Batch Normalization 學習筆記

原文:http://blog.csdn.net/happynear/article/details/44238541git

今年過年以前,MSRA和Google相繼在ImagenNet圖像識別數據集上報告他們的效果超越了人類水平,下面將分兩期介紹二者的算法細節。github

  此次先講Google的這篇《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》,主要是由於這裏面的思想比較有普適性,並且一直答應羣裏的人寫一個有關預處理的科普,但一直沒抽出時間來寫。算法

1、神經網絡中的權重初始化與預處理方法的關係

若是作過dnn的實驗,你們可能會發如今對數據進行預處理,例如白化或者zscore,甚至是簡單的減均值操做都是能夠加速收斂的,例以下圖所示的一個簡單的例子:網絡

  圖中紅點表明2維的數據點,因爲圖像數據的每一維通常都是0-255之間的數字,所以數據點只會落在第一象限,並且圖像數據具備很強的相關性,好比第一個灰度值爲30,比較黑,那它旁邊的一個像素值通常不會超過100,不然給人的感受就像噪聲同樣。因爲強相關性,數據點僅會落在第一象限的很小的區域中,造成相似上圖所示的狹長分佈。app

  而神經網絡模型在初始化的時候,權重W是隨機採樣生成的,一個常見的神經元表示爲:ReLU(Wx+b) = max(Wx+b,0),即在Wx+b=0的兩側,對數據採用不一樣的操做方法。具體到ReLU就是一側收縮,一側保持不變。函數

  隨機的Wx+b=0表現爲上圖中的隨機虛線,注意到,兩條綠色虛線實際上並無什麼意義,在使用梯度降低時,可能須要不少次迭代纔會使這些虛線對數據點進行有效的分割,就像紫色虛線那樣,這勢必會帶來求解速率變慢的問題。更況且,咱們這只是個二維的演示,數據佔據四個象限中的一個,若是是幾百、幾千、上萬維呢?並且數據在第一象限中也只是佔了很小的一部分區域而已,可想而知不對數據進行預處理帶來了多少運算資源的浪費,並且大量的數據外分割面在迭代時極可能會在剛進入數據中時就遇到了一個局部最優,致使overfit的問題。oop

  這時,若是咱們將數據減去其均值,數據點就再也不只分佈在第一象限,這時一個隨機分界面落入數據分佈的機率增長了多少呢?2^n倍!若是咱們使用去除相關性的算法,例如PCA和ZCA白化,數據再也不是一個狹長的分佈,隨機分界面有效的機率就又大大增長了。學習

  不過計算協方差矩陣的特徵值太耗時也太耗空間,咱們通常最多隻用到z-score處理,即每一維度減去自身均值,再除以自身標準差,這樣能使數據點在每維上具備類似的寬度,能夠起到必定的增大數據分佈範圍,進而使更多隨機分界面有意義的做用。測試

2、Batch Normalization

  上一節咱們講到對輸入數據進行預處理,減均值->zscore->白化能夠逐級提高隨機初始化的權重對數據分割的有效性,還能夠下降overfit的可能性。咱們都知道,如今的神經網絡的層數都是很深的,若是咱們對每一層的數據都進行處理,訓練時間和overfit程度是否能夠下降呢?Google的這篇論文給出了答案。大數據

