前面已經介紹了幾種經典的目標檢測算法,光學習理論不實踐的效果並不大,這裏咱們使用谷歌的開源框架來實現目標檢測。至於爲何不去本身實現呢?主要是由於本身實現比較麻煩,並且調參比較麻煩,咱們直接利用別人的庫去學習,能夠節約不少時間,並且逐漸吃透別人代碼,使得咱們能夠慢慢的接受。html
Object Detection API是谷歌開放的一個內部使用的物體識別系統。2016年 10月,該系統在COCO識別挑戰中名列第一。它支持當前最佳的實物檢測模型,可以在單個圖像中定位和識別多個對象。該系統不只用於谷歌於自身的產品和服務,還被推廣至整個研究社區。python
Object Detection 模塊的位置與slim的位置相近,同在github.com 中TensorFlow 的models\research目錄下。相似slim, Object Detection也囊括了各類關於物體檢測的各類先進模型:git
上述每個模型的凍結權重 (在COCO數據集上訓練)可被直接加載使用。github
SSD模型使用了輕量化的MobileNet,這意味着它們能夠垂手可得地在移動設備中實時使用。谷歌使用了 Faster R-CNN模型須要更多計算資源,但結果更爲準確。算法
在在實物檢測領域,訓練模型的最權威數據集就是COCO數據集。
COCO數據集是微軟發佈的一個能夠用來進行圖像識別訓練的數據集,官方網址爲http://mscoco.org 其圖像主要從複雜的平常場景中截取,圖像中的目標經過精確的segmentation進行位置的標定。
COCO數據集包括91類目標,分兩部分發布,前部分於2014年發佈,後部分於2015年發佈。express
Objet Detection API使用protobufs來配置模型和訓練參數,這些文件以".proto"的擴展名放models\research\object_detection\protos下。在使用框架以前,必須使用protobuf庫將其編譯成py文件才能夠正常運行。protobuf庫的下載地址爲https://github.com/google/protobuf/releases/tag/v2.6.1apache
下載並解壓protoc-2.6.1-win32.zip到models\research路徑下。ubuntu
打開cmd命令行,進入models\research目錄下,執行以下命令windows
protoc.exe object_detection/protos/*.proto --python_out=.
若是不顯示任何信息,則代表運行成功了,爲了檢驗成功效果,來到models\research\object_detection\protos下,能夠看到生成不少.py文件。數組
若是前面兩步都完成了,下面能夠測試一下object detection API是否能夠正常使用,還須要兩步操做:
代表object detection API一切正常,可使用、
爲了避免用每次都將文件複製到Object Detection文件夾外,能夠將Object Detection加到python引入庫的默認搜索路徑中,將Object Detection文件整個複製到anaconda3安裝文件目錄下lib\site-packages下:
這樣不管文件在哪裏,只要搜索import Objec Detection xxx,系統到會找到Objec Detection。
以前已經說過Objec Detection API默認提供了5個預訓練模型。他們都是使用COCO數據集訓練完成的,如何使用這些預訓練模型呢?官方已經給了一個用jupyter notebook編寫好的例子。首先在research文件下下,運行命令:jupyter-notebook,會直接打開http://localhost:8888/tree。
接着打開object_detection文件夾,並單擊object_detection_tutorial.jpynb運行示例文件。
該代碼使用Object Detection API基於COCO上訓練的ssd_mobilenet_v1模型,對任意圖片進行分類識別。
以前介紹的已有模型,在下面網站能夠下載:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
每個壓縮文件裏包含如下文件:
咱們在models\research文件夾下建立一個文件夾my_download_pretrained,用於保存預訓練的模型。
咱們對該代碼進行一些修改,並給出該代碼的中文註釋:
在models\research下建立my_object_detection.py文件。程序只能在GPU下運行,CPU會報錯。
# -*- coding: utf-8 -*- """ Created on Tue Jun 5 20:34:06 2018 @author: zy """ ''' 調用Object Detection API進行實物檢測 須要GPU運行環境,CPU下會報錯 模型下載網址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md TensorFlow 生成的 .ckpt 和 .pb 都有什麼用? https://www.cnblogs.com/nowornever-L/p/6991295.html 如何用Tensorflow訓練模型成pb文件(一)——基於原始圖片的讀取 https://blog.csdn.net/u011463646/article/details/77918980?fps=1&locationNum=7 ''' import matplotlib.pyplot as plt import numpy as np import os import tensorflow as tf from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util from PIL import Image def test(): #重置圖 tf.reset_default_graph() ''' 載入模型以及數據集樣本標籤,加載待測試的圖片文件 ''' #指定要使用的模型的路徑 包含圖結構,以及參數 PATH_TO_CKPT = './my_download_pretrained/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb' #測試圖片所在的路徑 PATH_TO_TEST_IMAGES_DIR = './object_detection/test_images' TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR,'image{}.jpg'.format(i)) for i in range(1,3) ] #數據集對應的label mscoco_label_map.pbtxt文件保存了index到類別名的映射 PATH_TO_LABELS = os.path.join('./object_detection/data','mscoco_label_map.pbtxt') NUM_CLASSES = 90 #從新定義一個圖 output_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_CKPT,'rb') as fid: #將*.pb文件讀入serialized_graph serialized_graph = fid.read() #將serialized_graph的內容恢復到圖中 output_graph_def.ParseFromString(serialized_graph) #print(output_graph_def) #將output_graph_def導入當前默認圖中(加載模型) tf.import_graph_def(output_graph_def,name='') print('模型加載完成') #載入coco數據集標籤文件 label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map,max_num_classes = NUM_CLASSES,use_display_name = True) category_index = label_map_util.create_category_index(categories) ''' 定義session ''' def load_image_into_numpy_array(image): ''' 將圖片轉換爲ndarray數組的形式 ''' im_width,im_height = image.size return np.array(image.getdata()).reshape((im_height,im_width,3)).astype(np.uint0) #設置輸出圖片的大小 IMAGE_SIZE = (12,8) #使用默認圖,此時已經加載了模型 detection_graph = tf.get_default_graph() with tf.Session(graph=detection_graph) as sess: for image_path in TEST_IMAGE_PATHS: image = Image.open(image_path) #將圖片轉換爲numpy格式 image_np = load_image_into_numpy_array(image) ''' 定義節點,運行並可視化 ''' #將圖片擴展一維,最後進入神經網絡的圖片格式應該是[1,?,?,3] image_np_expanded = np.expand_dims(image_np,axis = 0) ''' 獲取模型中的tensor ''' image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') #boxes用來顯示識別結果 boxes = detection_graph.get_tensor_by_name('detection_boxes:0') #Echo score表明識別出的物體與標籤匹配的類似程度,在類型標籤後面 scores = detection_graph.get_tensor_by_name('detection_scores:0') classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') #開始檢查 boxes,scores,classes,num_detections = sess.run([boxes,scores,classes,num_detections], feed_dict={image_tensor:image_np_expanded}) #可視化結果 vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) plt.figure(figsize=IMAGE_SIZE) print(type(image_np)) print(image_np.shape) image_np = np.array(image_np,dtype=np.uint8) plt.imshow(image_np) if __name__ == '__main__': test()
以VOC 2012數據集爲例,介紹如何使用Object Detection API訓練新的模型。VOC 2012是VOC2007數據集的升級版,一共有11530張圖片,每張圖片都有標準,標註的物體包括人、動物(如貓、狗、鳥等)、交通工具(如車、船飛機等)、傢俱(如椅子、桌子、沙發等)在內的20個類別。
首先下載數據集,並將其轉換爲tfrecord格式。下載地址爲:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar。
首先下載谷歌models庫,而後刪除一些沒必要要的文件,獲得文件結構以下:
在research文件夾下,建立一個voc文件夾,把VOC2012解壓到這個文件夾下,解壓後,獲得一個VOCdevkit文件夾:
JPEGImages文件中文件夾裏存放了所有的訓練圖片和驗證圖片。
對於每一張圖像,都在Annotations文件夾中存放有對應的xml文件。保存着物體框的標註,包括圖片文件名,圖片大小,圖片邊界框等信息。
以2007_000027.xml爲例:
<annotation> #數據所在的文件夾名 <folder>VOC2012</folder> #圖片名稱 <filename>2007_000027.jpg</filename> <source> <database>The VOC2007 Database</database> <annotation>PASCAL VOC2007</annotation> <image>flickr</image> </source> #圖片的寬和高 <size> <width>486</width> <height>500</height> <depth>3</depth> </size> <segmented>0</segmented> <object> #類別名 <name>person</name> #物體的姿式 <pose>Unspecified</pose> #物體是否被部分遮擋 <truncated>0</truncated> ##是否爲難以辨識的物體, 主要指要結合背景才能判斷出類別的物體。雖有標註, 但通常忽略這類物體 跳過難以識別的? <difficult>0</difficult> #邊界框 <bndbox> <xmin>174</xmin> <ymin>101</ymin> <xmax>349</xmax> <ymax>351</ymax> </bndbox> #下面的數據是人體各個部位邊界框 <part> <name>head</name> <bndbox> <xmin>169</xmin> <ymin>104</ymin> <xmax>209</xmax> <ymax>146</ymax> </bndbox> </part> <part> <name>hand</name> <bndbox> <xmin>278</xmin> <ymin>210</ymin> <xmax>297</xmax> <ymax>233</ymax> </bndbox> </part> <part> <name>foot</name> <bndbox> <xmin>273</xmin> <ymin>333</ymin> <xmax>297</xmax> <ymax>354</ymax> </bndbox> </part> <part> <name>foot</name> <bndbox> <xmin>319</xmin> <ymin>307</ymin> <xmax>340</xmax> <ymax>326</ymax> </bndbox> </part> </object> </annotation>
ImageSets文件夾包括Action Layout Main Segmentation四部分,Action存放的是人的動做,Layout存放人體部位數據,Main存放的是圖像物體識別數據(裏面的test.txt,train.txt,val.txt,trainval.txt當本身製做數據集時須要生成)。
ImageSets\Main文件夾以下。
SegmentationClass(標註出每個像素的類別)和SegmentationObject(標註出每一個像素屬於哪個物體)是分割相關的。
把pascal_label_map.pbtxt文件複製到voc文件夾下,這個文件存放在voc2012數據集物體的索引和對應的名字。
從object_detection\dataset_tools下把create_pascal_tf_record.py文件複製到research文件夾下,這個代碼是爲VOC2012數據集提早編寫好的。代碼以下:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== r"""Convert raw PASCAL dataset to TFRecord for object_detection. Example usage: ./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \ --year=VOC2012 \ --output_path=/home/user/pascal.record """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import hashlib import io import logging import os from lxml import etree import PIL.Image import tensorflow as tf from object_detection.utils import dataset_util from object_detection.utils import label_map_util import sys #配置logging logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO, stream=sys.stdout) #命令行參數 主要包括數據集根目錄,數據類型,輸出tf文件路徑等 flags = tf.app.flags flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.') flags.DEFINE_string('set', 'train', 'Convert training set, validation set or ' 'merged set.') flags.DEFINE_string('annotations_dir', 'Annotations', '(Relative) path to annotations directory.') flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.') flags.DEFINE_string('output_path', '', 'Path to output TFRecord') flags.DEFINE_string('label_map_path', 'voc/pascal_label_map.pbtxt', 'Path to label map proto') flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore ' 'difficult instances') FLAGS = flags.FLAGS SETS = ['train', 'val', 'trainval', 'test'] YEARS = ['VOC2007', 'VOC2012', 'merged'] def dict_to_tf_example(data, dataset_directory, label_map_dict, ignore_difficult_instances=False, image_subdirectory='JPEGImages'): """Convert XML derived dict to tf.Example proto. Notice that this function normalizes the bounding box coordinates provided by the raw data. Args: data: dict holding PASCAL XML fields for a single image (obtained by running dataset_util.recursive_parse_xml_to_dict) dataset_directory: Path to root directory holding PASCAL dataset label_map_dict: A map from string label names to integers ids. ignore_difficult_instances: Whether to skip difficult instances in the dataset (default: False). image_subdirectory: String specifying subdirectory within the PASCAL dataset directory holding the actual image data. Returns: example: The converted tf.Example. Raises: ValueError: if the image pointed to by data['filename'] is not a valid JPEG """ #獲取圖片相對數據集的相對路徑 img_path = os.path.join(data['folder'], image_subdirectory, data['filename']) #獲取圖片絕對路徑 full_path = os.path.join(dataset_directory, img_path) #讀取圖片 with tf.gfile.GFile(full_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = PIL.Image.open(encoded_jpg_io) if image.format != 'JPEG': raise ValueError('Image format not JPEG') key = hashlib.sha256(encoded_jpg).hexdigest() #獲取圖片的寬和高 width = int(data['size']['width']) height = int(data['size']['height']) xmin = [] ymin = [] xmax = [] ymax = [] classes = [] classes_text = [] truncated = [] poses = [] difficult_obj = [] for obj in data['object']: #是否爲難以辨識的物體, 主要指要結合背景才能判斷出類別的物體。雖有標註, 但通常忽略這類物體 跳過難以識別的? difficult = bool(int(obj['difficult'])) if ignore_difficult_instances and difficult: continue difficult_obj.append(int(difficult)) #bounding box 計算目標邊界 歸一化到[0,1]之間 左上角座標,右下角座標 xmin.append(float(obj['bndbox']['xmin']) / width) ymin.append(float(obj['bndbox']['ymin']) / height) xmax.append(float(obj['bndbox']['xmax']) / width) ymax.append(float(obj['bndbox']['ymax']) / height) #類別名 classes_text.append(obj['name'].encode('utf8')) #獲取該類別對應的標籤 classes.append(label_map_dict[obj['name']]) #物體是否被部分遮擋 truncated.append(int(obj['truncated'])) #物體的姿式 poses.append(obj['pose'].encode('utf8')) #tf文件一條記錄格式 example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/filename': dataset_util.bytes_feature( data['filename'].encode('utf8')), 'image/source_id': dataset_util.bytes_feature( data['filename'].encode('utf8')), 'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')), 'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')), 'image/object/bbox/xmin': dataset_util.float_list_feature(xmin), 'image/object/bbox/xmax': dataset_util.float_list_feature(xmax), 'image/object/bbox/ymin': dataset_util.float_list_feature(ymin), 'image/object/bbox/ymax': dataset_util.float_list_feature(ymax), 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), 'image/object/class/label': dataset_util.int64_list_feature(classes), 'image/object/difficult': dataset_util.int64_list_feature(difficult_obj), 'image/object/truncated': dataset_util.int64_list_feature(truncated), 'image/object/view': dataset_util.bytes_list_feature(poses), })) return example def main(_): ''' 主要是經過讀取VOCdevkit\VOC2012\Annotations下的xml文件 而後獲取對應的圖片文件的路徑,圖片大小,文件名,邊界框、以及圖片數據等信息寫入rfrecord文件 ''' if FLAGS.set not in SETS: raise ValueError('set must be in : {}'.format(SETS)) if FLAGS.year not in YEARS: raise ValueError('year must be in : {}'.format(YEARS)) data_dir = FLAGS.data_dir years = ['VOC2007', 'VOC2012'] if FLAGS.year != 'merged': years = [FLAGS.year] #建立對象,用於向記錄文件寫入記錄 writer = tf.python_io.TFRecordWriter(FLAGS.output_path) #獲取類別名->index的映射 dict類型 label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path) for year in years: logging.info('Reading from PASCAL %s dataset.', year) #獲取aeroplane_train.txt文件的全路徑 改文件保存部分文件名(一共5717/5823個文件,各種圖片都有) examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt') #獲取全部圖片標註xml文件的路徑 annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir) #list 存放文件名 examples_list = dataset_util.read_examples_list(examples_path) #遍歷annotations_dir目錄下,examples_list中指定的xml文件 for idx, example in enumerate(examples_list): if idx % 100 == 0: logging.info('On image %d of %d', idx, len(examples_list)) path = os.path.join(annotations_dir, example + '.xml') with tf.gfile.GFile(path, 'r') as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) #獲取annotation節點的內容 data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation'] #把數據整理成tfrecord須要的數據結構 tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict, FLAGS.ignore_difficult_instances) #向tf文件寫入一條記錄 writer.write(tf_example.SerializeToString()) writer.close() if __name__ == '__main__': tf.app.run()
若是讀者但願使用本身的數據集,有兩種方法:
在research文件夾中,執行如下命令能夠把VOC 2012數據集轉換爲tfrecord格式,轉換好的tfrecord保存在voc文件夾下,分別爲pasal_train.record和pascal_val.record:
python create_pascal_tf_record.py --data_dir voc/VOCdevkit --year=VOC2012 --set=train --output_path=voc/pascal_train.record
python create_pascal_tf_record.py --data_dir voc/VOCdevkit --year=VOC2012 --set=val --output_path=voc/pascal_val.record
以上執行完成後,咱們把voc文件夾和create_pascal_tf_record.py文件剪切到object_detection文件下。(其實在以前咱們就能夠直接把文件建立在object_detection文件夾下,主要是由於create_pascal_tf_record.py在執行的時候會調用到object_detection庫,我是懶得把object_detection庫加入環境變量了,因此才這樣作。)
若是想配置臨時環境變量,在research目錄下:
windows下命令:
set PYTHONPATH=%PYTHONPATH%;%CD%;%CD%/slim
ubuntu系統下:
export PYTHONPATH=$PYTHONPATH:${PWD}:${PWD}/slim
下載完VOC 2012數據集後,須要選擇合適的訓練模型。這裏以Faster R-CNN + Inception-ResNet_v2模型爲例進行介紹。首先下載在COCO數據集上預訓練的Faster R-CNN + Inception-ResNet_v2模型。解壓到voc文件夾下,如圖:
Object Detection API是依賴一種特殊的設置文件進行訓練的。在object_detection/samples/configs文件夾下,有一些設置文件的示例。能夠參考faster_rcnn_inception_resnet_v2_atrous_coco.config文件建立的設置文件。先將faster_rcnn_inception_resnet_v2_atrous_coco.config複製一份到voc文件夾下,命名爲faster_rcnn_inception_resnet_v2_atrous_voc.config。
faster_rcnn_inception_resnet_v2_atrous_voc.config文件有7處須要修改:
gradient_clipping_by_norm: 10.0 fine_tune_checkpoint: "voc/faster_rcnn_inception_resnet_v2_atrous_coco_2018_01_28/model.ckpt" from_detection_checkpoint: true # Note: The below line limits the training process to 200K steps, which we # empirically found to be sufficient enough to train the pets dataset. This # effectively bypasses the learning rate schedule (the learning rate will # never decay). Remove the below line to train indefinitely. num_steps: 200000 data_augmentation_options { random_horizontal_flip { } } } train_input_reader: { tf_record_input_reader { input_path: "voc/pascal_train.record" } label_map_path: "voc/pascal_label_map.pbtxt" } eval_config: { num_examples: 5823 # Note: The below line limits the evaluation process to 10 evaluations. # Remove the below line to evaluate indefinitely. max_evals: 10 } eval_input_reader: { tf_record_input_reader { input_path: "voc/pascal_val.record" } label_map_path: "voc/pascal_label_map.pbtxt" shuffle: false num_readers: 1 }
最後,在voc文件夾中新建一個train_dir做爲保存模型和日誌的目錄,在使用object_detection目錄下的train.py文件訓練的時候會使用到slim下庫,所以咱們須要先配置臨時環境變量,在research目錄下:
windows下命令:
set PYTHONPATH=%PYTHONPATH%;%CD%;%CD%/slim
ubuntu系統下:
export PYTHONPATH=$PYTHONPATH:${PWD}:${PWD}/slim
在object_detection目錄下,使用下面的命令就能夠開始訓練了:(要在GPU下運行,在CPU運行會拋出module 'tensorflow' has no attribute 'data'的錯誤)
python train.py --train_dir voc/train_dir/ --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config
解決:
出錯緣由:知乎的大佬說是python3的兼容問題
解決辦法:把research/object_detection/utils/learning_schedules.py
文件的 第167-169行由
解決:
出錯緣由:知乎的大佬說是python3的兼容問題
解決辦法:把research/object_detection/utils/learning_schedules.py
文件的 第167-169行由
程序運行結果以下:
....
因爲咱們在設置文件中設置的訓練步數爲200k,所以整個訓練可能會消耗大量時間,這裏我訓練到4萬屢次就強行終止訓練了.
num_steps: 200000
訓練的日誌和最終的模型(默認保存了5個不一樣步數時的模型)都會保存在train_dir中,所以,一樣可使用TensorBoard來監控訓練狀況。
使用cmd來到日誌文件的上級路徑下,輸入以下命令:
tensorboard --logdir ./train_dir
接着打開瀏覽器,輸入http://127.0.0.1:6006,若是訓練時保存了一下變量,則能夠在這裏看到(我這裏沒有保存變量):
須要注意的是,若是發生內存和顯存不足報錯的狀況,除了使用較小模型進行訓練外,還能夠修改配置文件中的如下內容:
image_resizer { keep_aspect_ratio_resizer { min_dimension: 600 max_dimension: 1024 } }
這個部分表示將輸入圖像進行等比例縮放再進行訓練,縮放後的最大邊長爲1024,最小邊長爲600.能夠將整兩個數值改小(我訓練的時候就分別改爲512和300),使用的顯存就會變小。不過這樣作也可能致使模型的精度降低,所以咱們須要根據本身的狀況選擇適合的處理方法。
如何將train_dir中的checkpoint文件導出並用於單張圖片的目標檢測?TensorFlow Object Detection API提供了一個export_inference_graph.py腳本用於導出訓練好的模型。具體方法是在research目錄下執行:
python export_inference_graph.py --input_type image_tensor --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config --trained_checkpoint_prefix voc/train_dir/model.ckpt-47837 --output_directory voc/export
其中model.ckpt-47837表示使用第47837步保存的模型。咱們須要根據voc/train_dir時間保存的checkpoint,將47837改成合適的數值。導出的模型是voc/export/frozen_inference_graph.pb文件。
而後能夠參考上面咱們介紹的jupyter notebook代碼,自行編寫利用導出模型對單張圖片作目標檢測的腳本。而後將PATH_TO_CKPT的值賦值爲voc/export/frozen_inference_graph.pb,即導出模型文件。將PATH_TO_LABELS修改成voc/pascal_label_map.pbtxt,即各個類別的名稱。把NUM_CLASSES設置爲20。其它代碼均可以不改變,而後測試咱們的圖片(注意:須要添加上文中提到的臨時環境變量),因爲VOC2012數據集中的類別也有狗和人,所以咱們能夠直接使用object_detection/test_images中的測試圖片。
# -*- coding: utf-8 -*- """ Created on Tue Jun 5 20:34:06 2018 @author: zy """ ''' 調用Object Detection API進行實物檢測 須要GPU運行環境,CPU下會報錯 TensorFlow 生成的 .ckpt 和 .pb 都有什麼用? https://www.cnblogs.com/nowornever-L/p/6991295.html 如何用Tensorflow訓練模型成pb文件(一)——基於原始圖片的讀取 https://blog.csdn.net/u011463646/article/details/77918980?fps=1&locationNum=7 ''' #運行前須要把object_detection添加到環境變量 #ubuntu 在research目錄下,打開終端,執行export PYTHONPATH=$PYTHONPATH:${PWD}:${PWD}/slim 而後執行spyder,運行程序 #windows 在research目錄下,打開cmd,執行set PYTHONPATH=%PYTHONPATH%;%CD%;%CD%/slim 而後執行spyder,運行程序 import matplotlib.pyplot as plt import numpy as np import os import tensorflow as tf from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util from PIL import Image def test(): #重置圖 tf.reset_default_graph() ''' 載入模型以及數據集樣本標籤,加載待測試的圖片文件 ''' #指定要使用的模型的路徑 包含圖結構,以及參數 PATH_TO_CKPT = './voc/export/frozen_inference_graph.pb' #測試圖片所在的路徑 PATH_TO_TEST_IMAGES_DIR = './test_images' TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR,'image{}.jpg'.format(i)) for i in range(1,3) ] #數據集對應的label pascal_label_map.pbtxt文件保存了index和類別名之間的映射 PATH_TO_LABELS = './voc/pascal_label_map.pbtxt' NUM_CLASSES = 20 #從新定義一個圖 output_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_CKPT,'rb') as fid: #將*.pb文件讀入serialized_graph serialized_graph = fid.read() #將serialized_graph的內容恢復到圖中 output_graph_def.ParseFromString(serialized_graph) #print(output_graph_def) #將output_graph_def導入當前默認圖中(加載模型) tf.import_graph_def(output_graph_def,name='') print('模型加載完成') #載入coco數據集標籤文件 label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map,max_num_classes = NUM_CLASSES,use_display_name = True) category_index = label_map_util.create_category_index(categories) ''' 定義session ''' def load_image_into_numpy_array(image): ''' 將圖片轉換爲ndarray數組的形式 ''' im_width,im_height = image.size return np.array(image.getdata()).reshape((im_height,im_width,3)).astype(np.uint0) #設置輸出圖片的大小 IMAGE_SIZE = (12,8) #使用默認圖,此時已經加載了模型 detection_graph = tf.get_default_graph() with tf.Session(graph=detection_graph) as sess: for image_path in TEST_IMAGE_PATHS: image = Image.open(image_path) #將圖片轉換爲numpy格式 image_np = load_image_into_numpy_array(image) ''' 定義節點,運行並可視化 ''' #將圖片擴展一維,最後進入神經網絡的圖片格式應該是[1,?,?,3] image_np_expanded = np.expand_dims(image_np,axis = 0) ''' 獲取模型中的tensor ''' image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') #boxes用來顯示識別結果 boxes = detection_graph.get_tensor_by_name('detection_boxes:0') #Echo score表明識別出的物體與標籤匹配的類似程度,在類型標籤後面 scores = detection_graph.get_tensor_by_name('detection_scores:0') classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') #開始檢查 boxes,scores,classes,num_detections = sess.run([boxes,scores,classes,num_detections], feed_dict={image_tensor:image_np_expanded}) #可視化結果 vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) plt.figure(figsize=IMAGE_SIZE) print(type(image_np)) print(image_np.shape) image_np = np.array(image_np,dtype=np.uint8) plt.imshow(image_np) if __name__ == '__main__': test()
咱們再來看一下若是直接使用官方在COCO數據集上訓練的Faster R-CNN + Inception-ResNet_v2模型,進行目標檢測:
咱們能夠看到咱們使用本身數據集訓練的模型進行目標檢測效果沒有官方提供的模型那個好,可能有如下幾個緣由:
參考文章:
[1]將數據集作成VOC2007格式用於Faster-RCNN訓練
[2]VOC數據集製做2——ImageSets\Main裏的四個txt文件
[3]21個項目玩轉深度學習-何之源
[4]深度學習之TensorFlow-李金洪