深度學習應用的服務端Flask部署

【GiantPandaCV導讀】這篇文章包含與PyTorch模型部署相關的兩部份內容:html

  • PyTorch-YOLOv3模型的Web頁面展現程序的編寫前端

  • 模型的服務接口相關工具的使用python

0. 環境依賴:

系統:Ubuntu 18.04nginx

Python版本:3.7git

依賴Python包:1. PyTorch==1.3 2. Flask==0.12 3. Gunicorngithub

須要注意的是Flask 0.12中默認的單進程單線程,而最新的1.0.2則不是(具體是多線程仍是多進程尚待考證),而中文博客裏面能查到的資料基本都在說Flask默認單進程單線程。web

依賴工具 1. nginx 2. apache2-utilsapache

nginx 用於代理轉發和負載均衡,apache2-utils用於測試接口json


1. 製做模型演示界面

圖像識別任務的展現這項工程通常是面向客戶的,這種場景下不可能把客戶拉到你的電腦前面,敲一行命令,等matplotlib彈個結果窗口出來。總歸仍是要有個圖形化界面才顯得有點誠意。flask

爲了節約時間,咱們選擇了Flask框架來開發這個界面。

上傳頁面和展現頁面

作識別演示須要用到兩個html頁面,代碼也比較簡單,編寫以下:

上傳界面

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Flask上傳圖片演示</title>
</head>
<body>
    <h1>使用Flask上傳本地圖片</h1>
    <form action="" enctype='multipart/form-data' method='POST'>
        <input type="file" name="file" style="margin-top:20px;"/>
        <br>
        <input type="submit" value="上傳" class="button-new" style="margin-top:15px;"/>
    </form>
</body>
</html>

展現界面

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Flask上傳圖片演示</title>
</head>
<body>
    <h1>使用Flask上傳本地圖片</h1>
    <form action="" enctype='multipart/form-data' method='POST'>
        <input type="file" name="file" style="margin-top:20px;"/>
        <br>
        <input type="submit" value="上傳" class="button-new" style="margin-top:15px;"/>
    </form>
    <img src="{{ url_for('static', filename= path,_t=val1) }}" width="400" height="400" alt="圖片識別失敗"/>
</body>
</html>

上傳界面以下圖所示,以爲醜的話能夠找前端同事美化一下:

flask上傳圖片及展現功能

而後就能夠編寫flask代碼了,爲了更好地展現圖片,能夠向html頁面傳入圖片地址參數。

from flask import Flask, render_template, request, redirect, url_for, make_response, jsonify
from werkzeug.utils import secure_filename
import os
import cv2
import time
from datetime import timedelta
from main import run, conf
ALLOWED_EXTENSIONS = set([
    "png","jpg","JPG","PNG""bmp"
])

def is_allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS

app = Flask(__name__)

# 靜態文件緩存過時時間
app.send_file_max_age_default = timedelta(seconds=1)

@app.route("/upload",methods = ['POST''GET'])
def upload():
    if request.method == "POST":
        f = request.files['file']
        if not ( f and is_allowed_file(f.filename)):
            return jsonify({
                "error":1001, "msg":"請檢查上傳的圖片類型,僅限於png、PNG、jpg、JPG、bmp"
            })
        user_input = request.form.get("name")

        basepath = os.path.dirname(__file__)
        upload_path = os.path.join(basepath, "static/images",secure_filename(f.filename))
        f.save(upload_path)
        
        detected_path = os.path.join(basepath, "static/images""output" + secure_filename(f.filename))
        run(upload_path, conf, detected_path)

        # return render_template("upload_ok.html", userinput = user_input, val1=time.time(), path = detected_path)
        path = "/images/" + "output" + secure_filename(f.filename)
        return render_template("upload_ok.html", path = path, val1 = time.time())
    return render_template("upload.html")


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8888, debug=True)

目標檢測函數

原項目中提供了detection.py來作批量的圖片檢測,須要稍微修改一下才能用來作flask代碼中的接口。

from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *

