使用TensorFlow Object Detection API+Google ML Engine訓練本身的手掌識別器

  上次使用Google ML Engine跑了一下TensorFlow Object Detection API中的Quick Start(http://www.cnblogs.com/take-fetter/p/8384564.html),可是遇到了不少錯誤,索性放棄了html

這兩天直接開始從本身的數據集開始製做手掌識別器。先放運行結果吧python

   

 全部代碼文件可在https://github.com/takefetter/hand-detection查看,歡迎star和issuegit

 

使用前所須要的準備:1.clone tensorflow models(site:https://github.com/tensorflow/models)github

          2.在model/research目錄下運行setup.py安裝object detection APIwindows

          3.其他必要條件:安裝tensorflow(版本需大於等於1.4),opencv-python等必須的packageapi

          4.安裝Google Cloud SDK,激活免費試用300美金(須要一張信用卡來驗證)和在命令行中使用gcloud init設置等機器學習

  •  準備數據集

  (關於手的圖片的dataset仍舊使用的dlib訓練(site:http://www.cnblogs.com/take-fetter/p/8321158.html)中的Hand Images Databases - https://www.mutah.edu.jo/biometrix/hand-images-databases.html提供的數據集,只不過此次使用了WEHI系列的圖片(MOHI的圖片我也試過,導入後會致使standard-gpu版的訓練沒法進行(內存不足)),做爲示例目前我只使用了1-50人的共計250張圖片)工具

   tensorflow訓練的數據集需爲TFRecord格式,咱們須要對訓練數據進行標註,可是我並無找到直接能夠標註生成的工具,還好有工具能夠生成Pascal VOC格式的xml文件      https://github.com/tzutalin/labelImg,推薦將圖片文件放於research/images中,保存xml文件夾位於research/images/xmls中學習

根據你要訓練的數據集,建立.pbtxt文件測試

  • 轉換爲tfrecord格式

   完成圖片標註後在xmls文件夾中運行xml_to_csv.py便可生成csv文件,再經過create_hand_tfrecord.py便可將圖片轉換爲hand.record文件

   須要注意的是,若是你須要訓練的數據集和我這裏的不同的話,create_hand_tfrecord.py的todo部分須要與你的.pbtxt文件內的內容一致

   (方法參考至https://github.com/datitran/raccoon_dataset 使用本做者的文件還能夠完成劃分測試集和分析數據等功能,固然我這裏並無使用)

  •  下載預訓練模型

   從新開始一個模型的訓練時間是很長的時間,而tensorflow model zoo爲咱們提供好了預訓練的模型(site:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models-coco-models),選擇並下載一個 我選擇的是

速度最快的ssd_mobilenet_v1,下載後解壓可找到3個含有ckpt的文件,如圖

  以後還需下載並配置model對應的config文件(https://github.com/tensorflow/models/tree/master/research/object_detection/samples/configs)並修改文件中的內容

須要修改的地方有:

  1. num_classes: 改成pbtxt中類的數目
  2. PATH_TO_BE_CONFIGURED的部分改成相應的目錄
  3. num_steps定義了學習的上限 默認是200000 可本身更改,訓練過程當中也能夠隨時中止
  • 上傳文件並在Google Cloud Platform中訓練

  1.上傳3個ckpt文件以及config文件和.record文件 

      到google cloud控制檯-存儲目錄下,建立存儲分區(這裏使用takefetter_hand_detector),並新建data文件夾,拖拽上傳到該目錄中完成後的目錄和文件以下

+ takefetter_hand_detector/
  + data/
    - ssd_mobilenet_v1_hand.config - model.ckpt.index - model.ckpt.meta - model.ckpt.data-00000-of-00001
    - hand_label_map.pbtxt - hand.record

 

  2. 打包tf slim和object detection

     在research目錄下運行

python setup.py sdist (cd slim && python setup.py sdist)

  3.建立機器學習任務

    在research目錄下運行此命令 開始訓練

gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \ --runtime-version 1.4 \ --job-dir=gs://takefetter_hand_detector/train \ --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \ --module-name object_detection.train \ --region us-central1 \ --config object_detection/samples/cloud/cloud.yml \ -- \ --train_dir=gs://takefetter_hand_detector/train \ --pipeline_config_path=gs://takefetter_hand_detector/data/ssd_mobilenet_v1_hand.config

須要注意的地方有 

  1. windows下須要放在同一行運行 並刪除\
  2. cloud.yml文件中的內容能夠自行更改,我這裏的設置爲
    trainingInput: runtimeVersion: "1.4" scaleTier: CUSTOM masterType: standard_gpu workerCount: 2 workerType: standard_gpu parameterServerCount: 2 parameterServerType: standard

在提交任務後在 機器學習引擎-做業中便可看到具體狀況,每運行幾千次後在 takefetter_hand_detector/train中存儲對應cheakpoint的文件 如圖

以後下載須要的cheak的3個文件 複製到research目錄下(這裏用30045的3個文件做爲示例),並將research/object_detectIon目錄下的export_inference_graph.py複製到research目錄下 運行例如

python object_detection/export_inference_graph.py \ --input_type image_tensor \ --pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v1_hand.config \ --trained_checkpoint_prefix model.ckpt-30045 \ --output_directory exported_graphs

在運行完成後research目錄中會生成文件夾exported_graphs_30045 包含的文件如圖所示

拷貝frozen_inference_graph.pb和pbtxt文件到test/hand_inference_graph文件夾,並運行hand_detector.py 便可獲得如文章開頭的結果

後記:

1.若是須要視頻實時的hand tracking,可以使用https://github.com/victordibia/handtracking 在個人渣本上FPS過低了......

2.我目前使用的數據集仍是較小訓練次數也比較少,很容易出現一些誤識別的狀況,以後還會加大數據集和訓練次數

3.換用其餘model應該也會顯著改善識別精確度

4.遇到任何問題,歡迎提問(雖然感受大多數stack overflow都有)

5.本地訓練要好不少,若是使用在Google Cloud訓練中可能會遇到問題,可是解決方法是將tensorflow版本改成1.2,可是1.2版本的object detection在準備階段就會遇到問題,目前來看確實無解。(畢竟API Caller)

6.本地訓練建議使用tensorflow版本爲1.2

感謝:

  1. https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md
  2. https://github.com/victordibia/handtracking
  3. https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/
  4. https://github.com/datitran/raccoon_dataset
  5. https://www.mutah.edu.jo/biometrix/hand-images-databases.html
相關文章
相關標籤/搜索