RNN神經網絡產生梯度消失和梯度爆炸的緣由及解決方案

一、RNN模型結構

  循環神經網絡RNN(Recurrent Neural Network)會記憶以前的信息,並利用以前的信息影響後面結點的輸出。也就是說,循環神經網絡的隱藏層之間的結點是有鏈接的,隱藏層的輸入不只包括輸入層的輸出,還包括上時刻隱藏層的輸出。下圖爲RNN模型結構圖:算法

 

二、RNN前向傳播算法

RNN前向傳播公式爲:網絡

  其中:app

    Stt時刻的隱含層狀態值;函數

    Ott時刻的輸出值;學習

    ①是隱含層計算公式,U是輸入x的權重矩陣,St-1t-1時刻的狀態值,WSt-1做爲輸入的權重矩陣,$\Phi $是激活函數;spa

    ②是輸出層計算公司,V是輸出層的權重矩陣,f是激活函數。3d

  損失函數(loss function)採用交叉熵$L_{t}=-\overline{o_{t}}logo_{_{t}}$(Ot是t時刻預測輸出,$\overline{o_{t}}$是t時刻正確的輸出) blog

那麼對於一次訓練任務中,損失函數$L=\sum_{i=1}^{T}-\overline{o_{t}}logo_{_{t}}$, T是序列總長度。get

假設初始狀態St爲0,t=3 有三段時間序列時,由 ① 帶入②可獲得 input

  t一、t二、t3 各個狀態和輸出爲:

  t=1:

    狀態值:$s_{1}=\Phi (Ux_{1}+Ws_{0})$

    輸出:$o_{1}=f(V\Phi (Ux_{1}+Ws_{0}))$

 

  t=2:

    狀態值:$s_{2}=\Phi (Ux_{2}+Ws_{1})$

    輸出:$o_{2}=f(V\Phi (Ux_{2}+Ws_{1}))=f(V\Phi (Ux_{2}+W\Phi(Ux_{1}+Ws_{0})))$

 

  t=3:

    狀態值:$s_{3}=\Phi (Ux_{3}+Ws_{2})$

    輸出:$o_{3}=f(V\Phi (Ux_{3}+Ws_{2}))=\cdots =f(V\Phi (Ux_{3}+W\Phi(Ux_{2}+W\Phi(Ux_{1}+Ws_{0}))))$

 

三、RNN反向傳播算法

  BPTT(back-propagation through time)算法是針對循層的訓練算法,它的基本原理和BP算法同樣。其算法本質仍是梯度降低法,那麼該算法的關鍵就是計算各個參數的梯度,對於RNN來講參數有 U、W、V

反向傳播 

  現對t=3時刻的U、W、V求偏導,由鏈式法則獲得:

能夠簡寫成:

 

  觀察③④⑤式,可知,對於 V 求偏導不存在依賴問題;可是對於 W、U 求偏導的時候,因爲時間序列長度,存在長期依賴的狀況。主要緣由可由 t=一、二、3 的狀況觀察得 , St會隨着時間序列向前傳播,同時StU、W 的函數。

  前面得出的求偏導公式⑥,取其中累乘的部分出來,其中激活函數 Φ 一般是tanh函數 ,則

四、梯度爆炸和梯度消失的緣由

  激活函數tanh和它的導數圖像以下:

 

由上圖可知當激活函數是tanh函數時,tanh函數的導數最大值爲1,又不可能一直都取1這種狀況,實際上這種狀況不多出現,那麼也就是說,大部分都是小於1的數在作累乘,若當t很大的時候,$\prod_{j=k-1}^{t}tan{h}'W$中的$\prod_{j=k-1}^{t}tan{h}'$趨向0,舉個例子:0.850=0.00001427247也已經接近0了,這是RNN中梯度消失的緣由。

再看⑦部分:

$\prod_{j=k-1}^{3}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{3}tan{h}'W$

若是參數 W 中的值太大,隨着序列長度一樣存在長期依賴的狀況,$\prod_{j=k-1}^{t}tan{h}'W$中的$\prod_{j=k-1}^{t}tan{h}'$趨向於無窮,那麼產生問題就是梯度爆炸。

