覚えたら書く

IT関係のデベロッパとして日々覚えたことを書き残したいです。twitter: @yyoshikaw

ND4J - Javaでベクトルとか行列を定義する

Javaでディープラーニングを試そうと思った場合に、名前があがっているライブラリはDeeplearning4J一択な気がします、現状は。
そのDeeplearning4Jの内部で使用されているのが、ND4J(N-Dimensional Arrays for Java)です。

ND4Jは、行列等のN次元配列を簡単に扱えるようにするための数値計算ライブラリです。
Python用に存在している数値計算ライブラリのNumPyを参考して作られたようです。

ディープラーニング云々の前にND4J自体も結構有用そうなので、とりあえず試してみます。
(時間の都合で、今回のエントリでは行列やベクトルを定義するとこまでしかやってません。演算は次の機会に・・・


準備

ND4Jを利用するためにはpom.xmlに以下を追記します。
(本質的なところではないのですが、ND4Jが内部でslf4jでロギングしているようなので、slf4jとLogbackについても依存関係に追加しています。)

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native</artifactId>
    <version>0.7.2</version>
</dependency>

<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-api</artifactId>
    <version>1.7.22</version>
</dependency>
<dependency>
    <groupId>ch.qos.logback</groupId>
    <artifactId>logback-classic</artifactId>
    <version>1.1.8</version>
</dependency>


とりあえず試してみる

基本的にNd4j(org.nd4j.linalg.factory.Nd4j)というファクトリクラスを経由して行列等を生成します。
生成した行列等はINDArray(org.nd4j.linalg.api.ndarray.INDArray)型で返されます。
このINDArrayに対して各種操作(演算)を行う感じになるようです。


列ベクトルを定義

3行1列の列ベクトルを定義して、その中身を確認しています。

■サンプルコード

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

// 列ベクトルを生成(値は(10, 20, 30))
INDArray columnVectorA = Nd4j.create(new double[]{10, 20, 30}, new int[]{3, 1});
System.out.println("Vector A: \n" + columnVectorA);

// この値はベクトルか?行列か?
System.out.println("Vector A is Vector?: " + columnVectorA.isVector());
System.out.println("Vector A is Matrix?: " + columnVectorA.isMatrix());

// 行数と列数はいくつか?
System.out.println("Vector A rows: " + columnVectorA.rows());
System.out.println("Vector A columns: " + columnVectorA.columns());
System.out.println("Vector A shape: " + Arrays.toString(columnVectorA.shape()));

■実行結果

Vector A: 
[10.00, 20.00, 30.00]
Vector A is Vector?: true
Vector A is Matrix?: false
Vector A rows: 3
Vector A columns: 1
Vector A shape: [3, 1]


行ベクトルを定義

1行3列の列ベクトルを定義して、その中身を確認しています。

■サンプルコード

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

// 行ベクトルを生成(値は(10, 20, 30))
INDArray rowVectorB = Nd4j.create(new double[]{10, 20, 30});
System.out.println("Vector B: \n" + rowVectorB);

// この値はベクトルか?行列か?
System.out.println("Vector B is Vector?: " + rowVectorB.isVector());
System.out.println("Vector B is Matrix?: " + rowVectorB.isMatrix());

// 行数と列数はいくつか?
System.out.println("Vector B rows: " + rowVectorB.rows());
System.out.println("Vector B columns: " + rowVectorB.columns());
System.out.println("Vector B shape: " + Arrays.toString(rowVectorB.shape()));

■実行結果

Vector B: 
[10.00, 20.00, 30.00]
Vector B is Vector?: true
Vector B is Matrix?: false
Vector B rows: 1
Vector B columns: 3
Vector B shape: [1, 3]


行列の定義

2行4列の行列を定義して、その中身を確認しています。

■サンプルコード

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

// 行列を生成
INDArray matrix1 = Nd4j.create(new double[][]{
        {1.0, 2.0, 3.0, 4.0},
        {5.0, 6.0, 7.0, 8.0}});

System.out.println("Matrix1: \n" + matrix1);

// この値はベクトルか?行列か?
System.out.println("Matrix1 is Vector?: " + matrix1.isVector());
System.out.println("Matrix1 is Matrix?: " + matrix1.isMatrix());

// 行数と列数はいくつか?
System.out.println("Matrix1 rows: " + matrix1.rows());
System.out.println("Matrix1 columns: " + matrix1.columns());
System.out.println("Matrix1 shape: " + Arrays.toString(matrix1.shape()));

// 行列の要素数はいくつか?
System.out.println("Matrix1 elements: " + matrix1.length());

■実行結果

Matrix1: 
[[1.00, 2.00, 3.00, 4.00],
 [5.00, 6.00, 7.00, 8.00]]
Matrix1 is Vector?: false
Matrix1 is Matrix?: true
Matrix1 rows: 2
Matrix1 columns: 4
Matrix1 shape: [2, 4]
Matrix1 elements: 8


単位行列とゼロ行列の定義

Nd4j#eyeを使うことで単位行列を生成することができます。また、Nd4j#zerosによりゼロ行列を生成することができます

■サンプルコード

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// 単位行列を生成(3行3列)
INDArray idMatrix = Nd4j.eye(3);
System.out.println("idMatrix: \n" + idMatrix);

System.out.println("-----------------------");

// ゼロ行列を生成(4行3列)
INDArray zeroMatrix = Nd4j.zeros(4, 3);
System.out.println("zeroMatrix: \n" + zeroMatrix);

■実行結果

idMatrix: 
[[1.00, 0.00, 0.00],
 [0.00, 1.00, 0.00],
 [0.00, 0.00, 1.00]]
-----------------------
zeroMatrix: 
[[0.00, 0.00, 0.00],
 [0.00, 0.00, 0.00],
 [0.00, 0.00, 0.00],
 [0.00, 0.00, 0.00]]



関連エントリ