To run ONNX properly, we need to install the latest PyTorch, built from source:
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.5 -DUSE_MPI=OFF
sudo make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build
Here "/opt/pytorch/build" is the directory in which you built PyTorch in the previous step; just make sure the path is correct.
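If importing torch later fails, a wrong PYTHONPATH entry is a likely cause. The following is a minimal standard-library sketch for checking it; check_build_dir is a hypothetical helper of our own, not part of PyTorch. It reports whether a directory is listed in PYTHONPATH and whether it exists on disk.

```python
import os


def check_build_dir(build_dir, pythonpath=None):
    """Return a list of problems with build_dir as a PYTHONPATH entry.

    An empty list means build_dir is listed in PYTHONPATH and exists on disk.
    By default the current PYTHONPATH environment variable is inspected.
    """
    if pythonpath is None:
        pythonpath = os.environ.get("PYTHONPATH", "")
    # Entries are separated by os.pathsep (":" on Linux). Empty entries, such
    # as the leading one left by "export PYTHONPATH=$PYTHONPATH:..." when
    # PYTHONPATH was unset, are skipped; paths are normalized before comparing.
    entries = [os.path.normpath(p) for p in pythonpath.split(os.pathsep) if p]
    problems = []
    if os.path.normpath(build_dir) not in entries:
        problems.append("not in PYTHONPATH")
    if not os.path.isdir(build_dir):
        problems.append("directory does not exist")
    return problems


if __name__ == "__main__":
    # Example using the build directory from the text.
    print(check_build_dir("/opt/pytorch/build"))
```

The code avoids f-strings and type hints so that it also runs under the Python 3.5 used in this setup.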
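Because this installs the whole of PyTorch from source, the ONNX-related libraries that PyTorch supports are installed along with it. The installation path is /usr/local/lib/python3.5/dist-packages/torch.
Run the following command to install the ONNX library: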
conda install -c conda-forge onnx
You also need onnx-caffe2, a pure-Python library that provides a Caffe2 backend for ONNX. You can install onnx-caffe2 with pip:
pip3 install onnx-caffe2
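Before running the conversion script, it can help to confirm that torch, onnx and onnx_caffe2 (the import names used in the script) are all importable, and to see where each one is installed. For torch, the expected location is the path given earlier. Below is a minimal standard-library sketch; find_install_dirs is a hypothetical helper name. It uses importlib.util.find_spec, which locates a top-level module without importing it.

```python
import importlib.util
import os


def find_install_dirs(names):
    """Map each top-level module name to where it is installed.

    Returns the package directory for packages, the containing directory for
    single-file modules, spec.origin (e.g. "built-in") for modules without a
    file location, and None for modules that cannot be found.
    """
    result = {}
    for name in names:
        spec = importlib.util.find_spec(name)
        if spec is None:
            result[name] = None
        elif spec.submodule_search_locations:
            result[name] = list(spec.submodule_search_locations)[0]
        elif spec.has_location:
            result[name] = os.path.dirname(spec.origin)
        else:
            result[name] = spec.origin
    return result


if __name__ == "__main__":
    for name, path in sorted(find_install_dirs(["torch", "onnx", "onnx_caffe2"]).items()):
        print("{}: {}".format(name, path or "NOT INSTALLED"))
```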
The pytorch2caffe2.py script at https://github.com/lindylin1817/pytorch2caffe2 is reference code that converts a trained DeblurGAN model to ONNX and then to Caffe2. The code, with comments, is as follows:
import os
import sys
import shutil

import torch
import torch.onnx
from torch.autograd import Variable

sys.path.append("../DeblurGAN")  # path to your DeblurGAN checkout
from models.models import create_model  # only needed for the commented-out create_model() call below
import models.networks as networks
from options.test_options import TestOptions

import onnx
from onnx_caffe2.backend import Caffe2Backend

batch_size = 1  # just a random number

# Load the pretrained generator weights from the local checkpoint,
# mapping any GPU tensors to CPU since the model is built without GPUs
model_path = './model/char_deblur/latest_net_G.pth'
onnx_model_path = "./deblurring.onnx.pb"
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

# Configure the DeblurGAN network
gan_opt = TestOptions().parse()
gan_opt.name = "char_deblur"
gan_opt.checkpoints_dir = "./model/"
gan_opt.model = "test"
gan_opt.dataset_mode = "single"
gan_opt.dataroot = "/tmp/gan/"
# Start from an empty dataroot directory
shutil.rmtree(gan_opt.dataroot, ignore_errors=True)
os.mkdir(gan_opt.dataroot)
gan_opt.loadSizeX = 64
gan_opt.loadSizeY = 64
gan_opt.fineSize = 64
gan_opt.learn_residual = True
gan_opt.nThreads = 1  # test code only supports nThreads = 1
gan_opt.batchSize = 1  # test code only supports batchSize = 1
gan_opt.serial_batches = True  # no shuffle
gan_opt.no_flip = True  # no flip

# Build the generator directly on CPU (empty gpu_ids) instead of via create_model()
#torch_model = create_model(gan_opt)
gpus = []
torch_model = networks.define_G(gan_opt.input_nc, gan_opt.output_nc, gan_opt.ngf,
                                gan_opt.which_model_netG, gan_opt.norm,
                                not gan_opt.no_dropout, gpus, False,
                                gan_opt.learn_residual)
torch_model.load_state_dict(state_dict)

# Set the train mode to False since we will only run the forward pass.
torch_model.train(False)

# Dummy input used to trace the model
x = Variable(torch.randn(batch_size, 3, 60, 60), requires_grad=True)

# Export the model to ONNX
torch.onnx.export(torch_model,         # model being run
                  x,                   # model input (or a tuple for multiple inputs)
                  onnx_model_path,     # where to save the model (a file or file-like object)
                  verbose=True,
                  export_params=True,  # store the trained parameter weights inside the model file
                  training=False)

# Sanity-check the exported ONNX model
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)

# Convert the ONNX graph into Caffe2 init/predict nets and serialize them
# (use device="CPU" if your Caffe2 build has no CUDA support)
model_name = onnx_model_path.replace('.onnx.pb', '')
init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model.graph, device="CUDA")
with open(model_name + "_init.pb", "wb") as f:
    f.write(init_net.SerializeToString())
with open(model_name + "_predict.pb", "wb") as f:
    f.write(predict_net.SerializeToString())
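The dummy input x is 3 × 60 × 60 even though fineSize is 64, so it is worth knowing which input sizes the generator accepts. With learn_residual = True, the DeblurGAN generator adds its input to its output, so the output must have exactly the input's height and width. The sketch below assumes the pix2pix-style ResNet generator that DeblurGAN builds on: two 3×3 stride-2 convolutions with padding 1 for downsampling, two 3×3 stride-2 transposed convolutions with padding 1 and output_padding 1 for upsampling, and every other layer size-preserving. These layer parameters are assumptions; check them against models/networks.py in your DeblurGAN checkout.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Height/width after nn.Conv2d (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def deconv_out(n, kernel=3, stride=2, padding=1, output_padding=1):
    """Height/width after nn.ConvTranspose2d (dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


def generator_output_size(n, n_downsampling=2):
    """Output size for an n x n input under the assumed generator layout:
    n_downsampling stride-2 convs followed by as many transposed convs.
    Size-preserving layers (padded 7x7 convs, residual blocks) are omitted."""
    for _ in range(n_downsampling):
        n = conv_out(n)
    for _ in range(n_downsampling):
        n = deconv_out(n)
    return n


def size_preserving_inputs(lo, hi):
    """Input sizes in [lo, hi] that come back unchanged, as the residual
    addition (input + output) requires."""
    return [n for n in range(lo, hi + 1) if generator_output_size(n) == n]


if __name__ == "__main__":
    print(generator_output_size(60))       # 60, the size used for x
    print(size_preserving_inputs(56, 64))  # [56, 60, 64]
```

Under these assumptions, any multiple of 4 (such as 60 or 64) round-trips exactly, while a size like 62 comes back as 64 and the residual addition would fail with a shape mismatch.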
Based on this example, the parts you need to modify for your own setup are as follows:
Running the code above produces two Caffe2 .pb files: deblurring_init.pb and deblurring_predict.pb.
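The two file names come from stripping ".onnx.pb" off onnx_model_path and appending _init.pb and _predict.pb. The script does this with str.replace, which would also rewrite an earlier ".onnx.pb" elsewhere in the path. The sketch below is a slightly stricter variant that strips the suffix only at the end; caffe2_output_paths is a hypothetical helper, not part of onnx-caffe2.

```python
def caffe2_output_paths(onnx_model_path, suffix=".onnx.pb"):
    """Derive the Caffe2 init/predict file names from the ONNX file name.

    The suffix is removed only if it ends the path, unlike str.replace,
    which replaces every occurrence anywhere in the string.
    """
    if onnx_model_path.endswith(suffix):
        base = onnx_model_path[:-len(suffix)]
    else:
        base = onnx_model_path
    return base + "_init.pb", base + "_predict.pb"


if __name__ == "__main__":
    print(caffe2_output_paths("./deblurring.onnx.pb"))
```

The init net holds the trained weights and the predict net holds the computation graph, so both files are needed to run the model in Caffe2.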