Siamese網絡

1.       對比損失函數(Contrastive Loss function)

孿生架構的目的不是對輸入圖像進行分類,而是區分它們。所以,分類損失函數(如交叉熵)不是最合適的選擇,這種架構更適合使用對比函數。對比損失函數以下:python

 

 

(以判斷圖片類似度爲例)其中Dw被定義爲姐妹孿生網絡的輸出之間的歐氏距離。Y值爲1或0。若是模型預測輸入是類似的,那麼Y的值爲0,不然Y爲1。m是大於0的邊際價值(margin value)。有一個邊際價值表示超出該邊際價值的不一樣對不會形成損失。git

Siamese網絡架構須要一個輸入對,以及標籤(相似/不類似)。github

2.       孿生網絡的訓練過程

(1)    經過網絡傳遞圖像對的第一張圖像。網絡

(2)    經過網絡傳遞圖像對的第二張圖像。架構

(3)    使用(1)和(2)中的輸出來計算損失。函數

 

 

  其中,l12爲標籤,用於表示x1的排名是否高於x2。性能

訓練過程當中兩個分支網絡的輸出爲高級特徵,能夠視爲quality score。在訓練時,輸入是兩張圖像,分別獲得對應的分數,將分數的差別嵌入loss層,再進行反向傳播。學習

(4)    返回傳播損失計算梯度。測試

(5)    使用優化器更新權重。優化

3.       基於Siamese網絡的無參考圖像質量評估:RankIQA

3.1          參考文獻

https://arxiv.org/abs/1707.08347

3.2          RankIQA的流程

(1)    合成失真圖像。

(2)    訓練Siamese網絡,使網絡輸出對圖像質量的排序。

(3)    提取Siamese網絡的一支,使用IQA數據集進行fine-tune。將網絡的輸出校訂爲IQA度量值。fine-tune階段的損失函數以下:

 

 

訓練階段使用Hinge loss,fine-tune階段使用MSE。

訓練時,每次從圖像中隨機提取224*224或者227*227大小的圖像塊。和AlexNet、VGG16有關。在訓練Siamese network時,初始的learning rate 是1e-4;fine-tuning是初始的learning rate是1e-6,每隔10k步,rate變成原來的0.1倍。訓練50k次。測試時,隨機在圖像中提取30個圖像塊,獲得30個分數以後,進行均值操做。

本文如何提升Siamese網絡的效率:

假設有三張圖片,則每張圖片將被輸入網絡兩次,緣由是含有某張圖片的排列數爲2。爲了減小計算量,每張圖片只輸入網絡一次,在網絡以後、損失函數以前新建一層,用於生成每一個mini-batch中圖片的可能排列組合。

使用上述方法,每張圖片只前向傳播一次,只在loss計算時考慮全部的圖片組合方式。

本文使用的網絡架構:Shallow, AlexNet, and VGG-16。

4.       Siamese網絡的開源實現

4.1          代碼地址

https://github.com/xialeiliu/RankIQA

4.2          RankIQA的運行過程

4.2.1            數據集

使用兩方面的數據集,通常性的非IQA數據集用於生成排序好的圖片,進而訓練Siamese網絡;IQA數據集用於微調和評估。

本文使用的IQA數據集:

(1)    LIVE數據集:http://live.ece.utexas.edu/research/quality/ 對29張原始圖片進行五類失真處理,獲得808張圖片。Ground Truth MOS在[0, 100]之間(人工評分)。

(2)    TID2013:25張原始圖片,3000張失真圖片。MOS範圍是[0, 9]。

本文使用的用於生成ranked pairs的數據集:

(1)    爲了測試LIVE數據集,人工生成了四類失真,GB(Gaussian Blur)、GN(Gaussian Noise)、JPEG、JPEG2K

(2)    爲了在TID2013上測試,生成了17種失真(去掉了#3, #4,#12, #13, #20, #21, #24)

Waterloo數據集:

包含4744張高質量天然圖片。

Places2數據集:

做爲驗證集(包含356種場景,http://places2.csail.mit.edu/ ),每類100張,共35600張。

兩種數據集的區別:

python generate_rank_txt_tid2013.py生成的是tid2013_train.txt,標籤只起到表示相對順序的做用,即,標籤爲{1, 2, 3, 4, 5};python generate_ft_txt_tid2013.py生成的是ft_tid2013_test.txt,其中的標籤是浮點數,表示圖片的質量評分。

 

4.2.2            訓練和測試過程

從原始圖像中隨機採樣子圖(sub-images),避免因差值和過濾而產生的失真。輸入的子圖至少佔原圖的1/3,以保留場景信息。本文采用227*227或者224*224的採樣圖像(根據使用的主幹網絡而不一樣)。

訓練過程使用mini-batch SGD,初始學習率1e-4,fine-tune學習率1e-6。

共迭代50K次,每10K次減少學習率(乘以0.1),兩個訓練過程都是用l2權重衰減(正則化係數lambda=5e-4)。

實驗一:本文首先使用Places2數據集(使用五種失真進行處理)訓練網絡(不進行微調),而後在Waterloo數據及上進行預測IQA(使用一樣的五種失真進行處理)。實驗結果如圖2所示。

 

 

實驗二:hard negative mining

難分樣本挖掘,是當獲得錯誤的檢測patch時,會明確的從這個patch中建立一個負樣本,並把這個負樣本添加到訓練集中去。從新訓練分類器後,分類器會表現的更好,而且不會像以前那樣產生多的錯誤的正樣本。

本實驗使用Alexnet進行。

實驗三:網絡性能分析

LIVE數據集,80%訓練集,評價指標LCC和SROCC。VGG-16的效果最好。

4.2.3            RankIQA對數據集的處理過程

將原始圖像文件放在data/rank_tid2013/pristine_images路徑下,而後運行data/rank_tid2013/路徑下的tid2013_main.m,進而生成排序數據集(17種失真形式)。

4.3          運行指令

4.3.1            Train RankIQA

To train the RankIQA models on tid2013 dataset:

./src/RankIQA/tid2013/train_vgg.sh

 

To train the RankIQA models on LIVE dataset:

./src/RankIQA/live/train_vgg.sh

 

FT

To train the RankIQA+FT models on tid2013 dataset:

./src/FT/tid2013/train_vgg.sh

 

To train the RankIQA+FT models on LIVE dataset:

./src/FT/live/train_live.sh

 

4.3.2            Evaluation for RankIQA

python src/eval/Rank_eval_each_tid2013.py  # evaluation for each distortions in tid2013

python src/eval/Rank_eval_all_tid2013.py   # evaluation for all distortions in tid2013

Evaluation for RankIQA+FT on tid2013:

python src/eval/FT_eval_each_tid2013.py  # evaluation for each distortions in tid2013

python src/eval/FT_eval_all_tid2013.py   # evaluation for all distortions in tid2013

Evaluation for RankIQA on LIVE:

python src/eval/Rank_eval_all_live.py   # evaluation for all distortions in LIVE

Evaluation for RankIQA+FT on LIVE:

python src/eval/FT_eval_all_live.py   # evaluation for all distortions in LIVE

5.       代碼調試過程

5.1          Python沒法導入某個模塊ImportError:could not find module XXX

解決方案:

配置環境變量:export PYTHONPATH=path/to/modules

相關文章
相關標籤/搜索