import os
import sys
import time
import datetime
import argparse

from PIL import Image

import torch
from torchvision import datasets
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator

class custom_dict(dict):
    def __init__(self, d = None):
        if d is not None:
            for k,v in d.items():
                self[k] = v
        return super().__init__()

    def __key(self, key):
        return "" if key is None else key.lower()

    def __str__(self):
        import json
        return json.dumps(self)

    def __setattr__(self, key, value):
        self[self.__key(key)] = value

    def __getattr__(self, key):
        return self.get(self.__key(key))

    def __getitem__(self, key):
        return super().get(self.__key(key))

    def __setitem__(self, key, value):
        return super().__setitem__(self.__key(key), value)

conf = custom_dict({
    "model_def":"config/yolov3.cfg",
    "weights_path":"weights/yolov3.weights",
    "class_path":"data/coco.names",
    "conf_thres":0.8,
    "nms_thres":0.4,
    "img_size":416
})

def run(img_path, conf, target_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs("output", exist_ok=True)
    classes = load_classes(conf.class_path)
    model = Darknet(conf.model_def, img_size=conf.img_size).to(device)

    if conf.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(conf.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(conf.weights_path))
    model.eval() 
    
    img = Image.open(img_path).convert("RGB")
    img = img.resize(((img.size[0] // 32) * 32, (img.size[1] // 32) * 32))
    img_array = np.array(img)
    img_tensor = pad_to_square(transforms.ToTensor()(img),0)[0].unsqueeze(0)
    conf.img_size = img_tensor.shape[2]
    
    with torch.no_grad():
        detections = model(img_tensor)
        detections = non_max_suppression(detections, conf.conf_thres, conf.nms_thres)[0]

    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]
    plt.figure()
    fig, ax = plt.subplots(1)
    ax.imshow(img_array)
    if detections is not None:
        # Rescale boxes to original image
        detections = rescale_boxes(detections, conf.img_size, img_array.shape[:2])
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        bbox_colors = random.sample(colors, n_cls_preds)
        for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:

            print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))

            box_w = x2 - x1
            box_h = y2 - y1

            color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0])]
            # Create a Rectangle patch
            bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
            # Add the bbox to the plot
            ax.add_patch(bbox)
            # Add label
            plt.text(
                x1,
                y1,
                s=classes[int(cls_pred)],
                color="white",
                verticalalignment="top",
                bbox={"color": color, "pad": 0},
            )

    # Save generated image with detections
    plt.axis("off")
    plt.gca().xaxis.set_major_locator(NullLocator())
    plt.gca().yaxis.set_major_locator(NullLocator())
    filename = img_path.split("/")[-1].split(".")[0]
    plt.savefig(target_path, bbox_inches='tight', pad_inches=0.0)
    plt.close()



if __name__ == "__main__":
    run("data/samples/dog.jpg",conf)

展現效果

編寫好了以後,啓動server.py,在本地打開localhost:8888/upload就能夠看到以下界面了,把圖片上傳上去,很快就能獲得檢測結果。

結果以下圖所示:

想試用的同窗能夠點擊:http://106.13.201.241:8888/upload

2. 深度學習的服務接口編寫

接下來介紹的是在生產環境下的部署,使用的是flask+gunicorn+nginx的方式,能夠處理較大規模的請求。

下面以圖像分類模型爲例演示一下深度學習服務接口如何編寫。

對於深度學習工程師來講,學習這些內容主要是瞭解一下本身的模型在生產環境的運行方式,便於在服務出現問題的時候與開發的同事一塊兒進行調試。

flask服務接口

接口不須要有界面顯示,固然也能夠添加一個API介紹界面,方便調用者查看服務是否已經啓動。

from flask import Flask, request
from werkzeug.utils import secure_filename
import uuid
from PIL import Image
import os
import time
import base64
import json

import torch
from torchvision.models import resnet18
from torchvision.transforms import ToTensor

from keys import key

app = Flask(__name__)
net = resnet18(pretrained=True)
net.eval()

