tensorflow中batch normalization的用法

網上找了下tensorflow中使用batch normalization的博客,發現寫的都不是很好,在此總結下:網絡

1.原理學習

公式以下:測試

y=γ(x-μ)/σ+βspa

其中x是輸入,y是輸出,μ是均值,σ是方差,γ和β是縮放(scale)、偏移(offset)係數。code

通常來說,這些參數都是基於channel來作的,好比輸入x是一個16*32*32*128(NWHC格式)的feature map,那麼上述參數都是128維的向量。其中γ和β是無關緊要的,有的話,就是一個能夠學習的參數(參與前向後向),沒有的話,就簡化成y=(x-μ)/σ。而μ和σ,在訓練的時候,使用的是batch內的統計值,測試/預測的時候,採用的是訓練時計算出的滑動平均值。orm

 

2.tensorflow中使用blog

tensorflow中batch normalization的實現主要有下面三個:ip

tf.nn.batch_normalizationci

tf.layers.batch_normalization字符串

tf.contrib.layers.batch_norm

封裝程度逐個遞進,建議使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,由於在tensorflow官網的解釋比較詳細。我平時多使用tf.layers.batch_normalization,所以下面的步驟都是基於這個。

 

3.訓練

訓練的時候須要注意兩點,(1)輸入參數training=True,(2)計算loss時,要添加如下代碼(即添加update_ops到最後的train_op中)。這樣才能計算μ和σ的滑動平均(測試時會用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss)

 

4.測試

測試時須要注意一點,輸入參數training=False,其餘就沒了

 

5.預測

預測時比較特別,由於這一步通常都是從checkpoint文件中讀取模型參數,而後作預測。通常來講,保存checkpoint的時候,不會把全部模型參數都保存下來,由於一些無關數據會增大模型的尺寸,常見的方法是隻保存那些訓練時更新的參數(可訓練參數),以下:

var_list = tf.trainable_variables() saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

 

但使用了batch_normalization,γ和β是可訓練參數沒錯,μ和σ不是,它們僅僅是經過滑動平均計算出的,若是按照上面的方法保存模型,在讀取模型預測時,會報錯找不到μ和σ。更詭異的是,利用tf.moving_average_variables()也無法獲取bn層中的μ和σ(也多是我用法不對),不過好在全部的參數都在tf.global_variables()中,所以能夠這麼寫:

var_list = tf.trainable_variables() g_list = tf.global_variables() bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] var_list += bn_moving_vars saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

按照上述寫法,便可把μ和σ保存下來,讀取模型預測時也不會報錯,固然輸入參數training=False仍是要的。

注意上面有個不嚴謹的地方,由於個人網絡結構中只有bn層包含moving_mean和moving_variance,所以只根據這兩個字符串作了過濾,若是你的網絡結構中其餘層也有這兩個參數,但你不須要保存,建議使用諸如bn/moving_mean的字符串進行過濾。

 

2018.4.22更新

提供一個基於mnist的示例,供你們參考。包含兩個文件,分別用於train/test。注意bn_train.py文件的51-61行,僅保存了網絡中的可訓練變量和bn層利用統計獲得的mean和var。注意示例中須要下載mnist數據集,要保持電腦能夠聯網。

相關文章
相關標籤/搜索