TensorFlow——批量歸一化操做

批量歸一化python

在對神經網絡的優化方法中,有一種使用十分普遍的方法——批量歸一化,使得神經網絡的識別準確度獲得了極大的提高。git

在網絡的前向計算過程當中,當輸出的數據再也不同一分佈時,可能會使得loss的值很是大,使得網絡沒法進行計算。產生梯度爆炸的緣由是由於網絡的內部協變量轉移,即正向傳播的不一樣層參數會將反向訓練計算時參照的數據樣本分佈改變。批量歸一化的目的,就是要最大限度地保證每次的正向傳播輸出在同一分佈上,這樣反向計算時參照的數據樣本分佈就會與正向計算時的數據分佈同樣了,保證分佈的統一。網絡

瞭解了原理,批量正則化的作法就會變得簡單,即將每一層運算出來的數據都歸一化成均值爲0方差爲1的標準高斯分佈。這樣就會在保留樣本分佈特徵的同時,又消除層與層間的分佈差別。在實際的應用中,批量歸一化的收斂很是快,而且有很強的泛化能力,在一些狀況下,徹底能夠代替前面的正則化,dropout。函數

批量歸一化的定義性能

在TensorFlow中有自帶的BN函數定義:測試

tf.nn.batch_normalization(x,
                          maen,
                          variance,
                          offset,
                          scale,
                          variance_epsilon)

各個參數的含義以下:優化

x:表明輸入spa

mean:表明樣本的均值code

variance:表明方差orm

offset:表明偏移量,即相加一個轉化值,一般是用激活函數來作。

scale:表明縮放,即乘以一個轉化值,同理,通常是1

variance_epsilon:爲了不分母是0的狀況,給分母加一個極小值。

要使用這個函數,還須要另外的一個函數的配合:tf.nn.moments(),由此函數來計算均值和方差,而後就可使用BN了,給函數的定義以下:

tf.nn.moments(x, axes, name, keep_dims=False),axes指定那個軸求均值和方差。

爲了更好的效果,咱們使用平滑指數衰減的方法來優化每次的均值和方差,這裏可使用

tf.train.ExponentialMovingAverage()函數,它的做用是讓上一次的值對本次的值有一個衰減後的影響,從而使的每次的值連起來後會相對平滑一下。

批量歸一化的簡單用法

下面介紹具體的用法,在使用的時候須要引入頭文件。

from tensorflow.contrib.layers.python.layers import batch_norm

函數的定義以下:

batch_norm(inputs, decay, center, scale, epsilon, activation_fn, param_initializers=None, param_regularizers=None, updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True, reuse=None, variables_collections=None, outputs_collections=None, trainable=True, batch_weights=None, fused=False, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, scope=None, renorm=False, renorm_clipping=None, renorm_decay=0.99)

各參數的具體含義以下:

inputs:輸入

decay:移動平均值的衰減速度,使用的是平滑指數衰減的方法更新均值方差,通常會設置0.9,值過小會致使更新太快,值太大會致使幾乎沒有衰減,容易出現過擬合。

scale:是否進行變換,經過乘以一個gamma值進行縮放,咱們常習慣在BN後面接一個線性變化,如relu。

epsilon:爲了不分母爲0,給分母加上一個極小值,通常默認。

is_training:當爲True時,表明訓練過程,這時會不斷更新樣本集的均值和方差,當測試時,要設置爲False,這樣就會使用訓練樣本的均值和方差。

updates_collections:在訓練時,提供一種內置的均值方差更新機制,即經過圖中的tf.GraphKeys.UPDATE_OPS變量來更新。但它是在每次當前批次訓練完成後才更新均值和方差,這樣致使當前數據老是使用前一次的均值和方差,沒有獲得最新的值,因此通常設置爲None,讓均值和方差及時更新,但在性能上稍慢。

reuse:支持變量共享。

具體的代碼以下:

x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
train = tf.Variable(tf.constant(False))

x_images = tf.reshape(x, [-1, 32, 32, 3])


def batch_norm_layer(value, train=False, name='batch_norm'):
    if train is not False:
        return batch_norm(value, decay=0.9, updates_collections=None, is_training=True)
    else:
        return batch_norm(value, decay=0.9, updates_collections=None, is_training=False)


w_conv1 = init_cnn.weight_variable([3, 3, 3, 64])  # [-1, 32, 32, 3]
b_conv1 = init_cnn.bias_variable([64])
h_conv1 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) + b_conv1), train))
h_pool1 = init_cnn.max_pool_2x2(h_conv1)


w_conv2 = init_cnn.weight_variable([3, 3, 64, 64])  # [-1, 16, 16, 64]
b_conv2 = init_cnn.bias_variable([64])
h_conv2 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) + b_conv2), train))
h_pool2 = init_cnn.max_pool_2x2(h_conv2)


w_conv3 = init_cnn.weight_variable([3, 3, 64, 32])  # [-1, 18, 8, 32]
b_conv3 = init_cnn.bias_variable([32])
h_conv3 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) + b_conv3), train))
h_pool3 = init_cnn.max_pool_2x2(h_conv3)

w_conv4 = init_cnn.weight_variable([3, 3, 32, 16])  # [-1, 18, 8, 32]
b_conv4 = init_cnn.bias_variable([16])
h_conv4 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) + b_conv4), train))
h_pool4 = init_cnn.max_pool_2x2(h_conv4)


w_conv5 = init_cnn.weight_variable([3, 3, 16, 10])  # [-1, 4, 4, 16]
b_conv5 = init_cnn.bias_variable([10])
h_conv5 = tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) + b_conv5), train))
h_pool5 = init_cnn.avg_pool_4x4(h_conv5)                 # [-1, 4, 4, 10]

y_pool = tf.reshape(h_pool5, shape=[-1, 10])


cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool))

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

加上了BN層以後,識別的準確率顯著的獲得了提高,而且計算速度也是飛起。

相關文章
相關標籤/搜索