Tensorflow之實現物體檢測

目錄

  • 項目背景
  • TensorFlow介紹
  • 環境搭建
  • 模型選用
  • Api使用說明
  • 運行路由
  • 小結

項目背景

產品看到競品能夠標記物體的功能,秉承一向的他有我也要有,他沒有我更要有的做風,丟過來一網站,說這個功能很簡單,必定能夠實現html

image

這時候萬能的谷歌發揮了做用,在茫茫的數據大海中發現了Tensorflow機器學習框架,也就是目前很是火爆的的深度學習(人工智能),既然方案已有,就差一個程序員了python

Tensorflow介紹

百科介紹:TensorFlow是谷歌基於DistBelief進行研發的第二代人工智能學習系統,可被用於語音識別或圖像識別等多項機器學習和深度學習領域。git

image

翻譯成大白話:是一個深度學習和神經網絡的框架,底層C++,經過Python進行控制,固然,也是支持Go、Java等語言程序員

環境搭建

  • Linux/Unix(筆者使用Mac)
  • Python3.6
  • protoc 3.5.1
  • tensorflow 1.7.0
一、克隆文件

git clone https://github.com/guandeng/tensorflow.gitgithub

文件目錄格式以下json

└── tensorflow
    ├── Dockerfile
    ├── README.md
    ├── data
    │   ├── models
    │   ├── pbtxt
    │   └── tf_models
    ├── object_detection_api.py
    ├── server.py
    ├── sh
    │   ├── download_data.sh
    │   └── ods.sh
    ├── static
    ├── templates
    └── upload
  • data/models 存放
  • data/pbtxt 物體標識名稱
  • data/tf_models 存放tensorflow/models數據
二、安裝依賴庫

pip3 install -r requirements.txtapi

三、下載模型

sh sh/download_data.sh數組

四、添加環境變量PYTHONPATH

echo 'export PYTHONPATH=$PYTHONPATH:pwd/data/tf_models/models/research'>> ~/.bashrc && source ~/.bashrc瀏覽器

五、啓動服務

python3 server.pybash

沒有報錯,說明你已成功搭建環境,使用過程是否是很是簡單,下面介紹代碼調用邏輯過程

模型選用

我從谷歌提供幾種模型選出來對比

image

  • Speed 是識別物體速度,值越小,識別越快
  • mAP(平均準確率)是精度和檢測邊界盒的乘積,值越高神經網絡的識別精確度越高,對應Speed越大

爲了測試方便,筆者選用輕量級(ssd_mobilenet)做爲本次識別物體模型

引入Python庫

import numpy as np
import os
import tensorflow as tf
import json
import time
from PIL import Image
# 兼容Python2.7版本
try:
    import urllib.request as ulib
except Exception as e:
    import urllib as ulib
import re
from object_detection.utils import label_map_util

載入模型

MODEL_NAME = 'data/models/ssd_mobilenet_v2_coco_2018_03_29'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data/pbtxt','mscoco_label_map.pbtxt')  # CWH: Add object_detection path
# data/pbtxt下mscoco_label_map.pbtxt最大item.id
NUM_CLASSES = 90
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  # 加載模型
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

載入標籤映射,內置函數返回整數會映射到pbtxt字符標籤

mscoco_label_map.pbtxt格式以下

item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
# 加載標籤
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
with detection_graph.as_default():
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(graph=detection_graph,config=config) as sess:
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # 物體座標
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # 檢測到物體的準確度
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
def get_objects(file_name, threshold=0.5):
    image = Image.open(file_name)
    # 判斷文件是不是jpeg格式
    if not image.format=='JPEG':
        result['status'] = 0
        result['msg'] = file_name+ ' is ' + image.format + ' ods system allow jpeg or jpg'
        return result
    image_np = load_image_into_numpy_array(image)
    # 擴展維度
    image_np_expanded = np.expand_dims(image_np, axis=0)
    output = []
    # 獲取運算結果
    (boxes, scores, classes, num) = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
    # 去掉緯度爲1的數組
    classes = np.squeeze(classes).astype(np.int32)
    scores = np.squeeze(scores)
    boxes = np.squeeze(boxes)
    for c in range(0, len(classes)):
        if scores[c] >= threshold:
            item = Object()
            item.class_name = category_index[classes[c]]['name'] # 物體名稱
            item.score = float(scores[c]) # 準確率
            # 物體座標軸百分比
            item.y1 = float(boxes[c][0])
            item.x1 = float(boxes[c][1])
            item.y2 = float(boxes[c][2])
            item.x2 = float(boxes[c][3])
            output.append(item)
    # 返回JSON格式
    outputJson = json.dumps([ob.__dict__ for ob in output])
    return outputJson

運行路由

server.py下的邏輯

def image():
    startTime = time.time()
    if request.method=='POST':
        image_file = request.files['file']
        base_path = os.path.abspath(os.path.dirname(__file__))
        upload_path = os.path.join(base_path,'static/upload/')
        # 保存上傳圖片文件
        file_name = upload_path + image_file.filename
        image_file.save(file_name)
        # 準確率過濾值
        threshold = request.form.get('threshold',0.5)
        # 調用Api服務
        objects = object_detection_api.get_objects(file_name, threshold)
        # 模板顯示
        return render_template('index.html',json_data = objects,img=image_file.filename)

curl http://localhost:5000 | python -m json.tool

[
    {
        "y2": 0.9886252284049988,
        "class_name": "bed",
        "x2": 0.4297400414943695,
        "score": 0.9562674164772034,
        "y1": 0.5202791094779968,
        "x1": 0
    },
    {
        "y2": 0.9805927872657776,
        "class_name": "couch",
        "x2": 0.4395904541015625,
        "score": 0.6422878503799438,
        "y1": 0.5051193833351135,
        "x1": 0.00021047890186309814
    }
]
  • class_name表示物體標籤名稱
  • score 可信度值
  • x1,y1表示對象所在最左上點位置
  • x2,y2表示對象最右下點位置

在瀏覽器訪問網址體驗

http://localhost:5000/upload

小結

  • Tensorflow使用GPU效率提高几個數量級
  • 能夠嘗試不一樣的模型比較速度和準確度
  • 本案例也是支持python2,爲了跟上時代步伐,建議使用python3
  • 案例有個攝像頭演示,須要https支持,且使用安卓系統

你們確定很好奇,怎麼訓練本身須要檢測的物體,能夠期待下一篇文章

相關文章
相關標籤/搜索