詳解機器學習中的梯度消失、爆炸緣由及其解決方法

前言

本文主要深刻介紹深度學習中的梯度消失和梯度爆炸的問題以及解決方案。本文分爲三部分,第一部分主要直觀的介紹深度學習中爲何使用梯度更新,第二部分主要介紹深度學習中梯度消失及爆炸的緣由,第三部分對提出梯度消失及爆炸的解決方案。有基礎的同鞋能夠跳着閱讀。
其中,梯度消失爆炸的解決方案主要包括如下幾個部分。html

- 預訓練加微調
- 梯度剪切、權重正則(針對梯度爆炸)
- 使用不一樣的激活函數
- 使用batchnorm
- 使用殘差結構
- 使用LSTM網絡

第一部分:爲何要使用梯度更新規則


在介紹梯度消失以及爆炸以前,先簡單說一說梯度消失的根源—–深度神經網絡和反向傳播。目前深度學習方法中,深度神經網絡的發展造就了咱們能夠構建更深層的網絡完成更復雜的任務,深層網絡好比深度卷積網絡,LSTM等等,並且最終結果代表,在處理複雜任務上,深度網絡比淺層的網絡具備更好的效果。可是,目前優化神經網絡的方法都是基於反向傳播的思想,即根據損失函數計算的偏差經過梯度反向傳播的方式,指導深度網絡權值的更新優化。這樣作是有必定緣由的,首先,深層網絡由許多非線性層堆疊而來,每一層非線性層均可以視爲是一個非線性函數 f ( x ) f(x) (非線性來自於非線性激活函數),所以整個深度網絡能夠視爲是一個複合的非線性多元函數
F ( x ) = f n ( . . . f 3 ( f 2 ( f 1 ( x ) θ 1 + b ) θ 2 + b ) . . . ) F(x) = {f_n}(...{f_3}({f_2}({f_1}(x)*{\theta _1} + b)*{\theta _2} + b)...) 咱們最終的目的是但願這個多元函數能夠很好的完成輸入到輸出之間的映射,假設不一樣的輸入,輸出的最優解是 g ( x ) g(x) ,那麼,優化深度網絡就是爲了尋找到合適的權值,知足 L o s s = L ( g ( x ) , F ( x ) ) Loss = L(g(x),F(x)) 取得極小值點,好比最簡單的損失函數
L o s s = g ( x ) f ( x ) 2 2 Loss = ||g(x)-f(x)||^2_2 ,假設損失函數的數據空間是下圖這樣的,咱們最優的權值就是爲了尋找下圖中的最小值點,對於這種數學尋找最小值問題,採用梯度降低的方法再適合不過了。
這裏寫圖片描述web

第二部分:梯度消失、爆炸

梯度消失與梯度爆炸實際上是一種狀況,看接下來的文章就知道了。兩種狀況下梯度消失常常出現,一是在深層網絡中,二是採用了不合適的損失函數,好比sigmoid。梯度爆炸通常出如今深層網絡和權值初始化值太大的狀況下,下面分別從這兩個角度分析梯度消失和爆炸的緣由。算法

1.深層網絡角度

比較簡單的深層網絡以下:
這裏寫圖片描述
圖中是一個四層的全鏈接網絡,假設每一層網絡激活後的輸出爲 f i ( x ) f_i(x) ,其中 i i 爲第 i i 層, x x 表明第 i i 層的輸入,也就是第 i 1 i-1 層的輸出, f f 是激活函數,那麼,得出 f i + 1 = f ( f i w i + 1 + b i + 1 ) f_{i+1}=f(f_i*w_{i+1}+b_{i+1}) ,簡單記爲 f i + 1 = f ( f i w i + 1 ) f_{i+1}=f(f_i*w_{i+1})
BP算法基於梯度降低策略,以目標的負梯度方向對參數進行調整,參數的更新爲 w w + Δ w w \leftarrow w+\Delta w ,給定學習率 α \alpha ,得出 Δ w = α L o s s w \Delta w=-\alpha \frac{\partial Loss}{\partial w} 。若是要更新第二隱藏層的權值信息,根據鏈式求導法則,更新梯度信息:
Δ w 2 = L o s s w 2 = L o s s f 4 f 4 f 3 f 3 f 2 f 2 w 2 \Delta w_2=\frac{\partial Loss}{\partial w_2}=\frac{\partial Loss}{\partial f_4}\frac{\partial f_4}{\partial f_3}\frac{\partial f_3}{\partial f_2}\frac{\partial f_2}{\partial w_2} ,很容易看出來 f 2 w 2 = f ( f 1 w 2 ) f 1 \frac{\partial f_2}{\partial w_2}=\frac{\partial f}{\partial (f_1*w_2)}f_1 ,即第二隱藏層的輸入。
因此說, f 4 f 3 × w 4 \frac{\partial f_4}{\partial f_3} \times w4 就是對激活函數進行求導,若是此部分大於1,那麼層數增多的時候,最終的求出的梯度更新將以指數形式增長,即發生梯度爆炸,若是此部分小於1,那麼隨着層數增多,求出的梯度更新信息將會以指數形式衰減,即發生了梯度消失。若是說從數學上看不夠直觀的話,下面幾個圖能夠很直觀的說明深層網絡的梯度問題 1 ^1 (圖片內容來自參考文獻1):微信

