C++ 使用 TorchScript 加載訓練好的 PyTorch 模型

1. 首先從 PyTorch 官網下載 libtorch，放到當前項目目錄下。

2. 使用 torch.jit.trace 將 PyTorch 訓練好的模型導出爲 .pt 格式（導出前需先調用 model.eval()）：

 # Export a trained PyTorch classifier to TorchScript (.pt) with torch.jit.trace
# and run one sanity-check prediction on a sample image.
import torch
from skimage import io, transform, color
import numpy as np
import os
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")

labels = ['cock', 'drawing', 'neutral', 'porn', 'sexy']
path = "test/n_1.jpg"
im = io.imread(path)
if im.ndim == 2:
    # Grayscale image: replicate it to three channels.
    im = color.gray2rgb(im)
elif im.shape[2] == 4:
    # RGBA image: drop the alpha channel.
    im = color.rgba2rgb(im)

# resize() returns floats in [0, 1]; reorder HWC -> CHW and add a batch axis.
im = transform.resize(im, (224, 224))
im = np.transpose(im, (2, 0, 1))
dummy_input = np.expand_dims(im, 0)
inp = torch.from_numpy(dummy_input).float()
# NOTE(review): the C++ loader additionally normalizes with ImageNet mean/std
# (0.485/0.456/0.406, 0.229/0.224/0.225) while this input is not normalized —
# confirm which preprocessing the model was trained with.

# NOTE(review): this file holds a whole pickled model. On PyTorch >= 2.6,
# torch.load defaults to weights_only=True and needs weights_only=False here;
# only do that for checkpoints you trust, since unpickling can run arbitrary code.
model = torch.load(
    "models/resnet50-epoch-0-accu-0.9213857428381079.pth", map_location='cpu')
# Tracing records the behavior of the current mode (BatchNorm/Dropout), so the
# model must be in inference mode *before* tracing.
model.eval()
traced_script_module = torch.jit.trace(model, inp)

with torch.no_grad():
    output = model(inp)
probs = F.softmax(output, dim=1).numpy()[0]
pred = int(np.argmax(probs))
print(labels[pred], float(probs[pred]))

traced_script_module.save("models/traced_resnet_model.pt")

在 C++ 中使用 TorchScript 加載 .pt 模型：

 // C++ standard library
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// One-stop header.
#include <torch/script.h>

// headers for opencv
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/opencv.hpp>

 // Network input geometry and how many predictions to print.
// Typed constants instead of macros: scoped, type-checked and debugger-visible.
constexpr int kIMAGE_SIZE = 224; // input height and width, in pixels
constexpr int kCHANNELS = 3;     // RGB channels after LoadImage's conversion
constexpr int kTOP_K = 1;        // print top k predicted results

 19 bool LoadImage(std::string file_name, cv::Mat &image)  20 {  21   image = cv::imread(file_name); // CV_8UC3  22   if (image.empty() || !image.data)  23  {  24     return false;  25  }  26  cv::cvtColor(image, image, CV_BGR2RGB);  27   // scale image to fit  28  cv::Size scale(kIMAGE_SIZE, kIMAGE_SIZE);  29  cv::resize(image, image, scale);  30 
 31   // convert [unsigned int] to [float]  32  image.convertTo(image, CV_32FC3,1.0/255);  33 
 34   return true;  35 }  36 
 // Read class names from a text file, one per line, appending them to
// `labels` in file order (main() indexes this vector by predicted class id).
// A trailing '\r' is stripped so label files with Windows line endings do not
// leak carriage returns into the printed labels.
//
// @param file_name  path of the label file
// @param labels     output; lines are appended to it
// @return false if the file cannot be opened or `labels` is still empty
//         afterwards (an empty label list would make any lookup out of range)
bool LoadImageNetLabel(const std::string &file_name,
                       std::vector<std::string> &labels)
{
  std::ifstream ifs(file_name);
  if (!ifs)
  {
    return false;
  }
  std::string line;
  while (std::getline(ifs, line))
  {
    if (!line.empty() && line.back() == '\r')
    {
      line.pop_back();
    }
    labels.push_back(line);
  }
  return !labels.empty();
}

 53 int main(int argc, const char *argv[])  54 {  55   if (argc != 3)  56  {  57     std::cerr << "Usage:classifier <path-to-exported-script-module> <path-to-lable-file> " << std::endl;  58     return -1;  59  }  60 
 61   //load model  62   torch::jit::script::Module module = torch::jit::load(argv[1]);  63   // to GPU  64   // module->to(at::kCUDA);  65   std::cout << "== ResNet50 loaded!\n";  66 
 67   //load labels(classes names)  68   std::vector<std::string> labels;  69   if (LoadImageNetLabel(argv[2], labels))  70  {  71     std::cout << "== Label loaded! Let's try it\n";  72  }  73   else
 74  {  75     std::cerr << "Please check your label file path." << std::endl;  76     return -1;  77  }  78 
 79   std::string file_name = "";  80  cv::Mat image;  81   while (true)  82  {  83     std::cout << "== Input image path: [enter q to exit]" << std::endl;  84     std::cin >> file_name;  85     if (file_name == "Q" || file_name == "q")  86  {  87       break;  88  }  89     if (LoadImage(file_name, image))  90  {  91       //read image tensor  92       auto input_tensor = torch::from_blob(  93           image.data, {1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS});  94       input_tensor = input_tensor.permute({0, 3, 1, 2});  95       input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);  96       input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);  97       input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);  98       // to GPU  99       // input_tensor = input_tensor.to(at::kCUDA); 100 
101       torch::Tensor out_tensor = module.forward({input_tensor}).toTensor(); 102 
103       auto results = out_tensor.sort(-1, true); 104       auto softmaxs = std::get<0>(results)[0].softmax(0); 105       auto indexs = std::get<1>(results)[0]; 106 
107       for (int i = 0; i < kTOP_K; ++i) 108  { 109         auto idx = indexs[i].item<int>(); 110         std::cout << " ============= Top-" << i + 1 << " =============" << std::endl; 111         std::cout << " Label: " << labels[idx] << std::endl; 112         std::cout << " With Probability: "
113                   << softmaxs[i].item<float>() * 100.0f << "%" << std::endl; 114  } 115  } 116     else
117  { 118       std::cout << "Can't load the image, please check your path." << std::endl; 119  } 120  } 121 }

使用 CMakeLists.txt 編譯：

 # Build the TorchScript ResNet demo against libtorch and OpenCV.
# Configure with: cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake_minimum_required(VERSION 3.10)
project(predict_demo)

# Default OpenCV location from the original setup; a user-supplied
# -DOpenCV_DIR=... is no longer overwritten.
if(NOT OpenCV_DIR)
  set(OpenCV_DIR /home/buyizhiyou/opencv-3.4.4/build)
endif()
find_package(OpenCV REQUIRED)
find_package(Torch REQUIRED)

# Append as one string: the old unquoted SET() built a ';'-separated list and
# corrupted the flags. TORCH_CXX_FLAGS carries libtorch's ABI definition.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -O3")

add_executable(resnet_demo resnet_demo.cpp)
# Header search paths
target_include_directories(resnet_demo PRIVATE ${OpenCV_INCLUDE_DIRS})
target_link_libraries(resnet_demo ${TORCH_LIBRARIES} ${OpenCV_LIBS})
# Recent libtorch releases require C++17 (older 1.x releases C++14);
# this replaces the previous -std=c++11 flag.
set_property(TARGET resnet_demo PROPERTY CXX_STANDARD 17)

運行：

./resnet_demo   models/traced_resnet_model.pt  labels.txt
相關文章
相關標籤/搜索