1.背景(Background)html
上圖顯示了目前深度學習模型在生產環境中的方法,本文僅探討如何部署pytorch模型!前端
至於爲何要用C++調用pytorch模型,其目的在於:使用C++及多線程能夠加快模型預測速度python
關於模型訓練有兩種方法,一種是直接使用C++編寫訓練代碼,能夠作到搭建完整的網絡模型,可是沒法使用遷移學習,而遷移學習是目前訓練樣本幾乎都會用到的方法,另外一種是使用python代碼訓練好模型,並使用JIT技術,將python模型導出爲C++可調用的模型,這裏具體介紹第二種。(我的以爲還能夠採用一種方式,即將pytorch模型做爲一種Web Service以供各類客戶端調用)linux
官方對TorchScript的介紹以下(https://pytorch.org/docs/master/jit.html#creating-torchscript-code):ios
TorchScript是一種從PyTorch代碼建立可序列化和可優化模型的方法。用TorchScript編寫的任何代碼均可以從Python進程中保存並加載到沒有Python依賴關係的進程中。
咱們提供了一些工具來增量地將模型從純Python程序轉換爲可以獨立於Python運行的TorchScript程序,例如,在一個獨立的c++程序中。這使得使用熟悉的工具在PyTorch中培訓模型,而後經過TorchScript將模型導出到生產環境中成爲可能。在生產環境中,出於性能和多線程的緣由,將模型做爲Python程序運行不是一個好主意。
首先,咱們在官網下載適合於Windows的libtorch
,由於穩定版出來了,因此能夠直接拿來使用。有CPU版本的和GPU版本的,這裏我都進行了測試,都是能夠直接使用的,這裏以CPU版本爲例進行介紹:c++
2.實驗(Experiments)git
1.python環境下跑模型的推斷代碼 github
以ESRGAN的inference code(https://github.com/xinntao/ESRGAN)爲例:segmentfault
環境:Windows10+Python3.5.2+Pytorch1.1網絡
Python packages: pip install numpy opencv-python
直接run test,結果以下(個人版本有作一些改動,如增長FPS的計算等):
2.將PyTorch模型轉換爲Torch Script
第一個方法是tracing.該方法經過將樣本輸入到模型中一次來對該過程進行評估從而捕獲模型結構.並記錄該樣本在模型中的flow.該方法適用於模型中不多使用控制flow的模型.
第二個方法就是向模型添加顯式註釋,通知Torch Script編譯器它能夠直接解析和編譯模型代碼,受Torch Script語言強加的約束。
要經過tracing來將PyTorch模型轉換爲Torch腳本,必須將模型的實例以及樣本輸入傳遞給torch.jit.trace函數.
這將生成一個torch.jit.ScriptModule對象,並在模塊的forward方法中嵌入模型評估的跟蹤:
import torch import architecture as arch # An instance of your model. model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \ mode='CNA', res_scale=1, upsample_mode='upconv') model.load_state_dict(torch.load('./models/RRDB_ESRGAN_x4.pth'), strict=True) model.eval() # An example input you would normally provide to your model's forward() method. example = torch.rand(64, 3, 3, 3) # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. traced_script_module = torch.jit.trace(model, example) output = traced_script_module(torch.ones(64, 3, 3, 3)) traced_script_module.save("./models/RRDB_ESRGAN_x4_000.pt") # The traced ScriptModule can now be evaluated identically to a regular PyTorch module print(output)
跟蹤的ScriptModule能夠與常規PyTorch模塊進行相同的計算,結果以下(注意在最後,將ScriptModule序列化爲一個文件.而後,C++就能夠不依賴任何Python代碼來執行該Script所對應的Pytorch模型.):
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN$ python model_jit_converter.py tensor([[[[0.9618, 1.0375, 1.0242, ..., 1.0049, 1.0399, 1.0255], [1.0199, 0.9996, 1.0096, ..., 1.0269, 1.0140, 1.0267], [1.0290, 1.0154, 1.0161, ..., 1.0201, 1.0077, 1.0298], ..., [1.0316, 1.0139, 1.0184, ..., 1.0184, 1.0179, 1.0197], [1.0391, 1.0174, 1.0162, ..., 1.0185, 1.0443, 1.0168], [1.0066, 1.0186, 0.9976, ..., 1.0143, 1.0066, 1.0249]], [[1.0155, 1.0491, 1.0004, ..., 0.9993, 0.9828, 0.9706], [0.9992, 1.0149, 1.0032, ..., 0.9851, 0.9937, 0.9887], [0.9974, 1.0106, 1.0089, ..., 1.0072, 1.0074, 1.0041], ..., [1.0130, 1.0036, 1.0059, ..., 0.9979, 1.0065, 1.0133], [1.0066, 0.9955, 1.0034, ..., 1.0030, 0.9875, 1.0011], [0.9788, 0.9983, 1.0113, ..., 1.0106, 1.0381, 1.0248]], [[0.9570, 0.9789, 0.9720, ..., 0.9920, 0.9740, 0.9940], [0.9522, 1.0182, 1.0109, ..., 1.0181, 1.0060, 0.9842], [0.9872, 1.0062, 1.0112, ..., 1.0172, 1.0072, 0.9803], ..., [1.0211, 1.0119, 1.0091, ..., 1.0082, 1.0339, 1.0348], [0.9894, 1.0227, 1.0226, ..., 0.9930, 1.0258, 1.0234], [0.9997, 0.9755, 0.9969, ..., 1.0227, 1.0308, 1.0109]]], [[[0.9618, 1.0375, 1.0242, ..., 1.0049, 1.0399, 1.0255], [1.0199, 0.9996, 1.0096, ..., 1.0269, 1.0140, 1.0267], [1.0290, 1.0154, 1.0161, ..., 1.0201, 1.0077, 1.0298], ..., [1.0316, 1.0139, 1.0184, ..., 1.0184, 1.0179, 1.0197], [1.0391, 1.0174, 1.0162, ..., 1.0185, 1.0443, 1.0168], [1.0066, 1.0186, 0.9976, ..., 1.0143, 1.0066, 1.0249]], [[1.0155, 1.0491, 1.0004, ..., 0.9993, 0.9828, 0.9706], [0.9992, 1.0149, 1.0032, ..., 0.9851, 0.9937, 0.9887], [0.9974, 1.0106, 1.0089, ..., 1.0072, 1.0074, 1.0041], ..., [1.0130, 1.0036, 1.0059, ..., 0.9979, 1.0065, 1.0133], [1.0066, 0.9955, 1.0034, ..., 1.0030, 0.9875, 1.0011], [0.9788, 0.9983, 1.0113, ..., 1.0106, 1.0381, 1.0248]], [[0.9570, 0.9789, 0.9720, ..., 0.9920, 0.9740, 0.9940], [0.9522, 1.0182, 1.0109, ..., 1.0181, 1.0060, 0.9842], [0.9872, 1.0062, 1.0112, ..., 1.0172, 1.0072, 0.9803], ..., [1.0211, 1.0119, 1.0091, ..., 1.0082, 1.0339, 1.0348], [0.9894, 1.0227, 1.0226, ..., 0.9930, 1.0258, 1.0234], [0.9997, 0.9755, 0.9969, ..., 1.0227, 1.0308, 1.0109]]], [[[0.9618, 1.0375, 1.0242, ..., 1.0049, 1.0399, 1.0255], [1.0199, 0.9996, 1.0096, ..., 1.0269, 1.0140, 1.0267], [1.0290, 1.0154, 1.0161, ..., 1.0201, 1.0077, 1.0298], ..., [1.0316, 1.0139, 1.0184, ..., 1.0184, 1.0179, 1.0197], [1.0391, 1.0174, 1.0162, ..., 1.0185, 1.0443, 1.0168], [1.0066, 1.0186, 0.9976, ..., 1.0143, 1.0066, 1.0249]], [[1.0155, 1.0491, 1.0004, ..., 0.9993, 0.9828, 0.9706], [0.9992, 1.0149, 1.0032, ..., 0.9851, 0.9937, 0.9887], [0.9974, 1.0106, 1.0089, ..., 1.0072, 1.0074, 1.0041], ..., [1.0130, 1.0036, 1.0059, ..., 0.9979, 1.0065, 1.0133], [1.0066, 0.9955, 1.0034, ..., 1.0030, 0.9875, 1.0011], [0.9788, 0.9983, 1.0113, ..., 1.0106, 1.0381, 1.0248]], [[0.9570, 0.9789, 0.9720, ..., 0.9920, 0.9740, 0.9940], [0.9522, 1.0182, 1.0109, ..., 1.0181, 1.0060, 0.9842], [0.9872, 1.0062, 1.0112, ..., 1.0172, 1.0072, 0.9803], ..., [1.0211, 1.0119, 1.0091, ..., 1.0082, 1.0339, 1.0348], [0.9894, 1.0227, 1.0226, ..., 0.9930, 1.0258, 1.0234], [0.9997, 0.9755, 0.9969, ..., 1.0227, 1.0308, 1.0109]]], ..., [[[0.9618, 1.0375, 1.0242, ..., 1.0049, 1.0399, 1.0255], [1.0199, 0.9996, 1.0096, ..., 1.0269, 1.0140, 1.0267], [1.0290, 1.0154, 1.0161, ..., 1.0201, 1.0077, 1.0298], ..., [1.0316, 1.0139, 1.0184, ..., 1.0184, 1.0179, 1.0197], [1.0391, 1.0174, 1.0162, ..., 1.0185, 1.0443, 1.0168], [1.0066, 1.0186, 0.9976, ..., 1.0143, 1.0066, 1.0249]], [[1.0155, 1.0491, 1.0004, ..., 0.9993, 0.9828, 0.9706], [0.9992, 1.0149, 1.0032, ..., 0.9851, 0.9937, 0.9887], [0.9974, 1.0106, 1.0089, ..., 1.0072, 1.0074, 1.0041], ..., [1.0130, 1.0036, 1.0059, ..., 0.9979, 1.0065, 1.0133], [1.0066, 0.9955, 1.0034, ..., 1.0030, 0.9875, 1.0011], [0.9788, 0.9983, 1.0113, ..., 1.0106, 1.0381, 1.0248]], [[0.9570, 0.9789, 0.9720, ..., 0.9920, 0.9740, 0.9940], [0.9522, 1.0182, 1.0109, ..., 1.0181, 1.0060, 0.9842], [0.9872, 1.0062, 1.0112, ..., 1.0172, 1.0072, 0.9803], ..., [1.0211, 1.0119, 1.0091, ..., 1.0082, 1.0339, 1.0348], [0.9894, 1.0227, 1.0226, ..., 0.9930, 1.0258, 1.0234], [0.9997, 0.9755, 0.9969, ..., 1.0227, 1.0308, 1.0109]]], [[[0.9618, 1.0375, 1.0242, ..., 1.0049, 1.0399, 1.0255], [1.0199, 0.9996, 1.0096, ..., 1.0269, 1.0140, 1.0267], [1.0290, 1.0154, 1.0161, ..., 1.0201, 1.0077, 1.0298], ..., [1.0316, 1.0139, 1.0184, ..., 1.0184, 1.0179, 1.0197], [1.0391, 1.0174, 1.0162, ..., 1.0185, 1.0443, 1.0168], [1.0066, 1.0186, 0.9976, ..., 1.0143, 1.0066, 1.0249]], [[1.0155, 1.0491, 1.0004, ..., 0.9993, 0.9828, 0.9706], [0.9992, 1.0149, 1.0032, ..., 0.9851, 0.9937, 0.9887], [0.9974, 1.0106, 1.0089, ..., 1.0072, 1.0074, 1.0041], ..., [1.0130, 1.0036, 1.0059, ..., 0.9979, 1.0065, 1.0133], [1.0066, 0.9955, 1.0034, ..., 1.0030, 0.9875, 1.0011], [0.9788, 0.9983, 1.0113, ..., 1.0106, 1.0381, 1.0248]], [[0.9570, 0.9789, 0.9720, ..., 0.9920, 0.9740, 0.9940], [0.9522, 1.0182, 1.0109, ..., 1.0181, 1.0060, 0.9842], [0.9872, 1.0062, 1.0112, ..., 1.0172, 1.0072, 0.9803], ..., [1.0211, 1.0119, 1.0091, ..., 1.0082, 1.0339, 1.0348], [0.9894, 1.0227, 1.0226, ..., 0.9930, 1.0258, 1.0234], [0.9997, 0.9755, 0.9969, ..., 1.0227, 1.0308, 1.0109]]], [[[0.9618, 1.0375, 1.0242, ..., 1.0049, 1.0399, 1.0255], [1.0199, 0.9996, 1.0096, ..., 1.0269, 1.0140, 1.0267], [1.0290, 1.0154, 1.0161, ..., 1.0201, 1.0077, 1.0298], ..., [1.0316, 1.0139, 1.0184, ..., 1.0184, 1.0179, 1.0197], [1.0391, 1.0174, 1.0162, ..., 1.0185, 1.0443, 1.0168], [1.0066, 1.0186, 0.9976, ..., 1.0143, 1.0066, 1.0249]], [[1.0155, 1.0491, 1.0004, ..., 0.9993, 0.9828, 0.9706], [0.9992, 1.0149, 1.0032, ..., 0.9851, 0.9937, 0.9887], [0.9974, 1.0106, 1.0089, ..., 1.0072, 1.0074, 1.0041], ..., [1.0130, 1.0036, 1.0059, ..., 0.9979, 1.0065, 1.0133], [1.0066, 0.9955, 1.0034, ..., 1.0030, 0.9875, 1.0011], [0.9788, 0.9983, 1.0113, ..., 1.0106, 1.0381, 1.0248]], [[0.9570, 0.9789, 0.9720, ..., 0.9920, 0.9740, 0.9940], [0.9522, 1.0182, 1.0109, ..., 1.0181, 1.0060, 0.9842], [0.9872, 1.0062, 1.0112, ..., 1.0172, 1.0072, 0.9803], ..., [1.0211, 1.0119, 1.0091, ..., 1.0082, 1.0339, 1.0348], [0.9894, 1.0227, 1.0226, ..., 0.9930, 1.0258, 1.0234], [0.9997, 0.9755, 0.9969, ..., 1.0227, 1.0308, 1.0109]]]], grad_fn=<MkldnnConvolutionBackward>)
3.在C++中加載你的Script Module
要在C ++中加載序列化的PyTorch模型,您的應用程序必須依賴於PyTorch C ++ API - 也稱爲LibTorch。LibTorch發行版包含一組共享庫,頭文件和CMake構建配置文件。雖然CMake不是依賴LibTorch的要求,但它是推薦的方法,而且未來會獲得很好的支持。在本教程中,咱們將使用CMake和LibTorch構建一個最小的C ++應用程序,它只需加載並執行序列化的PyTorch模型。
加載模塊的代碼:
#include <torch/script.h> // One-stop header. #include <iostream> #include <memory> int main(int argc, const char* argv[]) { if (argc != 2) { std::cerr << "usage: example-app <path-to-exported-script-module>\n"; return -1; } // Deserialize the ScriptModule from a file using torch::jit::load(). std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]); assert(module != nullptr); std::cout << "ok\n"; }
<torch / script.h>
頭文件包含運行該示例所需的LibTorch庫中的全部相關包含。咱們的應用程序接受序列化PyTorch ScriptModule的文件路徑做爲其惟一的命令行參數,而後使用torch :: jit :: load()
函數繼續反序列化模塊,該函數將此文件路徑做爲輸入。做爲回報,咱們收到一個指向torch :: jit :: script :: Module
的共享指針,至關於C ++中的torch.jit.ScriptModule
。目前,咱們只驗證此指針不爲null
。咱們將研究如何在接下來執行它。
LibTorch和構建應用程序
假設咱們將上面的代碼保存到名爲example-app.cpp的文件中。構建它的最小CMakeLists.txt以下:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR) project(custom_ops) find_package(Torch REQUIRED) add_executable(example-app example-app.cpp) target_link_libraries(example-app "${TORCH_LIBRARIES}") set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
構建應用程序時,假設咱們的示例目錄佈局以下:
example-app/ CMakeLists.txt example-app.cpp
如今能夠運行如下命令從example-app/文件夾中構建應用程序:
cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch make
若是一切順利,它將看起來像這樣:
(surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ cmake -DCMAKE_PREFIX_PATH=/home/anpi-cn/workspace_min/libtorch -- The C compiler identification is GNU 5.4.0 -- The CXX compiler identification is GNU 5.4.0 -- Check for working C compiler: /usr/bin/cc -- Check for working C compiler: /usr/bin/cc -- works -- Detecting C compiler ABI info -- Detecting C compiler ABI info - done -- Detecting C compile features -- Detecting C compile features - done -- Check for working CXX compiler: /usr/bin/c++ -- Check for working CXX compiler: /usr/bin/c++ -- works -- Detecting CXX compiler ABI info -- Detecting CXX compiler ABI info - done -- Detecting CXX compile features -- Detecting CXX compile features - done -- Looking for pthread.h -- Looking for pthread.h - found -- Looking for pthread_create -- Looking for pthread_create - not found -- Looking for pthread_create in pthreads -- Looking for pthread_create in pthreads - not found -- Looking for pthread_create in pthread -- Looking for pthread_create in pthread - found -- Found Threads: TRUE -- Found CUDA: /usr/local/cuda (found version "9.0") -- Caffe2: CUDA detected: 9.0 -- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc -- Caffe2: CUDA toolkit directory: /usr/local/cuda -- Caffe2: Header version is: 9.0 -- Found CUDNN: /usr/include -- Found cuDNN: v7.4.1 (include: /usr/include, library: /usr/lib/x86_64-linux-gnu/libcudnn.so) -- Autodetected CUDA architecture(s): 6.1 -- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61 -- Found torch: /home/anpi-cn/workspace_min/libtorch/lib/libtorch.so -- Configuring done CMake Warning at CMakeLists.txt:6 (add_executable): Cannot generate a safe runtime search path for target example-app because there is a cycle in the constraint graph: dir 0 is [/home/anpi-cn/workspace_min/libtorch/lib] dir 1 is [/usr/local/cuda/lib64/stubs] dir 2 is [/home/anpi-cn/.conda/envs/surper-resolution-pytorch/lib] dir 3 must precede it due to runtime library [libcudart.so.9.0] dir 3 is [/usr/local/cuda/lib64] dir 2 must precede it due to runtime library [libnvrtc.so.9.0] Some of these libraries may not be found correctly. -- Generating done -- Build files have been written to: /home/anpi-cn/workspace_min/Super-Resolution/ESRGAN/example-app (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ make Scanning dependencies of target example-app [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o [100%] Linking CXX executable example-app [100%] Built target example-app (surper-resolution-pytorch) anpi-cn@anpi-cn:~/workspace_min/Super-Resolution/ESRGAN/example-app$ ./example-app ../models/RRDB_ESRGAN_x4_000.pt ok
4.在C++代碼中執行Script Module
在C ++中成功加載了咱們的序列化模型後,添加如下代碼到C ++應用程序的main()
函數中:
// Create a vector of inputs. std::vector<torch::jit::IValue> inputs; inputs.push_back(torch::ones({64, 3, 3, 3})); // Execute the model and turn its output into a tensor. auto output = module->forward(inputs).toTensor(); std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
前兩行設置了咱們模型的輸入。咱們建立了一個torch :: jit :: IValue
的向量並添加一個輸入。要建立輸入張量,咱們使用torch :: ones()
,至關於C ++ API中的torch.ones
。而後咱們運行script::Module
的forward
方法,將它傳遞給咱們建立的輸入向量。做爲回報,咱們獲得一個新的IValue
,咱們經過調用toTensor()
將其轉換爲張量。
在最後一行中,咱們打印輸出的前五個條目。因爲在前面的Python中爲本次的模型提供了相同的輸入,所以理想狀況下應該看到相同的輸出。從新編譯上面的應用程序並使用相同的序列化模型運行它來嘗試。經過比較,發現C++的輸出與Python的輸出是同樣的,代表實驗成功啦!
參考文章:
https://pytorch.org/tutorials/advanced/cpp_export.html
PyTorch 1.0 中文官方教程:使用 PyTorch C++ 前端