注:下圖中的隱層標號和第一張全鏈接圖隱層標號恰好相反。
圖中的曲線表示權值更新的速度,對於下圖兩個隱層的網絡來講,已經能夠發現隱藏層2的權值更新速度要比隱藏層1更新的速度慢網絡

這裏寫圖片描述

那麼對於四個隱層的網絡來講,就更明顯了,第四隱藏層比第一隱藏層的更新速度慢了兩個數量級:

這裏寫圖片描述

總結:從深層網絡角度來說,不一樣的層學習的速度差別很大,表現爲網絡中靠近輸出的層學習的狀況很好,靠近輸入的層學習的很慢,有時甚至訓練了好久,前幾層的權值和剛開始隨機初始化的值差很少。所以,梯度消失、爆炸,其根本緣由在於反向傳播訓練法則,屬於先天不足,另外多說一句,Hinton提出capsule的緣由就是爲了完全拋棄反向傳播,若是真能大範圍普及,那真是一個革命。app

2.激活函數角度

其實也注意到了,上文中提到計算權值更新信息的時候須要計算前層偏導信息,所以若是激活函數選擇不合適,好比使用sigmoid,梯度消失就會很明顯了,緣由看下圖,左圖是sigmoid的損失函數圖,右邊是其導數的圖像,若是使用sigmoid做爲損失函數,其梯度是不可能超過0.25的,這樣通過鏈式求導以後,很容易發生梯度消失,sigmoid函數數學表達式爲: s i g m o i d ( x ) = 1 1 + e x sigmoid(x)=\frac{1}{1+e^{-x}}
sigmoid函數 sigmoid函數導數框架

同理,tanh做爲激活函數,它的導數圖以下,能夠看出,tanh比sigmoid要好一些,可是它的導數仍然是小於1的。tanh數學表達爲:

t a n h ( x ) = e x e x e x + e x tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}} 機器學習

這裏寫圖片描述

第三部分:梯度消失、爆炸的解決方案


2.1 方案1-預訓練加微調

此方法來自Hinton在2006年發表的一篇論文,Hinton爲了解決梯度的問題,提出採起無監督逐層訓練方法,其基本思想是每次訓練一層隱節點,訓練時將上一層隱節點的輸出做爲輸入,而本層隱節點的輸出做爲下一層隱節點的輸入,此過程就是逐層「預訓練」(pre-training);在預訓練完成後,再對整個網絡進行「微調」(fine-tunning)。Hinton在訓練深度信念網絡(Deep Belief Networks中,使用了這個方法,在各層預訓練完成後,再利用BP算法對整個網絡進行訓練。此思想至關因而先尋找局部最優,而後整合起來尋找全局最優,此方法有必定的好處,可是目前應用的不是不少了。svg

2.2 方案2-梯度剪切、正則

梯度剪切這個方案主要是針對梯度爆炸提出的,其思想是設置一個梯度剪切閾值,而後更新梯度的時候,若是梯度超過這個閾值,那麼就將其強制限制在這個範圍以內。這能夠防止梯度爆炸。函數

注:在WGAN中也有梯度剪切限制操做,可是和這個是不同的,WGAN限制梯度更新信息是爲了保證lipchitz條件。

另一種解決梯度爆炸的手段是採用權重正則化(weithts regularization)比較常見的是 l 1 l1 正則,和 l 2 l2 正則,在各個深度框架中都有相應的API可使用正則化,好比在 t e n s o r f l o w tensorflow 中,若搭建網絡的時候已經設置了正則化參數,則調用如下代碼能夠直接計算出正則損失:

regularization_loss = tf.add_n(tf.losses.get_regularization_losses(scope='my_resnet_50'))

若是沒有設置初始化參數,也可使用如下代碼計算 l 2 l2 正則損失:

l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables() if 'weights' in var.name])

正則化是經過對網絡權重作正則限制過擬合,仔細看正則項在損失函數的形式:
L o s s = ( y W T x ) 2 + α W 2 Loss=(y-W^Tx)^2+ \alpha ||W||^2
其中, α \alpha 是指正則項係數,所以,若是發生梯度爆炸,權值的範數就會變的很是大,經過正則化項,能夠部分限制梯度爆炸的發生。

注:事實上,在深度神經網絡中,每每是梯度消失出現的更多一些。

2.3 方案3-relu、leakrelu、elu等激活函數

**Relu:**思想也很簡單,若是激活函數的導數爲1,那麼就不存在梯度消失爆炸的問題了,每層的網絡均可以獲得相同的更新速度,relu就這樣應運而生。先看一下relu的數學表達式:

這裏寫圖片描述

其函數圖像:

這裏寫圖片描述
從上圖中,咱們能夠很容易看出,relu函數的導數在正數部分是恆等於1的,所以在深層網絡中使用relu激活函數就不會致使梯度消失和爆炸的問題。

relu的主要貢獻在於:

-- 解決了梯度消失、爆炸的問題
 -- 計算方便,計算速度快
 -- 加速了網絡的訓練

同時也存在一些缺點

-- 因爲負數部分恆爲0,會致使一些神經元沒法激活(可經過設置小學習率部分解決)
 -- 輸出不是以0爲中心的

儘管relu也有缺點,可是仍然是目前使用最多的激活函數

leakrelu
leakrelu就是爲了解決relu的0區間帶來的影響,其數學表達爲: l e a k r e l u = m a x ( k x , x ) leakrelu=max(k*x,x) 其中k是leak係數,通常選擇0.01或者0.02,或者經過學習而來

這裏寫圖片描述

leakrelu解決了0區間帶來的影響,並且包含了relu的全部優勢
elu
elu激活函數也是爲了解決relu的0區間帶來的影響,其數學表達爲:這裏寫圖片描述
其函數及其導數數學形式爲:

這裏寫圖片描述

可是elu相對於leakrelu來講,計算要更耗時間一些

2.4 解決方案4-batchnorm

Batchnorm是深度學習發展以來提出的最重要的成果之一了,目前已經被普遍的應用到了各大網絡中,具備加速網絡收斂速度,提高訓練穩定性的效果,Batchnorm本質上是解決反向傳播過程當中的梯度問題。batchnorm全名是batch normalization,簡稱BN,即批規範化,經過規範化操做將輸出信號x規範化保證網絡的穩定性。
具體的batchnorm原理很是複雜,在這裏不作詳細展開,此部分大概講一下batchnorm解決梯度的問題上。具體來講就是反向傳播中,通過每一層的梯度會乘以該層的權重,舉個簡單例子:
正向傳播中 f 2 = f 1 ( w T x + b ) f_2=f_1(w^T*x+b) ,那麼反向傳播中, f 2 w = f 2 f 1 x \frac {\partial f_2}{\partial w}=\frac{\partial f_2}{\partial f_1}x ,反向傳播式子中有 x x 的存在,因此 x x 的大小影響了梯度的消失和爆炸,batchnorm就是經過對每一層的輸出規範爲均值和方差一致的方法,消除了 x x 帶來的放大縮小的影響,進而解決梯度消失和爆炸的問題,或者能夠理解爲BN將輸出從飽和區拉倒了非飽和區。
有關batch norm詳細的內容能夠參考個人另外一篇博客:
http://blog.csdn.net/qq_25737169/article/details/79048516

2.5 解決方案5-殘差結構

殘差結構提及殘差的話,不得不提這篇論文了:Deep Residual Learning for Image Recognition,關於這篇論文的解讀,能夠參考知乎連接:https://zhuanlan.zhihu.com/p/31852747這裏只簡單介紹殘差如何解決梯度的問題。
事實上,就是殘差網絡的出現致使了image net比賽的終結,自從殘差提出後,幾乎全部的深度網絡都離不開殘差的身影,相比較以前的幾層,幾十層的深度網絡,在殘差網絡面前都不值一提,殘差能夠很輕鬆的構建幾百層,一千多層的網絡而不用擔憂梯度消失過快的問題,緣由就在於殘差的捷徑(shortcut)部分,其中殘差單元以下圖所示:
這裏寫圖片描述
相比較於之前網絡的直來直去結構,殘差中有不少這樣的跨層鏈接結構,這樣的結構在反向傳播中具備很大的好處,見下式:
這裏寫圖片描述
式子的第一個因子 l o s s x L \frac{\partial loss}{\partial {{x}_{L}}} 表示的損失函數到達 L 的梯度,小括號中的1代表短路機制能夠無損地傳播梯度,而另一項殘差梯度則須要通過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那麼巧全爲-1,並且就算其比較小,有1的存在也不會致使梯度消失。因此殘差學習會更容易。

注:上面的推導並非嚴格的證實。

2.6 解決方案6-LSTM

LSTM全稱是長短時間記憶網絡(long-short term memory networks),是不那麼容易發生梯度消失的,主要緣由在於LSTM內部複雜的「門」(gates),以下圖,LSTM經過它內部的「門」能夠接下來更新的時候「記住」前幾回訓練的」殘留記憶「,所以,常常用於生成文本中。目前也有基於CNN的LSTM,感興趣的能夠嘗試一下。

這裏寫圖片描述

參考資料:

1.《Neural networks and deep learning》
2.《機器學習》周志華
3. https://www.cnblogs.com/willnote/p/6912798.html
4. https://www.zhihu.com/question/38102762
5. http://www.jianshu.com/p/9dc9f41f0b29


若是感興趣,請關注微信公衆號,還有更多精彩:
這裏寫圖片描述