001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018import org.apache.commons.math3.complex.Complex;
019import org.apache.commons.math3.linear.Array2DRowRealMatrix;
020import org.apache.commons.math3.linear.ArrayRealVector;
021import org.apache.commons.math3.linear.CholeskyDecomposition;
022import org.apache.commons.math3.linear.ConjugateGradient;
023import org.apache.commons.math3.linear.EigenDecomposition;
024import org.apache.commons.math3.linear.LUDecomposition;
025import org.apache.commons.math3.linear.MatrixUtils;
026import org.apache.commons.math3.linear.QRDecomposition;
027import org.apache.commons.math3.linear.RealLinearOperator;
028import org.apache.commons.math3.linear.RealMatrix;
029import org.apache.commons.math3.linear.RealVector;
030import org.apache.commons.math3.linear.SingularValueDecomposition;
031
032
033public class LinearAlgebra {
034
035        private static final int CROSSOVERPOINT = 16; // point at which using slice iterators for inner loop is faster 
036
037        /**
038         * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
039         * from the given axes in each dataset
040         * @param a
041         * @param b
042         * @param axisa axis dimension in a to sum over (can be -ve)
043         * @param axisb axis dimension in b to sum over (can be -ve)
044         * @return tensor dot product
045         */
046        public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int axisa, final int axisb) {
047                // this is slower for summing lengths < ~15
048                final int[] ashape = a.getShapeRef();
049                final int[] bshape = b.getShapeRef();
050                final int arank = ashape.length;
051                final int brank = bshape.length;
052                int aaxis = ShapeUtils.checkAxis(arank, axisa);
053
054                if (ashape[aaxis] < CROSSOVERPOINT) { // faster to use position iteration
055                        return tensorDotProduct(a, b, new int[] {axisa}, new int[] {axisb});
056                }
057                int baxis = ShapeUtils.checkAxis(brank, axisb);
058
059                final boolean[] achoice = new boolean[arank];
060                final boolean[] bchoice = new boolean[brank];
061                Arrays.fill(achoice, true);
062                Arrays.fill(bchoice, true);
063                achoice[aaxis] = false; // flag which axes not to iterate over
064                bchoice[baxis] = false;
065
066                final boolean[] notachoice = new boolean[arank];
067                final boolean[] notbchoice = new boolean[brank];
068                notachoice[aaxis] = true; // flag which axes to iterate over
069                notbchoice[baxis] = true;
070
071                int drank = arank + brank - 2;
072                int[] dshape = new int[drank];
073                int d = 0;
074                for (int i = 0; i < arank; i++) {
075                        if (achoice[i])
076                                dshape[d++] = ashape[i];
077                }
078                for (int i = 0; i < brank; i++) {
079                        if (bchoice[i])
080                                dshape[d++] = bshape[i];
081                }
082                Dataset data = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), dshape);
083
084                SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
085                int l = 0;
086                final int[] apos = ita.getPos();
087                while (ita.hasNext()) {
088                        SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
089                        final int[] bpos = itb.getPos();
090                        while (itb.hasNext()) {
091                                SliceIterator itaa = a.getSliceIteratorFromAxes(apos, notachoice);
092                                SliceIterator itba = b.getSliceIteratorFromAxes(bpos, notbchoice);
093                                double sum = 0.0;
094                                double com = 0.0;
095                                while (itaa.hasNext() && itba.hasNext()) {
096                                        final double y = a.getElementDoubleAbs(itaa.index) * b.getElementDoubleAbs(itba.index) - com;
097                                        final double t = sum + y;
098                                        com = (t - sum) - y;
099                                        sum = t;
100                                }
101                                data.setObjectAbs(l++, sum);
102                        }
103                }
104
105                return data;
106        }
107
108        /**
109         * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
110         * from the given axes in each dataset
111         * @param a
112         * @param b
113         * @param axisa axis dimensions in a to sum over (can be -ve)
114         * @param axisb axis dimensions in b to sum over (can be -ve)
115         * @return tensor dot product
116         */
117        public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int[] axisa, final int[] axisb) {
118                if (axisa.length != axisb.length) {
119                        throw new IllegalArgumentException("Numbers of summing axes must be same");
120                }
121                final int[] ashape = a.getShapeRef();
122                final int[] bshape = b.getShapeRef();
123                final int arank = ashape.length;
124                final int brank = bshape.length;
125                final int[] aaxes = new int[axisa.length];
126                final int[] baxes = new int[axisa.length];
127                for (int i = 0; i < axisa.length; i++) {
128                        aaxes[i] = ShapeUtils.checkAxis(arank, axisa[i]);
129                        int n = ShapeUtils.checkAxis(brank, axisb[i]);
130                        baxes[i] = n;
131
132                        if (ashape[aaxes[i]] != bshape[n]) {
133                                throw new IllegalArgumentException("Summing axes do not have matching lengths");
134                        }
135                }
136
137                final boolean[] achoice = new boolean[arank];
138                final boolean[] bchoice = new boolean[brank];
139                Arrays.fill(achoice, true);
140                Arrays.fill(bchoice, true);
141                for (int i = 0; i < aaxes.length; i++) { // flag which axes to iterate over
142                        achoice[aaxes[i]] = false;
143                        bchoice[baxes[i]] = false;
144                }
145
146                int drank = arank + brank - 2*aaxes.length;
147                int[] dshape = new int[drank];
148                int d = 0;
149                for (int i = 0; i < arank; i++) {
150                        if (achoice[i])
151                                dshape[d++] = ashape[i];
152                }
153                for (int i = 0; i < brank; i++) {
154                        if (bchoice[i])
155                                dshape[d++] = bshape[i];
156                }
157                Dataset data = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), dshape);
158
159                SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
160                int l = 0;
161                final int[] apos = ita.getPos();
162                while (ita.hasNext()) {
163                        SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
164                        final int[] bpos = itb.getPos();
165                        while (itb.hasNext()) {
166                                double sum = 0.0;
167                                double com = 0.0;
168                                apos[aaxes[aaxes.length - 1]] = -1;
169                                bpos[baxes[aaxes.length - 1]] = -1;
170                                while (true) { // step through summing axes
171                                        int e = aaxes.length - 1;
172                                        for (; e >= 0; e--) {
173                                                int ai = aaxes[e];
174                                                int bi = baxes[e];
175
176                                                apos[ai]++;
177                                                bpos[bi]++;
178                                                if (apos[ai] == ashape[ai]) {
179                                                        apos[ai] = 0;
180                                                        bpos[bi] = 0;
181                                                } else
182                                                        break;
183                                        }
184                                        if (e == -1) break;
185                                        final double y = a.getDouble(apos) * b.getDouble(bpos) - com;
186                                        final double t = sum + y;
187                                        com = (t - sum) - y;
188                                        sum = t;
189                                }
190                                data.setObjectAbs(l++, sum);
191                        }
192                }
193
194                return data;
195        }
196
197        /**
198         * Calculate the dot product of two datasets. When <b>b</b> is a 1D dataset, the sum product over
199         * the last axis of <b>a</b> and <b>b</b> is returned. Where <b>a</b> is also a 1D dataset, a zero-rank dataset
200         * is returned. If <b>b</b> is 2D or higher, its second-to-last axis is used
201         * @param a
202         * @param b
203         * @return dot product
204         */
205        public static Dataset dotProduct(Dataset a, Dataset b) {
206                if (b.getRank() < 2)
207                        return tensorDotProduct(a, b, -1, 0);
208                return tensorDotProduct(a, b, -1, -2);
209        }
210
211        /**
212         * Calculate the outer product of two datasets
213         * @param a
214         * @param b
215         * @return outer product
216         */
217        public static Dataset outerProduct(Dataset a, Dataset b) {
218                int[] as = a.getShapeRef();
219                int[] bs = b.getShapeRef();
220                int rank = as.length + bs.length;
221                int[] shape = new int[rank];
222                for (int i = 0; i < as.length; i++) {
223                        shape[i] = as[i];
224                }
225                for (int i = 0; i < bs.length; i++) {
226                        shape[as.length + i] = bs[i];
227                }
228                int isa = a.getElementsPerItem();
229                int isb = b.getElementsPerItem();
230                if (isa != 1 || isb != 1) {
231                        throw new UnsupportedOperationException("Compound datasets not supported");
232                }
233                Dataset o = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), shape);
234
235                IndexIterator ita = a.getIterator();
236                IndexIterator itb = b.getIterator();
237                int j = 0;
238                while (ita.hasNext()) {
239                        double va = a.getElementDoubleAbs(ita.index);
240                        while (itb.hasNext()) {
241                                o.setObjectAbs(j++, va * b.getElementDoubleAbs(itb.index));
242                        }
243                        itb.reset();
244                }
245                return o;
246        }
247
248        /**
249         * Calculate the cross product of two datasets. Datasets must be broadcastable and
250         * possess last dimensions of length 2 or 3
251         * @param a
252         * @param b
253         * @return cross product
254         */
255        public static Dataset crossProduct(Dataset a, Dataset b) {
256                return crossProduct(a, b, -1, -1, -1);
257        }
258
259        /**
260         * Calculate the cross product of two datasets. Datasets must be broadcastable and
261         * possess dimensions of length 2 or 3. The axis parameters can be negative to indicate
262         * dimensions from the end of their shapes
263         * @param a
264         * @param b
265         * @param axisA dimension to be used a vector (must have length of 2 or 3)
266         * @param axisB dimension to be used a vector (must have length of 2 or 3)
267         * @param axisC dimension to assign as cross-product
268         * @return cross product
269         */
270        public static Dataset crossProduct(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
271                final int rankA = a.getRank();
272                final int rankB = b.getRank();
273                if (rankA == 0 || rankB == 0) {
274                        throw new IllegalArgumentException("Datasets must have one or more dimensions");
275                }
276                axisA = a.checkAxis(axisA);
277                axisB = b.checkAxis(axisB);
278
279                int la = a.getShapeRef()[axisA];
280                int lb = b.getShapeRef()[axisB];
281                if (Math.min(la,  lb) < 2 || Math.max(la, lb) > 3) {
282                        throw new IllegalArgumentException("Chosen dimension of A & B must be 2 or 3");
283                }
284
285                if (Math.max(la,  lb) == 2) {
286                        return crossProduct2D(a, b, axisA, axisB);
287                }
288
289                return crossProduct3D(a, b, axisA, axisB, axisC);
290        }
291
292        private static int[] removeAxisFromShape(int[] shape, int axis) {
293                int[] s = new int[shape.length - 1];
294                int i = 0;
295                int j = 0;
296                while (i < axis) {
297                        s[j++] = shape[i++];
298                }
299                i++;
300                while (i < shape.length) {
301                        s[j++] = shape[i++];
302                }
303                return s;
304        }
305
306        // assume axes is in increasing order
307        private static int[] removeAxesFromShape(int[] shape, int... axes) {
308                int n = axes.length;
309                int[] s = new int[shape.length - n];
310                int i = 0;
311                int j = 0;
312                for (int k = 0; k < n; k++) {
313                        int a = axes[k];
314                        while (i < a) {
315                                s[j++] = shape[i++];
316                        }
317                        i++;
318                }
319                while (i < shape.length) {
320                        s[j++] = shape[i++];
321                }
322                return s;
323        }
324
325        private static int[] addAxisToShape(int[] shape, int axis, int length) {
326                int[] s = new int[shape.length + 1];
327                int i = 0;
328                int j = 0;
329                while (i < axis) {
330                        s[j++] = shape[i++];
331                }
332                s[j++] = length;
333                while (i < shape.length) {
334                        s[j++] = shape[i++];
335                }
336                return s;
337        }
338
339        private static Dataset crossProduct2D(Dataset a, Dataset b, int axisA, int axisB) {
340                // need to broadcast and omit given axes
341                int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
342                int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);
343
344                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);
345
346                int[] maxShape = fullShapes.get(0);
347                Dataset c = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), maxShape);
348
349                PositionIterator ita = a.getPositionIterator(axisA);
350                PositionIterator itb = b.getPositionIterator(axisB);
351                IndexIterator itc = c.getIterator();
352
353                final int[] pa = ita.getPos();
354                final int[] pb = itb.getPos();
355                while (itc.hasNext()) {
356                        if (!ita.hasNext()) // TODO use broadcasting...
357                                ita.reset();
358                        if (!itb.hasNext())
359                                itb.reset();
360                        pa[axisA] = 0;
361                        pb[axisB] = 1;
362                        double cv = a.getDouble(pa) * b.getDouble(pb);
363                        pa[axisA] = 1;
364                        pb[axisB] = 0;
365                        cv -= a.getDouble(pa) * b.getDouble(pb);
366
367                        c.setObjectAbs(itc.index, cv);
368                }
369                return c;
370        }
371
372        private static Dataset crossProduct3D(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
373                int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
374                int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);
375
376                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);
377
378                int[] maxShape = fullShapes.get(0);
379                int rankC = maxShape.length + 1;
380                axisC = ShapeUtils.checkAxis(rankC, axisC);
381                maxShape = addAxisToShape(maxShape, axisC, 3);
382                Dataset c = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), maxShape);
383
384                PositionIterator ita = a.getPositionIterator(axisA);
385                PositionIterator itb = b.getPositionIterator(axisB);
386                PositionIterator itc = c.getPositionIterator(axisC);
387
388                final int[] pa = ita.getPos();
389                final int[] pb = itb.getPos();
390                final int[] pc = itc.getPos();
391                final int la = a.getShapeRef()[axisA];
392                final int lb = b.getShapeRef()[axisB];
393
394                if (la == 2) {
395                        while (itc.hasNext()) {
396                                if (!ita.hasNext()) // TODO use broadcasting...
397                                        ita.reset();
398                                if (!itb.hasNext())
399                                        itb.reset();
400                                double cv;
401                                pa[axisA] = 1;
402                                pb[axisB] = 2;
403                                cv = a.getDouble(pa) * b.getDouble(pb);
404                                pc[axisC] = 0;
405                                c.set(cv, pc);
406
407                                pa[axisA] = 0;
408                                pb[axisB] = 2;
409                                cv = -a.getDouble(pa) * b.getDouble(pb);
410                                pc[axisC] = 1;
411                                c.set(cv, pc);
412
413                                pa[axisA] = 0;
414                                pb[axisB] = 1;
415                                cv = a.getDouble(pa) * b.getDouble(pb);
416                                pa[axisA] = 1;
417                                pb[axisB] = 0;
418                                cv -= a.getDouble(pa) * b.getDouble(pb);
419                                pc[axisC] = 2;
420                                c.set(cv, pc);
421                        }
422                } else if (lb == 2) {
423                        while (itc.hasNext()) {
424                                if (!ita.hasNext()) // TODO use broadcasting...
425                                        ita.reset();
426                                if (!itb.hasNext())
427                                        itb.reset();
428                                double cv;
429                                pa[axisA] = 2;
430                                pb[axisB] = 1;
431                                cv = -a.getDouble(pa) * b.getDouble(pb);
432                                pc[axisC] = 0;
433                                c.set(cv, pc);
434
435                                pa[axisA] = 2;
436                                pb[axisB] = 0;
437                                cv = a.getDouble(pa) * b.getDouble(pb);
438                                pc[axisC] = 1;
439                                c.set(cv, pc);
440
441                                pa[axisA] = 0;
442                                pb[axisB] = 1;
443                                cv = a.getDouble(pa) * b.getDouble(pb);
444                                pa[axisA] = 1;
445                                pb[axisB] = 0;
446                                cv -= a.getDouble(pa) * b.getDouble(pb);
447                                pc[axisC] = 2;
448                                c.set(cv, pc);
449                        }
450                        
451                } else {
452                        while (itc.hasNext()) {
453                                if (!ita.hasNext()) // TODO use broadcasting...
454                                        ita.reset();
455                                if (!itb.hasNext())
456                                        itb.reset();
457                                double cv;
458                                pa[axisA] = 1;
459                                pb[axisB] = 2;
460                                cv = a.getDouble(pa) * b.getDouble(pb);
461                                pa[axisA] = 2;
462                                pb[axisB] = 1;
463                                cv -= a.getDouble(pa) * b.getDouble(pb);
464                                pc[axisC] = 0;
465                                c.set(cv, pc);
466
467                                pa[axisA] = 2;
468                                pb[axisB] = 0;
469                                cv = a.getDouble(pa) * b.getDouble(pb);
470                                pa[axisA] = 0;
471                                pb[axisB] = 2;
472                                cv -= a.getDouble(pa) * b.getDouble(pb);
473                                pc[axisC] = 1;
474                                c.set(cv, pc);
475
476                                pa[axisA] = 0;
477                                pb[axisB] = 1;
478                                cv = a.getDouble(pa) * b.getDouble(pb);
479                                pa[axisA] = 1;
480                                pb[axisB] = 0;
481                                cv -= a.getDouble(pa) * b.getDouble(pb);
482                                pc[axisC] = 2;
483                                c.set(cv, pc);
484                        }
485                }
486                return c;
487        }
488
489        /**
490         * Raise dataset to given power by matrix multiplication
491         * @param a
492         * @param n power
493         * @return {@code a ** n}
494         */
495        public static Dataset power(Dataset a, int n) {
496                if (n < 0) {
497                        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
498                        return createDataset(lud.getSolver().getInverse().power(-n));
499                }
500                Dataset p = createDataset(createRealMatrix(a).power(n));
501                if (!a.hasFloatingPointElements()) {
502                        return p.cast(a.getClass());
503                }
504                return p;
505        }
506
507        /**
508         * Create the Kronecker product as defined by 
509         * {@code kron[k0,...,kN] = a[i0,...,iN] * b[j0,...,jN]}
510         * where {@code kn = sn * in + jn} for {@code n = 0...N} and {@code s} is shape of {@code b}
511         * @param a
512         * @param b
513         * @return Kronecker product of a and b
514         */
515        public static Dataset kroneckerProduct(Dataset a, Dataset b) {
516                if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) {
517                        throw new UnsupportedOperationException("Compound datasets (including complex ones) are not currently supported");
518                }
519                int ar = a.getRank();
520                int br = b.getRank();
521                int[] aShape;
522                int[] bShape;
523                aShape = a.getShapeRef();
524                bShape = b.getShapeRef();
525                int r = ar;
526                // pre-pad if ranks are not same
527                if (ar < br) {
528                        r = br;
529                        int[] shape = new int[br];
530                        int j = 0;
531                        for (int i = ar; i < br; i++) {
532                                shape[j++] = 1;
533                        }
534                        int i = 0;
535                        while (j < br) {
536                                shape[j++] = aShape[i++];
537                        }
538                        a = a.reshape(shape);
539                        aShape = shape;
540                } else if (ar > br) {
541                        int[] shape = new int[ar];
542                        int j = 0;
543                        for (int i = br; i < ar; i++) {
544                                shape[j++] = 1;
545                        }
546                        int i = 0;
547                        while (j < ar) {
548                                shape[j++] = bShape[i++];
549                        }
550                        b = b.reshape(shape);
551                        bShape = shape;
552                }
553
554                int[] nShape = new int[r];
555                for (int i = 0; i < r; i++) {
556                        nShape[i] = aShape[i] * bShape[i];
557                }
558                Dataset kron = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), nShape);
559                IndexIterator ita = a.getIterator(true);
560                IndexIterator itb = b.getIterator(true);
561                int[] pa = ita.getPos();
562                int[] pb = itb.getPos();
563                int[] off = new int[1];
564                int[] stride = AbstractDataset.createStrides(1, nShape, null, 0, off);
565                if (kron instanceof LongDataset) {
566                        while (ita.hasNext()) {
567                                long av = a.getElementLongAbs(ita.index);
568
569                                int ka = 0; 
570                                for (int i = 0; i < r; i++) {
571                                        ka += stride[i] * bShape[i] * pa[i];
572                                }
573                                itb.reset();
574                                while (itb.hasNext()) {
575                                        long bv = b.getElementLongAbs(itb.index);
576                                        int kb = ka;
577                                        for (int i = 0; i < r; i++) {
578                                                kb += stride[i] * pb[i];
579                                        }
580                                        kron.setObjectAbs(kb, av * bv);
581                                }
582                        }
583                } else {
584                        while (ita.hasNext()) {
585                                double av = a.getElementDoubleAbs(ita.index);
586
587                                int ka = 0; 
588                                for (int i = 0; i < r; i++) {
589                                        ka += stride[i] * bShape[i] * pa[i];
590                                }
591                                itb.reset();
592                                while (itb.hasNext()) {
593                                        double bv = b.getElementLongAbs(itb.index);
594                                        int kb = ka;
595                                        for (int i = 0; i < r; i++) {
596                                                kb += stride[i] * pb[i];
597                                        }
598                                        kron.setObjectAbs(kb, av * bv);
599                                }
600                        }
601                }
602
603                return kron;
604        }
605
606        /**
607         * Calculate trace of dataset - sum of values over 1st axis and 2nd axis
608         * @param a
609         * @return trace of dataset
610         */
611        public static Dataset trace(Dataset a) {
612                return trace(a, 0, 0, 1);
613        }
614
615        /**
616         * Calculate trace of dataset - sum of values over axis1 and axis2 where axis2 is offset
617         * @param a
618         * @param offset
619         * @param axis1
620         * @param axis2
621         * @return trace of dataset
622         */
623        public static Dataset trace(Dataset a, int offset, int axis1, int axis2) {
624                int[] shape = a.getShapeRef();
625                int[] axes = new int[] { a.checkAxis(axis1), a.checkAxis(axis2) };
626                Arrays.sort(axes);
627                int is = a.getElementsPerItem();
628                Dataset trace = DatasetFactory.zeros(is, a.getClass(), removeAxesFromShape(shape, axes));
629
630                int am = axes[0];
631                int mmax = shape[am];
632                int an = axes[1];
633                int nmax = shape[an];
634                PositionIterator it = new PositionIterator(shape, axes);
635                int[] pos = it.getPos();
636                int i = 0;
637                int mmin;
638                int nmin;
639                if (offset >= 0) {
640                        mmin = 0;
641                        nmin = offset;
642                } else {
643                        mmin = -offset;
644                        nmin = 0;
645                }
646                if (is == 1) {
647                        if (a instanceof LongDataset) {
648                                while (it.hasNext()) {
649                                        int m = mmin;
650                                        int n = nmin;
651                                        long s = 0;
652                                        while (m < mmax && n < nmax) {
653                                                pos[am] = m++;
654                                                pos[an] = n++;
655                                                s += a.getLong(pos);
656                                        }
657                                        trace.setObjectAbs(i++, s);
658                                }
659                        } else {
660                                while (it.hasNext()) {
661                                        int m = mmin;
662                                        int n = nmin;
663                                        double s = 0;
664                                        while (m < mmax && n < nmax) {
665                                                pos[am] = m++;
666                                                pos[an] = n++;
667                                                s += a.getDouble(pos);
668                                        }
669                                        trace.setObjectAbs(i++, s);
670                                }
671                        }
672                } else {
673                        AbstractCompoundDataset ca = (AbstractCompoundDataset) a;
674                        if (ca instanceof CompoundLongDataset) {
675                                long[] t = new long[is];
676                                long[] s = new long[is];
677                                while (it.hasNext()) {
678                                        int m = mmin;
679                                        int n = nmin;
680                                        Arrays.fill(s, 0);
681                                        while (m < mmax && n < nmax) {
682                                                pos[am] = m++;
683                                                pos[an] = n++;
684                                                ((CompoundLongDataset)ca).getAbs(ca.get1DIndex(pos), t);
685                                                for (int k = 0; k < is; k++) {
686                                                        s[k] += t[k];
687                                                }
688                                        }
689                                        trace.setObjectAbs(i++, s);
690                                }
691                        } else {
692                                double[] t = new double[is];
693                                double[] s = new double[is];
694                                while (it.hasNext()) {
695                                        int m = mmin;
696                                        int n = nmin;
697                                        Arrays.fill(s, 0);
698                                        while (m < mmax && n < nmax) {
699                                                pos[am] = m++;
700                                                pos[an] = n++;
701                                                ca.getDoubleArray(t, pos);
702                                                for (int k = 0; k < is; k++) {
703                                                        s[k] += t[k];
704                                                }
705                                        }
706                                        trace.setObjectAbs(i++, s);
707                                }
708                        }
709                }
710
711                return trace;
712        }
713
714        /**
715         * Order value for norm
716         */
717        public enum NormOrder {
718                /**
719                 * 2-norm for vectors and Frobenius for matrices
720                 */
721                DEFAULT,
722                /**
723                 * Frobenius (not allowed for vectors)
724                 */
725                FROBENIUS,
726                /**
727                 * Zero-order (not allowed for matrices)
728                 */
729                ZERO,
730                /**
731                 * Positive infinity
732                 */
733                POS_INFINITY,
734                /**
735                 * Negative infinity
736                 */
737                NEG_INFINITY;
738        }
739
740        /**
741         * @param a
742         * @return norm of dataset
743         */
744        public static double norm(Dataset a) {
745                return norm(a, NormOrder.DEFAULT);
746        }
747
748        /**
749         * @param a
750         * @param order
751         * @return norm of dataset
752         */
753        public static double norm(Dataset a, NormOrder order) {
754                int r = a.getRank();
755                if (r == 1) {
756                        return vectorNorm(a, order);
757                } else if (r == 2) {
758                        return matrixNorm(a, order);
759                }
760                throw new IllegalArgumentException("Rank of dataset must be one or two");
761        }
762
763        private static double vectorNorm(Dataset a, NormOrder order) {
764                double n;
765                IndexIterator it;
766                switch (order) {
767                case FROBENIUS:
768                        throw new IllegalArgumentException("Not allowed for vectors");
769                case NEG_INFINITY:
770                case POS_INFINITY:
771                        it = a.getIterator();
772                        if (order == NormOrder.POS_INFINITY) {
773                                n = Double.NEGATIVE_INFINITY;
774                                if (a.isComplex()) {
775                                        while (it.hasNext()) {
776                                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
777                                                n = Math.max(n, v);
778                                        }
779                                } else {
780                                        while (it.hasNext()) {
781                                                double v = Math.abs(a.getElementDoubleAbs(it.index));
782                                                n = Math.max(n, v);
783                                        }
784                                }
785                        } else {
786                                n = Double.POSITIVE_INFINITY;
787                                if (a.isComplex()) {
788                                        while (it.hasNext()) {
789                                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
790                                                n = Math.min(n, v);
791                                        }
792                                } else {
793                                        while (it.hasNext()) {
794                                                double v = Math.abs(a.getElementDoubleAbs(it.index));
795                                                n = Math.min(n, v);
796                                        }
797                                }
798                        }
799                        break;
800                case ZERO:
801                        it = a.getIterator();
802                        n = 0;
803                        if (a.isComplex()) {
804                                while (it.hasNext()) {
805                                        if (!((Complex) a.getObjectAbs(it.index)).equals(Complex.ZERO))
806                                                n++;
807                                }
808                        } else {
809                                while (it.hasNext()) {
810                                        if (a.getElementBooleanAbs(it.index))
811                                                n++;
812                                }
813                        }
814                        
815                        break;
816                default:
817                        n = vectorNorm(a, 2);
818                        break;
819                }
820                return n;
821        }
822
823        private static double matrixNorm(Dataset a, NormOrder order) {
824                double n;
825                IndexIterator it;
826                switch (order) {
827                case NEG_INFINITY:
828                case POS_INFINITY:
829                        n = maxMinMatrixNorm(a, 1, order == NormOrder.POS_INFINITY);
830                        break;
831                case ZERO:
832                        throw new IllegalArgumentException("Not allowed for matrices");
833                default:
834                case FROBENIUS:
835                        it = a.getIterator();
836                        n = 0;
837                        if (a.isComplex()) {
838                                while (it.hasNext()) {
839                                        double v = ((Complex) a.getObjectAbs(it.index)).abs();
840                                        n += v*v;
841                                }
842                        } else {
843                                while (it.hasNext()) {
844                                        double v = a.getElementDoubleAbs(it.index);
845                                        n += v*v;
846                                }
847                        }
848                        n = Math.sqrt(n);
849                        break;
850                }
851                return n;
852        }
853
854        /**
855         * @param a
856         * @param p
857         * @return p-norm of dataset
858         */
859        public static double norm(Dataset a, final double p) {
860                if (p == 0) {
861                        return norm(a, NormOrder.ZERO);
862                }
863                int r = a.getRank();
864                if (r == 1) {
865                        return vectorNorm(a, p);
866                } else if (r == 2) {
867                        return matrixNorm(a, p);
868                }
869                throw new IllegalArgumentException("Rank of dataset must be one or two");
870        }
871
872        private static double vectorNorm(Dataset a, final double p) {
873                IndexIterator it = a.getIterator();
874                double n = 0;
875                if (a.isComplex()) {
876                        while (it.hasNext()) {
877                                double v = ((Complex) a.getObjectAbs(it.index)).abs();
878                                if (p == 2) {
879                                        v *= v;
880                                } else if (p != 1) {
881                                        v = Math.pow(v, p);
882                                }
883                                n += v;
884                        }
885                } else {
886                        while (it.hasNext()) {
887                                double v = a.getElementDoubleAbs(it.index);
888                                if (p == 1) {
889                                        v = Math.abs(v);
890                                } else if (p == 2) {
891                                        v *= v;
892                                } else {
893                                        v = Math.pow(Math.abs(v), p);
894                                }
895                                n += v;
896                        }
897                }
898                return Math.pow(n, 1./p);
899        }
900
901        private static double matrixNorm(Dataset a, final double p) {
902                double n;
903                if (Math.abs(p) == 1) {
904                        n = maxMinMatrixNorm(a, 0, p > 0);
905                } else if (Math.abs(p) == 2) {
906                        double[] s = calcSingularValues(a);
907                        n = p > 0 ? s[0] : s[s.length - 1];
908                } else {
909                        throw new IllegalArgumentException("Order not allowed");
910                }
911
912                return n;
913        }
914
915        private static double maxMinMatrixNorm(Dataset a, int d, boolean max) {
916                double n;
917                IndexIterator it;
918                int[] pos;
919                int l;
920                it = a.getPositionIterator(d);
921                pos = it.getPos();
922                l = a.getShapeRef()[d];
923                if (max) {
924                        n = Double.NEGATIVE_INFINITY;
925                        if (a.isComplex()) {
926                                while (it.hasNext()) {
927                                        double v = ((Complex) a.getObject(pos)).abs();
928                                        for (int i = 1; i < l; i++) {
929                                                pos[d] = i;
930                                                v += ((Complex) a.getObject(pos)).abs();
931                                        }
932                                        pos[d] = 0;
933                                        n = Math.max(n, v);
934                                }
935                        } else {
936                                while (it.hasNext()) {
937                                        double v = Math.abs(a.getDouble(pos));
938                                        for (int i = 1; i < l; i++) {
939                                                pos[d] = i;
940                                                v += Math.abs(a.getDouble(pos));
941                                        }
942                                        pos[d] = 0;
943                                        n = Math.max(n, v);
944                                }
945                        }
946                } else {
947                        n = Double.POSITIVE_INFINITY;
948                        if (a.isComplex()) {
949                                while (it.hasNext()) {
950                                        double v = ((Complex) a.getObject(pos)).abs();
951                                        for (int i = 1; i < l; i++) {
952                                                pos[d] = i;
953                                                v += ((Complex) a.getObject(pos)).abs();
954                                        }
955                                        pos[d] = 0;
956                                        n = Math.min(n, v);
957                                }
958                        } else {
959                                while (it.hasNext()) {
960                                        double v = Math.abs(a.getDouble(pos));
961                                        for (int i = 1; i < l; i++) {
962                                                pos[d] = i;
963                                                v += Math.abs(a.getDouble(pos));
964                                        }
965                                        pos[d] = 0;
966                                        n = Math.min(n, v);
967                                }
968                        }
969                }
970                return n;
971        }
972
973        /**
974         * @param a
975         * @return array of singular values
976         */
977        public static double[] calcSingularValues(Dataset a) {
978                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
979                return svd.getSingularValues();
980        }
981
982
983        /**
984         * Calculate singular value decomposition {@code A = U S V^T}
985         * @param a
986         * @return array of U - orthogonal matrix, s - singular values vector, V - orthogonal matrix
987         */
988        public static Dataset[] calcSingularValueDecomposition(Dataset a) {
989                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
990                return new Dataset[] {createDataset(svd.getU()), DatasetFactory.createFromObject(svd.getSingularValues()),
991                                createDataset(svd.getV())};
992        }
993
994        /**
995         * Calculate (Moore-Penrose) pseudo-inverse
996         * @param a
997         * @return pseudo-inverse
998         */
999        public static Dataset calcPseudoInverse(Dataset a) {
1000                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1001                return createDataset(svd.getSolver().getInverse());
1002        }
1003
1004        /**
1005         * Calculate matrix rank by singular value decomposition method
1006         * @param a
1007         * @return effective numerical rank of matrix
1008         */
1009        public static int calcMatrixRank(Dataset a) {
1010                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1011                return svd.getRank();
1012        }
1013
1014        /**
1015         * Calculate condition number of matrix by singular value decomposition method
1016         * @param a
1017         * @return condition number
1018         */
1019        public static double calcConditionNumber(Dataset a) {
1020                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1021                return svd.getConditionNumber();
1022        }
1023
1024        /**
1025         * @param a
1026         * @return determinant of dataset
1027         */
1028        public static double calcDeterminant(Dataset a) {
1029                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1030                return evd.getDeterminant();
1031        }
1032
1033        /**
1034         * @param a
1035         * @return dataset of eigenvalues (can be double or complex double)
1036         */
1037        public static Dataset calcEigenvalues(Dataset a) {
1038                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1039                double[] rev = evd.getRealEigenvalues();
1040
1041                if (evd.hasComplexEigenvalues()) {
1042                        double[] iev = evd.getImagEigenvalues();
1043                        return DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
1044                }
1045                return DatasetFactory.createFromObject(rev);
1046        }
1047
1048        /**
1049         * Calculate eigen-decomposition {@code A = V D V^T}
1050         * @param a
1051         * @return array of D eigenvalues (can be double or complex double) and V eigenvectors
1052         */
1053        public static Dataset[] calcEigenDecomposition(Dataset a) {
1054                EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
1055                Dataset[] results = new Dataset[2];
1056
1057                double[] rev = evd.getRealEigenvalues();
1058                if (evd.hasComplexEigenvalues()) {
1059                        double[] iev = evd.getImagEigenvalues();
1060                        results[0] = DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
1061                } else {
1062                        results[0] = DatasetFactory.createFromObject(rev);
1063                }
1064                results[1] = createDataset(evd.getV());
1065                return results;
1066        }
1067
1068        /**
1069         * Calculate QR decomposition {@code A = Q R}
1070         * @param a
1071         * @return array of Q and R
1072         */
1073        public static Dataset[] calcQRDecomposition(Dataset a) {
1074                QRDecomposition qrd = new QRDecomposition(createRealMatrix(a));
1075                return new Dataset[] {createDataset(qrd.getQT()).getTransposedView(), createDataset(qrd.getR())};
1076        }
1077
1078        /**
1079         * Calculate LU decomposition {@code A = P^-1 L U}
1080         * @param a
1081         * @return array of L, U and P
1082         */
1083        public static Dataset[] calcLUDecomposition(Dataset a) {
1084                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1085                return new Dataset[] {createDataset(lud.getL()), createDataset(lud.getU()),
1086                                createDataset(lud.getP())};
1087        }
1088
1089        /**
1090         * Calculate inverse of square dataset
1091         * @param a
1092         * @return inverse
1093         */
1094        public static Dataset calcInverse(Dataset a) {
1095                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1096                return createDataset(lud.getSolver().getInverse());
1097        }
1098
1099        /**
1100         * Solve linear matrix equation {@code A x = v}
1101         * @param a
1102         * @param v
1103         * @return x
1104         */
1105        public static Dataset solve(Dataset a, Dataset v) {
1106                LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
1107                if (v.getRank() == 1) {
1108                        RealVector x = createRealVector(v);
1109                        return createDataset(lud.getSolver().solve(x));
1110                }
1111                RealMatrix x = createRealMatrix(v);
1112                return createDataset(lud.getSolver().solve(x));
1113        }
1114
1115        
1116        /**
1117         * Solve least squares matrix equation {@code A x = v} by SVD
1118         * @param a
1119         * @param v
1120         * @return x
1121         */
1122        public static Dataset solveSVD(Dataset a, Dataset v) {
1123                SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
1124                if (v.getRank() == 1) {
1125                        RealVector x = createRealVector(v);
1126                        return createDataset(svd.getSolver().solve(x));
1127                }
1128                RealMatrix x = createRealMatrix(v);
1129                return createDataset(svd.getSolver().solve(x));
1130        }
1131        
1132        /**
1133         * Calculate Cholesky decomposition {@code A = L L^T}
1134         * @param a
1135         * @return L
1136         */
1137        public static Dataset calcCholeskyDecomposition(Dataset a) {
1138                CholeskyDecomposition cd = new CholeskyDecomposition(createRealMatrix(a));
1139                return createDataset(cd.getL());
1140        }
1141
1142        /**
1143         * Calculation {@code A x = v} by conjugate gradient method with the stopping criterion being
1144         * that the estimated residual {@code r = v - A x} satisfies {@code ||r|| < ||v||} with maximum of 100 iterations
1145         * @param a
1146         * @param v
1147         * @return value of {@code A^-1 v} by conjugate gradient method
1148         */
1149        public static Dataset calcConjugateGradient(Dataset a, Dataset v) {
1150                return calcConjugateGradient(a, v, 100, 1);
1151        }
1152
1153        /**
1154         * Calculation {@code A x = v} by conjugate gradient method with the stopping criterion being
1155         * that the estimated residual {@code r = v - A x} satisfies {@code ||r|| < delta ||v||}
1156         * @param a
1157         * @param v
1158         * @param maxIterations
1159         * @param delta parameter used by stopping criterion
1160         * @return value of {@code A^-1 v} by conjugate gradient method
1161         */
1162        public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) {
1163                ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false);
1164                return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v)));
1165        }
1166
1167        private static RealMatrix createRealMatrix(Dataset a) {
1168                if (a.getRank() != 2) {
1169                        throw new IllegalArgumentException("Dataset must be rank 2");
1170                }
1171                int[] shape = a.getShapeRef();
1172                IndexIterator it = a.getIterator(true);
1173                int[] pos = it.getPos();
1174                RealMatrix m = MatrixUtils.createRealMatrix(shape[0], shape[1]);
1175                while (it.hasNext()) {
1176                        m.setEntry(pos[0], pos[1], a.getElementDoubleAbs(it.index));
1177                }
1178                return m;
1179        }
1180
1181        private static RealVector createRealVector(Dataset a) {
1182                if (a.getRank() != 1) {
1183                        throw new IllegalArgumentException("Dataset must be rank 1");
1184                }
1185                int size = a.getSize();
1186                IndexIterator it = a.getIterator(true);
1187                int[] pos = it.getPos();
1188                RealVector m = new ArrayRealVector(size);
1189                while (it.hasNext()) {
1190                        m.setEntry(pos[0], a.getElementDoubleAbs(it.index));
1191                }
1192                return m;
1193        }
1194
1195        private static Dataset createDataset(RealVector v) {
1196                DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, v.getDimension());
1197                int size = r.getSize();
1198                if (v instanceof ArrayRealVector) {
1199                        double[] data = ((ArrayRealVector) v).getDataRef();
1200                        for (int i = 0; i < size; i++) {
1201                                r.setAbs(i, data[i]);
1202                        }
1203                } else {
1204                        for (int i = 0; i < size; i++) {
1205                                r.setAbs(i, v.getEntry(i));
1206                        }
1207                }
1208                return r;
1209        }
1210
1211        private static Dataset createDataset(RealMatrix m) {
1212                DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, m.getRowDimension(), m.getColumnDimension());
1213                if (m instanceof Array2DRowRealMatrix) {
1214                        double[][] data = ((Array2DRowRealMatrix) m).getDataRef();
1215                        IndexIterator it = r.getIterator(true);
1216                        int[] pos = it.getPos();
1217                        while (it.hasNext()) {
1218                                r.setAbs(it.index, data[pos[0]][pos[1]]);
1219                        }
1220                } else {
1221                        IndexIterator it = r.getIterator(true);
1222                        int[] pos = it.getPos();
1223                        while (it.hasNext()) {
1224                                r.setAbs(it.index, m.getEntry(pos[0], pos[1]));
1225                        }
1226                }
1227                return r;
1228        }
1229}