前回のエントリでND4J
を利用して、ベクトルの演算を試しました
今回は行列の演算を試してみます。
各種演算
行列の和
行列と行列の足し算です。
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixA = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, new int[]{2, 2}); INDArray matrixB = Nd4j.create(new double[]{11.0, 12.0, 13.0, 14.0}, new int[]{2, 2}); INDArray matrixRet1 = matrixA.add(matrixB); System.out.println("Ret1: \n" + matrixRet1);
■実行結果
Ret1: [[12.00, 14.00], [16.00, 18.00]]
行列の差
行列と行列の引き算です。
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixA = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, new int[]{2, 2}); INDArray matrixB = Nd4j.create(new double[]{11.0, 12.0, 13.0, 14.0}, new int[]{2, 2}); INDArray matrixRet2 = matrixA.sub(matrixB); System.out.println("Ret2: \n" + matrixRet2);
■実行結果
Ret2: [[-10.00, -10.00], [-10.00, -10.00]]
行列×スカラー
行列をスカラー倍します。
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixA = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, new int[]{2, 2}); INDArray matrixRet3 = matrixA.mul(3); System.out.println("Ret3: \n" + matrixRet3);
■実行結果
Ret3: [[3.00, 6.00], [9.00, 12.00]]
行列と行列の積
行列と行列のかけ算です
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixA = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, new int[]{2, 2}); INDArray matrixB = Nd4j.create(new double[]{11.0, 12.0, 13.0, 14.0}, new int[]{2, 2}); INDArray matrixRet4 = matrixA.mmul(matrixB); System.out.println("Ret4: \n" + matrixRet4);
■実行結果
Ret4: [[37.00, 40.00], [85.00, 92.00]]
零行列との積
行列と零行列のかけ算です
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixA = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, new int[]{2, 2}); INDArray zeroMatrix = Nd4j.zeros(2, 2); INDArray matrixRet5 = matrixA.mmul(zeroMatrix); System.out.println("Ret5: \n" + matrixRet5);
■実行結果
[[0.00, 0.00], [0.00, 0.00]]
単位行列との積
行列と単位行列のかけ算です
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixA = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, new int[]{2, 2}); INDArray identityMatrix = Nd4j.eye(2); INDArray matrixRet6 = matrixA.mmul(identityMatrix); System.out.println("Ret6: \n" + matrixRet6);
■実行結果
Ret6: [[1.00, 2.00], [3.00, 4.00]]
転置行列
対象行列Cの転置行列を求めます
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; INDArray matrixC = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}, new int[]{3, 3}); INDArray matrixRet7 = matrixC.transpose(); System.out.println("Raw Matrix: \n" + matrixC); System.out.println("Transposed Matrix: \n" + matrixRet7);
■実行結果
Raw Matrix: [[1.00, 2.00, 3.00], [4.00, 5.00, 6.00], [7.00, 8.00, 9.00]] Transposed Matrix: [[1.00, 4.00, 7.00], [2.00, 5.00, 8.00], [3.00, 6.00, 9.00]]
逆行列
対象行列Aの逆行列を求めます
■サンプルコード
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.inverse.InvertMatrix; INDArray matrixC = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}, new int[]{3, 3}); INDArray matrixRet8 = InvertMatrix.invert(matrixA, false); System.out.println("Raw Matrix: \n" + matrixA); System.out.println("Invert Matrix: \n" + matrixRet8); // 元の行列×逆行列=単位行列 System.out.println("identity Matrix: \n" + matrixA.mmul(matrixRet8));
■実行結果
Raw Matrix: [[1.00, 2.00], [3.00, 4.00]] Invert Matrix: [[-2.00, 1.00], [1.50, -0.50]] identity Matrix: [[1.00, 0.00], [0.00, 1.00]]