TensorFlow Javaの使い方

Maven Repository
Maven Repository: org.tensorflow » tensorflow-core-platform
※旧版(org.tensorflow » tensorflow や libtensorflow など、TensorFlow 1.x 系の旧アーティファクト)と間違えないよう注意

・Example1
Python学習

import tensorflow as tf
from tensorflow.keras import Model

class MyModel(Model):
    """Minimal Keras model with no layers: it returns its input times 2."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        """Return ``x * 2``."""
        return x * 2

def main():
    """Create MyModel, call it once, and export it to 'minimum_model'."""
    model = MyModel()
    # The model is called once before saving; the result is discarded.
    # NOTE(review): presumably this call is what gives the exported model
    # its input signature -- confirm against the Keras saving docs.
    model(3)
    tf.saved_model.save(model, 'minimum_model')


if __name__ == '__main__':
    main()


Java推論

/**
 * Loads the SavedModel stored in "minimum_model" and runs one inference with
 * the single int value 3.
 *
 * <p>The model bundle, the input tensor and the fetched output tensors are all
 * closed explicitly, because they are {@code AutoCloseable} handles.
 */
public class Minimum {
    public static void main(String[] args) {
        // Build the input on the Java side first: a 1-element int vector holding 3.
        IntNdArray inputMatrix = NdArrays.ofInts(Shape.of(1));
        inputMatrix.set(NdArrays.vectorOf(3));

        try (SavedModelBundle model = SavedModelBundle.load("minimum_model");
                Tensor inputTensor = TInt32.tensorOf(inputMatrix)) {

            // The feed/fetch names are the tensor names of the exported signature.
            List<Tensor> outputsList = model.session().runner()
                    .feed("serving_default_input_1:0", inputTensor)
                    .fetch("PartitionedCall:0")
                    .run();
            try {
                Tensor result = outputsList.get(0);
                // Read the first element of the output as an int.
                int scores = result.asRawTensor().data().asInts().getInt(0);
            } finally {
                // Fetched tensors are owned by the caller and must be released too.
                for (Tensor output : outputsList) {
                    output.close();
                }
            }
        }
    }
}

※feedとfetchの引数(入出力テンソル名)はsaved_model_cliで確認(例: saved_model_cli show --dir minimum_model --all)


・Example2
Python学習

import tensorflow as tf
from tensorflow.keras import Model, layers
import numpy as np

class MyModel2(Model):
    """Two dense layers: Dense(128, relu) followed by Dense(10, softmax)."""

    def __init__(self):
        super().__init__()
        self.fc1 = layers.Dense(128, activation='relu')
        self.fc2 = layers.Dense(10, activation='softmax')

    def call(self, x):
        """Apply fc1 and then fc2 to ``x`` and return the result."""
        x = self.fc1(x)
        x = self.fc2(x)
        return x

def main():
    """Run MyModel2 on a 2x3 sample, print the output, export to 'minimum_model2'."""
    data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    model = MyModel2()
    # The model is called once on the sample before it is saved.
    out = model(data)
    print(out)
    tf.saved_model.save(model, 'minimum_model2')


if __name__ == '__main__':
    main()


Java推論

/**
 * Loads the SavedModel stored in "minimum_model2" and runs one inference on a
 * 2x3 float input.
 *
 * <p>The model bundle, the input tensor and the fetched output tensors are all
 * closed explicitly, because they are {@code AutoCloseable} handles.
 */
public class Minimum2 {
    public static void main(String[] args) {
        float[][] rows = {
            {1.0f, 2.0f, 3.0f},
            {4.0f, 5.0f, 6.0f},
        };

        // Fill the 2x3 input element by element.
        // Alternative: set a whole row at once, e.g.
        //   inputMatrix.set(NdArrays.vectorOf(1.0f, 2.0f, 3.0f), 0);
        FloatNdArray inputMatrix = NdArrays.ofFloats(Shape.of(2, 3));
        for (int r = 0; r < rows.length; r++) {
            for (int c = 0; c < rows[r].length; c++) {
                inputMatrix.setFloat(rows[r][c], r, c);
            }
        }

        try (SavedModelBundle model = SavedModelBundle.load("minimum_model2");
                Tensor inputTensor = TFloat32.tensorOf(inputMatrix)) {

            List<Tensor> outputsList = model.session().runner()
                    .feed("serving_default_input_1:0", inputTensor)
                    .fetch("StatefulPartitionedCall:0")
                    .run();
            try {
                Tensor result = outputsList.get(0);
                // Element 12 of the flattened output buffer.
                // NOTE(review): assuming a (2, 10) output, this is row 1, column 2 -- confirm.
                float scores = result.asRawTensor().data().asFloats().getFloat(12);
            } finally {
                // Fetched tensors are owned by the caller and must be released too.
                for (Tensor output : outputsList) {
                    output.close();
                }
            }
        }
    }
}


・Example3(複数出力)
Python学習

import tensorflow as tf
from tensorflow.keras import Model, layers
import numpy as np

class MyModel3(Model):
    """Two dense layers whose ``call`` returns two outputs."""

    def __init__(self):
        super().__init__()
        self.fc1 = layers.Dense(128, activation='relu')
        self.fc2 = layers.Dense(10, activation='softmax')

    def call(self, x):
        """Return ``(y, y + 0.1)`` where ``y = fc2(fc1(x))``."""
        x = self.fc1(x)
        x = self.fc2(x)
        return x, x + 0.1

def main():
    """Run MyModel3 on a 2x3 sample, print both outputs, export to 'minimum_model3'."""
    data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    model = MyModel3()
    # The model is called once on the sample before it is saved; `out` is a pair.
    out = model(data)
    print(out)
    tf.saved_model.save(model, 'minimum_model3')


if __name__ == '__main__':
    main()


Java推論

/**
 * Loads the multi-output SavedModel stored in "minimum_model3" and fetches both
 * of its outputs for a 2x3 float input.
 *
 * <p>The model bundle, the input tensor and the fetched output tensors are all
 * closed explicitly, because they are {@code AutoCloseable} handles.
 *
 * <p>NOTE(review): the class name duplicates the Example 2 class; it is kept
 * here so existing references still work, but "Minimum3" was probably intended.
 */
public class Minimum2 {
    public static void main(String[] args) {
        // Fill the 2x3 input one row at a time.
        FloatNdArray inputMatrix = NdArrays.ofFloats(Shape.of(2, 3));
        inputMatrix.set(NdArrays.vectorOf(1.0f, 2.0f, 3.0f), 0);
        inputMatrix.set(NdArrays.vectorOf(4.0f, 5.0f, 6.0f), 1);

        // Load the model exported by the Example 3 Python script (minimum_model3),
        // not the single-output model from Example 2.
        try (SavedModelBundle model = SavedModelBundle.load("minimum_model3");
                Tensor inputTensor = TFloat32.tensorOf(inputMatrix)) {

            // One fetch() per output; results come back in fetch order.
            List<Tensor> outputsList = model.session().runner()
                    .feed("serving_default_input_1:0", inputTensor)
                    .fetch("StatefulPartitionedCall:0")
                    .fetch("StatefulPartitionedCall:1")
                    .run();
            try {
                Tensor result0 = outputsList.get(0);
                Tensor result1 = outputsList.get(1);
                // Element 12 of each flattened output buffer.
                // NOTE(review): assuming (2, 10) outputs, this is row 1, column 2 -- confirm.
                float scores = result0.asRawTensor().data().asFloats().getFloat(12);
                float shiftedScores = result1.asRawTensor().data().asFloats().getFloat(12);
            } finally {
                // Fetched tensors are owned by the caller and must be released too.
                for (Tensor output : outputsList) {
                    output.close();
                }
            }
        }
    }
}


・Example4(画像分類)

/**
 * Loads the SavedModel stored in "checkpoints", feeds it a batch of two PNG
 * images as a float tensor of shape (2, height, width, 3) scaled to [0, 1],
 * and prints every value of the three fetched outputs.
 *
 * <p>The model bundle, the input tensor and the fetched output tensors are all
 * closed explicitly, because they are {@code AutoCloseable} handles.
 */
public class Buffered {
    public static void main(String[] args) {
        try (SavedModelBundle model = SavedModelBundle.load("checkpoints")) {

            BufferedImage img0 = ImageIO.read(new File("data/00000_03_050_025/000000.png"));
            BufferedImage img1 = ImageIO.read(new File("data/00000_04_060_050/000000.png"));

            // The batch shape is taken from the first image; the second must match it.
            int height = img0.getHeight();
            int width = img0.getWidth();
            FloatNdArray inputMatrix = NdArrays.ofFloats(Shape.of(2, height, width, 3));
            copyPixels(img0, inputMatrix, 0, height, width);
            copyPixels(img1, inputMatrix, 1, height, width);

            try (Tensor inputTensor = TFloat32.tensorOf(inputMatrix)) {
                List<Tensor> outputsList = model.session().runner()
                        .feed("serving_default_input_1:0", inputTensor)
                        .fetch("StatefulPartitionedCall:0")
                        .fetch("StatefulPartitionedCall:1")
                        .fetch("StatefulPartitionedCall:2")
                        .run();
                try {
                    // size, flow, hardness
                    printValues(outputsList.get(0));
                    System.out.println();
                    printValues(outputsList.get(1));
                    System.out.println();
                    printValues(outputsList.get(2));
                } finally {
                    // Fetched tensors are owned by the caller and must be released too.
                    for (Tensor output : outputsList) {
                        output.close();
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Copies one image into slot {@code batchIndex} of the (batch, h, w, 3) array,
     * scaling each 8-bit channel to [0, 1].
     *
     * @throws IllegalArgumentException if the image is not {@code width} x {@code height}
     */
    private static void copyPixels(BufferedImage image, FloatNdArray batch, int batchIndex,
            int height, int width) {
        if (image.getHeight() != height || image.getWidth() != width) {
            throw new IllegalArgumentException("Image " + batchIndex + " is "
                    + image.getWidth() + "x" + image.getHeight()
                    + " but the batch expects " + width + "x" + height);
        }
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                // getRGB packs the pixel as 0xAARRGGBB; read it once per pixel.
                int argb = image.getRGB(w, h);
                // Channel 0 = red, 1 = green, 2 = blue (previously blue was written to
                // all three channels).
                // NOTE(review): assumes the model was trained on RGB-ordered input -- confirm.
                for (int c = 0; c < 3; c++) {
                    int value = (argb >> (16 - 8 * c)) & 0xFF;
                    batch.setFloat(value / 255.0f, batchIndex, h, w, c);
                }
            }
        }
    }

    /** Prints every element of the tensor's flattened float buffer, one per line. */
    private static void printValues(Tensor tensor) {
        long count = tensor.asRawTensor().size();
        for (long i = 0; i < count; i++) {
            System.out.println(tensor.asRawTensor().data().asFloats().getFloat(i));
        }
    }
}