/*-
 *******************************************************************************
 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Peter Chang - initial API and implementation and/or initial documentation
 *******************************************************************************/

package org.eclipse.january.dataset;

import java.util.Arrays;
import java.util.List;

import org.apache.commons.math3.complex.Complex;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.ConjugateGradient;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealLinearOperator;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;


/**
 * Static linear algebra routines that operate on {@link Dataset}s: tensor, dot, outer, cross and
 * Kronecker products, traces, norms, and matrix decompositions and solvers (the latter are
 * delegated to Apache Commons Math after copying the dataset into a {@link RealMatrix} or
 * {@link RealVector}).
 */
public class LinearAlgebra {

    private static final int CROSSOVERPOINT = 16; // point at which using slice iterators for inner loop is faster

    /**
     * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
     * from the given axes in each dataset
     * @param a
     * @param b
     * @param axisa axis dimension in a to sum over (can be -ve)
     * @param axisb axis dimension in b to sum over (can be -ve)
     * @return tensor dot product
     * @throws IllegalArgumentException if the two summing axes do not have the same length
     */
    public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int axisa, final int axisb) {
        // this is slower for summing lengths < ~15
        final int[] ashape = a.getShapeRef();
        final int[] bshape = b.getShapeRef();
        final int arank = ashape.length;
        final int brank = bshape.length;
        int aaxis = ShapeUtils.checkAxis(arank, axisa);

        if (ashape[aaxis] < CROSSOVERPOINT) { // faster to use position iteration
            return tensorDotProduct(a, b, new int[] {axisa}, new int[] {axisb});
        }
        int baxis = ShapeUtils.checkAxis(brank, axisb);

        // the inner loop below stops at the shorter axis, so reject mismatched lengths
        // explicitly (as the multi-axis overload does) rather than return a truncated sum
        if (ashape[aaxis] != bshape[baxis]) {
            throw new IllegalArgumentException("Summing axes do not have matching lengths");
        }

        final boolean[] achoice = new boolean[arank];
        final boolean[] bchoice = new boolean[brank];
        Arrays.fill(achoice, true);
        Arrays.fill(bchoice, true);
        achoice[aaxis] = false; // flag which axes not to iterate over
        bchoice[baxis] = false;

        final boolean[] notachoice = new boolean[arank];
        final boolean[] notbchoice = new boolean[brank];
        notachoice[aaxis] = true; // flag which axes to iterate over
        notbchoice[baxis] = true;

        // result shape is the non-summed axes of a followed by the non-summed axes of b
        int drank = arank + brank - 2;
        int[] dshape = new int[drank];
        int d = 0;
        for (int i = 0; i < arank; i++) {
            if (achoice[i]) {
                dshape[d++] = ashape[i];
            }
        }
        for (int i = 0; i < brank; i++) {
            if (bchoice[i]) {
                dshape[d++] = bshape[i];
            }
        }
        Dataset data = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), dshape);

        SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
        int l = 0;
        final int[] apos = ita.getPos();
        while (ita.hasNext()) {
            SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
            final int[] bpos = itb.getPos();
            while (itb.hasNext()) {
                SliceIterator itaa = a.getSliceIteratorFromAxes(apos, notachoice);
                SliceIterator itba = b.getSliceIteratorFromAxes(bpos, notbchoice);
                // Kahan compensated summation: com carries the rounding error of the previous addition
                double sum = 0.0;
                double com = 0.0;
                while (itaa.hasNext() && itba.hasNext()) {
                    final double y = a.getElementDoubleAbs(itaa.index) * b.getElementDoubleAbs(itba.index) - com;
                    final double t = sum + y;
                    com = (t - sum) - y;
                    sum = t;
                }
                data.setObjectAbs(l++, sum);
            }
        }

        return data;
    }

    /**
     * Calculate the tensor dot product over given axes. This is the sum of products of elements selected
     * from the given axes in each dataset
     * @param a
     * @param b
     * @param axisa axis dimensions in a to sum over (can be -ve)
     * @param axisb axis dimensions in b to sum over (can be -ve)
     * @return tensor dot product
     * @throws IllegalArgumentException if the numbers of summing axes differ or if any pair of
     * summing axes does not have the same length
     */
    public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int[] axisa, final int[] axisb) {
        if (axisa.length != axisb.length) {
            throw new IllegalArgumentException("Numbers of summing axes must be same");
        }
        final int[] ashape = a.getShapeRef();
        final int[] bshape = b.getShapeRef();
        final int arank = ashape.length;
        final int brank = bshape.length;
        final int[] aaxes = new int[axisa.length];
        final int[] baxes = new int[axisa.length];
        for (int i = 0; i < axisa.length; i++) {
            aaxes[i] = ShapeUtils.checkAxis(arank, axisa[i]);
            int n = ShapeUtils.checkAxis(brank, axisb[i]);
            baxes[i] = n;

            if (ashape[aaxes[i]] != bshape[n]) {
                throw new IllegalArgumentException("Summing axes do not have matching lengths");
            }
        }

        final boolean[] achoice = new boolean[arank];
        final boolean[] bchoice = new boolean[brank];
        Arrays.fill(achoice, true);
        Arrays.fill(bchoice, true);
        for (int i = 0; i < aaxes.length; i++) { // unflag the summing axes; flagged axes are iterated over
            achoice[aaxes[i]] = false;
            bchoice[baxes[i]] = false;
        }

        // result shape is the non-summed axes of a followed by the non-summed axes of b
        int drank = arank + brank - 2*aaxes.length;
        int[] dshape = new int[drank];
        int d = 0;
        for (int i = 0; i < arank; i++) {
            if (achoice[i]) {
                dshape[d++] = ashape[i];
            }
        }
        for (int i = 0; i < brank; i++) {
            if (bchoice[i]) {
                dshape[d++] = bshape[i];
            }
        }
        Dataset data = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), dshape);

        SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice);
        int l = 0;
        final int[] apos = ita.getPos();
        while (ita.hasNext()) {
            SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice);
            final int[] bpos = itb.getPos();
            while (itb.hasNext()) {
                // Kahan compensated summation: com carries the rounding error of the previous addition
                double sum = 0.0;
                double com = 0.0;
                // start one step before the first position so the first increment lands on zero
                apos[aaxes[aaxes.length - 1]] = -1;
                bpos[baxes[aaxes.length - 1]] = -1;
                while (true) { // step through summing axes
                    // increment the last summing axis and carry into earlier ones on wrap-around
                    int e = aaxes.length - 1;
                    for (; e >= 0; e--) {
                        int ai = aaxes[e];
                        int bi = baxes[e];

                        apos[ai]++;
                        bpos[bi]++;
                        if (apos[ai] == ashape[ai]) {
                            apos[ai] = 0;
                            bpos[bi] = 0;
                        } else {
                            break;
                        }
                    }
                    if (e == -1) {
                        break; // every summing axis wrapped round so all positions have been visited
                    }
                    final double y = a.getDouble(apos) * b.getDouble(bpos) - com;
                    final double t = sum + y;
                    com = (t - sum) - y;
                    sum = t;
                }
                data.setObjectAbs(l++, sum);
            }
        }

        return data;
    }

    /**
     * Calculate the dot product of two datasets. When <b>b</b> is a 1D dataset, the sum product over
     * the last axis of <b>a</b> and <b>b</b> is returned. Where <b>a</b> is also a 1D dataset, a zero-rank dataset
     * is returned. If <b>b</b> is 2D or higher, its second-to-last axis is used
     * @param a
     * @param b
     * @return dot product
     */
    public static Dataset dotProduct(Dataset a, Dataset b) {
        if (b.getRank() < 2) {
            return tensorDotProduct(a, b, -1, 0);
        }
        return tensorDotProduct(a, b, -1, -2);
    }

    /**
     * Calculate the outer product of two datasets. The shape of the result is the shape of
     * <b>a</b> followed by the shape of <b>b</b>
     * @param a
     * @param b
     * @return outer product
     * @throws UnsupportedOperationException if either dataset is a compound dataset
     */
    public static Dataset outerProduct(Dataset a, Dataset b) {
        if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) {
            throw new UnsupportedOperationException("Compound datasets not supported");
        }
        int[] as = a.getShapeRef();
        int[] bs = b.getShapeRef();
        int[] shape = new int[as.length + bs.length];
        System.arraycopy(as, 0, shape, 0, as.length);
        System.arraycopy(bs, 0, shape, as.length, bs.length);
        Dataset o = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), shape);

        IndexIterator ita = a.getIterator();
        IndexIterator itb = b.getIterator();
        int j = 0;
        while (ita.hasNext()) {
            double va = a.getElementDoubleAbs(ita.index);
            while (itb.hasNext()) {
                o.setObjectAbs(j++, va * b.getElementDoubleAbs(itb.index));
            }
            itb.reset(); // rewind b for the next element of a
        }
        return o;
    }

    /**
     * Calculate the cross product of two datasets. Datasets must be broadcastable and
     * possess last dimensions of length 2 or 3
     * @param a
     * @param b
     * @return cross product
     */
    public static Dataset crossProduct(Dataset a, Dataset b) {
        return crossProduct(a, b, -1, -1, -1);
    }

    /**
     * Calculate the cross product of two datasets. Datasets must be broadcastable and
     * possess dimensions of length 2 or 3. The axis parameters can be negative to indicate
     * dimensions from the end of their shapes
     * @param a
     * @param b
     * @param axisA dimension to be used a vector (must have length of 2 or 3)
     * @param axisB dimension to be used a vector (must have length of 2 or 3)
     * @param axisC dimension to assign as cross-product (not used when both vector dimensions have length 2)
     * @return cross product
     * @throws IllegalArgumentException if either dataset is zero-rank or a chosen dimension
     * does not have length 2 or 3
     */
    public static Dataset crossProduct(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
        final int rankA = a.getRank();
        final int rankB = b.getRank();
        if (rankA == 0 || rankB == 0) {
            throw new IllegalArgumentException("Datasets must have one or more dimensions");
        }
        axisA = a.checkAxis(axisA);
        axisB = b.checkAxis(axisB);

        int la = a.getShapeRef()[axisA];
        int lb = b.getShapeRef()[axisB];
        if (Math.min(la, lb) < 2 || Math.max(la, lb) > 3) {
            throw new IllegalArgumentException("Chosen dimension of A & B must be 2 or 3");
        }

        if (Math.max(la, lb) == 2) {
            return crossProduct2D(a, b, axisA, axisB);
        }

        return crossProduct3D(a, b, axisA, axisB, axisC);
    }

    /**
     * @param shape
     * @param axis index of dimension to omit
     * @return copy of shape without the given dimension
     */
    private static int[] removeAxisFromShape(int[] shape, int axis) {
        int[] s = new int[shape.length - 1];
        int i = 0;
        int j = 0;
        while (i < axis) {
            s[j++] = shape[i++];
        }
        i++;
        while (i < shape.length) {
            s[j++] = shape[i++];
        }
        return s;
    }

    /**
     * @param shape
     * @param axes indexes of dimensions to omit (assumed to be in increasing order)
     * @return copy of shape without the given dimensions
     */
    // assume axes is in increasing order
    private static int[] removeAxesFromShape(int[] shape, int... axes) {
        int n = axes.length;
        int[] s = new int[shape.length - n];
        int i = 0;
        int j = 0;
        for (int k = 0; k < n; k++) {
            int a = axes[k];
            while (i < a) {
                s[j++] = shape[i++];
            }
            i++;
        }
        while (i < shape.length) {
            s[j++] = shape[i++];
        }
        return s;
    }

    /**
     * @param shape
     * @param axis index at which the new dimension is inserted
     * @param length length of new dimension
     * @return copy of shape with the new dimension inserted
     */
    private static int[] addAxisToShape(int[] shape, int axis, int length) {
        int[] s = new int[shape.length + 1];
        int i = 0;
        int j = 0;
        while (i < axis) {
            s[j++] = shape[i++];
        }
        s[j++] = length;
        while (i < shape.length) {
            s[j++] = shape[i++];
        }
        return s;
    }

    /**
     * Cross product where both vector dimensions have length 2: each item of the result is
     * the scalar {@code a0*b1 - a1*b0}
     * @param a
     * @param b
     * @param axisA checked (non-negative) vector dimension of a
     * @param axisB checked (non-negative) vector dimension of b
     * @return dataset whose shape is the broadcast of the shapes of a and b without their vector dimensions
     */
    private static Dataset crossProduct2D(Dataset a, Dataset b, int axisA, int axisB) {
        // need to broadcast and omit given axes
        int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
        int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);

        List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);

        int[] maxShape = fullShapes.get(0);
        Dataset c = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), maxShape);

        PositionIterator ita = a.getPositionIterator(axisA);
        PositionIterator itb = b.getPositionIterator(axisB);
        IndexIterator itc = c.getIterator();

        final int[] pa = ita.getPos();
        final int[] pb = itb.getPos();
        while (itc.hasNext()) {
            if (!ita.hasNext()) { // TODO use broadcasting...
                ita.reset();
            }
            if (!itb.hasNext()) {
                itb.reset();
            }
            pa[axisA] = 0;
            pb[axisB] = 1;
            double cv = a.getDouble(pa) * b.getDouble(pb);
            pa[axisA] = 1;
            pb[axisB] = 0;
            cv -= a.getDouble(pa) * b.getDouble(pb);

            c.setObjectAbs(itc.index, cv);
        }
        return c;
    }

    /**
     * Cross product where at least one vector dimension has length 3. A vector of length 2
     * is treated as having a zero third component so the terms involving it are omitted
     * @param a
     * @param b
     * @param axisA checked (non-negative) vector dimension of a
     * @param axisB checked (non-negative) vector dimension of b
     * @param axisC dimension of result that holds the three components (can be -ve)
     * @return dataset whose shape is the broadcast of the shapes of a and b without their vector
     * dimensions and with a dimension of length 3 inserted at axisC
     */
    private static Dataset crossProduct3D(Dataset a, Dataset b, int axisA, int axisB, int axisC) {
        int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA);
        int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB);

        List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB);

        int[] maxShape = fullShapes.get(0);
        int rankC = maxShape.length + 1;
        axisC = ShapeUtils.checkAxis(rankC, axisC);
        maxShape = addAxisToShape(maxShape, axisC, 3);
        Dataset c = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), maxShape);

        PositionIterator ita = a.getPositionIterator(axisA);
        PositionIterator itb = b.getPositionIterator(axisB);
        PositionIterator itc = c.getPositionIterator(axisC);

        final int[] pa = ita.getPos();
        final int[] pb = itb.getPos();
        final int[] pc = itc.getPos();
        final int la = a.getShapeRef()[axisA];
        final int lb = b.getShapeRef()[axisB];

        if (la == 2) {
            // a2 is taken as zero: c = (a1*b2, -a0*b2, a0*b1 - a1*b0)
            while (itc.hasNext()) {
                if (!ita.hasNext()) { // TODO use broadcasting...
                    ita.reset();
                }
                if (!itb.hasNext()) {
                    itb.reset();
                }
                double cv;
                pa[axisA] = 1;
                pb[axisB] = 2;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 0;
                c.set(cv, pc);

                pa[axisA] = 0;
                pb[axisB] = 2;
                cv = -a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 1;
                c.set(cv, pc);

                pa[axisA] = 0;
                pb[axisB] = 1;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pa[axisA] = 1;
                pb[axisB] = 0;
                cv -= a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 2;
                c.set(cv, pc);
            }
        } else if (lb == 2) {
            // b2 is taken as zero: c = (-a2*b1, a2*b0, a0*b1 - a1*b0)
            while (itc.hasNext()) {
                if (!ita.hasNext()) { // TODO use broadcasting...
                    ita.reset();
                }
                if (!itb.hasNext()) {
                    itb.reset();
                }
                double cv;
                pa[axisA] = 2;
                pb[axisB] = 1;
                cv = -a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 0;
                c.set(cv, pc);

                pa[axisA] = 2;
                pb[axisB] = 0;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 1;
                c.set(cv, pc);

                pa[axisA] = 0;
                pb[axisB] = 1;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pa[axisA] = 1;
                pb[axisB] = 0;
                cv -= a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 2;
                c.set(cv, pc);
            }

        } else {
            // full case: c = (a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0)
            while (itc.hasNext()) {
                if (!ita.hasNext()) { // TODO use broadcasting...
                    ita.reset();
                }
                if (!itb.hasNext()) {
                    itb.reset();
                }
                double cv;
                pa[axisA] = 1;
                pb[axisB] = 2;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pa[axisA] = 2;
                pb[axisB] = 1;
                cv -= a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 0;
                c.set(cv, pc);

                pa[axisA] = 2;
                pb[axisB] = 0;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pa[axisA] = 0;
                pb[axisB] = 2;
                cv -= a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 1;
                c.set(cv, pc);

                pa[axisA] = 0;
                pb[axisB] = 1;
                cv = a.getDouble(pa) * b.getDouble(pb);
                pa[axisA] = 1;
                pb[axisB] = 0;
                cv -= a.getDouble(pa) * b.getDouble(pb);
                pc[axisC] = 2;
                c.set(cv, pc);
            }
        }
        return c;
    }

    /**
     * Raise dataset to given power by matrix multiplication. For a negative power, the inverse
     * (from an LU decomposition) is raised to the power {@code -n} and the result is left as a
     * double dataset; otherwise the result is cast back to the class of <b>a</b> when that does
     * not hold floating point elements
     * @param a a rank 2 dataset
     * @param n power
     * @return {@code a ** n}
     */
    public static Dataset power(Dataset a, int n) {
        if (n < 0) {
            LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
            return createDataset(lud.getSolver().getInverse().power(-n));
        }
        Dataset p = createDataset(createRealMatrix(a).power(n));
        if (!a.hasFloatingPointElements()) {
            return p.cast(a.getClass());
        }
        return p;
    }

    /**
     * Create the Kronecker product as defined by
     * {@code kron[k0,...,kN] = a[i0,...,iN] * b[j0,...,jN]}
     * where {@code kn = sn * in + jn} for {@code n = 0...N} and {@code s} is shape of {@code b}.
     * If the ranks differ, the shape of the lower rank dataset is pre-padded with ones
     * @param a
     * @param b
     * @return Kronecker product of a and b
     * @throws UnsupportedOperationException if either dataset is a compound dataset
     */
    public static Dataset kroneckerProduct(Dataset a, Dataset b) {
        if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) {
            throw new UnsupportedOperationException("Compound datasets (including complex ones) are not currently supported");
        }
        int ar = a.getRank();
        int br = b.getRank();
        int[] aShape;
        int[] bShape;
        aShape = a.getShapeRef();
        bShape = b.getShapeRef();
        int r = ar;
        // pre-pad if ranks are not same
        if (ar < br) {
            r = br;
            int[] shape = new int[br];
            int j = 0;
            for (int i = ar; i < br; i++) {
                shape[j++] = 1;
            }
            int i = 0;
            while (j < br) {
                shape[j++] = aShape[i++];
            }
            a = a.reshape(shape);
            aShape = shape;
        } else if (ar > br) {
            int[] shape = new int[ar];
            int j = 0;
            for (int i = br; i < ar; i++) {
                shape[j++] = 1;
            }
            int i = 0;
            while (j < ar) {
                shape[j++] = bShape[i++];
            }
            b = b.reshape(shape);
            bShape = shape;
        }

        int[] nShape = new int[r];
        for (int i = 0; i < r; i++) {
            nShape[i] = aShape[i] * bShape[i];
        }
        Dataset kron = DatasetFactory.zeros(InterfaceUtils.getBestInterface(a.getClass(), b.getClass()), nShape);
        IndexIterator ita = a.getIterator(true);
        IndexIterator itb = b.getIterator(true);
        int[] pa = ita.getPos();
        int[] pb = itb.getPos();
        int[] off = new int[1];
        int[] stride = AbstractDataset.createStrides(1, nShape, null, 0, off);
        if (kron instanceof LongDataset) {
            // keep integer arithmetic when the result is a long dataset
            while (ita.hasNext()) {
                long av = a.getElementLongAbs(ita.index);

                // index in kron of the start of the block that belongs to the current item of a
                int ka = 0;
                for (int i = 0; i < r; i++) {
                    ka += stride[i] * bShape[i] * pa[i];
                }
                itb.reset();
                while (itb.hasNext()) {
                    long bv = b.getElementLongAbs(itb.index);
                    int kb = ka;
                    for (int i = 0; i < r; i++) {
                        kb += stride[i] * pb[i];
                    }
                    kron.setObjectAbs(kb, av * bv);
                }
            }
        } else {
            while (ita.hasNext()) {
                double av = a.getElementDoubleAbs(ita.index);

                // index in kron of the start of the block that belongs to the current item of a
                int ka = 0;
                for (int i = 0; i < r; i++) {
                    ka += stride[i] * bShape[i] * pa[i];
                }
                itb.reset();
                while (itb.hasNext()) {
                    // read as double: reading as long would truncate fractional values of b
                    double bv = b.getElementDoubleAbs(itb.index);
                    int kb = ka;
                    for (int i = 0; i < r; i++) {
                        kb += stride[i] * pb[i];
                    }
                    kron.setObjectAbs(kb, av * bv);
                }
            }
        }

        return kron;
    }

    /**
     * Calculate trace of dataset - sum of values over 1st axis and 2nd axis
     * @param a
     * @return trace of dataset
     */
    public static Dataset trace(Dataset a) {
        return trace(a, 0, 0, 1);
    }

    /**
     * Calculate trace of dataset - sum of values over axis1 and axis2 where axis2 is offset
     * @param a
     * @param offset index offset of diagonal to sum along: when zero or positive, it is applied
     * to the higher of the two axes; when negative, its magnitude is applied to the lower one
     * @param axis1
     * @param axis2
     * @return trace of dataset, whose shape is that of <b>a</b> without the two given axes
     */
    public static Dataset trace(Dataset a, int offset, int axis1, int axis2) {
        int[] shape = a.getShapeRef();
        int[] axes = new int[] { a.checkAxis(axis1), a.checkAxis(axis2) };
        Arrays.sort(axes);
        int is = a.getElementsPerItem();
        Dataset trace = DatasetFactory.zeros(is, a.getClass(), removeAxesFromShape(shape, axes));

        int am = axes[0];
        int mmax = shape[am];
        int an = axes[1];
        int nmax = shape[an];
        PositionIterator it = new PositionIterator(shape, axes);
        int[] pos = it.getPos();
        int i = 0;
        int mmin;
        int nmin;
        if (offset >= 0) {
            mmin = 0;
            nmin = offset;
        } else {
            mmin = -offset;
            nmin = 0;
        }
        if (is == 1) {
            if (a instanceof LongDataset) {
                while (it.hasNext()) {
                    int m = mmin;
                    int n = nmin;
                    long s = 0;
                    while (m < mmax && n < nmax) {
                        pos[am] = m++;
                        pos[an] = n++;
                        s += a.getLong(pos);
                    }
                    trace.setObjectAbs(i++, s);
                }
            } else {
                while (it.hasNext()) {
                    int m = mmin;
                    int n = nmin;
                    double s = 0;
                    while (m < mmax && n < nmax) {
                        pos[am] = m++;
                        pos[an] = n++;
                        s += a.getDouble(pos);
                    }
                    trace.setObjectAbs(i++, s);
                }
            }
        } else {
            // compound dataset: sum each element of the items separately
            AbstractCompoundDataset ca = (AbstractCompoundDataset) a;
            if (ca instanceof CompoundLongDataset) {
                long[] t = new long[is];
                long[] s = new long[is];
                while (it.hasNext()) {
                    int m = mmin;
                    int n = nmin;
                    Arrays.fill(s, 0);
                    while (m < mmax && n < nmax) {
                        pos[am] = m++;
                        pos[an] = n++;
                        ((CompoundLongDataset)ca).getAbs(ca.get1DIndex(pos), t);
                        for (int k = 0; k < is; k++) {
                            s[k] += t[k];
                        }
                    }
                    trace.setObjectAbs(i++, s);
                }
            } else {
                double[] t = new double[is];
                double[] s = new double[is];
                while (it.hasNext()) {
                    int m = mmin;
                    int n = nmin;
                    Arrays.fill(s, 0);
                    while (m < mmax && n < nmax) {
                        pos[am] = m++;
                        pos[an] = n++;
                        ca.getDoubleArray(t, pos);
                        for (int k = 0; k < is; k++) {
                            s[k] += t[k];
                        }
                    }
                    trace.setObjectAbs(i++, s);
                }
            }
        }

        return trace;
    }

    /**
     * Order value for norm
     */
    public enum NormOrder {
        /**
         * 2-norm for vectors and Frobenius for matrices
         */
        DEFAULT,
        /**
         * Frobenius (not allowed for vectors)
         */
        FROBENIUS,
        /**
         * Zero-order (not allowed for matrices)
         */
        ZERO,
        /**
         * Positive infinity
         */
        POS_INFINITY,
        /**
         * Negative infinity
         */
        NEG_INFINITY;
    }

    /**
     * Calculate norm of dataset using {@link NormOrder#DEFAULT}
     * @param a a rank 1 or rank 2 dataset
     * @return norm of dataset
     * @throws IllegalArgumentException if rank of dataset is not one or two
     */
    public static double norm(Dataset a) {
        return norm(a, NormOrder.DEFAULT);
    }

    /**
     * Calculate norm of dataset of given order
     * @param a a rank 1 or rank 2 dataset
     * @param order
     * @return norm of dataset
     * @throws IllegalArgumentException if rank of dataset is not one or two, or if the order
     * is not allowed for that rank
     */
    public static double norm(Dataset a, NormOrder order) {
        int r = a.getRank();
        if (r == 1) {
            return vectorNorm(a, order);
        } else if (r == 2) {
            return matrixNorm(a, order);
        }
        throw new IllegalArgumentException("Rank of dataset must be one or two");
    }

    /**
     * @param a
     * @param order
     * @return largest (or smallest) absolute value for positive (or negative) infinity orders,
     * count of non-zero items for zero order, otherwise the 2-norm
     */
    private static double vectorNorm(Dataset a, NormOrder order) {
        double n;
        IndexIterator it;
        switch (order) {
        case FROBENIUS:
            throw new IllegalArgumentException("Not allowed for vectors");
        case NEG_INFINITY:
        case POS_INFINITY:
            it = a.getIterator();
            if (order == NormOrder.POS_INFINITY) {
                n = Double.NEGATIVE_INFINITY;
                if (a.isComplex()) {
                    while (it.hasNext()) {
                        double v = ((Complex) a.getObjectAbs(it.index)).abs();
                        n = Math.max(n, v);
                    }
                } else {
                    while (it.hasNext()) {
                        double v = Math.abs(a.getElementDoubleAbs(it.index));
                        n = Math.max(n, v);
                    }
                }
            } else {
                n = Double.POSITIVE_INFINITY;
                if (a.isComplex()) {
                    while (it.hasNext()) {
                        double v = ((Complex) a.getObjectAbs(it.index)).abs();
                        n = Math.min(n, v);
                    }
                } else {
                    while (it.hasNext()) {
                        double v = Math.abs(a.getElementDoubleAbs(it.index));
                        n = Math.min(n, v);
                    }
                }
            }
            break;
        case ZERO:
            it = a.getIterator();
            n = 0;
            if (a.isComplex()) {
                while (it.hasNext()) {
                    if (!((Complex) a.getObjectAbs(it.index)).equals(Complex.ZERO)) {
                        n++;
                    }
                }
            } else {
                while (it.hasNext()) {
                    if (a.getElementBooleanAbs(it.index)) {
                        n++;
                    }
                }
            }

            break;
        default:
            n = vectorNorm(a, 2);
            break;
        }
        return n;
    }

    /**
     * @param a
     * @param order
     * @return largest (or smallest) sum of absolute values taken along the second axis for
     * positive (or negative) infinity orders, otherwise the Frobenius norm
     */
    private static double matrixNorm(Dataset a, NormOrder order) {
        double n;
        IndexIterator it;
        switch (order) {
        case NEG_INFINITY:
        case POS_INFINITY:
            n = maxMinMatrixNorm(a, 1, order == NormOrder.POS_INFINITY);
            break;
        case ZERO:
            throw new IllegalArgumentException("Not allowed for matrices");
        default:
        case FROBENIUS:
            it = a.getIterator();
            n = 0;
            if (a.isComplex()) {
                while (it.hasNext()) {
                    double v = ((Complex) a.getObjectAbs(it.index)).abs();
                    n += v*v;
                }
            } else {
                while (it.hasNext()) {
                    double v = a.getElementDoubleAbs(it.index);
                    n += v*v;
                }
            }
            n = Math.sqrt(n);
            break;
        }
        return n;
    }

    /**
     * Calculate p-norm of dataset. A zero value of p is treated as {@link NormOrder#ZERO}
     * @param a a rank 1 or rank 2 dataset
     * @param p for rank 2 datasets, only -2, -1, 1 and 2 are allowed
     * @return p-norm of dataset
     * @throws IllegalArgumentException if rank of dataset is not one or two, or if p is
     * not allowed for that rank
     */
    public static double norm(Dataset a, final double p) {
        if (p == 0) {
            return norm(a, NormOrder.ZERO);
        }
        int r = a.getRank();
        if (r == 1) {
            return vectorNorm(a, p);
        } else if (r == 2) {
            return matrixNorm(a, p);
        }
        throw new IllegalArgumentException("Rank of dataset must be one or two");
    }

    /**
     * @param a
     * @param p
     * @return p-th root of sum of absolute values raised to the power p
     */
    private static double vectorNorm(Dataset a, final double p) {
        IndexIterator it = a.getIterator();
        double n = 0;
        if (a.isComplex()) {
            while (it.hasNext()) {
                double v = ((Complex) a.getObjectAbs(it.index)).abs();
                if (p == 2) {
                    v *= v;
                } else if (p != 1) {
                    v = Math.pow(v, p);
                }
                n += v;
            }
        } else {
            while (it.hasNext()) {
                double v = a.getElementDoubleAbs(it.index);
                if (p == 1) {
                    v = Math.abs(v);
                } else if (p == 2) {
                    v *= v;
                } else {
                    v = Math.pow(Math.abs(v), p);
                }
                n += v;
            }
        }
        return Math.pow(n, 1./p);
    }

    /**
     * @param a
     * @param p
     * @return for |p| = 1, largest (or smallest when p is negative) sum of absolute values taken
     * along the first axis; for |p| = 2, first (or last when p is negative) singular value
     */
    private static double matrixNorm(Dataset a, final double p) {
        double n;
        if (Math.abs(p) == 1) {
            n = maxMinMatrixNorm(a, 0, p > 0);
        } else if (Math.abs(p) == 2) {
            double[] s = calcSingularValues(a);
            n = p > 0 ? s[0] : s[s.length - 1];
        } else {
            throw new IllegalArgumentException("Order not allowed");
        }

        return n;
    }

    /**
     * @param a
     * @param d dimension to sum absolute values along
     * @param max if true, find maximum of the sums, otherwise find minimum
     * @return maximum or minimum of sums of absolute values taken along given dimension
     */
    private static double maxMinMatrixNorm(Dataset a, int d, boolean max) {
        double n;
        IndexIterator it;
        int[] pos;
        int l;
        it = a.getPositionIterator(d);
        pos = it.getPos();
        l = a.getShapeRef()[d];
        if (max) {
            n = Double.NEGATIVE_INFINITY;
            if (a.isComplex()) {
                while (it.hasNext()) {
                    double v = ((Complex) a.getObject(pos)).abs();
                    for (int i = 1; i < l; i++) {
                        pos[d] = i;
                        v += ((Complex) a.getObject(pos)).abs();
                    }
                    pos[d] = 0; // restore position for iterator
                    n = Math.max(n, v);
                }
            } else {
                while (it.hasNext()) {
                    double v = Math.abs(a.getDouble(pos));
                    for (int i = 1; i < l; i++) {
                        pos[d] = i;
                        v += Math.abs(a.getDouble(pos));
                    }
                    pos[d] = 0; // restore position for iterator
                    n = Math.max(n, v);
                }
            }
        } else {
            n = Double.POSITIVE_INFINITY;
            if (a.isComplex()) {
                while (it.hasNext()) {
                    double v = ((Complex) a.getObject(pos)).abs();
                    for (int i = 1; i < l; i++) {
                        pos[d] = i;
                        v += ((Complex) a.getObject(pos)).abs();
                    }
                    pos[d] = 0; // restore position for iterator
                    n = Math.min(n, v);
                }
            } else {
                while (it.hasNext()) {
                    double v = Math.abs(a.getDouble(pos));
                    for (int i = 1; i < l; i++) {
                        pos[d] = i;
                        v += Math.abs(a.getDouble(pos));
                    }
                    pos[d] = 0; // restore position for iterator
                    n = Math.min(n, v);
                }
            }
        }
        return n;
    }

    /**
     * Calculate singular values by singular value decomposition
     * @param a a rank 2 dataset
     * @return array of singular values
     */
    public static double[] calcSingularValues(Dataset a) {
        SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
        return svd.getSingularValues();
    }


    /**
     * Calculate singular value decomposition {@code A = U S V^T}
     * @param a a rank 2 dataset
     * @return array of U - orthogonal matrix, s - singular values vector, V - orthogonal matrix
     */
    public static Dataset[] calcSingularValueDecomposition(Dataset a) {
        SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
        return new Dataset[] {createDataset(svd.getU()), DatasetFactory.createFromObject(svd.getSingularValues()),
                createDataset(svd.getV())};
    }

    /**
     * Calculate (Moore-Penrose) pseudo-inverse
     * @param a a rank 2 dataset
     * @return pseudo-inverse
     */
    public static Dataset calcPseudoInverse(Dataset a) {
        SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
        return createDataset(svd.getSolver().getInverse());
    }

    /**
     * Calculate matrix rank by singular value decomposition method
     * @param a a rank 2 dataset
     * @return effective numerical rank of matrix
     */
    public static int calcMatrixRank(Dataset a) {
        SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
        return svd.getRank();
    }

    /**
     * Calculate condition number of matrix by singular value decomposition method
     * @param a a rank 2 dataset
     * @return condition number
     */
    public static double calcConditionNumber(Dataset a) {
        SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
        return svd.getConditionNumber();
    }

    /**
     * Calculate determinant by LU decomposition method
     * @param a a square, rank 2 dataset
     * @return determinant of dataset
     */
    public static double calcDeterminant(Dataset a) {
        // LU decomposition is used rather than an eigen-decomposition so the result does not
        // depend on the matrix having only real eigenvalues
        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
        return lud.getDeterminant();
    }

    /**
     * Calculate eigenvalues
     * @param a a rank 2 dataset
     * @return dataset of eigenvalues (can be double or complex double)
     */
    public static Dataset calcEigenvalues(Dataset a) {
        EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
        double[] rev = evd.getRealEigenvalues();

        if (evd.hasComplexEigenvalues()) {
            double[] iev = evd.getImagEigenvalues();
            return DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
        }
        return DatasetFactory.createFromObject(rev);
    }

    /**
     * Calculate eigen-decomposition {@code A = V D V^T}
     * @param a a rank 2 dataset
     * @return array of D eigenvalues (can be double or complex double) and V eigenvectors
     */
    public static Dataset[] calcEigenDecomposition(Dataset a) {
        EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a));
        Dataset[] results = new Dataset[2];

        double[] rev = evd.getRealEigenvalues();
        if (evd.hasComplexEigenvalues()) {
            double[] iev = evd.getImagEigenvalues();
            results[0] = DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev);
        } else {
            results[0] = DatasetFactory.createFromObject(rev);
        }
        results[1] = createDataset(evd.getV());
        return results;
    }

    /**
     * Calculate QR decomposition {@code A = Q R}
     * @param a a rank 2 dataset
     * @return array of Q and R
     */
    public static Dataset[] calcQRDecomposition(Dataset a) {
        QRDecomposition qrd = new QRDecomposition(createRealMatrix(a));
        return new Dataset[] {createDataset(qrd.getQT()).getTransposedView(), createDataset(qrd.getR())};
    }

    /**
     * Calculate LU decomposition {@code A = P^-1 L U}
     * @param a a rank 2 dataset
     * @return array of L, U and P
     */
    public static Dataset[] calcLUDecomposition(Dataset a) {
        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
        return new Dataset[] {createDataset(lud.getL()), createDataset(lud.getU()),
                createDataset(lud.getP())};
    }

    /**
     * Calculate inverse of square dataset
     * @param a a rank 2 dataset
     * @return inverse
     */
    public static Dataset calcInverse(Dataset a) {
        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
        return createDataset(lud.getSolver().getInverse());
    }

    /**
     * Solve linear matrix equation {@code A x = v}
     * @param a a rank 2 dataset
     * @param v a rank 1 dataset, otherwise it must be a rank 2 dataset
     * @return x
     */
    public static Dataset solve(Dataset a, Dataset v) {
        LUDecomposition lud = new LUDecomposition(createRealMatrix(a));
        if (v.getRank() == 1) {
            RealVector x = createRealVector(v);
            return createDataset(lud.getSolver().solve(x));
        }
        RealMatrix x = createRealMatrix(v);
        return createDataset(lud.getSolver().solve(x));
    }


    /**
     * Solve least squares matrix equation {@code A x = v} by SVD
     * @param a a rank 2 dataset
     * @param v a rank 1 dataset, otherwise it must be a rank 2 dataset
     * @return x
     */
    public static Dataset solveSVD(Dataset a, Dataset v) {
        SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a));
        if (v.getRank() == 1) {
            RealVector x = createRealVector(v);
            return createDataset(svd.getSolver().solve(x));
        }
        RealMatrix x = createRealMatrix(v);
        return createDataset(svd.getSolver().solve(x));
    }

    /**
     * Calculate Cholesky decomposition {@code A = L L^T}
     * @param a a rank 2 dataset
     * @return L
     */
    public static Dataset calcCholeskyDecomposition(Dataset a) {
        CholeskyDecomposition cd = new CholeskyDecomposition(createRealMatrix(a));
        return createDataset(cd.getL());
    }

    /**
     * Calculation {@code A x = v} by conjugate gradient method with the stopping criterion being
     * that the estimated residual {@code r = v - A x} satisfies {@code ||r|| < ||v||} with maximum of 100 iterations
     * @param a a rank 2 dataset
     * @param v a rank 1 dataset
     * @return value of {@code A^-1 v} by conjugate gradient method
     */
    public static Dataset calcConjugateGradient(Dataset a, Dataset v) {
        return calcConjugateGradient(a, v, 100, 1);
    }

    /**
     * Calculation {@code A x = v} by conjugate gradient method with the stopping criterion being
     * that the estimated residual {@code r = v - A x} satisfies {@code ||r|| < delta ||v||}
     * @param a a rank 2 dataset
     * @param v a rank 1 dataset
     * @param maxIterations
     * @param delta parameter used by stopping criterion
     * @return value of {@code A^-1 v} by conjugate gradient method
     */
    public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) {
        ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false);
        return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v)));
    }

    /**
     * Copy a dataset to a new matrix
     * @param a
     * @return matrix with same shape as dataset and its items as doubles
     * @throws IllegalArgumentException if dataset is not rank 2
     */
    private static RealMatrix createRealMatrix(Dataset a) {
        if (a.getRank() != 2) {
            throw new IllegalArgumentException("Dataset must be rank 2");
        }
        int[] shape = a.getShapeRef();
        IndexIterator it = a.getIterator(true);
        int[] pos = it.getPos();
        RealMatrix m = MatrixUtils.createRealMatrix(shape[0], shape[1]);
        while (it.hasNext()) {
            m.setEntry(pos[0], pos[1], a.getElementDoubleAbs(it.index));
        }
        return m;
    }

    /**
     * Copy a dataset to a new vector
     * @param a
     * @return vector with same size as dataset and its items as doubles
     * @throws IllegalArgumentException if dataset is not rank 1
     */
    private static RealVector createRealVector(Dataset a) {
        if (a.getRank() != 1) {
            throw new IllegalArgumentException("Dataset must be rank 1");
        }
        int size = a.getSize();
        IndexIterator it = a.getIterator(true);
        int[] pos = it.getPos();
        RealVector m = new ArrayRealVector(size);
        while (it.hasNext()) {
            m.setEntry(pos[0], a.getElementDoubleAbs(it.index));
        }
        return m;
    }

    /**
     * Copy a vector to a new dataset
     * @param v
     * @return 1D double dataset
     */
    private static Dataset createDataset(RealVector v) {
        DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, v.getDimension());
        int size = r.getSize();
        if (v instanceof ArrayRealVector) {
            // read backing array directly to avoid a call per entry
            double[] data = ((ArrayRealVector) v).getDataRef();
            for (int i = 0; i < size; i++) {
                r.setAbs(i, data[i]);
            }
        } else {
            for (int i = 0; i < size; i++) {
                r.setAbs(i, v.getEntry(i));
            }
        }
        return r;
    }

    /**
     * Copy a matrix to a new dataset
     * @param m
     * @return 2D double dataset
     */
    private static Dataset createDataset(RealMatrix m) {
        DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, m.getRowDimension(), m.getColumnDimension());
        if (m instanceof Array2DRowRealMatrix) {
            // read backing array directly to avoid a call per entry
            double[][] data = ((Array2DRowRealMatrix) m).getDataRef();
            IndexIterator it = r.getIterator(true);
            int[] pos = it.getPos();
            while (it.hasNext()) {
                r.setAbs(it.index, data[pos[0]][pos[1]]);
            }
        } else {
            IndexIterator it = r.getIterator(true);
            int[] pos = it.getPos();
            while (it.hasNext()) {
                r.setAbs(it.index, m.getEntry(pos[0], pos[1]));
            }
        }
        return r;
    }
}