@app.route("/",methods=["GET"])
def show():
    return "classifier api"

@app.route("/run",methods = ["GET","POST"])
def run():
    file = request.files['file']
    base_path = os.path.dirname(__file__)
    if not os.path.exists(os.path.join(base_path, "temp")):
        os.makedirs(os.path.join(base_path, "temp"))
    file_name = uuid.uuid4().hex
    upload_path = os.path.join(base_path, "temp", file_name)
    file.save(upload_path)

    img = Image.open(upload_path)
    img_tensor = ToTensor()(img).unsqueeze(0)
    out = net(img_tensor)
    pred = torch.argmax(out,dim = 1)

    return "result : {}".format(key[pred])

if __name__ == "__main__":
    app.run(host="0.0.0.0",port=5555,debug=True)

在命令行輸入python server.py便可啓動服務。

gunicorn啓動多個實例

新版的flask已經支持多進程了,不過用在生產環境仍是不太穩定,通常生產環境會使用gunicorn來啓動多個服務。

使用以下命令便可啓動多個圖像分類實例

gunicorn -w 4 -b 0.0.0.0:5555 server:app

輸出以下內容表明服務建立成功:

[2020-02-11 14:50:24 +0800] [892] [INFO] Starting gunicorn 20.0.4
[2020-02-11 14:50:24 +0800] [892] [INFO] Listening at: http://0.0.0.0:5555 (892)
[2020-02-11 14:50:24 +0800] [892] [INFO] Using worker: sync
[2020-02-11 14:50:24 +0800] [895] [INFO] Booting worker with pid: 895
[2020-02-11 14:50:24 +0800] [896] [INFO] Booting worker with pid: 896
[2020-02-11 14:50:24 +0800] [898] [INFO] Booting worker with pid: 898
[2020-02-11 14:50:24 +0800] [899] [INFO] Booting worker with pid: 899

若是配置比較複雜,也能夠將配置寫入一個文件中,如:

bind = '0.0.0.0:5555'
timeout = 10
workers = 4

而後運行:

gunicorn -c gunicorn.conf sim_server:app

nginx負載均衡

若是有多個服務器,可使用nginx作請求分發與負載均衡。

安裝好nginx以後,修改nginx的配置文件

worker_processes auto;
error_log /var/log/nginx/error.log;
pid /run/nginx.pid;
# Load dynamic modules. See /usr/share/nginx/README.dynamic.
include /usr/share/nginx/modules/*.conf;

events {
worker_connections 1024;
}

http {
server
{
listen 5556; # nginx端口
server_name localhost;
location / {
proxy_pass http://localhost:5555/run; # gunicorn的url
}
}
}

而後按配置文件啓動

sudo nginx -c nginx.conf

測試一下服務是否正常

啓動了這麼多服務以後,可使用apache2-utils來測試服務的併發性能。

使用apache2-utils進行上傳圖片的post請求方法參考:

https://gist.github.com/chiller/dec373004894e9c9bb38ac647c7ccfa8

嚴格參照,注意一個標點,一個符號都不要錯。使用這種方法傳輸圖片的base64編碼,在服務端不須要解碼也能使用

而後使用下面的方式訪問

gunicorn 接口

ab -n 2 -c 2 -T "multipart/form-data; boundary=1234567890" -p turtle.txt http://localhost:5555/run

nginx 接口

ab -n 2 -c 2 -T "multipart/form-data; boundary=1234567890" -p turtle.txt http://localhost:5556/run

有了gunicorn和nginx就能夠輕鬆地實現PyTorch模型的多機多卡部署了。



往期精彩

任意圖像轉素描:Python分分鐘實現

肖像轉素描:AI小素的前世此生

一鍵智能摳圖-原理與實現 | 可在線體驗

【能夠玩的】人臉檢測 & 人臉關鍵點檢測 & 人臉卡通化

卡通化-二次元的你長什麼樣









長按指紋,識別二維碼,一鍵關注



本文分享自微信公衆號 - CVPy(x-cvpy)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索