KotlinでONNX推論

以下の公式サンプル（Java）をKotlinで書き直し、分かりやすく整理しました。
onnxruntime/ScoreMNIST.java at main · microsoft/onnxruntime · GitHub


データセット（LIBSVM形式）は以下のページからダウンロードしてください。

LIBSVM Data: Classification (Multi Class)


ONNXモデルは以下の記事の手順で生成しました。

onnxの使い方 - LeMU_Researchの日記

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.io.BufferedReader
import java.io.FileReader
import java.util.*
import java.util.regex.Pattern
import kotlin.collections.ArrayList


/**
 * In-memory sparse dataset.
 *
 * The three arrays are parallel: sample `i` has label `labels[i]`, feature indices `indices[i]`
 * and the matching feature values `values[i]`. Only the features actually listed for a sample
 * are stored, so `indices[i]` and `values[i]` may differ in length between samples.
 */
class Dataset(
    val labels: Array<Int>,
    val indices: Array<Array<Int>>,
    val values: Array<Array<Float>>,
)


/**
 * Loads a sparse dataset from a text file in LIBSVM format.
 *
 * Every non-blank line has the form `label index:value index:value ...`. For each line, the
 * indices and values of the listed entries (the non-zero pixels) are extracted. Indices are
 * stored exactly as written in the file, with no base conversion.
 *
 * @param path path of the text file to read.
 * @return the parsed [Dataset], with one entry per non-blank line, in file order.
 * @throws NumberFormatException if a label, index or value is not numeric.
 * @throws IllegalArgumentException if a feature token has no `index:` part.
 */
fun load(path: String): Dataset {
    // Compiled once per call instead of once per line.
    val separator = Regex("\\s+")
    val labels = ArrayList<Int>()
    val indices = ArrayList<Array<Int>>()
    val values = ArrayList<Array<Float>>()

    // `use` closes the reader even when parsing throws.
    BufferedReader(FileReader(path)).use { reader ->
        reader.lineSequence()
            .map { it.trim() }
            // Blank lines (e.g. a trailing empty line) carry no sample; skip rather than crash.
            .filter { it.isNotEmpty() }
            .forEach { line ->
                val fields = line.split(separator)
                val features = fields.drop(1).map { token ->
                    val colon = token.indexOf(':')
                    require(colon > 0) { "malformed feature '$token' in line: $line" }
                    token.substring(0, colon).toInt() to token.substring(colon + 1).toFloat()
                }
                labels.add(fields[0].toInt())
                indices.add(features.map { it.first }.toTypedArray())
                values.add(features.map { it.second }.toTypedArray())
            }
    }
    return Dataset(labels.toTypedArray(), indices.toTypedArray(), values.toTypedArray())
}


/**
 * Converts one sparse sample into a dense, normalized input array for the model.
 *
 * The result has shape `[1][channels][width][width]`: a batch of one square image with
 * channels first. Every pixel starts at -1, which is what a raw value of 0 normalizes to, so
 * pixels missing from the sparse data behave like zeros. Each listed pixel is normalized as
 * `(v / 255 - 0.5) / 0.5`. This assumes the file stores unscaled 0–255 values (TODO: confirm
 * per dataset).
 *
 * The 0-based flat position `p = index - indexBase` maps to channel `p / (width*width)`,
 * row `p % (width*width) / width` and column `p % width`.
 *
 * @param indices feature indices of the sample, as read from the file.
 * @param values pixel values, parallel to [indices].
 * @param width side length of the square image.
 * @param channels number of channels.
 * @param indexBase value of the first pixel's index. LIBSVM indices start at 1 (the default);
 *   pass 0 for data that is already 0-based.
 * @throws IllegalArgumentException if [indices] and [values] differ in length, or an index
 *   falls outside the image.
 */
fun writeData(
    indices: Array<Int>,
    values: Array<Float>,
    width: Int,
    channels: Int,
    indexBase: Int = 1,
): Array<Array<Array<FloatArray>>> {
    require(indices.size == values.size) {
        "indices (${indices.size}) and values (${values.size}) must have the same length"
    }
    val plane = width * width
    val pixelCount = plane * channels
    val data = Array(1) { Array(channels) { Array(width) { FloatArray(width) { -1f } } } }
    for (i in indices.indices) {
        // Convert to a 0-based flat position. Out-of-range indices are rejected instead of
        // being clamped onto a neighbouring pixel.
        val pos = indices[i] - indexBase
        require(pos in 0 until pixelCount) {
            "pixel index ${indices[i]} out of range for ${channels}x${width}x$width (indexBase=$indexBase)"
        }
        data[0][pos / plane][pos % plane / width][pos % width] = (values[i] / 255 - 0.5f) / 0.5f
    }
    return data
}


/**
 * Scores every sample of a LIBSVM-format dataset with an ONNX model and prints the accuracy.
 *
 * Usage: `<program> <datasetName>`. Reads the dataset from the file `<datasetName>` and the
 * model from `<datasetName>.onnx`. The name `cifar10` selects a 3x32x32 input; any other name
 * selects 1x28x28. The sample index is printed every 2000 samples as progress output.
 */
fun main(args: Array<String>) {
    require(args.isNotEmpty()) { "usage: <datasetName>  (reads <datasetName> and <datasetName>.onnx)" }
    val datasetName = args[0]
    val onnxName = "$datasetName.onnx"

    val isCifar = datasetName == "cifar10"
    val width = if (isCifar) 32 else 28
    val channels = if (isCifar) 3 else 1

    val dataset = load(datasetName)
    require(dataset.labels.isNotEmpty()) { "dataset '$datasetName' contains no samples" }

    val env = OrtEnvironment.getEnvironment()
    var correct = 0
    // Session, options, tensors and results wrap native resources; `use` releases each one
    // deterministically instead of leaking it on every iteration.
    OrtSession.SessionOptions().use { opts ->
        env.createSession(onnxName, opts).use { session ->
            val inputName = session.inputNames.first()
            for (i in dataset.labels.indices) {
                val data = writeData(dataset.indices[i], dataset.values[i], width, channels)
                val pred = OnnxTensor.createTensor(env, data).use { tensor ->
                    session.run(mapOf(inputName to tensor)).use { output ->
                        // The model is expected to emit a [batch][classes] float matrix.
                        @Suppress("UNCHECKED_CAST")
                        val probs = output[0].value as? Array<FloatArray>
                            ?: error("unexpected model output type: ${output[0].value?.javaClass}")
                        val scores = probs.firstOrNull() ?: error("model returned an empty batch")
                        // Arg-max over class scores is the predicted label.
                        scores.withIndex().maxByOrNull { it.value }?.index
                    }
                }
                if (pred == dataset.labels[i]) correct += 1
                if (i % 2000 == 0) println(i)
            }
        }
    }
    println("accuracy = ${correct.toFloat() / dataset.labels.size}")
}