Inference of mmsegmentation models with ONNX in Kotlin

・Network selection
ONNX export cannot handle Adaptive Pooling, so pick a network that does not use it (a quick check is sketched below).
pspnet, which is used in the mmsegmentation demo, will not work.
fcn is fine.
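
PSPNet fails because its pyramid pooling module is built from nn.AdaptiveAvgPool2d layers, which torch.onnx cannot export when the input size is dynamic. A quick way to check a candidate network (a minimal sketch; the config/checkpoint paths are placeholders and init_segmentor is the mmseg 0.x API):

import torch.nn as nn
from mmseg.apis import init_segmentor

model = init_segmentor('configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py',
                       'fcn_r50.pth', device='cpu')  # placeholder checkpoint

# If this prints anything, ONNX export with dynamic input sizes will likely fail.
for name, m in model.named_modules():
    if isinstance(m, nn.AdaptiveAvgPool2d):
        print(name, m)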

・Converting to an ONNX model
Trying to convert the full mmsegmentation model to ONNX directly raises an error.
Instead, define a simple network like the following that keeps only the backbone and the decode head, and convert that to ONNX (an export sketch follows the class).

import torch
from torch import nn


# Keep only the parts needed for inference: backbone + decode head.
class SimpleModel(nn.Module):
    def __init__(self, orig_model):
        super().__init__()
        self.backbone = orig_model.backbone
        self.decode_head = orig_model.decode_head

    def forward(self, x):
        x = self.backbone(x)       # tuple of multi-scale feature maps
        out = self.decode_head(x)  # per-class logits at feature-map resolution
        return out
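
Loading the trained checkpoint, wrapping it, and exporting might look like this (a minimal sketch: the config/checkpoint paths are placeholders, init_segmentor follows the mmseg 0.x API, and the dummy input size is arbitrary since the spatial dims are exported as dynamic):

from mmseg.apis import init_segmentor

orig_model = init_segmentor('work_dirs/tutorial/config.py',   # placeholder paths --
                            'work_dirs/tutorial/latest.pth',  # substitute your own
                            device='cuda:0')
model = SimpleModel(orig_model).eval()

dummy = torch.zeros(1, 3, 360, 480).cuda()
torch.onnx.export(model, dummy, 'mmseg.onnx',
                  input_names=['input'],
                  dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'}})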


・Replacing the test_pipeline with transforms
Replace the resize and normalization steps of the mmsegmentation test_pipeline with torchvision transforms; a quick onnxruntime check follows the snippet.

from PIL import Image
from torchvision import transforms

img = Image.open('chitose000.png')  # same test image as in the Kotlin code below

long_side = 480
scale = long_side / max(img.width, img.height)
transform = transforms.Compose([
    transforms.ToTensor(),
    # resize so the long side becomes 480, keeping the aspect ratio
    transforms.Resize((round(img.height * scale), round(img.width * scale))),
    # ImageNet mean/std (the mmseg img_norm_cfg values divided by 255)
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
data = transform(img).unsqueeze(0).numpy()
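
The preprocessed array can be fed to onnxruntime in Python to confirm the exported model works before moving to Kotlin (a minimal sketch; the input name 'input' assumes the model was exported with input_names=['input'] as above):

import onnxruntime

session = onnxruntime.InferenceSession('mmseg.onnx', providers=['CPUExecutionProvider'])
out = session.run(None, {'input': data})
seg = out[0].argmax(axis=1)  # (1, H', W') class indices at feature-map resolution
print(seg.shape)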

 

・Inference in Kotlin

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.awt.Color
import java.awt.Image
import java.awt.image.BufferedImage
import java.io.File
import java.lang.Integer.max
import java.util.*
import javax.imageio.ImageIO
import kotlin.math.roundToInt


// Convert a BufferedImage into a 1 x 3 x H x W float array, applying the same
// ImageNet mean/std normalization as the Python transforms above.
fun makeData(img: BufferedImage): Array<Array<Array<FloatArray>>> {
    val data = Array(1) { Array(3) { Array(img.height) { FloatArray(img.width) } } }
    for (y in 0 until img.height) {
        for (x in 0 until img.width) {
            val color = Color(img.getRGB(x, y))
            data[0][0][y][x] = (color.red.toFloat() / 255f - 0.485f) / 0.229f
            data[0][1][y][x] = (color.green.toFloat() / 255f - 0.456f) / 0.224f
            data[0][2][y][x] = (color.blue.toFloat() / 255f - 0.406f) / 0.225f
        }
    }
    return data
}

// Take the per-pixel argmax over the class scores and write the class index
// into a grayscale image (pixel value = class index).
fun makeSegmentImage(probs: Array<Array<Array<FloatArray>>>, classNum: Int): BufferedImage {
    val height = probs[0][0].size
    val width = probs[0][0][0].size
    val segImg = BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR)
    for (y in 0 until height) {
        for (x in 0 until width) {
            // start from class 0's score so that class 0 can also be selected
            var maxValue = probs[0][0][y][x]
            var maxIdx = 0
            for (c in 1 until classNum) {
                if (probs[0][c][y][x] > maxValue) {
                    maxValue = probs[0][c][y][x]
                    maxIdx = c
                }
            }
            segImg.setRGB(x, y, Color(maxIdx, maxIdx, maxIdx).rgb)
        }
    }
    return segImg
}


fun main(args: Array<String>) {

    val onnxName = "mmseg.onnx"
    val longSide = 480f
    val classNum = 5

    // Resize so the long side becomes 480, mirroring the Python preprocessing.
    val origImg = ImageIO.read(File("chitose000.png"))
    val scale = longSide / max(origImg.width, origImg.height).toFloat()
    val width = (origImg.width * scale).roundToInt()
    val height = (origImg.height * scale).roundToInt()
    val img = BufferedImage(width, height, origImg.type)
    img.createGraphics().drawImage(
        origImg.getScaledInstance(width, height, Image.SCALE_AREA_AVERAGING), 0, 0, width, height, null
    )

    val data = makeData(img)

    val env = OrtEnvironment.getEnvironment()
    val opts = OrtSession.SessionOptions()
    val session = env.createSession(onnxName, opts)
    val inputName = session.inputNames.iterator().next()

    // Run inference; the output shape is (1, classNum, H', W') at feature-map resolution.
    val tensor = OnnxTensor.createTensor(env, data)
    val output = session.run(Collections.singletonMap(inputName, tensor))
    val probs = output[0].value as? Array<Array<Array<FloatArray>>> ?: return
    val segImg = makeSegmentImage(probs, classNum)
    ImageIO.write(segImg, "png", File("seg.png"))
}

 

ONNX inference in Kotlin

I cleaned up the sample from the following repository into something easier to follow.
onnxruntime/ScoreMNIST.java at main · microsoft/onnxruntime · GitHub


Download the data from:

LIBSVM Data: Classification (Multi Class)


The ONNX model is generated with the steps in:

onnxの使い方 - LeMU_Researchの日記

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.io.BufferedReader
import java.io.FileReader
import java.util.*
import java.util.regex.Pattern
import kotlin.collections.ArrayList


// One LIBSVM-format dataset: per-sample label, non-zero feature indices, and values.
class Dataset(val labels: Array<Int>, val indices: Array<Array<Int>>, val values: Array<Array<Float>>)


// Parse a LIBSVM-format file: for every sample, extract the label and the
// indices/values of all pixels with non-zero values.
fun load(path: String): Dataset {
    val reader = BufferedReader(FileReader(path))
    val labels: ArrayList<Int> = arrayListOf()
    val indices: ArrayList<Array<Int>> = arrayListOf()
    val values: ArrayList<Array<Float>> = arrayListOf()

    while (true) {
        // each line looks like: "<label> <index>:<value> <index>:<value> ..."
        val line = reader.readLine() ?: break
        val fields = Pattern.compile("\\s+").split(line)
        val curIndices: ArrayList<Int> = arrayListOf()
        val curValues: ArrayList<Float> = arrayListOf()
        for (i in 1 until fields.size) {
            val ind = fields[i].indexOf(':')
            curIndices.add(fields[i].substring(0, ind).toInt())
            curValues.add(fields[i].substring(ind + 1).toFloat())
        }
        labels.add(fields[0].toInt())
        indices.add(curIndices.toTypedArray())
        values.add(curValues.toTypedArray())
    }
    return Dataset(labels.toTypedArray(), indices.toTypedArray(), values.toTypedArray())
}


// Expand one sparse sample into a dense 1 x channels x width x width array.
// -1f is the normalized value of a zero pixel: (0 / 255 - 0.5) / 0.5 = -1.
fun writeData(indices: Array<Int>, values: Array<Float>, width: Int, channels: Int): Array<Array<Array<FloatArray>>> {
    val data = Array(1) { Array(channels) { Array(width) { FloatArray(width) { -1f } } } }
    val maxIndex = width * width * channels - 1
    for (i in indices.indices) {
        // clamp: LIBSVM feature indices start at 1, so the last one can run past the end
        val index = Integer.min(indices[i], maxIndex)
        // decode the flat index into channel / row / column
        val c = index / (width * width)
        val y = index % (width * width) / width
        val x = index % width
        data[0][c][y][x] = (values[i] / 255 - 0.5f) / 0.5f
    }
    return data
}


fun main(args: Array<String>) {

    // dataset file name ("mnist" or "cifar10") downloaded from the LIBSVM page
    val datasetName = args[0]
    val onnxName = "$datasetName.onnx"

    val width = if (datasetName == "cifar10") 32 else 28
    val channels = if (datasetName == "cifar10") 3 else 1

    val env = OrtEnvironment.getEnvironment()
    val opts = OrtSession.SessionOptions()
    val session = env.createSession(onnxName, opts)
    val inputName = session.inputNames.iterator().next()
    val dataset = load(datasetName)

    var correct = 0
    for (i in 0 until dataset.labels.size) {
        val data = writeData(dataset.indices[i], dataset.values[i], width, channels)
        val tensor = OnnxTensor.createTensor(env, data)
        val output = session.run(Collections.singletonMap(inputName, tensor))
        val probs = output[0].value as? Array<FloatArray> ?: return
        // predicted class = argmax over the output scores
        val pred = probs[0].withIndex().maxByOrNull { it.value }?.index
        if (pred == dataset.labels[i]) correct += 1
        if (i % 2000 == 0) println(i)  // progress
    }
    println("accuracy = " + (correct.toFloat() / dataset.labels.size))
}

Summary of how to compute loss and accuracy

criterion's reduction defaults to 'mean'.
len(data_loader) is the number of batches.

So the average loss for an epoch is obtained by accumulating the per-batch (mean) loss and dividing by len(data_loader).

The number of correct predictions is accumulated with

correct += pred.eq(target.view_as(pred)).sum().item()

Accuracy is then computed by dividing not by the number of batches len(data_loader) but by the number of samples len(dataset).

In the numpy (onnx) case, count the correct predictions with

correct += np.equal(pred, np.reshape(target, pred.size)).sum()

A full evaluation loop combining these rules is sketched below.
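
A minimal PyTorch sketch (model and data_loader are assumed to already exist):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # reduction='mean' by default

total_loss = 0.0
correct = 0
with torch.no_grad():
    for data, target in data_loader:
        out = model(data)
        total_loss += criterion(out, target).item()  # per-batch mean loss
        pred = out.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

mean_loss = total_loss / len(data_loader)      # divide by the number of batches
accuracy = correct / len(data_loader.dataset)  # divide by the number of samples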

How to use ONNX

・Converting a pth file to an ONNX file

import torch
from torch import onnx
from torchvision import transforms, datasets


def main():

    transform = transforms.Compose([
        transforms.ToTensor(),
        # MNIST is single-channel, so use a single mean/std value
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    # torch.load restores a whole model object saved with torch.save(model, path)
    model = torch.load('checkpoints/cifar10_resnet18/00004.pth')
    model.eval()

    # use one real sample as the dummy input for tracing
    data, _ = train_dataset[0]
    data = data.unsqueeze(0).cuda()
    onnx.export(model, data, "mnist.onnx",
                input_names=["input"],
                # mark the batch size and the spatial dims as dynamic
                dynamic_axes={
                    "input": {0: "batch_size", 2: "height", 3: "width"}
                })


if __name__ == '__main__':
    main()
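
The exported file can be sanity-checked with the onnx package before running inference (a short sketch):

import onnx

model = onnx.load("mnist.onnx")
onnx.checker.check_model(model)  # raises if the graph is malformed
print(model.graph.input)         # shows the input name and the dynamic dims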


・Inference with ONNX

import onnxruntime
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm


def test_onnx():

    # requires onnxruntime-gpu for the CUDA execution provider
    session = onnxruntime.InferenceSession('test.onnx', providers=["CUDAExecutionProvider"])

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=False, num_workers=2)

    correct = 0
    for data, target in tqdm(train_loader):

        # onnxruntime takes numpy arrays, not torch tensors
        data = np.array(data)
        target = np.array(target)
        out = session.run(None, {"input": data})

        pred = out[0].argmax(axis=1)
        correct += np.equal(pred, np.reshape(target, pred.size)).sum()
    print('accuracy = ', correct / len(train_dataset))


if __name__ == '__main__':
    test_onnx()

 

How to use mmsegmentation

・git clone
GitHub - open-mmlab/mmsegmentation: OpenMMLab Semantic Segmentation Toolbox and Benchmark.

・get_started.md

Add mim install mmengine to Step 0.

mmsegmentation/get_started.md at master · open-mmlab/mmsegmentation · GitHub

・Training on your own data

MMSegmentationによる多数クラス画像(Multi Class)のセマンティックセグメンテーション(Semantic Segmentation). - Qiita
See Section 4.

Place https://github.com/alexgkendall/SegNet-Tutorial directly under the mmsegmentation root.
Then run the following script to generate the train/val split files.

import os.path as osp
import mmcv


def main():
    data_root = 'SegNet-Tutorial/CamVid'
    ann_dir = 'trainannot'

    split_dir = 'splits_resnet50A'
    mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
    # collect annotation basenames (no folder, no extension)
    filename_list = [osp.splitext(filename)[0]
                     for filename in mmcv.scandir(osp.join(data_root, ann_dir), suffix='.png')]
    # 4/5 of the files go to train, the rest to val
    with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
        train_length = int(len(filename_list) * 4 / 5)
        f.writelines(line + '\n' for line in filename_list[:train_length])
    with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
        f.writelines(line + '\n' for line in filename_list[train_length:])


if __name__ == '__main__':
    main()

The CamVid folder already ships with its own train.txt etc., but these do not match the format mmsegmentation expects, so they cannot be used as-is.
The newly generated train.txt and val.txt are plain lists of basenames, with folder names and extensions stripped.

After that, run the following training script.

import os.path as osp
import mmcv
from mmcv import Config
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmseg.datasets import build_dataset
from mmseg.apis import set_random_seed
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


data_root = 'SegNet-Tutorial/CamVid'
img_dir = 'train'
ann_dir = 'trainannot'

classes = (
    'sky', 'Bulding', 'Pole', 'Road_marking', 'Road', 'Pavement', 'Tree',
    'SingSymbole', 'Fence', 'Car', 'Pedestrian', 'Bicyclist'
)

palette = [
    [128, 128, 128], [128, 0, 0], [192, 192, 128], [255, 69, 0], [128, 64, 128], [60, 40, 222],
    [128, 128, 0], [192, 128, 128], [64, 64, 128], [64, 0, 128], [64, 64, 0], [0, 128, 192]
]


# Register a CustomDataset subclass so the config can refer to it by name.
@DATASETS.register_module()
class splits_resnet50A(CustomDataset):

    CLASSES = classes
    PALETTE = palette

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.png', seg_map_suffix='.png', split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None


def main():

    w = 480
    h = 360

    cfg = Config.fromfile('configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')

    cfg.checkpoint_config.meta = dict(CLASSES=classes, PALETTE=palette)
    # single-GPU training, so replace SyncBN with plain BN
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.num_classes = len(classes)
    cfg.model.auxiliary_head.num_classes = len(classes)

    cfg.dataset_type = 'splits_resnet50A'
    cfg.data_root = data_root

    cfg.data.samples_per_gpu = 8
    cfg.data.workers_per_gpu = 8

    cfg.img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
    cfg.crop_size = (256, 256)
    cfg.train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        # dict(type='Resize', img_scale=(w, h), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', flip_ratio=0.5),
        # dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **cfg.img_norm_cfg),
        # dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]

    cfg.test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(w, h),
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                # dict(type='RandomFlip'),
                dict(type='Normalize', **cfg.img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]

    cfg.data.train.type = cfg.dataset_type
    cfg.data.train.data_root = cfg.data_root
    cfg.data.train.img_dir = img_dir
    cfg.data.train.ann_dir = ann_dir
    cfg.data.train.pipeline = cfg.train_pipeline
    cfg.data.train.split = 'splits_resnet50A/train.txt'  # 3)

    cfg.data.val.type = cfg.dataset_type
    cfg.data.val.data_root = cfg.data_root
    cfg.data.val.img_dir = img_dir
    cfg.data.val.ann_dir = ann_dir
    cfg.data.val.pipeline = cfg.test_pipeline
    cfg.data.val.split = 'splits_resnet50A/val.txt'  # 3)

    cfg.work_dir = './work_dirs/tutorial_pspnet_r50A'

    cfg.runner.max_iters = 40000
    cfg.log_config.interval = 10
    cfg.evaluation.interval = 200
    cfg.checkpoint_config.interval = 1000

    cfg.seed = 0
    set_random_seed(0, deterministic=False)
    cfg.gpu_ids = range(1)
    cfg.device = 'cuda'

    datasets = [build_dataset(cfg.data.train)]
    model = build_segmentor(
        cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')
    )
    model.CLASSES = datasets[0].CLASSES

    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    train_segmentor(model, datasets, cfg, distributed=False, validate=True, meta=dict())


if __name__ == '__main__':
    main()


The images in the train folder are 480x360 RGB images.
The images in the trainannot folder are 480x360 label images (pixel value = label index).
For example, pixels with value 0 are sky. A quick way to inspect the labels is sketched below.
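
Since the annotation pixels hold raw label indices, these pngs look almost black in an ordinary viewer; the labels can be verified directly (a minimal sketch; the file name is just an example, substitute an actual file from trainannot):

import numpy as np
from PIL import Image

# example file name -- replace with an actual annotation file
ann = np.array(Image.open('SegNet-Tutorial/CamVid/trainannot/0001TP_006690.png'))
print(ann.shape)       # (360, 480)
print(np.unique(ann))  # label indices, e.g. 0 (sky) ... 11 (Bicyclist)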