001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.util.Arrays; 016import java.util.List; 017 018import org.apache.commons.math3.complex.Complex; 019import org.apache.commons.math3.linear.Array2DRowRealMatrix; 020import org.apache.commons.math3.linear.ArrayRealVector; 021import org.apache.commons.math3.linear.CholeskyDecomposition; 022import org.apache.commons.math3.linear.ConjugateGradient; 023import org.apache.commons.math3.linear.EigenDecomposition; 024import org.apache.commons.math3.linear.LUDecomposition; 025import org.apache.commons.math3.linear.MatrixUtils; 026import org.apache.commons.math3.linear.QRDecomposition; 027import org.apache.commons.math3.linear.RealLinearOperator; 028import org.apache.commons.math3.linear.RealMatrix; 029import org.apache.commons.math3.linear.RealVector; 030import org.apache.commons.math3.linear.SingularValueDecomposition; 031 032 033public class LinearAlgebra { 034 035 private static final int CROSSOVERPOINT = 16; // point at which using slice iterators for inner loop is faster 036 037 /** 038 * Calculate the tensor dot product over given axes. This is the sum of products of elements selected 039 * from the given axes in each dataset 040 * @param a 041 * @param b 042 * @param axisa axis dimension in a to sum over (can be -ve) 043 * @param axisb axis dimension in b to sum over (can be -ve) 044 * @return tensor dot product 045 */ 046 public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int axisa, final int axisb) { 047 // this is slower for summing lengths < ~15 048 final int[] ashape = a.getShapeRef(); 049 final int[] bshape = b.getShapeRef(); 050 final int arank = ashape.length; 051 final int brank = bshape.length; 052 int aaxis = axisa; 053 if (aaxis < 0) 054 aaxis += arank; 055 if (aaxis < 0 || aaxis >= arank) 056 throw new IllegalArgumentException("Summing axis outside valid rank of 1st dataset"); 057 058 if (ashape[aaxis] < CROSSOVERPOINT) { // faster to use position iteration 059 return tensorDotProduct(a, b, new int[] {axisa}, new int[] {axisb}); 060 } 061 int baxis = axisb; 062 if (baxis < 0) 063 baxis += arank; 064 if (baxis < 0 || baxis >= arank) 065 throw new IllegalArgumentException("Summing axis outside valid rank of 2nd dataset"); 066 067 final boolean[] achoice = new boolean[arank]; 068 final boolean[] bchoice = new boolean[brank]; 069 Arrays.fill(achoice, true); 070 Arrays.fill(bchoice, true); 071 achoice[aaxis] = false; // flag which axes not to iterate over 072 bchoice[baxis] = false; 073 074 final boolean[] notachoice = new boolean[arank]; 075 final boolean[] notbchoice = new boolean[brank]; 076 notachoice[aaxis] = true; // flag which axes to iterate over 077 notbchoice[baxis] = true; 078 079 int drank = arank + brank - 2; 080 int[] dshape = new int[drank]; 081 int d = 0; 082 for (int i = 0; i < arank; i++) { 083 if (achoice[i]) 084 dshape[d++] = ashape[i]; 085 } 086 for (int i = 0; i < brank; i++) { 087 if (bchoice[i]) 088 dshape[d++] = bshape[i]; 089 } 090 int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType()); 091 @SuppressWarnings("deprecation") 092 Dataset data = DatasetFactory.zeros(dshape, dtype); 093 094 SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice); 095 int l = 0; 096 final int[] apos = ita.getPos(); 097 while (ita.hasNext()) { 098 SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice); 099 final int[] bpos = itb.getPos(); 100 while (itb.hasNext()) { 101 SliceIterator itaa = a.getSliceIteratorFromAxes(apos, notachoice); 102 SliceIterator itba = b.getSliceIteratorFromAxes(bpos, notbchoice); 103 double sum = 0.0; 104 double com = 0.0; 105 while (itaa.hasNext() && itba.hasNext()) { 106 final double y = a.getElementDoubleAbs(itaa.index) * b.getElementDoubleAbs(itba.index) - com; 107 final double t = sum + y; 108 com = (t - sum) - y; 109 sum = t; 110 } 111 data.setObjectAbs(l++, sum); 112 } 113 } 114 115 return data; 116 } 117 118 /** 119 * Calculate the tensor dot product over given axes. This is the sum of products of elements selected 120 * from the given axes in each dataset 121 * @param a 122 * @param b 123 * @param axisa axis dimensions in a to sum over (can be -ve) 124 * @param axisb axis dimensions in b to sum over (can be -ve) 125 * @return tensor dot product 126 */ 127 public static Dataset tensorDotProduct(final Dataset a, final Dataset b, final int[] axisa, final int[] axisb) { 128 if (axisa.length != axisb.length) { 129 throw new IllegalArgumentException("Numbers of summing axes must be same"); 130 } 131 final int[] ashape = a.getShapeRef(); 132 final int[] bshape = b.getShapeRef(); 133 final int arank = ashape.length; 134 final int brank = bshape.length; 135 final int[] aaxes = new int[axisa.length]; 136 final int[] baxes = new int[axisa.length]; 137 for (int i = 0; i < axisa.length; i++) { 138 int n; 139 140 n = axisa[i]; 141 if (n < 0) n += arank; 142 if (n < 0 || n >= arank) 143 throw new IllegalArgumentException("Summing axis outside valid rank of 1st dataset"); 144 aaxes[i] = n; 145 146 n = axisb[i]; 147 if (n < 0) n += brank; 148 if (n < 0 || n >= brank) 149 throw new IllegalArgumentException("Summing axis outside valid rank of 2nd dataset"); 150 baxes[i] = n; 151 152 if (ashape[aaxes[i]] != bshape[n]) 153 throw new IllegalArgumentException("Summing axes do not have matching lengths"); 154 } 155 156 final boolean[] achoice = new boolean[arank]; 157 final boolean[] bchoice = new boolean[brank]; 158 Arrays.fill(achoice, true); 159 Arrays.fill(bchoice, true); 160 for (int i = 0; i < aaxes.length; i++) { // flag which axes to iterate over 161 achoice[aaxes[i]] = false; 162 bchoice[baxes[i]] = false; 163 } 164 165 int drank = arank + brank - 2*aaxes.length; 166 int[] dshape = new int[drank]; 167 int d = 0; 168 for (int i = 0; i < arank; i++) { 169 if (achoice[i]) 170 dshape[d++] = ashape[i]; 171 } 172 for (int i = 0; i < brank; i++) { 173 if (bchoice[i]) 174 dshape[d++] = bshape[i]; 175 } 176 int dtype = DTypeUtils.getBestDType(a.getDType(), b.getDType()); 177 @SuppressWarnings("deprecation") 178 Dataset data = DatasetFactory.zeros(dshape, dtype); 179 180 SliceIterator ita = a.getSliceIteratorFromAxes(null, achoice); 181 int l = 0; 182 final int[] apos = ita.getPos(); 183 while (ita.hasNext()) { 184 SliceIterator itb = b.getSliceIteratorFromAxes(null, bchoice); 185 final int[] bpos = itb.getPos(); 186 while (itb.hasNext()) { 187 double sum = 0.0; 188 double com = 0.0; 189 apos[aaxes[aaxes.length - 1]] = -1; 190 bpos[baxes[aaxes.length - 1]] = -1; 191 while (true) { // step through summing axes 192 int e = aaxes.length - 1; 193 for (; e >= 0; e--) { 194 int ai = aaxes[e]; 195 int bi = baxes[e]; 196 197 apos[ai]++; 198 bpos[bi]++; 199 if (apos[ai] == ashape[ai]) { 200 apos[ai] = 0; 201 bpos[bi] = 0; 202 } else 203 break; 204 } 205 if (e == -1) break; 206 final double y = a.getDouble(apos) * b.getDouble(bpos) - com; 207 final double t = sum + y; 208 com = (t - sum) - y; 209 sum = t; 210 } 211 data.setObjectAbs(l++, sum); 212 } 213 } 214 215 return data; 216 } 217 218 /** 219 * Calculate the dot product of two datasets. When <b>b</b> is a 1D dataset, the sum product over 220 * the last axis of <b>a</b> and <b>b</b> is returned. Where <b>a</b> is also a 1D dataset, a zero-rank dataset 221 * is returned. If <b>b</b> is 2D or higher, its second-to-last axis is used 222 * @param a 223 * @param b 224 * @return dot product 225 */ 226 public static Dataset dotProduct(Dataset a, Dataset b) { 227 if (b.getRank() < 2) 228 return tensorDotProduct(a, b, -1, 0); 229 return tensorDotProduct(a, b, -1, -2); 230 } 231 232 /** 233 * Calculate the outer product of two datasets 234 * @param a 235 * @param b 236 * @return outer product 237 */ 238 public static Dataset outerProduct(Dataset a, Dataset b) { 239 int[] as = a.getShapeRef(); 240 int[] bs = b.getShapeRef(); 241 int rank = as.length + bs.length; 242 int[] shape = new int[rank]; 243 for (int i = 0; i < as.length; i++) { 244 shape[i] = as[i]; 245 } 246 for (int i = 0; i < bs.length; i++) { 247 shape[as.length + i] = bs[i]; 248 } 249 int isa = a.getElementsPerItem(); 250 int isb = b.getElementsPerItem(); 251 if (isa != 1 || isb != 1) { 252 throw new UnsupportedOperationException("Compound datasets not supported"); 253 } 254 @SuppressWarnings("deprecation") 255 Dataset o = DatasetFactory.zeros(shape, DTypeUtils.getBestDType(a.getDType(), b.getDType())); 256 257 IndexIterator ita = a.getIterator(); 258 IndexIterator itb = b.getIterator(); 259 int j = 0; 260 while (ita.hasNext()) { 261 double va = a.getElementDoubleAbs(ita.index); 262 while (itb.hasNext()) { 263 o.setObjectAbs(j++, va * b.getElementDoubleAbs(itb.index)); 264 } 265 itb.reset(); 266 } 267 return o; 268 } 269 270 /** 271 * Calculate the cross product of two datasets. Datasets must be broadcastable and 272 * possess last dimensions of length 2 or 3 273 * @param a 274 * @param b 275 * @return cross product 276 */ 277 public static Dataset crossProduct(Dataset a, Dataset b) { 278 return crossProduct(a, b, -1, -1, -1); 279 } 280 281 /** 282 * Calculate the cross product of two datasets. Datasets must be broadcastable and 283 * possess dimensions of length 2 or 3. The axis parameters can be negative to indicate 284 * dimensions from the end of their shapes 285 * @param a 286 * @param b 287 * @param axisA dimension to be used a vector (must have length of 2 or 3) 288 * @param axisB dimension to be used a vector (must have length of 2 or 3) 289 * @param axisC dimension to assign as cross-product 290 * @return cross product 291 */ 292 public static Dataset crossProduct(Dataset a, Dataset b, int axisA, int axisB, int axisC) { 293 final int rankA = a.getRank(); 294 final int rankB = b.getRank(); 295 if (rankA == 0 || rankB == 0) { 296 throw new IllegalArgumentException("Datasets must have one or more dimensions"); 297 } 298 if (axisA < 0) { 299 axisA += rankA; 300 } 301 if (axisA < 0 || axisA >= rankA) { 302 throw new IllegalArgumentException("Axis A argument exceeds rank"); 303 } 304 if (axisB < 0) { 305 axisB += rankB; 306 } 307 if (axisB < 0 || axisB >= rankB) { 308 throw new IllegalArgumentException("Axis B argument exceeds rank"); 309 } 310 311 final int[] shapeA = a.getShape(); 312 final int[] shapeB = b.getShape(); 313 int la = shapeA[axisA]; 314 int lb = shapeB[axisB]; 315 if (Math.min(la, lb) < 2 || Math.max(la, lb) > 3) { 316 throw new IllegalArgumentException("Chosen dimension of A & B must be 2 or 3"); 317 } 318 319 if (Math.max(la, lb) == 2) { 320 return crossProduct2D(a, b, axisA, axisB); 321 } 322 323 return crossProduct3D(a, b, axisA, axisB, axisC); 324 } 325 326 private static int[] removeAxisFromShape(int[] shape, int axis) { 327 int[] s = new int[shape.length - 1]; 328 int i = 0; 329 int j = 0; 330 while (i < axis) { 331 s[j++] = shape[i++]; 332 } 333 i++; 334 while (i < shape.length) { 335 s[j++] = shape[i++]; 336 } 337 return s; 338 } 339 340 // assume axes is in increasing order 341 private static int[] removeAxesFromShape(int[] shape, int... axes) { 342 int n = axes.length; 343 int[] s = new int[shape.length - n]; 344 int i = 0; 345 int j = 0; 346 for (int k = 0; k < n; k++) { 347 int a = axes[k]; 348 while (i < a) { 349 s[j++] = shape[i++]; 350 } 351 i++; 352 } 353 while (i < shape.length) { 354 s[j++] = shape[i++]; 355 } 356 return s; 357 } 358 359 private static int[] addAxisToShape(int[] shape, int axis, int length) { 360 int[] s = new int[shape.length + 1]; 361 int i = 0; 362 int j = 0; 363 while (i < axis) { 364 s[j++] = shape[i++]; 365 } 366 s[j++] = length; 367 while (i < shape.length) { 368 s[j++] = shape[i++]; 369 } 370 return s; 371 } 372 373 // assume axes is in increasing order 374 private static int[] addAxesToShape(int[] shape, int[] axes, int[] lengths) { 375 int n = axes.length; 376 if (lengths.length != n) { 377 throw new IllegalArgumentException("Axes and lengths arrays must be same size"); 378 } 379 int[] s = new int[shape.length + n]; 380 int i = 0; 381 int j = 0; 382 for (int k = 0; k < n; k++) { 383 int a = axes[k]; 384 while (i < a) { 385 s[j++] = shape[i++]; 386 } 387 s[j++] = lengths[k]; 388 } 389 while (i < shape.length) { 390 s[j++] = shape[i++]; 391 } 392 return s; 393 } 394 395 private static Dataset crossProduct2D(Dataset a, Dataset b, int axisA, int axisB) { 396 // need to broadcast and omit given axes 397 int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA); 398 int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB); 399 400 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB); 401 402 int[] maxShape = fullShapes.get(0); 403 @SuppressWarnings("deprecation") 404 Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType())); 405 406 PositionIterator ita = a.getPositionIterator(axisA); 407 PositionIterator itb = b.getPositionIterator(axisB); 408 IndexIterator itc = c.getIterator(); 409 410 final int[] pa = ita.getPos(); 411 final int[] pb = itb.getPos(); 412 while (itc.hasNext()) { 413 if (!ita.hasNext()) // TODO use broadcasting... 414 ita.reset(); 415 if (!itb.hasNext()) 416 itb.reset(); 417 pa[axisA] = 0; 418 pb[axisB] = 1; 419 double cv = a.getDouble(pa) * b.getDouble(pb); 420 pa[axisA] = 1; 421 pb[axisB] = 0; 422 cv -= a.getDouble(pa) * b.getDouble(pb); 423 424 c.setObjectAbs(itc.index, cv); 425 } 426 return c; 427 } 428 429 private static Dataset crossProduct3D(Dataset a, Dataset b, int axisA, int axisB, int axisC) { 430 int[] shapeA = removeAxisFromShape(a.getShapeRef(), axisA); 431 int[] shapeB = removeAxisFromShape(b.getShapeRef(), axisB); 432 433 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(shapeA, shapeB); 434 435 int[] maxShape = fullShapes.get(0); 436 int rankC = maxShape.length + 1; 437 if (axisC < 0) { 438 axisC += rankC; 439 } 440 if (axisC < 0 || axisC >= rankC) { 441 throw new IllegalArgumentException("Axis C argument exceeds rank"); 442 } 443 maxShape = addAxisToShape(maxShape, axisC, 3); 444 @SuppressWarnings("deprecation") 445 Dataset c = DatasetFactory.zeros(maxShape, DTypeUtils.getBestDType(a.getDType(), b.getDType())); 446 447 PositionIterator ita = a.getPositionIterator(axisA); 448 PositionIterator itb = b.getPositionIterator(axisB); 449 PositionIterator itc = c.getPositionIterator(axisC); 450 451 final int[] pa = ita.getPos(); 452 final int[] pb = itb.getPos(); 453 final int[] pc = itc.getPos(); 454 final int la = a.getShapeRef()[axisA]; 455 final int lb = b.getShapeRef()[axisB]; 456 457 if (la == 2) { 458 while (itc.hasNext()) { 459 if (!ita.hasNext()) // TODO use broadcasting... 460 ita.reset(); 461 if (!itb.hasNext()) 462 itb.reset(); 463 double cv; 464 pa[axisA] = 1; 465 pb[axisB] = 2; 466 cv = a.getDouble(pa) * b.getDouble(pb); 467 pc[axisC] = 0; 468 c.set(cv, pc); 469 470 pa[axisA] = 0; 471 pb[axisB] = 2; 472 cv = -a.getDouble(pa) * b.getDouble(pb); 473 pc[axisC] = 1; 474 c.set(cv, pc); 475 476 pa[axisA] = 0; 477 pb[axisB] = 1; 478 cv = a.getDouble(pa) * b.getDouble(pb); 479 pa[axisA] = 1; 480 pb[axisB] = 0; 481 cv -= a.getDouble(pa) * b.getDouble(pb); 482 pc[axisC] = 2; 483 c.set(cv, pc); 484 } 485 } else if (lb == 2) { 486 while (itc.hasNext()) { 487 if (!ita.hasNext()) // TODO use broadcasting... 488 ita.reset(); 489 if (!itb.hasNext()) 490 itb.reset(); 491 double cv; 492 pa[axisA] = 2; 493 pb[axisB] = 1; 494 cv = -a.getDouble(pa) * b.getDouble(pb); 495 pc[axisC] = 0; 496 c.set(cv, pc); 497 498 pa[axisA] = 2; 499 pb[axisB] = 0; 500 cv = a.getDouble(pa) * b.getDouble(pb); 501 pc[axisC] = 1; 502 c.set(cv, pc); 503 504 pa[axisA] = 0; 505 pb[axisB] = 1; 506 cv = a.getDouble(pa) * b.getDouble(pb); 507 pa[axisA] = 1; 508 pb[axisB] = 0; 509 cv -= a.getDouble(pa) * b.getDouble(pb); 510 pc[axisC] = 2; 511 c.set(cv, pc); 512 } 513 514 } else { 515 while (itc.hasNext()) { 516 if (!ita.hasNext()) // TODO use broadcasting... 517 ita.reset(); 518 if (!itb.hasNext()) 519 itb.reset(); 520 double cv; 521 pa[axisA] = 1; 522 pb[axisB] = 2; 523 cv = a.getDouble(pa) * b.getDouble(pb); 524 pa[axisA] = 2; 525 pb[axisB] = 1; 526 cv -= a.getDouble(pa) * b.getDouble(pb); 527 pc[axisC] = 0; 528 c.set(cv, pc); 529 530 pa[axisA] = 2; 531 pb[axisB] = 0; 532 cv = a.getDouble(pa) * b.getDouble(pb); 533 pa[axisA] = 0; 534 pb[axisB] = 2; 535 cv -= a.getDouble(pa) * b.getDouble(pb); 536 pc[axisC] = 1; 537 c.set(cv, pc); 538 539 pa[axisA] = 0; 540 pb[axisB] = 1; 541 cv = a.getDouble(pa) * b.getDouble(pb); 542 pa[axisA] = 1; 543 pb[axisB] = 0; 544 cv -= a.getDouble(pa) * b.getDouble(pb); 545 pc[axisC] = 2; 546 c.set(cv, pc); 547 } 548 } 549 return c; 550 } 551 552 /** 553 * Raise dataset to given power by matrix multiplication 554 * @param a 555 * @param n power 556 * @return a ** n 557 */ 558 public static Dataset power(Dataset a, int n) { 559 if (n < 0) { 560 LUDecomposition lud = new LUDecomposition(createRealMatrix(a)); 561 return createDataset(lud.getSolver().getInverse().power(-n)); 562 } 563 Dataset p = createDataset(createRealMatrix(a).power(n)); 564 if (!a.hasFloatingPointElements()) 565 return p.cast(a.getDType()); 566 return p; 567 } 568 569 /** 570 * Create the Kronecker product as defined by 571 * kron[k0,...,kN] = a[i0,...,iN] * b[j0,...,jN] 572 * where kn = sn * in + jn for n = 0...N and s is shape of b 573 * @param a 574 * @param b 575 * @return Kronecker product of a and b 576 */ 577 public static Dataset kroneckerProduct(Dataset a, Dataset b) { 578 if (a.getElementsPerItem() != 1 || b.getElementsPerItem() != 1) { 579 throw new UnsupportedOperationException("Compound datasets (including complex ones) are not currently supported"); 580 } 581 int ar = a.getRank(); 582 int br = b.getRank(); 583 int[] aShape; 584 int[] bShape; 585 aShape = a.getShapeRef(); 586 bShape = b.getShapeRef(); 587 int r = ar; 588 // pre-pad if ranks are not same 589 if (ar < br) { 590 r = br; 591 int[] shape = new int[br]; 592 int j = 0; 593 for (int i = ar; i < br; i++) { 594 shape[j++] = 1; 595 } 596 int i = 0; 597 while (j < br) { 598 shape[j++] = aShape[i++]; 599 } 600 a = a.reshape(shape); 601 aShape = shape; 602 } else if (ar > br) { 603 int[] shape = new int[ar]; 604 int j = 0; 605 for (int i = br; i < ar; i++) { 606 shape[j++] = 1; 607 } 608 int i = 0; 609 while (j < ar) { 610 shape[j++] = bShape[i++]; 611 } 612 b = b.reshape(shape); 613 bShape = shape; 614 } 615 616 int[] nShape = new int[r]; 617 for (int i = 0; i < r; i++) { 618 nShape[i] = aShape[i] * bShape[i]; 619 } 620 @SuppressWarnings("deprecation") 621 Dataset kron = DatasetFactory.zeros(nShape, DTypeUtils.getBestDType(a.getDType(), b.getDType())); 622 IndexIterator ita = a.getIterator(true); 623 IndexIterator itb = b.getIterator(true); 624 int[] pa = ita.getPos(); 625 int[] pb = itb.getPos(); 626 int[] off = new int[1]; 627 int[] stride = AbstractDataset.createStrides(1, nShape, null, 0, off); 628 if (kron.getDType() == Dataset.INT64) { 629 while (ita.hasNext()) { 630 long av = a.getElementLongAbs(ita.index); 631 632 int ka = 0; 633 for (int i = 0; i < r; i++) { 634 ka += stride[i] * bShape[i] * pa[i]; 635 } 636 itb.reset(); 637 while (itb.hasNext()) { 638 long bv = b.getElementLongAbs(itb.index); 639 int kb = ka; 640 for (int i = 0; i < r; i++) { 641 kb += stride[i] * pb[i]; 642 } 643 kron.setObjectAbs(kb, av * bv); 644 } 645 } 646 } else { 647 while (ita.hasNext()) { 648 double av = a.getElementDoubleAbs(ita.index); 649 650 int ka = 0; 651 for (int i = 0; i < r; i++) { 652 ka += stride[i] * bShape[i] * pa[i]; 653 } 654 itb.reset(); 655 while (itb.hasNext()) { 656 double bv = b.getElementLongAbs(itb.index); 657 int kb = ka; 658 for (int i = 0; i < r; i++) { 659 kb += stride[i] * pb[i]; 660 } 661 kron.setObjectAbs(kb, av * bv); 662 } 663 } 664 } 665 666 return kron; 667 } 668 669 /** 670 * Calculate trace of dataset - sum of values over 1st axis and 2nd axis 671 * @param a 672 * @return trace of dataset 673 */ 674 public static Dataset trace(Dataset a) { 675 return trace(a, 0, 0, 1); 676 } 677 678 /** 679 * Calculate trace of dataset - sum of values over axis1 and axis2 where axis2 is offset 680 * @param a 681 * @param offset 682 * @param axis1 683 * @param axis2 684 * @return trace of dataset 685 */ 686 public static Dataset trace(Dataset a, int offset, int axis1, int axis2) { 687 int[] shape = a.getShapeRef(); 688 int[] axes = new int[] { a.checkAxis(axis1), a.checkAxis(axis2) }; 689 Arrays.sort(axes); 690 int is = a.getElementsPerItem(); 691 @SuppressWarnings("deprecation") 692 Dataset trace = DatasetFactory.zeros(is, removeAxesFromShape(shape, axes), a.getDType()); 693 694 int am = axes[0]; 695 int mmax = shape[am]; 696 int an = axes[1]; 697 int nmax = shape[an]; 698 PositionIterator it = new PositionIterator(shape, axes); 699 int[] pos = it.getPos(); 700 int i = 0; 701 int mmin; 702 int nmin; 703 if (offset >= 0) { 704 mmin = 0; 705 nmin = offset; 706 } else { 707 mmin = -offset; 708 nmin = 0; 709 } 710 if (is == 1) { 711 if (a.getDType() == Dataset.INT64) { 712 while (it.hasNext()) { 713 int m = mmin; 714 int n = nmin; 715 long s = 0; 716 while (m < mmax && n < nmax) { 717 pos[am] = m++; 718 pos[an] = n++; 719 s += a.getLong(pos); 720 } 721 trace.setObjectAbs(i++, s); 722 } 723 } else { 724 while (it.hasNext()) { 725 int m = mmin; 726 int n = nmin; 727 double s = 0; 728 while (m < mmax && n < nmax) { 729 pos[am] = m++; 730 pos[an] = n++; 731 s += a.getDouble(pos); 732 } 733 trace.setObjectAbs(i++, s); 734 } 735 } 736 } else { 737 AbstractCompoundDataset ca = (AbstractCompoundDataset) a; 738 if (ca instanceof CompoundLongDataset) { 739 long[] t = new long[is]; 740 long[] s = new long[is]; 741 while (it.hasNext()) { 742 int m = mmin; 743 int n = nmin; 744 Arrays.fill(s, 0); 745 while (m < mmax && n < nmax) { 746 pos[am] = m++; 747 pos[an] = n++; 748 ((CompoundLongDataset)ca).getAbs(ca.get1DIndex(pos), t); 749 for (int k = 0; k < is; k++) { 750 s[k] += t[k]; 751 } 752 } 753 trace.setObjectAbs(i++, s); 754 } 755 } else { 756 double[] t = new double[is]; 757 double[] s = new double[is]; 758 while (it.hasNext()) { 759 int m = mmin; 760 int n = nmin; 761 Arrays.fill(s, 0); 762 while (m < mmax && n < nmax) { 763 pos[am] = m++; 764 pos[an] = n++; 765 ca.getDoubleArray(t, pos); 766 for (int k = 0; k < is; k++) { 767 s[k] += t[k]; 768 } 769 } 770 trace.setObjectAbs(i++, s); 771 } 772 } 773 } 774 775 return trace; 776 } 777 778 /** 779 * Order value for norm 780 */ 781 public enum NormOrder { 782 /** 783 * 2-norm for vectors and Frobenius for matrices 784 */ 785 DEFAULT, 786 /** 787 * Frobenius (not allowed for vectors) 788 */ 789 FROBENIUS, 790 /** 791 * Zero-order (not allowed for matrices) 792 */ 793 ZERO, 794 /** 795 * Positive infinity 796 */ 797 POS_INFINITY, 798 /** 799 * Negative infinity 800 */ 801 NEG_INFINITY; 802 } 803 804 /** 805 * @param a 806 * @return norm of dataset 807 */ 808 public static double norm(Dataset a) { 809 return norm(a, NormOrder.DEFAULT); 810 } 811 812 /** 813 * @param a 814 * @param order 815 * @return norm of dataset 816 */ 817 public static double norm(Dataset a, NormOrder order) { 818 int r = a.getRank(); 819 if (r == 1) { 820 return vectorNorm(a, order); 821 } else if (r == 2) { 822 return matrixNorm(a, order); 823 } 824 throw new IllegalArgumentException("Rank of dataset must be one or two"); 825 } 826 827 private static double vectorNorm(Dataset a, NormOrder order) { 828 double n; 829 IndexIterator it; 830 switch (order) { 831 case FROBENIUS: 832 throw new IllegalArgumentException("Not allowed for vectors"); 833 case NEG_INFINITY: 834 case POS_INFINITY: 835 it = a.getIterator(); 836 if (order == NormOrder.POS_INFINITY) { 837 n = Double.NEGATIVE_INFINITY; 838 if (a.isComplex()) { 839 while (it.hasNext()) { 840 double v = ((Complex) a.getObjectAbs(it.index)).abs(); 841 n = Math.max(n, v); 842 } 843 } else { 844 while (it.hasNext()) { 845 double v = Math.abs(a.getElementDoubleAbs(it.index)); 846 n = Math.max(n, v); 847 } 848 } 849 } else { 850 n = Double.POSITIVE_INFINITY; 851 if (a.isComplex()) { 852 while (it.hasNext()) { 853 double v = ((Complex) a.getObjectAbs(it.index)).abs(); 854 n = Math.min(n, v); 855 } 856 } else { 857 while (it.hasNext()) { 858 double v = Math.abs(a.getElementDoubleAbs(it.index)); 859 n = Math.min(n, v); 860 } 861 } 862 } 863 break; 864 case ZERO: 865 it = a.getIterator(); 866 n = 0; 867 if (a.isComplex()) { 868 while (it.hasNext()) { 869 if (!((Complex) a.getObjectAbs(it.index)).equals(Complex.ZERO)) 870 n++; 871 } 872 } else { 873 while (it.hasNext()) { 874 if (a.getElementBooleanAbs(it.index)) 875 n++; 876 } 877 } 878 879 break; 880 default: 881 n = vectorNorm(a, 2); 882 break; 883 } 884 return n; 885 } 886 887 private static double matrixNorm(Dataset a, NormOrder order) { 888 double n; 889 IndexIterator it; 890 switch (order) { 891 case NEG_INFINITY: 892 case POS_INFINITY: 893 n = maxMinMatrixNorm(a, 1, order == NormOrder.POS_INFINITY); 894 break; 895 case ZERO: 896 throw new IllegalArgumentException("Not allowed for matrices"); 897 default: 898 case FROBENIUS: 899 it = a.getIterator(); 900 n = 0; 901 if (a.isComplex()) { 902 while (it.hasNext()) { 903 double v = ((Complex) a.getObjectAbs(it.index)).abs(); 904 n += v*v; 905 } 906 } else { 907 while (it.hasNext()) { 908 double v = a.getElementDoubleAbs(it.index); 909 n += v*v; 910 } 911 } 912 n = Math.sqrt(n); 913 break; 914 } 915 return n; 916 } 917 918 /** 919 * @param a 920 * @param p 921 * @return p-norm of dataset 922 */ 923 public static double norm(Dataset a, final double p) { 924 if (p == 0) { 925 return norm(a, NormOrder.ZERO); 926 } 927 int r = a.getRank(); 928 if (r == 1) { 929 return vectorNorm(a, p); 930 } else if (r == 2) { 931 return matrixNorm(a, p); 932 } 933 throw new IllegalArgumentException("Rank of dataset must be one or two"); 934 } 935 936 private static double vectorNorm(Dataset a, final double p) { 937 IndexIterator it = a.getIterator(); 938 double n = 0; 939 if (a.isComplex()) { 940 while (it.hasNext()) { 941 double v = ((Complex) a.getObjectAbs(it.index)).abs(); 942 if (p == 2) { 943 v *= v; 944 } else if (p != 1) { 945 v = Math.pow(v, p); 946 } 947 n += v; 948 } 949 } else { 950 while (it.hasNext()) { 951 double v = a.getElementDoubleAbs(it.index); 952 if (p == 1) { 953 v = Math.abs(v); 954 } else if (p == 2) { 955 v *= v; 956 } else { 957 v = Math.pow(Math.abs(v), p); 958 } 959 n += v; 960 } 961 } 962 return Math.pow(n, 1./p); 963 } 964 965 private static double matrixNorm(Dataset a, final double p) { 966 double n; 967 if (Math.abs(p) == 1) { 968 n = maxMinMatrixNorm(a, 0, p > 0); 969 } else if (Math.abs(p) == 2) { 970 double[] s = calcSingularValues(a); 971 n = p > 0 ? s[0] : s[s.length - 1]; 972 } else { 973 throw new IllegalArgumentException("Order not allowed"); 974 } 975 976 return n; 977 } 978 979 private static double maxMinMatrixNorm(Dataset a, int d, boolean max) { 980 double n; 981 IndexIterator it; 982 int[] pos; 983 int l; 984 it = a.getPositionIterator(d); 985 pos = it.getPos(); 986 l = a.getShapeRef()[d]; 987 if (max) { 988 n = Double.NEGATIVE_INFINITY; 989 if (a.isComplex()) { 990 while (it.hasNext()) { 991 double v = ((Complex) a.getObject(pos)).abs(); 992 for (int i = 1; i < l; i++) { 993 pos[d] = i; 994 v += ((Complex) a.getObject(pos)).abs(); 995 } 996 pos[d] = 0; 997 n = Math.max(n, v); 998 } 999 } else { 1000 while (it.hasNext()) { 1001 double v = Math.abs(a.getDouble(pos)); 1002 for (int i = 1; i < l; i++) { 1003 pos[d] = i; 1004 v += Math.abs(a.getDouble(pos)); 1005 } 1006 pos[d] = 0; 1007 n = Math.max(n, v); 1008 } 1009 } 1010 } else { 1011 n = Double.POSITIVE_INFINITY; 1012 if (a.isComplex()) { 1013 while (it.hasNext()) { 1014 double v = ((Complex) a.getObject(pos)).abs(); 1015 for (int i = 1; i < l; i++) { 1016 pos[d] = i; 1017 v += ((Complex) a.getObject(pos)).abs(); 1018 } 1019 pos[d] = 0; 1020 n = Math.min(n, v); 1021 } 1022 } else { 1023 while (it.hasNext()) { 1024 double v = Math.abs(a.getDouble(pos)); 1025 for (int i = 1; i < l; i++) { 1026 pos[d] = i; 1027 v += Math.abs(a.getDouble(pos)); 1028 } 1029 pos[d] = 0; 1030 n = Math.min(n, v); 1031 } 1032 } 1033 } 1034 return n; 1035 } 1036 1037 /** 1038 * @param a 1039 * @return array of singular values 1040 */ 1041 public static double[] calcSingularValues(Dataset a) { 1042 SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a)); 1043 return svd.getSingularValues(); 1044 } 1045 1046 1047 /** 1048 * Calculate singular value decomposition A = U S V^T 1049 * @param a 1050 * @return array of U - orthogonal matrix, s - singular values vector, V - orthogonal matrix 1051 */ 1052 public static Dataset[] calcSingularValueDecomposition(Dataset a) { 1053 SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a)); 1054 return new Dataset[] {createDataset(svd.getU()), DatasetFactory.createFromObject(svd.getSingularValues()), 1055 createDataset(svd.getV())}; 1056 } 1057 1058 /** 1059 * Calculate (Moore-Penrose) pseudo-inverse 1060 * @param a 1061 * @return pseudo-inverse 1062 */ 1063 public static Dataset calcPseudoInverse(Dataset a) { 1064 SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a)); 1065 return createDataset(svd.getSolver().getInverse()); 1066 } 1067 1068 /** 1069 * Calculate matrix rank by singular value decomposition method 1070 * @param a 1071 * @return effective numerical rank of matrix 1072 */ 1073 public static int calcMatrixRank(Dataset a) { 1074 SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a)); 1075 return svd.getRank(); 1076 } 1077 1078 /** 1079 * Calculate condition number of matrix by singular value decomposition method 1080 * @param a 1081 * @return condition number 1082 */ 1083 public static double calcConditionNumber(Dataset a) { 1084 SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a)); 1085 return svd.getConditionNumber(); 1086 } 1087 1088 /** 1089 * @param a 1090 * @return determinant of dataset 1091 */ 1092 public static double calcDeterminant(Dataset a) { 1093 EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a)); 1094 return evd.getDeterminant(); 1095 } 1096 1097 /** 1098 * @param a 1099 * @return dataset of eigenvalues (can be double or complex double) 1100 */ 1101 public static Dataset calcEigenvalues(Dataset a) { 1102 EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a)); 1103 double[] rev = evd.getRealEigenvalues(); 1104 1105 if (evd.hasComplexEigenvalues()) { 1106 double[] iev = evd.getImagEigenvalues(); 1107 return DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev); 1108 } 1109 return DatasetFactory.createFromObject(rev); 1110 } 1111 1112 /** 1113 * Calculate eigen-decomposition A = V D V^T 1114 * @param a 1115 * @return array of D eigenvalues (can be double or complex double) and V eigenvectors 1116 */ 1117 public static Dataset[] calcEigenDecomposition(Dataset a) { 1118 EigenDecomposition evd = new EigenDecomposition(createRealMatrix(a)); 1119 Dataset[] results = new Dataset[2]; 1120 1121 double[] rev = evd.getRealEigenvalues(); 1122 if (evd.hasComplexEigenvalues()) { 1123 double[] iev = evd.getImagEigenvalues(); 1124 results[0] = DatasetFactory.createComplexDataset(ComplexDoubleDataset.class, rev, iev); 1125 } else { 1126 results[0] = DatasetFactory.createFromObject(rev); 1127 } 1128 results[1] = createDataset(evd.getV()); 1129 return results; 1130 } 1131 1132 /** 1133 * Calculate QR decomposition A = Q R 1134 * @param a 1135 * @return array of Q and R 1136 */ 1137 public static Dataset[] calcQRDecomposition(Dataset a) { 1138 QRDecomposition qrd = new QRDecomposition(createRealMatrix(a)); 1139 return new Dataset[] {createDataset(qrd.getQT()).getTransposedView(), createDataset(qrd.getR())}; 1140 } 1141 1142 /** 1143 * Calculate LU decomposition A = P^-1 L U 1144 * @param a 1145 * @return array of L, U and P 1146 */ 1147 public static Dataset[] calcLUDecomposition(Dataset a) { 1148 LUDecomposition lud = new LUDecomposition(createRealMatrix(a)); 1149 return new Dataset[] {createDataset(lud.getL()), createDataset(lud.getU()), 1150 createDataset(lud.getP())}; 1151 } 1152 1153 /** 1154 * Calculate inverse of square dataset 1155 * @param a 1156 * @return inverse 1157 */ 1158 public static Dataset calcInverse(Dataset a) { 1159 LUDecomposition lud = new LUDecomposition(createRealMatrix(a)); 1160 return createDataset(lud.getSolver().getInverse()); 1161 } 1162 1163 /** 1164 * Solve linear matrix equation A x = v 1165 * @param a 1166 * @param v 1167 * @return x 1168 */ 1169 public static Dataset solve(Dataset a, Dataset v) { 1170 LUDecomposition lud = new LUDecomposition(createRealMatrix(a)); 1171 if (v.getRank() == 1) { 1172 RealVector x = createRealVector(v); 1173 return createDataset(lud.getSolver().solve(x)); 1174 } 1175 RealMatrix x = createRealMatrix(v); 1176 return createDataset(lud.getSolver().solve(x)); 1177 } 1178 1179 1180 /** 1181 * Solve least squares matrix equation A x = v by SVD 1182 * @param a 1183 * @param v 1184 * @return x 1185 */ 1186 public static Dataset solveSVD(Dataset a, Dataset v) { 1187 SingularValueDecomposition svd = new SingularValueDecomposition(createRealMatrix(a)); 1188 if (v.getRank() == 1) { 1189 RealVector x = createRealVector(v); 1190 return createDataset(svd.getSolver().solve(x)); 1191 } 1192 RealMatrix x = createRealMatrix(v); 1193 return createDataset(svd.getSolver().solve(x)); 1194 } 1195 1196 /** 1197 * Calculate Cholesky decomposition A = L L^T 1198 * @param a 1199 * @return L 1200 */ 1201 public static Dataset calcCholeskyDecomposition(Dataset a) { 1202 CholeskyDecomposition cd = new CholeskyDecomposition(createRealMatrix(a)); 1203 return createDataset(cd.getL()); 1204 } 1205 1206 /** 1207 * Calculation A x = v by conjugate gradient method with the stopping criterion being 1208 * that the estimated residual r = v - A x satisfies ||r|| < ||v|| with maximum of 100 iterations 1209 * @param a 1210 * @param v 1211 * @return solution of A^-1 v by conjugate gradient method 1212 */ 1213 public static Dataset calcConjugateGradient(Dataset a, Dataset v) { 1214 return calcConjugateGradient(a, v, 100, 1); 1215 } 1216 1217 /** 1218 * Calculation A x = v by conjugate gradient method with the stopping criterion being 1219 * that the estimated residual r = v - A x satisfies ||r|| < delta ||v|| 1220 * @param a 1221 * @param v 1222 * @param maxIterations 1223 * @param delta parameter used by stopping criterion 1224 * @return solution of A^-1 v by conjugate gradient method 1225 */ 1226 public static Dataset calcConjugateGradient(Dataset a, Dataset v, int maxIterations, double delta) { 1227 ConjugateGradient cg = new ConjugateGradient(maxIterations, delta, false); 1228 return createDataset(cg.solve((RealLinearOperator) createRealMatrix(a), createRealVector(v))); 1229 } 1230 1231 private static RealMatrix createRealMatrix(Dataset a) { 1232 if (a.getRank() != 2) { 1233 throw new IllegalArgumentException("Dataset must be rank 2"); 1234 } 1235 int[] shape = a.getShapeRef(); 1236 IndexIterator it = a.getIterator(true); 1237 int[] pos = it.getPos(); 1238 RealMatrix m = MatrixUtils.createRealMatrix(shape[0], shape[1]); 1239 while (it.hasNext()) { 1240 m.setEntry(pos[0], pos[1], a.getElementDoubleAbs(it.index)); 1241 } 1242 return m; 1243 } 1244 1245 private static RealVector createRealVector(Dataset a) { 1246 if (a.getRank() != 1) { 1247 throw new IllegalArgumentException("Dataset must be rank 1"); 1248 } 1249 int size = a.getSize(); 1250 IndexIterator it = a.getIterator(true); 1251 int[] pos = it.getPos(); 1252 RealVector m = new ArrayRealVector(size); 1253 while (it.hasNext()) { 1254 m.setEntry(pos[0], a.getElementDoubleAbs(it.index)); 1255 } 1256 return m; 1257 } 1258 1259 private static Dataset createDataset(RealVector v) { 1260 DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, v.getDimension()); 1261 int size = r.getSize(); 1262 if (v instanceof ArrayRealVector) { 1263 double[] data = ((ArrayRealVector) v).getDataRef(); 1264 for (int i = 0; i < size; i++) { 1265 r.setAbs(i, data[i]); 1266 } 1267 } else { 1268 for (int i = 0; i < size; i++) { 1269 r.setAbs(i, v.getEntry(i)); 1270 } 1271 } 1272 return r; 1273 } 1274 1275 private static Dataset createDataset(RealMatrix m) { 1276 DoubleDataset r = DatasetFactory.zeros(DoubleDataset.class, m.getRowDimension(), m.getColumnDimension()); 1277 if (m instanceof Array2DRowRealMatrix) { 1278 double[][] data = ((Array2DRowRealMatrix) m).getDataRef(); 1279 IndexIterator it = r.getIterator(true); 1280 int[] pos = it.getPos(); 1281 while (it.hasNext()) { 1282 r.setAbs(it.index, data[pos[0]][pos[1]]); 1283 } 1284 } else { 1285 IndexIterator it = r.getIterator(true); 1286 int[] pos = it.getPos(); 1287 while (it.hasNext()) { 1288 r.setAbs(it.index, m.getEntry(pos[0], pos[1])); 1289 } 1290 } 1291 return r; 1292 } 1293}