・PyTorchの.pthファイルからONNXファイルへの変換
import torch
from torch import onnx
from torchvision import transforms, datasets
def main(checkpoint_path='checkpoints/cifar10_resnet18/00004.pth',
         onnx_path='test.onnx'):
    """Export a pickled PyTorch model checkpoint to an ONNX file.

    The checkpoint must contain a whole ``nn.Module`` (saved with
    ``torch.save(model)``), not a ``state_dict``, because it is passed
    directly to ``torch.onnx.export``. Judging by its path it is a CIFAR-10
    model, so a CIFAR-10 image is used as the example input for tracing.

    The exported graph has an input named "input" with dynamic batch, height
    and width axes, and an output named "output" with a dynamic batch axis,
    so the inference script can feed batches of any size.

    Args:
        checkpoint_path: path of the pickled model to export.
        onnx_path: destination ONNX file. The default matches the file the
            inference script loads.
    """
    # Run on the GPU when available instead of hard-coding .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Same preprocessing as the inference script. Here it only produces a
    # realistic example input for tracing.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # The model is a CIFAR-10 model, so the example must be a 3-channel
    # CIFAR-10 image. MNIST (1 channel) fails in the 3-channel Normalize above.
    dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                               transform=transform)

    # torch>=2.6 defaults to weights_only=True, which cannot unpickle a whole
    # nn.Module; torch<1.13 does not accept the keyword at all.
    # SECURITY: this unpickles arbitrary objects -- only load trusted files.
    try:
        model = torch.load(checkpoint_path, map_location=device,
                           weights_only=False)
    except TypeError:
        model = torch.load(checkpoint_path, map_location=device)
    model.eval()  # export inference behaviour (e.g. BatchNorm running stats)

    data, _ = dataset[0]
    data = data.unsqueeze(0).to(device)  # add the batch axis

    onnx.export(
        model, data, onnx_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
            # Without this the output batch size is frozen to 1.
            'output': {0: 'batch_size'},
        },
    )


if __name__ == '__main__':
    main()
・ONNX Runtimeによる推論
import onnxruntime
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
def test_onnx(onnx_path='test.onnx', train=True, batch_size=8):
    """Print the top-1 accuracy of an ONNX classifier on CIFAR-10.

    Args:
        onnx_path: ONNX model to evaluate. The default is the file the export
            script writes.
        train: evaluate on the training split when True (the original
            behaviour) or on the test split when False.
        batch_size: images per ``session.run`` call. Values other than 1
            require the model's batch axis to be dynamic.
    """
    # Prefer CUDA, but fall back to CPU explicitly instead of depending on a
    # GPU-enabled onnxruntime build.
    session = onnxruntime.InferenceSession(
        onnx_path,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    # Read the input name from the graph rather than assuming it.
    input_name = session.get_inputs()[0].name

    # Must match the preprocessing used when the model was trained/exported.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.CIFAR10(root='./data', train=train, download=True,
                               transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

    correct = 0
    for data, target in tqdm(loader):
        # onnxruntime consumes NumPy arrays, not torch tensors.
        logits = session.run(None, {input_name: data.numpy()})[0]
        pred = logits.argmax(axis=1)
        # pred and target are both 1-D with one entry per image.
        correct += int(np.count_nonzero(pred == target.numpy()))
    print('accuracy = ', correct / len(dataset))


if __name__ == '__main__':
    test_onnx()