pytorch循環神經網絡RNN從結構原理到應用實例

1、 RNN概述

人工神經網絡和卷積神經網絡的假設前提都是:元素之間是相互獨立的 ,可是在生活中不少狀況下這種假設並不成立,好比你寫一段有意義的話 「碰見一我的只需1秒,喜歡一我的只需3,秒,愛上一我的只需1分鐘,而我卻用個人[?]在愛你。」 ,做爲正常人咱們知道這裏應該填 「一輩子」,但之因此咱們會這樣填是由於咱們讀取了上下文,而普通的神經網絡輸入之間是相互獨立的,網絡沒有記憶能力。擴展一下:訓練樣本是連續的序列且其長短不一,如一段連續的語音、一段連續的文本等,這些序列前面的輸入與後面的輸入有有必定的相關性,很難將其拆解爲一個個單獨的樣原本進行DNN/CNN訓練。html

循環神經網絡(Recurrent Neural Networks,簡稱RNN)普遍應用於:算法

  • 語義分析(Semantic Analysis):按照語法分析器識別語法範疇進行語義檢查和處理,產生相應的中間代碼或者目標代碼
  • 情感分析(Sentiment Classification)
  • 圖像標註(Image Captioning):對圖片進行文本描述
  • 語言翻譯(Language Translation)

2、RNN網絡結構及原理

圖中各個參數意義:網絡

1)x(t)表明在序列索引號t時訓練樣本的輸入。一樣的,x(t−1)x(t+1)表明在序列索引號t−1t+1時訓練樣本的輸入。函數

2)h(t)表明在序列索引號t時模型的隱藏狀態。h(t)x(t)h(t−1)共同決定。spa

3)o(t)表明在序列索引號t時模型的輸出。o(t)只由模型當前的隱藏狀態h(t)決定。.net

4)L(t)表明在序列索引號t時模型的損失函數。翻譯

5)y(t)表明在序列索引號t時訓練樣本序列的真實輸出。3d

6)U,W,V這三個矩陣是咱們的模型的線性關係參數,它在整個RNN網絡中是共享的,這點和DNN很不相同。 也正由於是共享了,它體現了RNN的模型的「循環反饋」的思想。 [1]code

3、RNN前向傳播原理

對於任何一個序列索引號t,隱藏狀態\(h{(t)}\)\(h^{(t-1)}\)\(x^{(t)}\)獲得:htm

\[h^{(t)} = \sigma(z^{(t)} = \sigma(Ux^{(t)}+Wh^{(t-1)}+b)) \]

其中σ爲RNN的激活函數,b爲偏置值(bias)

序列索引號爲t的時候模型的輸出\(o^{(t)}\)的表達式比較簡單:

\[o^{(t)} = Vh^{(t)}+c \]

此時預測輸出爲:

\[\hat{y}^{(t)} = \sigma(o^{(t)}) \]

在上面這一過程當中使用了兩次激活函數(第一次得到隱藏狀態\(h^{(t)}\),第二次得到預測輸出\(\hat{y}^{(t)}\))一般在第一次使用tanh激活函數,第二次使用softmax激活函數

4、RNN反向傳播推導

RNN的法向傳播經過梯度降低一次次迭代獲得合適的參數U、W、V、b、c。在RNN中U、W、V、b、c參數在序列的各個位置都是相同的,反向傳播咱們更新的是一樣的參數。

對於RNN,咱們在序列的每個位置上都有損失,因此最終的損失L爲:

\[L = \sum_{t=1}^{\tau}L^{(t)} \]

損失函數對更新的參數進行求偏導(注意咱們這裏使用的兩個激活函數分別爲softmaxtanh,使用的偏差計算公式爲交叉熵):

  • 首先考慮與損失函數直接相關的兩個變量cV(即預測輸出時的權值和偏置值),利用損失函數能夠對這兩個變量進行直接求偏導(即對softmax函數求導):

\[\frac{\partial{L}}{\partial{c}} = \sum_{t=1}^{\tau}\frac{\partial{L^{(t)}}}{\partial{c}} = \sum_{t =1}^{\tau}\hat{y}^{(t)}-y^{(t)} \]

\[\frac{\partial{L}}{\partial{V}} = \sum_{t=1}^{\tau}\frac{\partial{L^{(t)}}}{\partial{V}} = \sum_{t =1}^{\tau}(\hat{y}^{(t)}-y^{(t)})(h^{(t)})^T \]

  • 而損失函數對W、U、b的偏導數計算就比較複雜了:在反向傳播時,某一序列位置t的梯度損失由當前位置的輸出對應的梯度損失和序列索引位置t+1時的梯度損失兩部分共同決定。
    從正向傳播來看:

\[h^{(t+1)} = tanh(Ux^{(t+1)}+Wh^{(t)}+b)) \]

對於W、U、b在某一序列位置t的梯度損失須要反向傳播一步步的計算。咱們定義序列索引t位置的隱藏狀態的梯度爲:

\[\delta^{(t)} = \frac{\partial{L}}{\partial{(h^{(t)})}} \]

