首先,簡單介紹下,TensorFlow Object Detection API是一個構建在TensorFlow之上的開源框架,它使構建、訓練和部署對象檢測模型變得很容易。
首先,關於Win10下深度學習基本環境的搭建,比如Anaconda,TensorFlow CPU或GPU版本,PyCharm等的安裝這塊就不說了,網上的教程很多。
額外需要的Python庫有 pillow, lxml,可以通過 pip install 命令進行安裝。
https://github.com/tensorflow/models ,直接從GitHub上下載源碼。
Protoc是用來將下載來的源碼中 object_detection/protos 目錄下的proto文件編譯爲py文件的工具。
Windows下,建議下載3.4的版本,下載鏈接:https://github.com/protocolbuffers/protobuf/releases
下載完成後,將對應目錄下的bin文件夾路徑添加到環境變量PATH中。
打開cmd命令行,輸入 protoc,顯示如下內容說明安裝成功:
將之前下載好的TensorFlow Object Detection文件解壓,在命令行中cd進入models-master\research目錄下,然後執行命令:
protoc ./object_detection/protos/*.proto --python_out=.
該命令將object_detection/protos目錄下的proto文件編譯爲py文件。
執行完畢後,進入object_detection/protos目錄下查看,可以看到生成了對應的py文件。
首先,在PyCharm中重新創建一個新項目,我這裏的項目名稱爲 using_pre-trained_model_to_detect_objects,然後將下載的TensorFlow Object Detection中的models-master\research\object_detection拷貝進using_pre-trained_model_to_detect_objects新項目中。
在項目中建立 object_detection_tutorial.py 文件用來進行目標檢測,項目結構爲:
預測程序如下,需要注意相關路徑問題:
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
import matplotlib.pyplot as plt
from PIL import Image
from object_detection.utils import ops as utils_ops
# Refuse to run on TensorFlow releases older than 1.12.0.
if StrictVersion('1.12.0') > StrictVersion(tf.__version__):
    raise ImportError('Please upgrade your TensorFlow installation to v1.12.*.')
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
# Name of the pre-trained detector; the archive name and the extracted folder
# are both derived from it.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = '{}.tar.gz'.format(MODEL_NAME)
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
# Location of the frozen inference graph (.pb model file).
PATH_TO_FROZEN_GRAPH = '{}/frozen_inference_graph.pb'.format(MODEL_NAME)
# Label map file for the COCO dataset.
PATH_TO_LABELS = os.path.join('object_detection/data', 'mscoco_label_map.pbtxt')
PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images'
# image1.jpg and image2.jpg inside the test-images directory.
TEST_IMAGE_PATHS = [
    os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image%d.jpg' % image_number)
    for image_number in (1, 2)
]
# Download and unpack the model.
def downloadModel():
    """Download the model archive and extract its frozen inference graph.

    Fetches ``DOWNLOAD_BASE + MODEL_FILE`` and saves it as ``MODEL_FILE``,
    then extracts every archive member whose base name contains
    ``frozen_inference_graph.pb`` below the current working directory,
    keeping the member's path inside the archive.
    """
    # urlretrieve replaces the legacy URLopener class, which is deprecated.
    urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    # The context manager closes the archive even if extraction fails.
    with tarfile.open(MODEL_FILE) as tar_file:
        for member in tar_file.getmembers():
            # NOTE(review): members are written using the path stored in the
            # archive; only use this with archives from a trusted source.
            if 'frozen_inference_graph.pb' in os.path.basename(member.name):
                tar_file.extract(member, os.getcwd())
# Load the model.
def loadModel():
    """Read the frozen graph at ``PATH_TO_FROZEN_GRAPH`` into a new graph.

    Returns:
        A ``tf.Graph`` into which the serialized ``GraphDef`` has been
        imported with an empty name prefix.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as pb_file:
            graph_def.ParseFromString(pb_file.read())
        # name='' keeps the tensor names exactly as stored in the file.
        tf.import_graph_def(graph_def, name='')
    return graph
# Convert an image into a three-dimensional array of dtype uint8.
def load_image_into_numpy_array(image):
    """Return the pixels of a PIL image as a ``(height, width, 3)`` array.

    Args:
        image: A ``PIL.Image.Image``. Images that are not already in RGB
            mode (grayscale, palette, RGBA, ...) are converted to RGB first,
            so the result always has three channels instead of failing on
            the reshape.

    Returns:
        A new, writable ``numpy.ndarray`` of dtype ``uint8``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # np.array copies the pixel buffer in one step, which avoids building a
    # Python tuple per pixel via getdata(); the copy is writable, so boxes
    # can later be drawn onto it in place.
    return np.array(image, dtype=np.uint8)
# Run object detection.
def run_inference_for_single_image(image, graph):
    """Run the detection graph on one image batch and unpack the outputs.

    Args:
        image: Array fed to the ``image_tensor:0`` placeholder. The caller in
            this file passes a single image with a leading batch axis.
        graph: Graph that contains the imported detection model.

    Returns:
        A dict with ``num_detections`` (int), ``detection_classes`` (int64
        array), ``detection_boxes`` and ``detection_scores``, each taken from
        the first element of the batch.
    """
    wanted_outputs = (
        'num_detections', 'detection_boxes', 'detection_scores',
        'detection_classes',
    )
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors.
            default_graph = tf.get_default_graph()
            available_names = {
                tensor.name
                for operation in default_graph.get_operations()
                for tensor in operation.outputs
            }
            # Only request the outputs that this graph actually provides.
            fetches = {
                key: default_graph.get_tensor_by_name(key + ':0')
                for key in wanted_outputs
                if key + ':0' in available_names
            }
            image_tensor = default_graph.get_tensor_by_name('image_tensor:0')

            # Run inference.
            output_dict = sess.run(fetches, feed_dict={image_tensor: image})

            # All outputs are float32 numpy arrays, so convert types as
            # appropriate and drop the batch axis.
            first_classes = output_dict['detection_classes'][0]
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = first_classes.astype(np.int64)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
    return output_dict
def predict(detection_graph):
    """Run detection on every test image and display the annotated result.

    For each path in ``TEST_IMAGE_PATHS`` the image is loaded, passed through
    ``run_inference_for_single_image``, annotated with the detected boxes and
    labels, and shown with matplotlib.

    Args:
        detection_graph: Graph holding the detection model; it is passed
            unchanged to ``run_inference_for_single_image``.
    """
    # Mapping between class ids and category descriptions. It does not depend
    # on the image, so it is built once instead of once per iteration.
    category_index = label_map_util.create_category_index_from_labelmap(
        PATH_TO_LABELS, use_display_name=True)

    for image_path in TEST_IMAGE_PATHS:
        # The array based representation of the image will be used later in
        # order to prepare the result image with boxes and labels on it.
        # The file is closed as soon as its pixels have been copied out.
        with Image.open(image_path) as image:
            image_np = load_image_into_numpy_array(image)
        # Expand dimensions since the model expects images to have shape:
        # [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Actual detection.
        output_dict = run_inference_for_single_image(
            image_np_expanded, detection_graph)
        # Visualization of the results of a detection (drawn onto image_np).
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            output_dict['detection_boxes'],
            output_dict['detection_classes'],
            output_dict['detection_scores'],
            category_index,
            instance_masks=output_dict.get('detection_masks'),
            use_normalized_coordinates=True,
            line_thickness=8)
        plt.figure(figsize=(12, 8))
        plt.imshow(image_np)
        plt.axis('off')
        # NOTE(review): shown once per image — confirm this is the intended
        # placement rather than a single call after the loop.
        plt.show()
if __name__ == '__main__':
    # Fetch the model only when the frozen graph is not on disk yet, so a
    # first run works and later runs skip the download.
    if not os.path.exists(PATH_TO_FROZEN_GRAPH):
        downloadModel()
    detection_graph = loadModel()
    predict(detection_graph)
輸出結果爲:
歡迎關注我的個人公衆號 AI計算機視覺工坊,本公衆號不定期推送機器學習,深度學習,計算機視覺等相關文章,歡迎大家和我一起學習,交流。