# -*- coding: utf-8 -*-
"""
Created on Sat Aug 15 01:10:42 2020

@author: LX
"""
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnx
import timm
import torch
import torchvision

print(torch.__version__)
print(cv2.__version__)
print(np.__version__)
print(onnx.__version__)

# Class names, one per line (the path points at OpenCV's sample
# classification_classes_ILSVRC2012.txt). `classes` stays None when the file
# cannot be read, so code that formats labels can fall back to numeric ids.
classes = None
class_file = r"E:\ScientificComputing\opencv\sources\samples\data\dnn\classification_classes_ILSVRC2012.txt"
try:
    with open(class_file, 'rt', encoding='utf-8') as f:
        classes = f.read().rstrip('\n').split('\n')
except OSError as err:
    print('Could not read class names from %s (%s); using numeric class ids.'
          % (class_file, err))

# model_name -> (attribute of torchvision.models, value passed as `pretrained`).
# Factories are looked up by name at call time so that only the requested one
# has to exist in the installed torchvision.
_TORCHVISION_MODELS = {
    'alexnet': ('alexnet', True),
    'densnet': ('densenet121', True),  # historical misspelling, kept for callers
    'densenet': ('densenet121', True),
    'resnet': ('resnet50', True),
    'mobilenet': ('mobilenet_v2', True),
    'squeezenet': ('squeezenet1_1', True),
    # NOTE(review): the only entry created without pretrained weights, so its
    # predictions come from random initialisation -- confirm this is intended.
    'inception': ('inception_v3', False),
    'googlenet': ('googlenet', True),
    'vgg16': ('vgg16', True),
    'vgg19': ('vgg19', True),
    'shufflenet': ('shufflenet_v2_x1_0', True),
}

# Names forwarded unchanged to timm.create_model(name, pretrained=True).
_TIMM_MODELS = (
    'cspdarknet53',
    'seresnet18',
    'senet154',
    'seresnet50',
    'resnest50d',
    'skresnet50',
)


def init_model(model_name):
    """Build a classification model and a matching dummy input.

    Args:
        model_name: One of the keys of ``_TORCHVISION_MODELS`` or one of the
            names in ``_TIMM_MODELS``.

    Returns:
        A ``(model, dummy)`` tuple. ``model`` has been switched to eval mode;
        ``dummy`` is a random tensor of shape ``(1, 3, 299, 299)`` for
        ``'inception'`` and ``(1, 3, 224, 224)`` for every other name.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    if model_name in _TORCHVISION_MODELS:
        factory_name, pretrained = _TORCHVISION_MODELS[model_name]
        factory = getattr(torchvision.models, factory_name)
        model = factory(pretrained=pretrained)
    elif model_name in _TIMM_MODELS:
        model = timm.create_model(model_name, pretrained=True)
    else:
        supported = sorted(set(_TORCHVISION_MODELS) | set(_TIMM_MODELS))
        raise ValueError('Unknown model name %r; supported names: %s'
                         % (model_name, ', '.join(supported)))

    model.eval()

    side = 299 if model_name == 'inception' else 224
    dummy = torch.randn(1, 3, side, side)
    return model, dummy
from PIL import Image
from torchvision import transforms

# Per-channel statistics (RGB order) used to normalise inputs on both the
# PyTorch path and the OpenCV path, expressed for pixel values in [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _format_label(class_id, score):
    """Return 'name: score', using 'Class #id' when no class name is available."""
    if classes and class_id < len(classes):
        name = classes[class_id]
    else:
        name = 'Class #%d' % class_id
    return '%s: %.4f' % (name, score)


model, dummy = init_model('seresnet18')
onnx_name = 'exported.onnx'
torch.onnx.export(model, dummy, onnx_name)

# Load the exported ONNX graph.
model_ = onnx.load(onnx_name)
# Check that the IR is well formed.
onnx.checker.check_model(model_)

# Load the same file with the OpenCV dnn module.
net = cv2.dnn.readNetFromONNX(onnx_name)

img_file = r"C:\Users\LX\Pictures\dog.jpg"
if not os.path.exists(img_file):
    raise FileNotFoundError(img_file)

# Spatial size the network was exported with (dummy is laid out as N, C, H, W).
inpHeight = int(dummy.shape[-2])
inpWidth = int(dummy.shape[-1])

#%% Inference with the torchvision / timm model
# Resize so that a centre crop of the network input size keeps the usual
# 256:224 ratio (256 for a 224 input, 342 for a 299 input).
resize_size = round(inpHeight * 256 / 224)
transform = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.CenterCrop((inpHeight, inpWidth)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Force three channels so grayscale / RGBA files also work with Normalize.
img = Image.open(img_file).convert('RGB')
img_t = transform(img)
# The model expects a batch dimension in front of the image tensor.
batch_t = torch.unsqueeze(img_t, 0)
with torch.no_grad():
    tc_out = model(batch_t).numpy()

# Get the class with the highest score. No softmax is applied here, so
# `confidence` is the raw network output for that class.
tc_out = tc_out.flatten()
classId = int(np.argmax(tc_out))
confidence = tc_out[classId]
label = _format_label(classId, confidence)
print(label)

#%% Inference on the exported ONNX model through OpenCV
frame = cv2.imread(img_file)
if frame is None:
    raise IOError('OpenCV could not decode image: %s' % img_file)

# Create a 4D blob from the frame. blobFromImage subtracts `mean` from the raw
# 0-255 pixels *before* multiplying by `scalefactor`, so the mean has to be
# given on the 0-255 scale. swapRB=True makes the blob RGB, matching the
# channel order of IMAGENET_MEAN / IMAGENET_STD.
# Unlike the torch path (resize + centre crop) the frame is squashed to the
# input size here (crop=False), so scores can still differ slightly.
mean_255 = [255.0 * m for m in IMAGENET_MEAN]
blob = cv2.dnn.blobFromImage(frame,
                             scalefactor=1.0 / 255,
                             size=(inpWidth, inpHeight),
                             mean=mean_255,
                             swapRB=True,
                             crop=False)
# blobFromImage has no std parameter; divide each channel of the NCHW blob.
blob /= np.array(IMAGENET_STD, dtype=np.float32).reshape(1, 3, 1, 1)

# Run the model.
net.setInput(blob)
out = net.forward()
print(out.shape)

# Get the class with the highest score.
out = out.flatten()
classId = int(np.argmax(out))
confidence = out[classId]

# Put efficiency information.
t, _ = net.getPerfProfile()
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())
print(label)
cv2.putText(frame, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

# Print predicted class.
label = _format_label(classId, confidence)
print(label)
cv2.putText(frame, label, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

winName = 'onnx'
cv2.imshow(winName, frame)
cv2.waitKey(0)
cv2.destroyAllWindows()
以上模型中,shufflenet、resnest、skresnet 模型還不能被 OpenCV dnn 模塊成功加載:部分 op 尚未支持,會報錯。另外需要注意的是 blobFromImage
函數的參數:
blob = cv2.dnn.blobFromImage(frame, scalefactor=1.0 / 255, size=(inpWidth, inpHeight), mean=[0.485, 0.456, 0.406], swapRB=True, crop=False)
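對圖像數據的變換要和 PyTorch 端的預處理保持一致,比如歸一化操作、減去均值等。需要留意 blobFromImage 是先減去 mean、再乘以 scalefactor,因此 mean 必須與原始像素值(0–255)處於同一尺度;此外 blobFromImage 沒有 std 參數,除以標準差需要另外對 blob 逐通道進行。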