\(\delta^{(\tau+1)}\)遞推\(\delta^{(t)}\)

\[\delta^{(t)} = (\frac{\partial{\delta^{(t)}}}{\partial{h^{(t)}}})^T \frac{\partial{L}}{\partial{o^{(t)}}} + (\frac{\partial{h^{(t+1)}}}{\partial{h^{(t)}}})^T \frac{\partial{L}}{\partial{h^{(t+1)}}} = V^T(\hat{y}^{(t)}-y^{(t)}) +W^Tdiag(1-(h^{(t+1)})^2)\delta^{(t+1)} \]

對於\(\delta{(\tau)}\),其後面沒有其餘的索引(最後一個輸入),所以:

\[\delta^{(\tau)} = (\frac{\partial{\delta^{(\tau)}}}{\partial{h^{(\tau)}}})^T \frac{\partial{L}}{\partial{o^{(\tau)}}} = V^T(\hat{y}^{(\tau)}-y^{(t)}) \]

根據\(\delta{(t)}\),咱們就能夠計算W、U、b了:

\[\frac{\partial{L}}{\partial{W}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(h^{(t-1)})^T \]

\[\frac{\partial{L}}{\partial{b}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)} \]

\[\frac{\partial{L}}{\partial{V}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(x^{(t)})^T \]

5、RNN梯度消失問題


假設時間序列只有三段,\(S_0\)爲給定值,神經元沒有激活函數,而RNN按照最簡單的前向傳播:

\[S_1 = W_xX_1 + W_sS_0+b_1 ; O_1 = W_0S_1 +b2 \]

\[S_2 = W_xX_2 + W_sS_1+b_1 ; O_2 = W_0S_2 +b2 \]

\[S_3 = W_xX_3 + W_sS_2+b_1 ; O_3 = W_0S_3 +b2 \]

假設在t=3時刻,損失函數爲$$L_3 = \frac{1}{2}(Y_3-O_3)^2$$
對於一次訓練,其損失函數值是累加的:$$L = \sum_{t = 0}{T}L_t$$
此處利用反向傳播公式僅對Wx、Ws、W0求偏導數(Wx、Ws與輸出Output相關,並不是直接求損失函數Loss的偏導,在第四部分也已經說明了:

\[\frac{\partial{L}_3}{\partial{W}_0} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{W}_0} \]

\[\frac{\partial{L}_3}{\partial{W}_x} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{W}_x} + \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{W}_x}+ \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{S}_1}\frac{\partial{S}_1}{\partial{w}_x} \]

\[\frac{\partial{L}_3}{\partial{W}_s} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{W}_s} + \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{W}_s}+ \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{S}_1}\frac{\partial{S}_1}{\partial{w}_s} \]

從這冗長的公式中能夠看見用梯度降低法對損失函數求W0的偏導數其沒有很長的依賴(就是公式很短、求解簡單)可是對於WxWs的公式就很是長了,上面僅僅推到了三層網絡結構就已經如此繁雜了,推導任意時刻損失函數關於WxWs的偏導數公式:

\[\frac{\partial{L}_t}{\partial{W}_x} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_x} \]

\[\frac{\partial{L}_t}{\partial{W}_s} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_s} \]

若是再加上激活函數:$$S_j = tanh(W_xX_j + W_sS_{j-1}+b_1)$$

則$$\prod_{j=k+1}^{t}\frac{\partial{S}j}{\partial{S}{j-1}} = \prod_{j=k+1}^{t}W_s tanh^{'}$$

激活函數tanh[2]:

\[f(x) = tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} \]

tanh函數導數:

\[f(x)^{'} = 1 - (tanh(x))^2 \]

tanh函數及其導數

根據激活函數及其導數的圖像可見 [3]

  • \[tanh^{'}(x) ≤ 1 \]

  • 絕大部分狀況下,tanh的導數都是小於1的。不多狀況出現:

\[W_xX_j + W_sS_{j-1} + b_1 = 0 \]

  • 若是Ws是一個大於0小於1的值,當t很大的時候

\[\prod_{j=k+1}^{t}W_s tanh^{'} --> 0 \]

  • 若是Ws是一個很大的值,當t很大的時候

\[\prod_{j=k+1}^{t}W_s tanh^{'} --> ∞ \]

6、消除梯度爆炸和梯度消失

在公式:

\[\frac{\partial{L}_t}{\partial{W}_x} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_x} \]

\[\frac{\partial{L}_t}{\partial{W}_s} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_s} \]

致使梯度消失和梯度爆炸的緣由在於:

\[\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}} \]

消除這個部分的影響一個考慮是使得

\[\frac{\partial{S}_j}{\partial{S}_{j-1}} ≈ 1 \]

另外一種是使得:

\[\frac{\partial{S}_j}{\partial{S}_{j-1}} ≈ 0 \]


  1. 循環神經網絡(RNN)模型與前向反向傳播算法 ↩︎

  2. Tanh激活函數及求導過程 ↩︎

  3. RNN梯度消失和爆炸的緣由 ↩︎

相關文章
相關標籤/搜索