最近因爲項目須要,要對tensorflow構造的模型中部分變量凍結,而後繼續訓練,所以研究了一下tf中凍結變量的方法,目前找到三種,各有優缺點,記錄以下:python
1.名詞解釋api
凍結變量,指的是在訓練模型時,對某些可訓練變量不更新,即僅參與前向loss計算,不參與後向傳播,通常用於模型的finetuning等場景。例如:咱們在其餘數據上訓練了一個resnet152模型,而後但願在目前數據上作finetuning,通常來說,網絡的前幾層卷積是用來提取底層圖像特徵的,所以能夠對前3個卷積層進行凍結,不改變其weight和bias的數值。網絡
2.方法介紹函數
目前我找到了三種tf凍結變量的方法,各有優缺點,具體以下:spa
2.1 trainable=Falsecode
一切tf.Variable或tf.Variable的子類,在建立時,都有一個trainable參數,在tf官方文檔(https://www.tensorflow.org/api_docs/python/tf/Variable)中有對這個參數的定義,
blog
意思是,若是trainable設置爲True,就會把變量添加到GraphKeys.TRAINABLE_VARIABLES集合中,若是是False,則不添加。而在計算梯度進行後向傳播時,咱們通常會使用一個optimizer,而後調用該optimizer的compute_gradients方法。在compute_gradients中,第二個參數var_list若是不傳入,則默認爲GraphKeys.TRAINABLE_VARIABLES。
文檔
總結下,trainable=False凍結變量的邏輯:trainable=False → 該變量不會放入GraphKeys.TRAINABLE_VARIABLES → 調用optimizer.compute_gradients方法時默認變量列表爲GraphKeys.TRAINABLE_VARIABLES,該變量不在其中,所以不參與後向傳播,值不進行更新,達到凍結變量效果。it
優勢:操做簡單,只要在你建立變量時設置trainable=False便可io
缺點:不知道你們發現沒有,我上面的總結中,optimizer.compute_gradients方法默認變量列表是GraphKeys.TRAINABLE_VARIABLES,這句話還意味着,若是我不想用默認變量列表,而使用自定義變量列表,那麼即便設置了trainable=False,只要把該變量加入到自定義變量列表中,變量仍是會參與後向傳播的,值也會更新。另外,tf.layers、tf.contrib.rnn等一些高度封裝的API是不支持這個參數的,無法用該方法凍結變量。最後,若是咱們在使用Saver保存ckpt時,通常調動tf.trainable_variables()方法只保存可訓練參數,這時返回的變量列表,也有上面的問題,即設置了trainable=False的變量不會在裏面。
2.2 tf.stop_gradient()
咱們還能夠經過在某個變量外面包裹一層tf.stop_gradient()函數來達到凍結變量的目的。例如咱們想凍結w1,能夠寫成這樣:
w1 = tf.stop_gradient(w1)
在後向傳播時,w1的值就不會更新。下面說下優缺點。
優勢:操做簡單,針對想凍結的變量,添加上面這一行便可,並且相比於上一個方法,設置了tf.stop_gradient()的變量,不會從GraphKeys.TRAINABLE_VARIABLES集合中去除,所以不會影響梯度計算和保存模型
缺點:和上一個方法相似,tf.stop_gradient()的輸入是Tensor,tf.layers、tf.contrib.rnn等一些高度封裝的API的返回值無法做爲參數傳入,即不能用該方法凍結
2.3 optimizer.compute_gradients(loss,var_list=no_freeze_vars)
optimizer.compute_gradients在2.1中提到過,其實咱們只須要在計算梯度時,指定變量列表,把但願凍結的變量去除,便可完成凍結變量。但這麼作有一個前提,咱們必須知道全部可訓練變量的名字,並根據一些規則去除變量。獲取全部可訓練變量名字調用tf.trainable_variables()方法便可,但去除變量則須要咱們在構建網絡的時候,合理利用tf.variable_scope,對不一樣變量作區分。例如,咱們若是想把可訓練變量中全部卷積層變量凍結,能夠這麼寫:
trainable_vars = tf.trainable_variables() freeze_conv_var_list = [t for t in trainable_vars if not t.name.startswith(u'conv')] grads = opt.compute_gradients(loss, var_list=freeze_conv_var_list)
下面總結下優缺點,
優勢:沒有2.1和2.2的缺點,是一種適用範圍更加普遍的方法
缺點:相對2.1,2.2使用起來比較複雜,須要本身去除凍結變量,而且variable_scope不能隨意改動,由於可能使去除變量的過濾操做無效化。例如:若是把原來'cnn' scope改成'vgg',那麼上面的代碼就無效了
3.總結
tf對於一些經常使用操做,每每會提供多種方法,但每種方法通常都是有區別的,而且操做原理和後面的邏輯也會有不一樣,要謹慎使用