打造本身的圖像識別模型

1.目標

本篇文章介紹的重點是如何使用TensorFlow在本身的圖像數據上訓練深度學習模型,主要涉及的方法是對已經預訓練好的ImageNet模型進行微調(Fine-tune)。使用谷歌的Colaboratory(python3 環境)實現。python

2.微調原理

什麼是微調?這裏以VGG16爲例進行講解。git

如圖下圖所示,VGG16的結構爲卷積+全鏈接層。卷積層分爲5個部分共13層,即圖中的conv1~conv5。還有3層是全鏈接層,即圖中的fc六、fe七、fc8。卷積層加上全鏈接層合起來一共爲16層,所以它被稱爲VGG16。若是要將VGG16的結構用於一個新的數據集,首先要去掉fc8這一層。緣由是fc8層的輸入是fc7層的特徵,輸出是1000類的機率,這1000類正好對應了ImageNet 模型中的1000個類別。在本身的數據中,類別數通常不是1000類,所以fc8層的結構在此時是不適用的,必須將fc8層去掉,從新採用符合數據集類別數的全鏈接層,做爲新的fe8。好比數據集爲5類,那麼新的fc8的輸出也應當是5類。github

此外,在訓練的時候,網絡的參數的初始值並非隨機化生成的,而是採用VGG16在ImageNet 上已經訓練好的參數做爲訓練的初始值。這樣作的緣由在於,在ImageNet數據集上訓練過的VGG16中的參數已經包含了大量有用的卷積過濾器,與其從零開始初始化VGG16的全部參數,不如使用已經訓練好的參數看成訓練的起點。這樣作不只能夠節約大量訓練時間,並且有助於分類器性能的提升。shell

載入VGG16的參數後,就能夠開始訓練了。此時須要指定訓練層數的範圍。通常來講,能夠選擇如下幾種範圍進行訓練:數據庫

  • 只訓練fc8。訓練範圍必定要包含fc8這一層。以前說過,fc8的結構被調整過,所以它的參數不能直接從lmageNet預訓練模型中取得。能夠只訓練fe8,保持其餘層的參數不動。這就至關於將VGG16看成一個「特徵提取器」:用fc7層提取的特徵作一個Softmax模型分類。這樣作的好處是訓練速度快,但每每性能不會太好。
  • 訓練全部參數。還能夠對網絡中的全部參數進行訓練,這種方法的訓練速度可能比較慢,可是能取得較高的性能,能夠充分發揮深度模型的威力。
  • 訓練部分參數。一般是固定淺層參數不變,訓練深層參數。如固定 conv一、conv2部分的參數不訓練,只訓練 conv三、conv 四、conv 五、fc六、fc七、fc8的參數。

這種訓練方法就是所謂的對神經網絡模型作微調。藉助微調,能夠從預訓練模型出發,將神經網絡應用到本身的數據集上。下面介紹如何在TensorFlow中進行微調。segmentfault

3.TensorFlow Slim 微調

TensorFlow Slim 是Google公司公佈的一個圖像分類工具包,它不只定義了一些方便的接口,還提供了不少ImageNet數據集上經常使用的網絡結構和預訓練模型。截至2017年7月,Slim提供包括VGG1六、VGG1九、Inception V1~V四、ResNet 50、ResNet 10一、MobileNet在內大多數經常使用模型的結構以及預訓練模型,更多的模型還會被持續添加進來。windows

3.1 數據準備

首先要將本身的數據集切分爲訓練集和驗證集,訓練集用於訓練模型,驗證集用來驗證模型的準確率。本次使用的是衛星圖片分類數據集,這個數據集一共有6個類別,見下表所示:網絡

類別名 含義
Wetland 農田
Glacier 冰川
Urban 城市區域
Rock 岩石
water 水域
Wood 森林

在data_prepare目錄中,用一個pic文件夾保存原始的圖像文件,圖像文件保存的結構以下:工具

data prepare/
    pic/
        train/
            wood/
            water/
            rock/
            wetland/
            glacier/
            urban/
        validation/
            wood/
            water/
            rock/
            wetland/
            glacier/
            urban/

將圖片分爲trainvalidation兩個目錄,分別表示訓練使用的圖片和驗證使用的圖片。在每一個目錄中,分別以類別名爲文件夾名保存全部圖像。在每一個類別文件夾下,存放的就是原始的圖像(如jpg格式的圖像文件)。下面,在data_prepare文件夾下,使用預先編制好的腳本data_convert.py,將圖片轉換爲爲tfrecord格式:性能

!python data_ convert.py -t pic/ \
    --train-shards 2 \
    --validation-shards 2 \
    --num-threads 2 \
    --dataset-name satellite

