ONNXの使い方

・PyTorchの.pthファイルからONNXファイルへの変換

import torch
from torch import onnx
from torchvision import transforms, datasets


def main(checkpoint_path='checkpoints/cifar10_resnet18/00004.pth',
         onnx_path="mnist.onnx"):
    """Export a PyTorch checkpoint to ONNX with dynamic batch/height/width axes.

    The checkpoint is loaded with ``torch.load`` and passed straight to
    ``onnx.export``, so it is assumed to hold a whole pickled ``nn.Module``
    rather than a ``state_dict`` -- TODO confirm.  A single normalized CIFAR-10
    training image is used as the example input that drives tracing.

    Args:
        checkpoint_path: Path of the ``.pth`` file to convert.
        onnx_path: Destination ONNX file.  The default keeps the original file
            name; whichever inference script consumes the model must be
            pointed at the same path.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Three per-channel mean/std values: this only fits 3-channel input.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # The checkpoint directory names a CIFAR-10 ResNet-18, so a CIFAR-10 image
    # is the matching example input.  MNIST images have a single channel and
    # make the 3-channel Normalize above raise a shape-mismatch error.
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                     transform=transform)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints.  On PyTorch >= 2.6 (weights_only defaults to True) loading a
    # whole pickled module requires passing weights_only=False.
    model = torch.load(checkpoint_path)
    # eval() switches layers such as BatchNorm/Dropout to inference behaviour
    # so the exported graph is the inference graph.
    model.eval()

    # Place the example input on whatever device the model was restored to,
    # instead of unconditionally moving it to CUDA.
    device = next(model.parameters()).device
    data, _ = train_dataset[0]
    data = data.unsqueeze(0).to(device)  # add a leading batch dimension

    onnx.export(model, data, onnx_path,
                input_names=["input"],
                output_names=["output"],
                dynamic_axes={
                    "input": {0: "batch_size", 2: "height", 3: "width"},
                    # Keep the output batch size symbolic as well, so batched
                    # inference does not disagree with the graph's output shape.
                    "output": {0: "batch_size"},
                })


if __name__ == '__main__':
    main()


・ONNX Runtimeによる推論

import onnxruntime
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm


def test_onnx(onnx_path='test.onnx'):
    """Print the top-1 accuracy of an ONNX classifier on the CIFAR-10 train split.

    Args:
        onnx_path: ONNX model file to evaluate.
    """
    # Prefer CUDA, but fall back to CPU when the CUDA provider is unavailable.
    session = onnxruntime.InferenceSession(
        onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    # Read the input name from the graph instead of hard-coding it.
    input_name = session.get_inputs()[0].name

    # NOTE(review): presumably matches the preprocessing used at training and
    # export time -- confirm, otherwise the accuracy is meaningless.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                     transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=8, shuffle=False, num_workers=2)

    correct = 0
    for data, target in tqdm(train_loader):
        # ONNX Runtime consumes NumPy arrays; the loader yields CPU tensors.
        logits = session.run(None, {input_name: data.numpy()})[0]
        pred = logits.argmax(axis=1)
        correct += int((pred == target.numpy()).sum())
    print('accuracy = ', correct / len(train_dataset))


if __name__ == '__main__':
    test_onnx()