



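There are two ways to get a trained model into a C++ application. The first is to write the training code directly in C++. This lets you build a complete network, but in my experience it makes transfer learning hard to use, and transfer learning is part of almost every training workflow today. The second is to train the model in Python and use the JIT to export it as a model that C++ can call. This post covers the second approach. (Personally, I think there is a third option as well: exposing the PyTorch model as a web service that any kind of client can call.)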


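TorchScript is a way to create serializable and optimizable models from PyTorch code. Any code written in TorchScript can be saved from a Python process and loaded into a process that has no Python dependency. PyTorch provides tools to incrementally convert a model from a pure Python program into a TorchScript program that runs independently of Python, for example inside a standalone C++ program. This makes it possible to train a model in PyTorch with familiar tools and then export it through TorchScript to a production environment, where running the model as a Python program is often a poor choice for performance and multithreading reasons.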




1. Run the model's inference code in Python

We use the ESRGAN inference code (https://github.com/xinntao/ESRGAN) as the example:


Python packages: pip install numpy opencv-python

Run the test script directly. (My version has a few modifications, such as an added FPS calculation.)
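The FPS calculation mentioned above is not shown in the original post. A minimal sketch of one way to measure it follows; `measure_fps` and `fake_infer` are hypothetical names, and the workload is a stand-in rather than the ESRGAN model.

```python
import time


def measure_fps(infer, frames, warmup=1):
    """Return frames per second for calling `infer` on every item in `frames`."""
    # Warm-up calls are excluded so one-time setup cost does not skew the result.
    for frame in frames[:warmup]:
        infer(frame)
    start = time.perf_counter()
    for frame in frames:
        infer(frame)
    elapsed = time.perf_counter() - start
    return len(frames) / elapsed


# Stand-in workload (assumption); replace it with the real model call on an image.
def fake_infer(frame):
    return sum(frame)


fps = measure_fps(fake_infer, [list(range(1000))] * 50)
print(f"{fps:.1f} FPS")
```

When the model runs on a GPU, call torch.cuda.synchronize() before reading the clock, because CUDA kernels launch asynchronously and the timing would otherwise be too optimistic.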

2. Convert the PyTorch model to Torch Script


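There are two ways to convert a PyTorch model to Torch Script. The first is tracing: the model is run once on an example input and the operations it executes are recorded. The second is to add explicit annotations to the model, telling the Torch Script compiler that it may parse and compile the model code directly, subject to the constraints imposed by the Torch Script language.

  • Converting the model to Torch Script via tracing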
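The difference between the two conversion methods shows up once the model contains a data-dependent branch. The sketch below uses a toy module (`Gate`, a hypothetical name, not part of ESRGAN) to illustrate it. The ESRGAN generator appears to be a plain stack of convolutions without such branches, which is why tracing is used for it in this post.

```python
import warnings

import torch


class Gate(torch.nn.Module):
    """Toy module whose output depends on a data-dependent branch."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x


pos, neg = torch.ones(2), -torch.ones(2)

# Scripting compiles the Python source, so both branches survive.
scripted = torch.jit.script(Gate())

# Tracing records only the ops executed for the example input (the `x * 2` path).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning about the branch
    traced = torch.jit.trace(Gate(), pos)

print(scripted(neg))  # follows the `-x` branch
print(traced(neg))    # replays `x * 2` regardless of the input
```

The traced module silently gives a different answer for the negative input, so tracing is only safe when the control flow does not depend on the data.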



import torch
import architecture as arch

# An instance of your model.
model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                      mode='CNA', res_scale=1, upsample_mode='upconv')
model.load_state_dict(torch.load('./models/RRDB_ESRGAN_x4.pth'), strict=True)
model.eval()  # switch to inference mode before tracing

# An example input you would normally provide to your model's forward() method.
example = torch.rand(64, 3, 3, 3)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# The traced ScriptModule can now be evaluated identically to a regular PyTorch module.
output = traced_script_module(torch.ones(64, 3, 3, 3))
print(output)

# Serialize the module so the C++ application can load it
# (file name taken from the example-app run later in this post).
traced_script_module.save('./models/RRDB_ESRGAN_x4_000.pt')


(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN$ python model_jit_converter.py 
tensor([[[[0.9618, 1.0375, 1.0242,  ..., 1.0049, 1.0399, 1.0255],
          [1.0199, 0.9996, 1.0096,  ..., 1.0269, 1.0140, 1.0267],
          [1.0290, 1.0154, 1.0161,  ..., 1.0201, 1.0077, 1.0298],
          [1.0316, 1.0139, 1.0184,  ..., 1.0184, 1.0179, 1.0197],
          [1.0391, 1.0174, 1.0162,  ..., 1.0185, 1.0443, 1.0168],
          [1.0066, 1.0186, 0.9976,  ..., 1.0143, 1.0066, 1.0249]],

         [[1.0155, 1.0491, 1.0004,  ..., 0.9993, 0.9828, 0.9706],
          [0.9992, 1.0149, 1.0032,  ..., 0.9851, 0.9937, 0.9887],
          [0.9974, 1.0106, 1.0089,  ..., 1.0072, 1.0074, 1.0041],
          [1.0130, 1.0036, 1.0059,  ..., 0.9979, 1.0065, 1.0133],
          [1.0066, 0.9955, 1.0034,  ..., 1.0030, 0.9875, 1.0011],
          [0.9788, 0.9983, 1.0113,  ..., 1.0106, 1.0381, 1.0248]],

         [[0.9570, 0.9789, 0.9720,  ..., 0.9920, 0.9740, 0.9940],
          [0.9522, 1.0182, 1.0109,  ..., 1.0181, 1.0060, 0.9842],
          [0.9872, 1.0062, 1.0112,  ..., 1.0172, 1.0072, 0.9803],
          [1.0211, 1.0119, 1.0091,  ..., 1.0082, 1.0339, 1.0348],
          [0.9894, 1.0227, 1.0226,  ..., 0.9930, 1.0258, 1.0234],
          [0.9997, 0.9755, 0.9969,  ..., 1.0227, 1.0308, 1.0109]]],

        ...  (the remaining batch entries repeat the same values)
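The C++ application in the next step loads a serialized file, so the traced module has to be written to disk first. Below is a minimal sketch of the save and reload round trip, assuming a tiny stand-in network and a hypothetical file name in place of the real RRDB_Net checkpoint.

```python
import os
import tempfile

import torch

# Tiny stand-in network (assumption); the real model would be traced the same way.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, kernel_size=3, padding=1),
    torch.nn.LeakyReLU(0.2),
)
model.eval()

example = torch.rand(1, 3, 8, 8)
traced = torch.jit.trace(model, example)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_traced.pt")  # hypothetical file name
    traced.save(path)                # serialize code and weights together
    reloaded = torch.jit.load(path)  # needs no Python class definition
    with torch.no_grad():
        same = torch.allclose(model(example), reloaded(example))

print("outputs match:", same)
```

Checking the reloaded module in Python first makes it easier to tell a broken export apart from a problem on the C++ side.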

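3. Load your Script Module in C++

To load a serialized PyTorch model in C++, your application must depend on the PyTorch C++ API, also known as LibTorch. The LibTorch distribution contains a set of shared libraries, header files, and CMake build configuration files. CMake is not required for depending on LibTorch, but it is the recommended approach and will be well supported in the future. In this tutorial we use CMake and LibTorch to build a minimal C++ application that simply loads and executes a serialized PyTorch model.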


#include <torch/script.h> // One-stop header.

#include <cassert>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Deserialize the ScriptModule from a file using torch::jit::load().
  // PyTorch 1.0 returns a shared_ptr here; later LibTorch releases return the
  // module by value, so this line needs adjusting for newer versions.
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

  assert(module != nullptr);
  std::cout << "ok\n";
}

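The <torch/script.h> header pulls in everything from the LibTorch library that is needed to run the example. Our application takes the file path of a serialized PyTorch ScriptModule as its only command-line argument and deserializes the module with torch::jit::load(), which takes that path as input. In return we receive a shared pointer to a torch::jit::script::Module, the C++ equivalent of torch.jit.ScriptModule. For now we only verify that this pointer is not null; executing the module comes next.

With the code saved as example-app.cpp, a minimal CMakeLists.txt to build it looks like this: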



cmake_minimum_required(VERSION 3.0 FATAL_ERROR)

find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)




cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch


(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
-- Found CUDA: /usr/local/cuda (found version "9.0") 
-- Caffe2: CUDA detected: 9.0
-- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda
-- Caffe2: Header version is: 9.0
-- Found CUDNN: /usr/include  
-- Found cuDNN: v7.4.1  (include: /usr/include, library: /usr/lib/x86_64-linux-gnu/libcudnn.so)
-- Autodetected CUDA architecture(s):  6.1
-- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61
-- Found torch: /home/anpi-cn/workspace_min/libtorch/lib/libtorch.so  
-- Configuring done
CMake Warning at CMakeLists.txt:6 (add_executable):
  Cannot generate a safe runtime search path for target example-app because
  there is a cycle in the constraint graph:

    dir 0 is [/home/anpi-cn/workspace_min/libtorch/lib]
    dir 1 is [/usr/local/cuda/lib64/stubs]
    dir 2 is [/home/anpi-cn/.conda/envs/surper-resolution-pytorch/lib]
      dir 3 must precede it due to runtime library [libcudart.so.9.0]
    dir 3 is [/usr/local/cuda/lib64]
      dir 2 must precede it due to runtime library [libnvrtc.so.9.0]

  Some of these libraries may not be found correctly.

-- Generating done
-- Build files have been written to: /home/anpi-cn/workspace_min/Super-Resolution/ESRGAN/example-app
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ ./example-app ../models/RRDB_ESRGAN_x4_000.pt 

4. Execute the Script Module in C++

With the serialized model successfully loaded in C++, add the following code to the main() function of the C++ application:

// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({64, 3, 3, 3}));

// Execute the model and turn its output into a tensor.
auto output = module->forward(inputs).toTensor();

std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

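The first two lines set up the inputs to our model. We create a vector of torch::jit::IValue and add a single input. To create the input tensor we use torch::ones(), the C++ API equivalent of torch.ones. We then call the forward method of the script::Module, passing it the input vector we created. In return we get a new IValue, which we convert to a tensor by calling toTensor().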
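The final output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) call corresponds to ordinary slicing along dimension 1 in Python. Because a super-resolution output has only 3 channels on that dimension, the end index of 5 is clamped and the whole tensor is printed. A small sketch with a stand-in tensor (the sizes are hypothetical, not the real ESRGAN output shape):

```python
import torch

# Stand-in for the model output: 2 images, 3 channels, 4x4 pixels (assumed sizes).
output = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)

# Python counterpart of output.slice(/*dim=*/1, /*start=*/0, /*end=*/5).
sliced = output[:, 0:5]

# The end index is clamped to the size of dim 1, so all 3 channels come back.
print(sliced.shape)
```

To print only a small part of the result, slicing along one of the spatial dimensions is probably closer to what is wanted.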












