kotlinのonnxでmmsegmentationの推論

・ネットワーク選定
onnxでAdaptive Poolingが使えないので、Adaptive Poolingを使ってないネットワークを選定する。
mmsegmentationのデモで使われていpspnetはダメ。
fcnは大丈夫。

・onnxモデルへの変換
直接onnxモデルへ変換しようとするとエラーが出る。
以下のように、単純なネットワークを定義して、それをonnxモデルに変換する。

class SimpleModel(nn.Module):
def __init__(self, orig_model):
super().__init__()
self.backbone = orig_model.backbone
self.decode_head = orig_model.decode_head

def forward(self, x):
x = self.backbone(x)
out = self.decode_head(x)
return out


・test_poipelineのtransformsによる置き換え
リサイズと正規化をtransformsで置き換える。

long_side = 480
scale = long_side / max(img.width, img.height)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((round(img.height * scale), round(img.width * scale))),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
data = transform(img).unsqueeze(0).numpy()

 

・kotlinでの推論

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.awt.Color
import java.awt.Image
import java.awt.image.BufferedImage
import java.io.File
import java.lang.Integer.max
import java.util.*
import javax.imageio.ImageIO
import kotlin.math.roundToInt


fun makeData(img: BufferedImage) : Array<Array<Array<FloatArray>>> {
val data = Array(1) { Array(3) { Array(img.height) { FloatArray(img.width) { 0f } } } }
for (y in 0 until img.height) {
for (x in 0 until img.width) {
val color = Color(img.getRGB(x, y))
data[0][0][y][x] = (color.red.toFloat() / 255f - 0.485f) / 0.229f
data[0][1][y][x] = (color.green.toFloat() / 255f - 0.456f) / 0.224f
data[0][2][y][x] = (color.blue.toFloat() / 255f - 0.406f) / 0.225f
}
}
return data
}

fun makeSegmentImage(probs: Array<Array<Array<FloatArray>>>, classNum: Int) : BufferedImage {
val height = probs[0][0].size
val width = probs[0][0][0].size
val segImg = BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR)
for (y in 0 until height) {
for (x in 0 until width) {
var maxValue = -Float.MAX_VALUE
var maxIdx = 0
for (c in 1 until classNum) {
if (probs[0][c][y][x] > maxValue) {
maxValue = probs[0][c][y][x]
maxIdx = c
}
}
segImg.setRGB(x, y, Color(maxIdx, maxIdx, maxIdx).rgb)
}
}
return segImg
}


fun main(args: Array<String>) {

val onnxName = "mmseg.onnx"
val longSide = 480f
val classNum = 5

val origImg = ImageIO.read(File("chitose000.png"))
val scale = longSide / max(origImg.width, origImg.height).toFloat()
val width = (origImg.width * scale).roundToInt()
val height = (origImg.height * scale).roundToInt()
val img = BufferedImage(width, height, origImg.type)
img.createGraphics().drawImage(
origImg.getScaledInstance(width, height, Image.SCALE_AREA_AVERAGING), 0, 0, width, height, null
)

val data = makeData(img)

val env = OrtEnvironment.getEnvironment()
val opts = OrtSession.SessionOptions()
val session = env.createSession(onnxName, opts)
val inputName = session.inputNames.iterator().next()

val tensor = OnnxTensor.createTensor(env, data)
val output = session.run(Collections.singletonMap(inputName, tensor))
val probs = output[0].value as? Array<Array<Array<FloatArray>>> ?: return
val segImg = makeSegmentImage(probs, classNum)
ImageIO.write(segImg, "png", File("seg.png"))
}