一、算法描述

  按照第一章的理論,應當在每一層的激活函數以後,例如ReLU=max(Wx+b,0)以後,對數據進行歸一化。然而,文章中說這樣作在訓練初期,分界面還在劇烈變化時,計算出的參數不穩定,因此退而求其次,在Wx+b以後進行歸一化。由於初始的W是從標準高斯分佈中採樣獲得的,而W中元素的數量遠大於x,Wx+b每維的均值自己就接近0、方差接近1,因此在Wx+b後使用Batch Normalization能獲得更穩定的結果。

       文中使用了相似z-score的歸一化方式:每一維度減去自身均值,再除以自身標準差,因爲使用的是隨機梯度降低法,這些均值和方差也只能在當前迭代的batch中計算,故做者給這個算法命名爲Batch Normalization。這裏有一點須要注意,像卷積層這樣具備權值共享的層,Wx+b的均值和方差是對整張map求得的,在batch_size * channel * height * width這麼大的一層中,對總共batch_size*height*width個像素點統計獲得一個均值和一個標準差,共獲得channel組參數。

  在Normalization完成後,Google的研究員仍對數值穩定性不放心,又加入了兩個參數gamma和beta,使得

       注意到,若是咱們令gamma等於以前求得的標準差,beta等於以前求得的均值,則這個變換就又將數據還原回去了。在他們的模型中,這兩個參數與每層的W和b同樣,是須要迭代求解的。文章中舉了個例子,在sigmoid激活函數的中間部分,函數近似於一個線性函數(以下圖所示),使用BN後會使歸一化後的數據僅使用這一段線性的部分(吐槽一下:再乘個2之類的不就好了)。

       能夠看到,在[0.2, 0.8]範圍內,sigmoid函數基本呈線性遞增,甚至在[0.1, 0.9]範圍內,sigmoid函數都是相似於線性函數的,若是隻用這一段,那網絡不就成了線性網絡了麼,這顯然不是你們願意見到的。至於這兩個參數對ReLU起的做用文中沒說,我就不妄自揣摩了哈。

       算法原理到這差很少就講完了,下面是你們 最不喜歡的公式環節了,求均值和方差就不用說了,在BP的時候,咱們須要求最終的損失函數對gamma和beta兩個參數的導數,還要求損失函數對Wx+b中的x的導數,以便使偏差繼續向後傳播。求導公式以下:

 

  具體的公式推導就不寫了,有興趣的讀者能夠本身推一下,主要用到了鏈式法則。

  在訓練的最後一個epoch時,要對這一epoch全部的訓練樣本的均值和標準差進行統計,這樣在一張測試圖片進來時,使用訓練樣本中的標準差的指望和均值的指望(好繞口)對測試數據進行歸一化,注意這裏標準差使用的指望是其無偏估計:

二、算法優點

  論文中將Batch Normalization的做用說得突破天際,好似一下解決了全部問題,下面就來一一列舉一下:
  (1) 可使用更高的學習率。若是每層的scale不一致,實際上每層須要的學習率是不同的,同一層不一樣維度的scale每每也須要不一樣大小的學習率,一般須要使用最小的那個學習率才能保證損失函數有效降低,Batch Normalization將每層、每維的scale保持一致,那麼咱們就能夠直接使用較高的學習率進行優化。
  (2) 移除或使用較低的dropout。 dropout是經常使用的防止overfitting的方法,而致使overfit的位置每每在數據邊界處,若是初始化權重就已經落在數據內部,overfit現象就能夠獲得必定的緩解。論文中最後的模型分別使用10%、5%和0%的dropout訓練模型,與以前的40%-50%相比,能夠大大提升訓練速度。
  (3) 下降L2權重衰減係數。 仍是同樣的問題,邊界處的局部最優每每有幾維的權重(斜率)較大,使用L2衰減能夠緩解這一問題,如今用了Batch Normalization,就能夠把這個值下降了,論文中下降爲原來的5倍。
  (4) 取消Local Response Normalization層。 因爲使用了一種Normalization,再使用LRN就顯得沒那麼必要了。並且LRN實際上也沒那麼work。
  (5) 減小圖像扭曲的使用。 因爲如今訓練epoch數下降,因此要對輸入數據少作一些扭曲,讓神經網絡多看看真實的數據。

3、實驗

  這裏我只在matlab上面對算法進行了仿真,修改了DeepLearnToolbox 裏面的NN模型,代碼以下:

  在前向傳播時,分兩種狀況進行討論:若是是在train過程,就使用當前batch的數據統計均值和標準差,並按照第二章所述公式對Wx+b進行歸一化,以後再乘上gamma,加上beta獲得Batch Normalization層的輸出;若是在進行test過程,則使用記錄下的均值和標準差,還有以前訓練好的gamma和beta計算獲得結果

 

[plain]  view plain  copy
 
 在CODE上查看代碼片派生到個人代碼片
  1. if nn.testing  
  2.     nn.a_pre{i} = nn.a{i - 1} * nn.W{i - 1}';  
  3.     norm_factor = nn.gamma{i-1}./sqrt(nn.mean_sigma2{i-1}+nn.epsilon);  
  4.     nn.a_hat{i} = bsxfun(@times, nn.a_pre{i}, norm_factor);  
  5.     nn.a_hat{i} = bsxfun(@plus, nn.a_hat{i}, nn.beta{i-1} -  norm_factor .* nn.mean_mu{i-1});  
  6. else  
  7.     nn.a_pre{i} = nn.a{i - 1} * nn.W{i - 1}';  
  8.     nn.mu{i-1} = mean(nn.a_pre{i});  
  9.     x_mu = bsxfun(@minus,nn.a_pre{i},nn.mu{i-1});  
  10.     nn.sigma2{i-1} = mean(x_mu.^2);  
  11.     norm_factor = nn.gamma{i-1}./sqrt(nn.sigma2{i-1}+nn.epsilon);  
  12.     nn.a_hat{i} = bsxfun(@times, nn.a_pre{i}, norm_factor);  
  13.     nn.a_hat{i} = bsxfun(@plus, nn.a_hat{i}, nn.beta{i-1} -  norm_factor .* nn.mu{i-1});  
  14. end;  


  反向傳播就跟上面那一堆公式同樣啦,注意爲了運行效率,儘可能使用向量化的代碼,避免使用for循環:

 

 

 

[plain]  view plain  copy
 
 在CODE上查看代碼片派生到個人代碼片
  1. d_xhat = bsxfun(@times, d{i}(:,2:end), nn.gamma{i-1});  
  2. x_mu = bsxfun(@minus, nn.a_pre{i}, nn.mu{i-1});  
  3. inv_sqrt_sigma = 1 ./ sqrt(nn.sigma2{i-1} + nn.epsilon);  
  4. d_sigma2 = -0.5 * sum(d_xhat .* x_mu) .* inv_sqrt_sigma.^3;  
  5. d_mu = bsxfun(@times, d_xhat, inv_sqrt_sigma);  
  6. d_mu = -1 * sum(d_mu) -2 .* d_sigma2 .* mean(x_mu);  
  7. d_gamma = mean(d{i}(:,2:end) .* nn.a_hat{i});  
  8. d_beta = mean(d{i}(:,2:end));  
  9. di1 = bsxfun(@times,d_xhat,inv_sqrt_sigma);  
  10. di2 = 2/m * bsxfun(@times, d_sigma2,x_mu);  
  11. d{i}(:,2:end) = di1 + di2 + 1/m * repmat(d_mu,m,1);  

  在訓練的最後一個epoch,要對全部的gamma和beta進行統計,代碼很簡單就不貼了,完整代碼在個人Github上有:https://github.com/happynear/DeepLearnToolbox

 

一、sigmoid激活函數的過飽和問題

  經測試發現算法對sigmoid激活函數的提高很是明顯,解決了困擾學術界十幾年的sigmoid過飽和的問題,即在深層的神經網絡中,前幾層在梯度降低時獲得的梯度太低,致使深層神經網絡變成了前邊是隨機變換,只在最後幾層纔是真正在作分類的問題。
  下面是使用一個10個隱藏層的nn網絡,對mnist進行分類,每層的梯度值:

  使用Batch Normalization前:

 

[plain]  view plain  copy
 
 在CODE上查看代碼片派生到個人代碼片
  1. epoch:1 iteration:10/300  
  2.  3.23e-07 8.3215e-07 3.3605e-06 1.5193e-05 6.4892e-05 0.00027249 0.0011954 0.006295 0.029835 0.12476 0.38948  
  3. epoch:1 iteration:20/300  
  4.  4.4649e-07 1.3282e-06 5.6753e-06 2.5294e-05 0.00010326 0.00043651 0.0019583 0.0096396 0.040469 0.16142 0.5235  
  5. epoch:1 iteration:30/300  
  6.  4.6973e-07 1.2993e-06 5.3923e-06 2.3111e-05 9.4839e-05 0.00040398 0.0017893 0.0081367 0.037543 0.1544 0.46472  
  7. epoch:1 iteration:40/300  
  8.  4.6986e-07 1.3801e-06 5.677e-06 2.4355e-05 0.00010245 0.00041999 0.0019832 0.0095022 0.043719 0.17696 0.56134  
  9. epoch:1 iteration:50/300  
  10.  4.6964e-07 1.6532e-06 7.2543e-06 3.0731e-05 0.00011805 0.00048795 0.0021705 0.0099466 0.042835 0.17993 0.5319  

  能夠看到,最開始的幾層只有1e-6到1e-7這個量級的梯度,基本上梯度在最後3層就已經飽和了。

 

  使用Batch Normalization後:

 

[plain]  view plain  copy
 
 在CODE上查看代碼片派生到個人代碼片
  1. epoch:1 iteration:10/300  
  2.  0.27121 0.15534 0.15116 0.15409 0.15515 0.14542 0.12878 0.13888 0.16607 0.21036 0.76037  
  3. epoch:1 iteration:20/300  
  4.  0.24567 0.15369 0.14169 0.13183 0.1278 0.13904 0.13546 0.12032 0.14332 0.14868 0.54481  
  5. epoch:1 iteration:30/300  
  6.  0.30403 0.16365 0.14119 0.14502 0.13916 0.12851 0.11781 0.11424 0.11082 0.1088 0.39574  
  7. epoch:1 iteration:40/300  
  8.  0.32681 0.19801 0.16792 0.14741 0.13294 0.12805 0.13754 0.12941 0.13288 0.12957 0.50937  
  9. epoch:1 iteration:50/300  
  10.  0.32358 0.17484 0.16367 0.16605 0.17118 0.14703 0.14458 0.12693 0.13928 0.11938 0.3692  


  我第一次看到的時候,就像以前看到ReLU同樣驚豔,終於,sigmoid的飽和問題也獲得瞭解決。不過論文中還有我本身的實驗都代表,sigmoid在分類問題上確實沒有ReLU好用,大概是由於sigmoid的中間部分太「線性」了,不像ReLU一個很大的轉折,在擬合複雜非線性函數的時候可能沒那麼高效,真的是蠻遺憾的。

 

 

二、gamma和beta的做用

  在第二章提到,引入gamma和beta兩個參數是爲了不數據只用sigmoid的線性部分,這裏作了個簡單的測試,將用和不用gamma與beta參數訓練出的網絡的最大/最小激活值顯示出來:
 

 

  能夠看到,若是不使用gamma和beta,激活值基本上會在[0.1 0.9]這個近似線性的區域中,這與深度神經網絡所要求的「多層非線性函數逼近任意函數」的要求不符,因此引入gamma和beta仍是有必要的,深度網絡會自動決定使用哪一段函數(這是我本身想的,其具體做用歡迎討論)。

  對於ReLU來講,gamma的做用可能不是很明顯,由於relu是分段」線性「的,對數值進行伸縮並不能影響relu取x仍是取0。但beta的做用就很大了,試想一下若是沒有beta,通過batch normalization層的特徵,都具備0均值的指望,這樣豈不是強制令ReLU的輸出有一半是0一半非0麼?這與咱們的初衷不太相符,咱們但願神經網絡自行決定在什麼位置去設定這個閾值,而不是增長一個如此強的限制。另外,由於這個beta我曾經還鬧了個大笑話,記錄在http://blog.csdn.net/happynear/article/details/46583811,請你們引覺得戒。

4、總結

  Batch Normalization的加速做用體如今兩個方面:一是歸一化了每層和每維度的scale,因此能夠總體使用一個較高的學習率,而沒必要像之前那樣遷就小scale的維度;二是歸一化後使得更多的權重分界面落在了數據中,下降了overfit的可能性,所以一些防止overfit但會下降速度的方法,例如dropout和權重衰減就能夠不使用或者下降其權重。
  截止到目前,尚未哪一個機構宣佈重現了論文中的結果,不過歸一化的用處在理論層面就已經有了保證,之後也許歸一化的形式會有所改變,但逐層的歸一化應該會成爲一種標準。本博客文章僅僅給出了歸一化優勢的幾何解釋,但願有更多的理論解釋來指導咱們使用歸一化層。
  就目前來看,爭議的重點在於歸一化的位置,還有gamma與beta參數的引入,從理論上分析,論文中的這兩個細節實際上並不符合ReLU的特性:ReLU後,數據分佈從新回到第一象限,這時是最應當進行歸一化的;gamma與beta對sigmoid函數確實能起到必定的做用(實際也不如固定gamma=2),但對於ReLU這種分段線性的激活函數,並不存在sigmoid的低scale呈線性的現象。期待更多的理論分析,我本身也會持續跟進這個方向。

5、一些資源

本文所用到的matlab代碼:https://github.com/happynear/DeepLearnToolbox
Caffe的BN實現:https://github.com/ducha-aiki/caffe/tree/bn
cxxnet的BN實現:https://github.com/antinucleon/cxxnet
相關文章
相關標籤/搜索