解釋這裏參數的含義:

  • -t pic/:表示轉換pic文件夾中的數據。pic文件夾中必須有一個train目錄和一個validation目錄,分別表明訓練和驗證數據集。每一個目錄下按類別存放了圖像數據。
  • --train-shards 2:將訓練數據集分爲兩塊,即最後的訓練數據就是兩個tfrecord格式的文件。若是讀者的數據集較大,能夠考慮將其分爲更多的數據塊。
  • --validation-shards 2:將驗證數據集分爲兩塊。
  • --num-threads 2:採用兩個線程產生數據。注意線程數必需要能整除train-shards 和validation-shards,來保證每一個線程處理的數據塊數是相同的。
  • --dataset-name satellite:給生成的數據集起一個名字。這裏將數據集起名叫「satellite」,最後生成文件的開頭就是satelite_train 和satelite_validation。

運行上述命令後,就能夠在pic文件夾中找到5個新生成的文件,分別是訓練數據 satellite_train_00000-of-00002.tfrecord、satellite_train_00001-of-00002.tfrecord,以及驗證數據 satellite validation_00000-of-00002.tfrecord、satellite validation_00001-of-00002.tfrecord。另外,還有一個文本文件label.txt,它表示圖片的內部標籤(數字)到真實類別(字符串)之間的映射順序。如圖片在tfrecord中的標籤爲0,那麼就對應label.txt 第一行的類別,在tfrecord的標籤爲1,就對應label.txt中第二行的類別,依此類推。

3.2 下載TensorFlow Slim

若是須要使用Slim微調模型,首先要下載Slim的源代碼。Slim的源代碼保存在tensorflow/models項目中,可使用下面的git命令下載tensorflow/models:

git clone https://github.com/tensorflow/models. git

找到models/research/目錄中的slim文件夾,這就是要用到的TensorFlow Slim的源代碼。這裏簡單介紹TensorFlow Slim的代碼結構,見下表。

文件夾或文件名 用途
datasets/ 定義一些訓練時使用的數據集。若是須要訓練本身的數據,必須一樣在datasets文件夾中進行定義,會在下面對此進行介紹
nets/ 定義了一些經常使用的網絡結構,如AlexNet、VGGl六、VGG1九、Inception 系列、ResNet、MobileNet等
preprocessing/ 在模型讀入圖片前,經常須要對圖像作預處理和數據加強。這個文件夾針對不一樣的網絡,分別定義了它們的預處理方法
scripts 包含了一些訓練的示例腳本
train_ image_classifier.py 訓練模型的入口代碼
eval_image_classifier.py 驗證模型性能的入口代碼
download_and _convert data.py 下載並轉換數據集格式的入口代碼

上表只列出了TensorFlow Slim中最重要的幾個文件以及文件夾的做用。其餘還有少許文件和文件夾,若是讀者對它們的做用感興趣,能夠自行參閱其文檔。

3.3 定義新的datasets文件

在slim/datasets中,定義了全部可使用的數據庫,爲了使用在第3.1節中建立的tfrecord數據進行訓練,必需要在datasets中定義新的數據庫。

首先,在datasets/目錄下新建一個文件satellite.py,並將flowers.py文件中的內容複製到satellite.py中。接下來,須要修改如下幾處內容。

第一處是_FILE_PATTERNSPLITS_TO_SIZES_NUM_CLASSES,將其進行如下修改:

