tensorflow finuetuning 例子

最近研究了下如何使用tensorflow進行finetuning,相比於caffe,tensorflow的finetuning麻煩一些,記錄以下:網絡

1.原理

finetuning原理很簡單,利用一個在數據A集上已訓練好的模型做爲初始值,改變其部分結構,在另外一數據集B上(採用小學習率)訓練的過程叫作finetuning。學習

通常來說,符合以下狀況會採用finetuningspa

  • 數據集A和B有相關性
  • 數據集A較大
  • 數據集B較小

 

2.關鍵代碼

在數據集A上訓練的時候,和普通的tensorflow訓練過程徹底一致。可是在數據集B上進行finetuning時,須要先從以前訓練好的checkpoint中恢復模型參數,這個地方比較關鍵,rest

須要注意只恢復須要恢復的參數,其餘參數不要恢復,不然會由於找不到的聲明而報錯。以mnist爲例子,若是我想先訓練一個0-7的8類分類器,網絡結構以下:code

conv1-conv2-fc8(其餘不帶權重的pooling、softmaxloss層忽略)blog

而後我想用這個訓練出的模型參數,在0-9的10類分類器上作finetuning,網絡結構以下:ip

conv1-conv2-fc10get

那麼在從checkpoint中恢復模型參數時,我只能恢復conv1-conv2,若是連fc8都恢復了,就會由於找不到fc8的定義而報錯it

以上描述對應的代碼以下:io

1     if tf.train.latest_checkpoint('ckpts') is not None:
2         trainable_vars = tf.trainable_variables()
3         res_vars = [t for t in trainable_vars if t.name.startswith('conv')]
4         saver = tf.train.Saver(var_list=res_vars)
5         saver.restore(sess, tf.train.latest_checkpoint('ckpts'))
6     else:
7         saver = tf.train.Saver()

 

3.demo

利用mnist寫了一個簡單的finetuning例子,你們能夠試試,事實證實,利用一個相關的已有模型作finuetuning比從0開始訓練收斂的更快而且收斂到的準確率更高,

點我下載

相關文章
相關標籤/搜索