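The product team noticed that a competitor's product could label objects in images. True to their usual style ("if they have it, we must have it too; if they don't, we need it even more"), they sent over a link to a website and said the feature was simple and could definitely be built.

This is where the almighty Google came in handy. Somewhere in the vast sea of information I found TensorFlow, a machine learning framework for deep learning (AI), which is extremely popular right now. With a solution in hand, all that was missing was a programmer.

Encyclopedia description: TensorFlow is Google's second-generation AI learning system, developed on the basis of DistBelief. It can be used in many machine learning and deep learning areas, such as speech recognition and image recognition.

In plain language: it is a framework for deep learning and neural networks. Its core is written in C++ and controlled through Python, and it also supports other languages such as Go and Java.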
```bash
git clone https://github.com/guandeng/tensorflow.git
cd tensorflow
```
The directory structure is as follows:

```
└── tensorflow
    ├── Dockerfile
    ├── README.md
    ├── data
    │   ├── models
    │   ├── pbtxt
    │   └── tf_models
    ├── object_detection_api.py
    ├── server.py
    ├── sh
    │   ├── download_data.sh
    │   └── ods.sh
    ├── static
    ├── templates
    └── upload
```
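Install the dependencies, download the models, add the Object Detection API to PYTHONPATH, and start the server (run these from the repository root):

```bash
pip3 install -r requirements.txt
sh sh/download_data.sh
echo "export PYTHONPATH=\$PYTHONPATH:$(pwd)/data/tf_models/models/research" >> ~/.bashrc && source ~/.bashrc
python3 server.py
```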
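A quick way to confirm that the PYTHONPATH line took effect is to ask Python whether it can locate the `object_detection` package before starting the server. The sketch below uses the standard library's `importlib.util.find_spec`. The helper name `module_available` is hypothetical and is not part of the repository.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` can be located on the current sys.path.

    Hypothetical helper: it does not import the module itself, although
    for dotted names find_spec imports the parent packages.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name is missing
        return False


if __name__ == "__main__":
    # object_detection lives under data/tf_models/models/research
    for mod in ("object_detection", "object_detection.utils.label_map_util"):
        status = "found" if module_available(mod) else "NOT found, check PYTHONPATH"
        print(f"{mod}: {status}")
```

If either module prints "NOT found", open a new shell or re-run `source ~/.bashrc` before starting `server.py`.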
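If no errors are reported, the environment has been set up successfully. Pretty simple, isn't it? Next, I'll walk through the logic of the code.

I picked a few of the models Google provides and compared them.

For ease of testing, I chose the lightweight ssd_mobilenet as the object detection model for this article.

Import the Python libraries: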
```python
import numpy as np
import os
import tensorflow as tf
import json
import time
from PIL import Image
# Compatibility with Python 2.7
try:
    import urllib.request as ulib
except Exception as e:
    import urllib as ulib
import re
from object_detection.utils import label_map_util
```
Load the model:
```python
MODEL_NAME = 'data/models/ssd_mobilenet_v2_coco_2018_03_29'
# Frozen inference graph shipped with the model
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data/pbtxt', 'mscoco_label_map.pbtxt')
# Largest item.id in data/pbtxt/mscoco_label_map.pbtxt
NUM_CLASSES = 90

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    # Load the model
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
```
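Load the label map. The model returns integer class IDs, and the built-in helper functions map them to the string labels in the pbtxt file.

The format of mscoco_label_map.pbtxt is as follows: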
```
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
```
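To see how the integer IDs relate to these entries, here is a minimal regex-based sketch that turns a label map in exactly this flat format into a `{id: display_name}` dict. It is an illustration only and assumes one `id` and one `display_name` per `item` block. The real `label_map_util` parses the file with protobuf's text format and handles the full schema. `parse_label_map` is a hypothetical name.

```python
import re

# One `item { ... }` block. Items in this format contain no nested braces,
# which is what makes the non-greedy match safe here.
ITEM_RE = re.compile(r"item\s*\{(.*?)\}", re.DOTALL)
ID_RE = re.compile(r"\bid:\s*(\d+)")
DISPLAY_RE = re.compile(r'display_name:\s*"([^"]*)"')


def parse_label_map(text: str) -> dict:
    """Map each integer id to its display_name (hypothetical helper)."""
    labels = {}
    for body in ITEM_RE.findall(text):
        id_match = ID_RE.search(body)
        name_match = DISPLAY_RE.search(body)
        if id_match and name_match:
            labels[int(id_match.group(1))] = name_match.group(1)
    return labels


sample = '''
item { name: "/m/01g317" id: 1 display_name: "person" }
item { name: "/m/0199g" id: 2 display_name: "bicycle" }
'''

print(parse_label_map(sample))  # {1: 'person', 2: 'bicycle'}
```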
```python
# Load the labels
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
```
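The resulting `category_index` is a plain dict keyed by class id, which is why the detection code can look up `category_index[classes[c]]['name']`. Below is a simplified pure-Python mimic of what the two helpers produce. It assumes the items have already been parsed into dicts. `build_category_index` is hypothetical and leaves out details of the real implementation, such as logging.

```python
def build_category_index(items, max_num_classes, use_display_name=True):
    """Simplified mimic of convert_label_map_to_categories + create_category_index.

    `items` is a list of dicts with 'id', 'name' and optionally 'display_name'.
    """
    categories = []
    seen_ids = set()
    for item in items:
        # ids outside 1..max_num_classes are skipped
        if not 0 < item["id"] <= max_num_classes:
            continue
        if use_display_name and "display_name" in item:
            name = item["display_name"]
        else:
            name = item["name"]
        # keep only the first entry for a duplicated id
        if item["id"] not in seen_ids:
            seen_ids.add(item["id"])
            categories.append({"id": item["id"], "name": name})
    # category_index: {id: {'id': id, 'name': name}}
    return {cat["id"]: cat for cat in categories}


items = [
    {"name": "/m/01g317", "id": 1, "display_name": "person"},
    {"name": "/m/0199g", "id": 2, "display_name": "bicycle"},
]
category_index = build_category_index(items, max_num_classes=90)
print(category_index[1]["name"])  # person
```

With `use_display_name=True` the human-readable `display_name` is used, which is why the API returns names like "bed" instead of IDs like "/m/03ssj5".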
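```python
with detection_graph.as_default():
    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front
    config.gpu_options.allow_growth = True
    # Keep the session open for the lifetime of the process so get_objects()
    # can reuse it; a `with tf.Session(...)` block would close it on exit
    sess = tf.Session(graph=detection_graph, config=config)
    # Input tensor: a batch of images
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Bounding boxes of the detected objects
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Confidence score of each detection
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    # Class id of each detection
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    # Number of valid detections
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
```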
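```python
def load_image_into_numpy_array(image):
    # Convert a PIL image into a (height, width, 3) uint8 array
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


class Object(object):
    # Plain container whose __dict__ is serialized to JSON
    pass


def get_objects(file_name, threshold=0.5):
    result = {}
    image = Image.open(file_name)
    # Only JPEG files are accepted
    if image.format != 'JPEG':
        result['status'] = 0
        result['msg'] = file_name + ' is ' + str(image.format) + ', ods system only allows jpeg or jpg'
        return json.dumps(result)
    image_np = load_image_into_numpy_array(image)
    # Add a batch dimension: the model expects shape [1, height, width, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    output = []
    # Run inference
    (boxes, scores, classes, num) = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
    # Remove the dimensions of size 1 (the batch dimension)
    classes = np.squeeze(classes).astype(np.int32)
    scores = np.squeeze(scores)
    boxes = np.squeeze(boxes)
    # Form values arrive as strings; compare scores against a float
    threshold = float(threshold)
    for c in range(len(classes)):
        if scores[c] >= threshold:
            item = Object()
            item.class_name = category_index[classes[c]]['name']  # object name
            item.score = float(scores[c])  # confidence score
            # Box corners as fractions of the image height/width
            item.y1 = float(boxes[c][0])
            item.x1 = float(boxes[c][1])
            item.y2 = float(boxes[c][2])
            item.x2 = float(boxes[c][3])
            output.append(item)
    # Return the result as JSON
    return json.dumps([ob.__dict__ for ob in output])
```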
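The loop at the end of `get_objects` is plain filtering: keep the detections whose score reaches the threshold, look up the label, and serialize the result. Below is a stdlib-only sketch of the same step. It works on already-squeezed Python lists instead of NumPy arrays, and `filter_detections` is a hypothetical name. The comparison `score >= threshold` only works when `threshold` is a number, which matters when the value comes from an HTTP form.

```python
import json


def filter_detections(boxes, scores, classes, category_index, threshold=0.5):
    """Return a JSON string of the detections with score >= threshold.

    boxes: list of [ymin, xmin, ymax, xmax], normalized to 0..1
    scores: list of floats; classes: list of int class ids
    """
    output = []
    for box, score, cls in zip(boxes, scores, classes):
        if score >= threshold:
            output.append({
                "class_name": category_index[cls]["name"],
                "score": float(score),
                "y1": float(box[0]),
                "x1": float(box[1]),
                "y2": float(box[2]),
                "x2": float(box[3]),
            })
    return json.dumps(output)


category_index = {1: {"id": 1, "name": "person"}, 2: {"id": 2, "name": "bicycle"}}
# Illustrative values, not real model output
boxes = [[0.1, 0.2, 0.8, 0.6], [0.5, 0.5, 0.9, 0.9]]
scores = [0.91, 0.30]
classes = [1, 2]
print(filter_detections(boxes, scores, classes, category_index))
```

Lowering the threshold to 0.2 would also keep the bicycle entry. That is the knob the `threshold` form field controls.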
The logic in server.py:
```python
import os
import time

from flask import request, render_template

import object_detection_api


def image():
    startTime = time.time()
    if request.method == 'POST':
        image_file = request.files['file']
        base_path = os.path.abspath(os.path.dirname(__file__))
        upload_path = os.path.join(base_path, 'static/upload/')
        # Save the uploaded image
        file_name = upload_path + image_file.filename
        image_file.save(file_name)
        # Score threshold; form values are strings, so convert to float
        threshold = float(request.form.get('threshold', 0.5))
        # Call the detection API
        objects = object_detection_api.get_objects(file_name, threshold)
        # Render the result in the template
        return render_template('index.html', json_data=objects, img=image_file.filename)
```
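`request.form.get` returns the submitted value as a string, so the threshold has to be converted before it reaches the numeric comparison in `get_objects`. Below is a small sketch of a more defensive parser that also tolerates missing or malformed input. `parse_threshold` and its clamping to [0, 1] are assumptions, not part of the repository.

```python
import math


def parse_threshold(raw, default=0.5):
    """Convert a form value to a float in [0, 1], falling back to `default`.

    Hypothetical helper: accepts str, int, float or None.
    """
    try:
        value = float(raw)
    except (TypeError, ValueError):
        # None or a non-numeric string
        return default
    if math.isnan(value):
        return default
    # Clamp into the valid score range
    return min(max(value, 0.0), 1.0)


# Usage inside the Flask view (hypothetical):
#   threshold = parse_threshold(request.form.get('threshold'))
print(parse_threshold("0.7"), parse_threshold("abc"), parse_threshold("2"))
```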
```bash
curl http://localhost:5000 | python -m json.tool
```

```json
[
    {
        "y2": 0.9886252284049988,
        "class_name": "bed",
        "x2": 0.4297400414943695,
        "score": 0.9562674164772034,
        "y1": 0.5202791094779968,
        "x1": 0
    },
    {
        "y2": 0.9805927872657776,
        "class_name": "couch",
        "x2": 0.4395904541015625,
        "score": 0.6422878503799438,
        "y1": 0.5051193833351135,
        "x1": 0.00021047890186309814
    }
]
```
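The coordinates in this response are fractions of the image height and width, with `y1`/`x1` as the top-left corner and `y2`/`x2` as the bottom-right. To draw a box, scale them by the image size. Below is a minimal sketch. The 640x480 image size is an assumed example, and `to_pixel_box` is a hypothetical helper.

```python
import json


def to_pixel_box(det, width, height):
    """Convert one detection's normalized coords to (left, top, right, bottom) pixels."""
    return (
        round(det["x1"] * width),
        round(det["y1"] * height),
        round(det["x2"] * width),
        round(det["y2"] * height),
    )


# First entry of the response above
response = '[{"y2": 0.9886252284049988, "class_name": "bed", "x2": 0.4297400414943695, "score": 0.9562674164772034, "y1": 0.5202791094779968, "x1": 0}]'
for det in json.loads(response):
    # 640x480 is an assumed image size for illustration
    print(det["class_name"], to_pixel_box(det, width=640, height=480))
```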
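Open the URL in a browser to try it out.

You're probably curious how to train a model to detect your own objects. Stay tuned for the next article.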