_FILE_PATTERN='satellite _%s*. tfrecord'
SPLTTS_TO_SIZES={' train:4800,' validation':1200}
_NUM_CLASSES=6

_FILE_PATTERN 變量定義了數據的文件名的格式和訓練集、驗證集的數量。

_NUM_CLASSES 變量定義了數據集中圖片的類別數目。

第二處修改成image/format部分,將之修改成:

'image/format': tf. FixedLenFeature((), tf. string, default_value ='jpg'),

此處定義了圖片的默認格式。收集的衛星圖片的格式爲jpg圖片,所以修改成jpg。最後,讀者也能夠對文件中的註釋內容進行合適的修改。修改完satellite.py後,還須要在同目錄的dataset_factory.py文件中註冊satellite數據庫。以下:

datasets_map={
'cifar10':cifarl0,
'flowers':flowers,
'imagenet':imagenet,
'satellite':satellite,}

3.4 準備訓練文件夾

定義完數據集後,在slim文件夾下再新建一個satellite目錄,在這個目錄中,完成最後的幾項準備工做:

  • 新建一個data目錄,並將第3.1節中準備好的5個轉換好格式的訓練數據複製進去。
  • 新建一個空的train_dir目錄,用來保存訓練過程當中的日誌和模型。
  • 新建一個pretrained目錄,在slim的GitHub頁面找到InceptionV3模型的下載地址http:/download.tensorflow.org/models/inception_V3_2016_0828.tar.gz,下載並解壓後,會獲得一個 inception_v3.ckpt文件,將該文件複製到pretrained目錄下。

3.5 開始訓練

在slim文件夾下,運行如下命令就能夠開始訓練了:

!python train_image_classifier. py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3. ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits, InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits \
--max_number_of steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed  \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004

這裏的參數比較多,下面一一進行介紹:

  • --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits:首先來解釋參數trainable_scopes的做用,由於它很是重要。trainable_scopes規定了在模型中微調變量的範圍。這裏的設定表示只對InceptionV3/Logits,InceptionV3/AuxLogits 兩個變量進行微調,其餘變量都保持不動。InceptionV3/Logits,InceptionV3/AuxLogits 就至關於在第2章中所講的fc8,它們是Inception V3的「末端層」。若是不設定trainable_scopes,就會對模型中全部的參數進行訓練。
  • --train_dir-satellite/train_dir:代表會在satellite/train_dir目錄下保存日誌和checkpoint。
  • --dataset_name=satellite、--dataset_split_name=train:指定訓練的數據集。
  • --dataset_dir=satellite/data:指定訓練數據集保存的位置。
  • --model_name=inception_v3:使用的模型名稱。
  • --checkpoint_path=satellite/pretrained/inception_v3.ckpt:預訓練模型的保存位置。
  • --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits:在恢復預訓練模型時,不恢復這兩層。正如以前所說,這兩層是InceptionV3模型的末端層,對應着ImageNet數據集的1000類,和當前的數據集不符,所以不要去恢復它。
  • --max_number_of_steps=100000:最大的執行步數。
  • --batch_size=32:每步使用的batch數量。
  • ---learning_rate=0.001:學習率。
  • --learning_rate_decay_type=fixed:學習率是否自動降低,此處使用固定的學習率。
  • --save_interval_secs=300:每隔300s,程序會把當前模型保存到train_dir中。此處就是目錄satellite/train_dir。
  • --save_summaries_secs=2:每隔2s,就會將日誌寫入到train_dir中。能夠用TensorBoard 查看該日誌。此處爲了方便觀察,設定的時間間隔較多,實際訓練時,爲了性能考慮,能夠設定較長的時間間隔。
  • --log_every_n_steps=10:每隔10步,就會在屏幕上打出訓練信息。
  • --optimizer=rmsprop:表示選定的優化器。
  • --weight_decay=0.00004:選定的weight_decay值。即模型中全部參數的二次正則化超參數。

以上命令是隻訓練未端層InceptionV3/Logits,InceptionV3/AuxLogits,還能夠對全部層進行訓練:與只訓練末端層的命令相比,只有一處發生了變化,即去掉了--trainable_scopes參數。

3.6 驗證模型準確率

在slim文件下執行下列命令:

!python eval_image_classifier.py  \
--checkpoint_path=satellite/train_dir  \
-eval_dir=satellite/eval_dir  \
--dataset_name=satellite  \
--dataset_split_name=validation  \
--dataset_dir=satellite/data  \
--model_name=inception_v3

這裏參數的含義爲:

  • --checkpoint_path=satellite/train_dir:這個參數既能夠接收一個目錄的路徑,也能夠接收一個文件的路徑。若是接收的是一個目錄的路徑,如這裏的satellite/train_dir,就會在這個目錄中尋找最新保存的模型文件,執行驗證。也能夠指定一個模型進行驗證,以第300步的模型爲例,在satellite/train_dir文件夾下它被保存爲model.ckpt-300.meta、model.ckpt-300.index、model.ckpt-300.data-00000-of-00001三個文件。此時,若是要對它執行驗證,給checkpoint_path傳遞的參數應該爲satellite/train_dir/model.ckpt-300。|
  • --eval_dir-=satellite/eval_dir:執行結果的日誌就保存在eval_dir中,一樣能夠經過TensorBoard查看。
  • --dataset_name=satellite、--dataset_split_name=validation:指定須要執行的數據集。注意此處是使用驗證集(validation)執行驗證。
  • --dataset_dir=satellite/data:數據集保存的位置。
  • --model_name=inception_v3:使用的模型。

執行後,應該會出現相似下面的結果:

eval/Accuracy[0.51]
eval/Recal1_5[0.97333336]

Accuracy 表示模型的分類準確率,而Recall_5表示Top5的準確率,若是不須要top5 。而須要top2或者top3準確率,只要在eval_image_classifier.py中修改下面的部分就能夠了:

names_to_values, names_to_updates=slim. metrics. aggregate_metric map({
'Accuracy': slim. metrics. streaming_accuracy (predictions,labels),
'Recall_5': slim. metrics. streaming_recall_at_k(1ogits,labels,5),
})

4.代碼及數據集

百度網盤 提取碼:8qqt

5.補充

Colab等同於一個Linux的操做系統,若是你的電腦是windows,建議安裝一個cmder,解壓之後就可使用。cmder不只可使用windows下的全部命令,更爽的是可使用Linux的命令,shell命令。安裝與配置參考:Windows命令行工具cmder配置

相關文章
相關標籤/搜索