TensorFlow Object Detection API
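Building accurate machine learning models that can localize and identify multiple objects in a single image remains a core challenge in computer vision. The TensorFlow Object Detection API is an open-source framework built on top of TensorFlow that makes it easy to construct, train, and deploy object detection models. At Google we have found this codebase useful for our computer vision research, and we hope you will find it useful as well.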
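1. Install TensorFlow and download the Object Detection API
Install TensorFlow:
CPU version: pip install tensorflow
GPU version: pip install tensorflow-gpu
Upgrade TensorFlow to the latest version (1.4.0 at the time of writing): pip install --upgrade tensorflow-gpu
Install the required libraries: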
sudo pip install pillow
sudo pip install lxml
sudo pip install jupyter
sudo pip install matplotlib
Compile the protobuf files (run from the tensorflow/models/research/ directory):
protoc object_detection/protos/*.proto --python_out=.
Add the libraries to PYTHONPATH (run from the tensorflow/models/research/ directory):
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
Test the installation:
python object_detection/builders/model_builder_test.py
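Before running the builder test, it can help to confirm that the libraries installed above are importable at all. The following is a minimal sketch, not part of the API: the function name find_missing_modules is hypothetical, and the list assumes that pillow is imported as PIL and that the jupyter package exposes a top-level module of the same name.

```python
import importlib.util

# Import names of the libraries installed above. The pillow package is
# imported as "PIL"; the rest are assumed to import under their pip names.
REQUIRED_MODULES = ["tensorflow", "PIL", "lxml", "jupyter", "matplotlib"]


def find_missing_modules(module_names):
    """Return the names in module_names that cannot be located for import."""
    return [name for name in module_names
            if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

This only checks that each module can be located; it does not verify versions or GPU support.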
Download the Object Detection API:
Run the demo notebook: object_detection_tutorial.ipynb
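2. Prepare the training dataset
Create a new folder named dataset under the models directory (the commands below use models/research/object_detection/dataset), convert my Pascal VOC-format dataset (VOC3000) to TFRecord format, and store the result in the dataset folder.
Copy create_pascal_tf_record.py into the dataset folder and modify it as follows: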
(1) Modify line 55 to: YEARS = ['VOC2007', 'VOC2012', 'VOC3000', 'merged']
(2) Modify line 58: def dict_to_tf_example(data,
Change to: def dict_to_tf_example(year, data,
(3) Modify line 84: img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
Change to: img_path = os.path.join(year, image_subdirectory, data['filename'])
(4) Modify line 152: years = ['VOC2007', 'VOC2012']
Change to: years = ['VOC2007', 'VOC2012', 'VOC3000']
(5) Modify line 163: examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                 'aeroplane_' + FLAGS.set + '.txt')
Change to: examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                 FLAGS.set + '.txt')
(6) Modify line 175: tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                FLAGS.ignore_difficult_instances)
Change to: tf_example = dict_to_tf_example(year, data, FLAGS.data_dir, label_map_dict, FLAGS.ignore_difficult_instances)
The paths above need to be adjusted to match your own dataset.
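To see what modifications (3) and (5) do to the paths, the sketch below reproduces only the two os.path.join calls in isolation. The helper names examples_list_path and image_path are hypothetical, and 'JPEGImages' and '000001.jpg' are assumed example values, not taken from the dataset described here.

```python
import os


def examples_list_path(data_dir, year, set_name):
    """Path of the image-id list read after modification (5).

    The 'aeroplane_' prefix is gone, so the script reads e.g. train.txt
    instead of aeroplane_train.txt.
    """
    return os.path.join(data_dir, year, "ImageSets", "Main", set_name + ".txt")


def image_path(year, image_subdirectory, filename):
    """Image path relative to data_dir after modification (3).

    The first component is the year passed on the command line rather
    than the <folder> value stored in the annotation XML.
    """
    return os.path.join(year, image_subdirectory, filename)


if __name__ == "__main__":
    print(examples_list_path("VOCdevkit", "VOC3000", "train"))
    print(image_path("VOC3000", "JPEGImages", "000001.jpg"))
```

If the files these paths point to do not exist in your dataset layout, the conversion script will fail when it tries to open them, so it is worth printing the paths once before a long run.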
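Run the following commands to generate the TFRecord files for training and validation (as invoked here, the modified script is data/create_pascal3000_tf_record.py; adjust the path to wherever you saved your copy):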
python data/create_pascal3000_tf_record.py \
    --data_dir=/data/models/research/object_detection/dataset/VOCdevkit \
    --label_map_path=/data/models/research/object_detection/dataset/pascal_label_map.pbtxt \
    --year=VOC3000 \
    --set=train \
    --output_path=/data/models/research/object_detection/dataset/pascal_train.record
python data/create_pascal3000_tf_record.py \
    --data_dir=/data/models/research/object_detection/dataset/VOCdevkit \
    --label_map_path=/data/models/research/object_detection/dataset/pascal_label_map.pbtxt \
    --year=VOC3000 \
    --set=val \
    --output_path=/data/models/research/object_detection/dataset/pascal_val.record
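Both commands read pascal_label_map.pbtxt, so a custom dataset needs a label map listing its own classes. Below is a minimal sketch that renders one in the item { id ... name ... } text format, with ids starting at 1 because id 0 is reserved for the background class. The function name build_label_map and the class names are hypothetical placeholders.

```python
def build_label_map(class_names):
    """Render class names as label-map text, assigning ids from 1 upward."""
    items = []
    for class_id, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, name))
    return "\n".join(items)


if __name__ == "__main__":
    # Hypothetical class names; replace them with the names used in the
    # <name> tags of your own annotation XML files.
    print(build_label_map(["class_a", "class_b", "class_c"]))
```

The names must match the object names in the annotations exactly, since the conversion script looks each one up in this map.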
3. Extract the SSD MobileNet model (it was already downloaded together with the API)
tar -xvf ssd_mobilenet_v1_coco_2017_11_08.tar.gz
Extracting the archive produces a folder of model files. Copy the three model.ckpt.* files from that folder into the dataset folder.
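The copy step can also be scripted. This is a minimal sketch using only the standard library; the function name copy_checkpoint_files and the example directory paths are assumptions for illustration.

```python
import glob
import os
import shutil


def copy_checkpoint_files(model_dir, dataset_dir):
    """Copy every model.ckpt.* file from model_dir into dataset_dir.

    Returns the sorted base names of the copied files.
    """
    os.makedirs(dataset_dir, exist_ok=True)
    copied = []
    for src in sorted(glob.glob(os.path.join(model_dir, "model.ckpt.*"))):
        shutil.copy(src, dataset_dir)
        copied.append(os.path.basename(src))
    return copied


if __name__ == "__main__":
    # Assumed locations; adjust both paths to your own layout.
    print(copy_checkpoint_files("ssd_mobilenet_v1_coco_2017_11_08", "dataset"))
```

Checking that exactly three names come back is a quick way to confirm the archive was extracted completely.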
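4. Modify the config file
Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config into the dataset folder.
Make the following changes:
(1) Set num_classes to the number of classes in your own dataset; mine is 10.
(2) Update the paths (5 places):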
fine_tune_checkpoint: "/data/models/research/object_detection/dataset/model.ckpt"
input_path: "/data/models/research/object_detection/dataset/pascal_train.record"
label_map_path: "/data/models/research/object_detection/dataset/pascal_label_map.pbtxt"
input_path: "/data/models/research/object_detection/dataset/pascal_val.record"
label_map_path: "/data/models/research/object_detection/dataset/pascal_label_map.pbtxt"
Save the config file and rename it to ssd_mobilenet_v1_pascal.config. My dataset folder is shown in the figure.
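The same edits can be applied with a short script instead of by hand. The sketch below assumes the sample config marks each path with the placeholder PATH_TO_BE_CONFIGURED and uses pet-specific file names such as pet_train.record; check your copy of the config before relying on either assumption. The function name customize_config is hypothetical.

```python
import re


def customize_config(config_text, num_classes, dataset_dir, file_renames=None):
    """Return config_text with num_classes and placeholder paths filled in.

    file_renames maps file names in the sample config to the names used
    by your dataset (both sides are plain string replacements).
    """
    text = re.sub(r"num_classes:\s*\d+", "num_classes: %d" % num_classes,
                  config_text)
    # Assumed placeholder string; verify it against the actual config file.
    text = text.replace("PATH_TO_BE_CONFIGURED", dataset_dir.rstrip("/"))
    for old_name, new_name in (file_renames or {}).items():
        text = text.replace(old_name, new_name)
    return text


if __name__ == "__main__":
    sample = ('num_classes: 37\n'
              'input_path: "PATH_TO_BE_CONFIGURED/pet_train.record"\n')
    print(customize_config(
        sample, 10, "/data/models/research/object_detection/dataset",
        {"pet_train.record": "pascal_train.record"}))
```

After running it on the real file, it is still worth reading through the five path entries once to confirm each points at an existing file.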
5. Start training (here I switched to a different model, faster_rcnn_inception_resnet)
python train.py \
    --logtostderr \
    --train_dir=/home/amax/guo/models/object_detection/dataset/output \
    --pipeline_config_path=/home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config
6. Evaluate the model
Create a new folder named evaluation under the dataset folder, then run:
python eval.py \
    --logtostderr \
    --checkpoint_dir=/home/amax/guo/models/object_detection/dataset/output \
    --pipeline_config_path=/home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config \
    --eval_dir=/home/amax/guo/models/object_detection/dataset/evaluation
Error: ImportError: No module named nets
Fix: make the slim module importable by adding its directory to sys.path:
import sys
sys.path.append('/data/models/research/slim')
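If the same script is imported or re-run several times, the plain append adds a duplicate entry each time. A slightly more defensive variant is sketched below; the helper name add_to_sys_path is hypothetical, and the example paths simply reuse the slim location shown above.

```python
import os
import sys


def add_to_sys_path(*directories):
    """Append each directory to sys.path once; return those newly added."""
    added = []
    for directory in directories:
        path = os.path.abspath(directory)
        if path not in sys.path:
            sys.path.append(path)
            added.append(path)
    return added


if __name__ == "__main__":
    # Call this before any import that needs the nets package.
    add_to_sys_path("/data/models/research", "/data/models/research/slim")
```

Setting PYTHONPATH as in the installation step achieves the same result for a whole shell session, so this helper is mainly useful when the environment variable was not exported.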
7. View the results
tensorboard --logdir=/home/amax/guo/models/object_detection/dataset
8. Export a model that can be loaded for inference
python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path /home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config \
    --trained_checkpoint_prefix /home/amax/guo/models/object_detection/dataset/output/model.ckpt-10000 \
    --output_directory /home/amax/guo/models/object_detection/dataset/savedModel
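The value of --trained_checkpoint_prefix depends on how many steps training ran, so model.ckpt-10000 is specific to this run. The sketch below finds the highest-numbered checkpoint prefix in a training directory by looking at the model.ckpt-<step>.index files. The function name latest_checkpoint_prefix is hypothetical; TensorFlow's own tf.train.latest_checkpoint serves a similar purpose when TensorFlow is available.

```python
import glob
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-<step> prefix with the largest step, or None."""
    best_step, best_prefix = -1, None
    for index_file in glob.glob(os.path.join(train_dir, "model.ckpt-*.index")):
        match = re.search(r"model\.ckpt-(\d+)\.index$", index_file)
        # Compare steps numerically so that 10000 ranks above 9999.
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = index_file[:-len(".index")]
    return best_prefix


if __name__ == "__main__":
    # Training output directory used in the commands above.
    print(latest_checkpoint_prefix(
        "/home/amax/guo/models/object_detection/dataset/output"))
```

The returned string can be passed directly as the --trained_checkpoint_prefix argument.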
The exported model is shown in the figure:
9. Use the exported model
Modify object_detection_tutorial.py:
PATH_TO_CKPT ='/home/amax/guo/models/object_detection/dataset/savedModel/frozen_inference_graph.pb'
PATH_TO_LABELS='/home/amax/guo/models/object_detection/dataset/pascal_label_map.pbtxt'
NUM_CLASSES = 10
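NUM_CLASSES has to agree with the label map that PATH_TO_LABELS points to. The tutorial script loads the label map through the API's own utilities; the sketch below is a dependency-free stand-in that parses the plain id/name format with regular expressions so the count can be checked without TensorFlow. The function name parse_label_map is hypothetical, and the parser only handles simple item blocks without nested braces.

```python
import re

ITEM_PATTERN = re.compile(r"item\s*\{(.*?)\}", re.DOTALL)


def parse_label_map(pbtxt_text):
    """Map each item's integer id to its name, for plain id/name label maps."""
    id_to_name = {}
    for body in ITEM_PATTERN.findall(pbtxt_text):
        id_match = re.search(r"\bid:\s*(\d+)", body)
        # \b keeps this from matching the tail of a display_name field.
        name_match = re.search(r"\bname:\s*['\"]([^'\"]*)['\"]", body)
        if id_match and name_match:
            id_to_name[int(id_match.group(1))] = name_match.group(1)
    return id_to_name


if __name__ == "__main__":
    example = "item {\n  id: 1\n  name: 'class_a'\n}\n"  # hypothetical entry
    print(parse_label_map(example))
```

Comparing len(parse_label_map(text)) with NUM_CLASSES before running inference catches a mismatched label map early.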
The results are as follows: