神經網絡中的批標準化

做者|Emrick Sinitambirivoutin
編譯|VK
來源|Towards Data Sciencepython

訓練學習系統的一個主要假設是在整個訓練過程當中輸入的分佈保持不變。對於簡單地將輸入數據映射到某些適當輸出的線性模型,這種條件老是知足的,但在處理由多層疊加而成的神經網絡時,狀況就不同了。git

在這樣的體系結構中,每一層的輸入都受到前面全部層的參數的影響(隨着網絡變得更深,對網絡參數的小變化會被放大)。所以,在一層內的反向傳播步驟中所作的一個小的變化能夠產生另外一層的輸入的一個巨大的變化,並在最後改變特徵映射分佈。在訓練過程當中,每一層都須要不斷地適應前一層獲得的新分佈,這就減慢了收斂速度。github

批標準化克服了這一問題,同時經過減小訓練過程當中內層的協方差移位(因爲訓練過程當中網絡參數的變化而致使的網絡激活分佈的變化)算法

本文將討論如下內容

  • 批標準化如何減小內部協方差移位,如何改進神經網絡的訓練。
  • 如何在PyTorch中實現批標準化層。
  • 一些簡單的實驗顯示了使用批標準化的優勢。

減小內部協方差移位

減小消除神經網絡內部協方差移位的不良影響的一種方法是對層輸入進行歸一化。這個操做不只使輸入具備相同的分佈,並且還使每一個輸入都白化(白化是對原始數據x實現一種變換,變換成x_Whitened,使x_Whitened的協方差矩陣的爲單位陣。)。該方法是由一些研究提出的,這些研究代表,若是對網絡的輸入進行白化,則網絡訓練收斂得更快,所以,加強各層輸入的白化是網絡的一個理想特性。網絡

然而,每一層輸入的徹底白化是昂貴的,而且不是徹底可微的。批標準化經過考慮兩個假設克服了這個問題:架構

  • 咱們將獨立地對每一個標量特徵進行歸一化(經過設置均值爲0和方差爲1),而不是對層的輸入和輸出的特徵進行白化。
  • 咱們不使用整個數據集來進行標準化,而是使用mini-batch,每一個mini-batch生成每一個激活層的平均值和方差的估計值。

對於具備d維輸入的層x = (x1, x2, ..xd)咱們獲得瞭如下公式的歸一化(對batch B的指望和方差進行計算):機器學習

然而,簡單地標準化一個層的每一個輸入可能會改變層所能表示的內容。例如,對一個sigmoid的輸入進行歸一化會將其約束到非線性的線性狀態。這樣的行爲對網絡來講是不可取的,由於它會下降其非線性的能力(它將成爲至關於一個單層網絡)。函數

爲了解決這個問題,批標準化還確保插入到網絡中的轉換能夠表示單位轉換(模型仍然在每一個層學習一些參數,這些參數在沒有線性映射的狀況下調整從上一層接收到的激活)。這是經過引入一對可學習參數gamma_k和beta_k來實現的,這兩個參數根據模型學習的內容縮放和移動標準化值。學習

最後,獲得的層的輸入(基於前一層的輸出x)爲:.net

批標準化算法

訓練時

全鏈接層

全鏈接層的實現很是簡單。咱們只須要獲得每一個批次的均值和方差,而後用以前給出的alpha和beata參數來縮放和移動。

在反向傳播期間,咱們將使用反向傳播來更新這兩個參數。

mean = torch.mean(X, axis=0)
variance = torch.mean((X-mean)**2, axis=0)
X_hat = (X-mean) * 1.0 /torch.sqrt(variance + eps)
out = gamma * X_hat + beta
卷積層

卷積層的實現幾乎與之前同樣。咱們只須要執行一些改造,以適應咱們從上一層得到的輸入結構。

N, C, H, W = X.shape
mean = torch.mean(X, axis = (0, 2, 3))
variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))
X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + eps)
out = gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))

在PyTorch中,反向傳播很是容易處理,這裏的一件重要事情是指定alpha和beta是在反向傳播階段更新它們的參數。

爲此,咱們將在層中將它們聲明爲nn.Parameter(),並使用隨機值初始化它們。

推理時

在推理過程當中,咱們但願網絡的輸出只依賴於輸入,所以咱們不能考慮以前考慮的批的統計數據(它們與批相關,所以它們根據數據而變化)。爲了確保咱們有一個固定的指望和方差,咱們須要使用整個數據集來計算這些值,而不是隻考慮批。然而,就時間和計算而言,爲全部數據集計算這些統計信息是至關昂貴的。

論文中提出的方法是使用咱們在訓練期間計算的滑動統計。咱們使用參數beta(動量)調整當前批次計算的指望的重要性:

該滑動平均線存儲在一個全局變量中,該全局變量在訓練階段更新。
爲了在訓練期間將這個滑動平均線存儲在咱們的層中,咱們可使用緩衝區。當咱們使用PyTorch的register_buffer()方法實例化咱們的層時,咱們將初始化這些緩衝區。

最後一個模塊

而後,最後一個模塊由前面描述的全部塊組成。咱們在輸入數據的形狀上添加一個條件,以瞭解咱們處理的是全鏈接層仍是卷積層。

這裏須要注意的一件重要事情是,咱們只須要實現forward()方法。由於咱們的類繼承自nn.Module,咱們就能夠自動獲得backward()函數。

class CustomBatchNorm(nn.Module):

    def __init__(self, in_size, momentum=0.9, eps = 1e-5):
        super(CustomBatchNorm, self).__init__()
        
        self.momentum = momentum
        self.insize = in_size
        self.eps = eps
        
        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.gamma = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
        self.beta = nn.Parameter(torch.zeros(self.insize))
            
        self.register_buffer('running_mean', torch.zeros(self.insize))
        self.register_buffer('running_var', torch.ones(self.insize))
        
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def forward(self, input):
        
        X = input

        if len(X.shape) not in (2, 4):
            raise ValueError("only support dense or 2dconv")
        
        #全鏈接層
        elif len(X.shape) == 2:
            if self.training:
                mean = torch.mean(X, axis=0)
                variance = torch.mean((X-mean)**2, axis=0)
                
                self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean
                self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*variance)
            
            else:
                mean = self.running_mean
                variance = self.running_var
                
            X_hat = (X-mean) * 1.0 /torch.sqrt(variance + self.eps)
            out = self.gamma * X_hat + self.beta
  
				# 卷積層
        elif len(X.shape) == 4:
            if self.training:
                N, C, H, W = X.shape
                mean = torch.mean(X, axis = (0, 2, 3))
                variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))
                
                self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean
                self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*variance)
            else:
                mean = self.running_mean
                var = self.running_var
                
            X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + self.eps)
            out = self.gamma.reshape((1, C, 1, 1)) * X_hat + self.beta.reshape((1, C, 1, 1))
        
        return out

實驗MNIST

爲了觀察批處理歸一化對訓練的影響,咱們能夠比較沒有批處理歸一化的簡單神經網絡和有批處理歸一化的神經網絡的收斂速度。

爲了簡單起見,咱們在MNIST數據集上訓練這兩個簡單的全鏈接網絡,不進行預處理(只應用數據標準化)。

沒有批標準化的網絡架構

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

有批標準化的網絡架構

class SimpleNetBN(nn.Module):
    def __init__(self):
        super(SimpleNetBN, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(28 * 28, 64),
            CustomBatchNorm(64),
            nn.ReLU(),
            nn.Linear(64, 128),
            CustomBatchNorm(128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

結果

下圖顯示了在咱們的SimpleNet的第一層以後得到的激活的分佈。咱們能夠看到,即便通過20個epoch,分佈仍然是高斯分佈(在訓練過程當中學習到的小尺度和移位)。

咱們也能夠看到收斂速度方面的巨大進步。綠色曲線(帶有批標準化)代表,咱們能夠更快地收斂到具備批標準化的最優解。

實驗結果詳見(https://github.com/sinitame/neuralnetworks-ents/blob/master/batch_normalization/batch_normaliz.ipynb)

結論

使用批標準化進行訓練的優勢

  • 一個mini-batch處理的損失梯度是對訓練集的梯度的估計,訓練的質量隨着批處理大小的增長而提升。
  • 因爲gpu提供的並行性,批處理大小上的計算要比單個示例的屢次計算效率高得多。
  • 在每一層使用批處理歸一化來減小內部方差的移位,大大提升了網絡的學習效率。

原文連接:https://towardsdatascience.com/understanding-batch-normalization-for-neural-networks-1cd269786fa6

歡迎關注磐創AI博客站:
http://panchuang.net/

sklearn機器學習中文官方文檔:
http://sklearn123.com/

歡迎關注磐創博客資源彙總站:
http://docs.panchuang.net/

相關文章
相關標籤/搜索