在平時運用中,RNN比較深,使得梯度爆炸或者梯度消失問題會比較明顯。

五、解決梯度爆炸和梯度消失的方案

1)採使用ReLu激活函數

  面對梯度消失問題,能夠採用ReLu做爲激活函數,下圖爲ReLu函數:

  ReLU函數在定義域大於0部分的導數恆等於1,這樣能夠解決梯度消失的問題,(雖然恆等於1很容易發生梯度爆炸的狀況,但可經過設置適當的閾值可解決)。

另外計算方便,計算速度快,能夠加速網絡訓練。可是,定義域負數部分恆等於零,這樣會形成神經元沒法激活(可經過合理設置學習率,下降發生的機率)。

  ReLU有優勢也有缺點,其中的缺點能夠經過其餘操做取避免或者減低發生的機率,是目前使用最多的激活函數。

還能夠經過更改內部結構來解決梯度消失和梯度爆炸問題,那就是LSTM了。

2)使用長短記憶網絡LSTM

  使用長短時間記憶(LSTM)單元和相關的門類型神經元結構能夠減小梯度爆炸和梯度消失問題,LSTM的經典圖爲:

能夠抽象爲:

  三個×分別表明的就是forget gate,input gate,output gate,而我認爲LSTM最關鍵的就是forget gate這個部件。這三個gate是如何控制流入流出的呢,其實就是經過下面 ft,it,ot 三個函數來控制,由於$\sigma (x)$表明sigmoid函數) 的值是介於0到1之間的,恰好用趨近於0時表示流入不能經過gate,趨近於1時表示流入能夠經過gate。

$f_{t}=\sigma (W_{f}X_{t}+b_{f})$

$i_{t}=\sigma (W_{i}X_{t}+b_{i)$

$o_{t}=\sigma (W_{o}X_{t}+b_{o})$

  LSTM當前的狀態值爲: $S_{t}=f_{t}S_{t-1}+i_{t}X_{t}$,表達式展開後得:

$S_{t}=\sigma (W_{f}X_{t}+b_{f})S_{t-1}+\sigma (W_{i}X_{t}+b_{i})X_{t}$

  若是加上激活函數:

$S_{t}=tanh[\sigma (W_{f}X_{t}+b_{f})S_{t-1}+\sigma (W_{i}X_{t}+b_{i})X_{t}]$

  上文中講到傳統RNN求偏導的過程包含:

$\prod_{j=k-1}^{t}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{t}tan{h}'W$ 

  對於LSTM一樣也包含這樣的一項,可是在LSTM中 爲:

$\prod_{j=k-1}^{t}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{t}tan{h}'(W_{f}X_{t}+b_{f})$

  假設$Z=tanh'(x)\sigma (y)$,則Z的函數圖像以下圖所示:

 

  能夠看到該函數值基本上不是0就是1。

  傳統RNN的求偏導過程:

$\frac{\sigma L_{3}}{\sigma W}=\sum_{k=0}^{t}\frac{\partial L_{3}}{\partial o_{3}}\frac{\partial o_{3}}{\partial s_{3}}(\prod_{j=k-1}^{3}\frac{\partial s_{j}}{\partial s_{j-1}})\frac{\partial s_{k}}{\partial W}$

  若是在LSTM中上式可能就會變成:

$\frac{\sigma L_{3}}{\sigma W}=\sum_{k=0}^{t}\frac{\partial L_{3}}{\partial o_{3}}\frac{\partial o_{3}}{\partial s_{3}}\frac{\partial s_{k}}{\partial W}$

  由於$\prod_{j=k-1}^{3}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{3}tan{h}'\sigma (W_{f}X_{t}+b_{f})\approx 0|1$,這樣解決了傳統RNN中梯度消失的問題。

 

參考

  https://www.jiqizhixin.com/articles/2019-01-17-7

  https://zhuanlan.zhihu.com/p/28687529