RNN - LSTM - GRU


循環神經網絡 (Recurrent Neural Network,RNN) 是一類具備短時間記憶能力的神經網絡,於是經常使用於序列建模。本篇先總結 RNN 的基本概念,以及其訓練中時常遇到梯度爆炸和梯度消失問題,再引出 RNN 的兩個主流變種 —— LSTM 和 GRU。算法


Vanilla RNN



Vanilla RNN 的主體結構: 網絡


上圖中 \(\bf{X, h, y}\) 都是向量,公式以下:
\[ % <![CDATA[ \begin{align} \textbf{h}_{t} &= f_{\textbf{W}}\left(\textbf{h}_{t-1}, \textbf{x}_{t} \right) \tag{1} \\ \textbf{h}_{t} &= f\left(\textbf{W}_{hx}\textbf{x}_{t} + \textbf{W}_{hh}\textbf{h}_{t-1} + \textbf{b}_{h}\right) \tag{2a} \\ \textbf{h}_{t} &= \textbf{tanh}\left(\textbf{W}_{hx}\textbf{x}_{t} + \textbf{W}_{hh}\textbf{h}_{t-1} + \textbf{b}_{h}\right) \tag{2b} \\ \hat{\textbf{y}}_{t} &= \textbf{softmax}\left(\textbf{W}_{yh}\textbf{h}_{t} + \textbf{b}_{y}\right) \tag{3} \end{align} %]]> \]
其中 \(\textbf{W}_{hx} \in \mathbb{R}^{h \times x}, \; \textbf{W}_{hh} \in \mathbb{R}^{h \times h}, \; \textbf{W}_{yh} \in \mathbb{R}^{y \times h}, \; \textbf{b}_{h} \in \mathbb{R}^{h}, \; \textbf{b}_{y} \in \mathbb{R}^{y}\)函數


\((2a)\) 式中的兩個矩陣 \(\mathbf{W}\) 能夠合併:
\[ \begin{align*} \textbf{h}_{t} &= f\left(\textbf{W}_{hx}\textbf{x}_{t} + \textbf{W}_{hh}\textbf{h}_{t-1} + \textbf{b}_{h}\right) \\ & = f\left(\left(\textbf{W}_{hx}, \textbf{W}_{hh}\right) \begin{pmatrix} \textbf{x}_t \\ \textbf{h}_{t-1} \end{pmatrix} + \textbf{b}_{h}\right) \\ & = f\left(\textbf{W} \begin{pmatrix} \textbf{x}_t \\ \textbf{h}_{t-1} \end{pmatrix} + \textbf{b}_{h}\right) \end{align*} \]性能


注意到在計算時,每一 time step 中使用的參數 \(\textbf{W}, \; \textbf{b}\) 是同樣的,也就是說每一個步驟的參數都是共享的,這是RNN的重要特色。學習

和普通的全鏈接層相比,RNN 除了輸入 \(\textbf{x}_t\) 外,還有輸入隱藏層上一節點 \(\mathbf{h}_{t-1}\) ,RNN 每一層的輸出就是這兩個輸入用矩陣 \(\textbf{W}_{hx}\)\(\textbf{W}_{hh}\)和激活函數進行組合的結果。從 \((2a)\) 式能夠看出 \(\textbf{x}_t\)\(\mathbf{h}_{t-1}\) 都是與 \(\textbf{h}_t\) 全鏈接的,下圖形象展現了各個時間節點 RNN 隱藏層記憶的變化。隨着時間流逝,最初的藍色結點保留地愈來愈少,這意味着RNN對於長時記憶的困難。spa




Vanishing & Exploding Gradient Problems

RNN 對於長時記憶的困難主要來源於梯度爆炸 / 消失問題,下面進行說明。RNN 中 Loss 的計算圖示例:翻譯



總的 Loss 是每一個 time step 的加和 : \(\mathcal{\large{L}} (\hat{\textbf{y}}, \textbf{y}) = \sum_{t = 1}^{T} \mathcal{ \large{L} }(\hat{\textbf{y}_t}, \textbf{y}_{t})\)3d


backpropagation through time (BPTT) 算法,參數的梯度爲:
\[ \frac{\partial \boldsymbol{\mathcal{L}}}{\partial \textbf{W}} = \sum_{t=1}^{T} \frac{\partial \boldsymbol{\mathcal{L}}_{t}}{\partial \textbf{W}} = \sum_{t=1}^{T} \frac{\partial \boldsymbol{\mathcal{L}}_t}{\partial \textbf{y}_{t}} \frac{\partial \textbf{y}_{t}}{\partial \textbf{h}_{t}} \overbrace{\frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{k}}}^{ \bigstar } \frac{\partial \textbf{h}_{k}}{\partial \textbf{W}} \]
其中 \(\frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{k}}\) 包含一系列 \(\text{Jacobian}\) 矩陣,
\[ \frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{k}} = \frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{t-1}} \frac{\partial \textbf{h}_{t-1}}{\partial \textbf{h}_{t-2}} \cdots \frac{\partial \textbf{h}_{k+1}}{\partial \textbf{h}_{k}} = \prod_{i=k+1}^{t} \frac{\partial \textbf{h}_{i}}{\partial \textbf{h}_{i-1}} \]
因爲 RNN 中每一個 time step 都是用相同的 \(\textbf{W}\) ,因此由 \((2a)\) 式可得:
\[ \prod_{i=k+1}^{t} \frac{\partial \textbf{h}_{i}}{\partial \textbf{h}_{i-1}} = \prod_{i=k+1}^{t} \textbf{W}^\top \text{diag} \left[ f'\left(\textbf{h}_{i-1}\right) \right] \]orm


因爲 \(\textbf{W}_{hh} \in \mathbb{R}^{h \times h}\) 爲方陣,對其進行特徵值分解:
\[ \mathbf{W} = \mathbf{V} \, \text{diag}(\boldsymbol{\lambda}) \, \mathbf{V}^{-1} \]
因爲上式是連乘 \(\text{t}\)\(\mathbf{W}\) :
\[ \mathbf{W}^t = (\mathbf{V} \, \text{diag}(\boldsymbol{\lambda}) \, \mathbf{V}^{-1})^t = \mathbf{V} \, \text{diag}(\boldsymbol{\lambda})^t \, \mathbf{V}^{-1} \]
連乘的次數多了以後,則若最大的特徵值 \(\lambda >1\) ,會產生梯度爆炸; \(\lambda < 1\) ,則會產生梯度消失 。不論哪一種狀況,都會致使模型難以學到有用的模式。blog


下左圖顯示一個 time step 中 tanh 函數的計算結果,右圖顯示整個神經網絡的計算結果,能夠清楚地看到哪一個區域最容易產生梯度爆炸/消失問題。



梯度爆炸的解決辦法:

(1) Truncated Backpropagation through time:每次只 BP 固定的 time step 數,相似於 mini-batch SGD。缺點是喪失了長距離記憶的能力。


(2) Clipping Gradients: 當梯度超過必定的 threshold 後,就進行 element-wise 的裁剪,該方法的缺點是又引入了一個新的參數 threshold。同時該方法也可視爲一種基於瞬時梯度大小來自適應 learning rate 的方法:
\[ \text{if} \quad \lVert \textbf{g} \rVert \ge \text{threshold} \\[1ex] \textbf{g} \leftarrow \frac{\text{threshold}}{\lVert \textbf{g} \rVert} \textbf{g} \]



梯度消失的解決辦法

(1) 使用 LSTM、GRU等升級版 RNN,使用各類 gates 控制信息的流通。

(2) 在這篇論文 ( https://arxiv.org/pdf/1602.06662.pdf ) 中提出將權重矩陣 \(\textbf{W}\) 初始化爲正交矩陣。正交矩陣有以下性質:\(A^T A =A A^T = I, \; A^T = A^{-1}\), 正交矩陣的特徵值的絕對值爲 \(\text{1}\) 。證實以下, 對矩陣 \(A\) 有:
\[ \begin{align*} & A \mathbf{v} = \lambda \mathbf{v} \\[1ex] ||A \mathbf{v}||^2& = (A \mathbf{v})^\text{T} (A \mathbf{v}) \\ &= \mathbf{v}^\text{T}A ^{\text{T}}A \mathbf{v} \\ & = \mathbf{v}^{\text{T}}\mathbf{v} \\ & = ||\mathbf{v}||^2 \\ & = |\lambda|^2 ||\mathbf{v}||^2 \end{align*} \]
因爲 \(\mathbf{v}\) 爲特徵向量,\(\mathbf{v} \neq 0\) ,因此 \(|\lambda| = 1\) ,這樣連乘以後 \(\lambda^t\) 不會出現愈來愈小的狀況。

(3) 反轉輸入序列。像在機器翻譯中使用 seq2seq 模型,若使用正常序列輸入,則輸入序列的第一個詞和輸出序列的第一個詞相距較遠,難以學到長期依賴。將輸入序列反向後,輸入序列的第一個詞就會和輸出序列的第一個詞很是接近,兩者的相互關係也就比較容易學習了。這樣模型能夠先學前幾個詞的短時間依賴,再學後面詞的長期依賴關係。見下圖正常輸入順序是 \(|\text{ABC}|\),反向是 \(|\text{CBA}|\) ,則 \(\text{A}\) 與第一個輸出詞 \(\text{W}\) 接近:





LSTM



雖然 Vanilla RNN 理論上能夠創建長時間間隔狀態之間的依賴關係,但因爲梯度爆炸或消失問題,實際上只能學到短時間依賴關係。爲了學到長期依賴關係,LSTM 中引入了門控機制來控制信息的累計速度,包括有選擇地加入新的信息,並有選擇地遺忘以前累計的信息,整個 LSTM 單元結構以下圖所示:

\[ \begin{align} \text{input gate}&: \quad \textbf{i}_t = \sigma(\textbf{W}_i\textbf{x}_t + \textbf{U}_i\textbf{h}_{t-1} + \textbf{b}_i)\tag{1} \\ \text{forget gate}&: \quad \textbf{f}_t = \sigma(\textbf{W}_f\textbf{x}_t + \textbf{U}_f\textbf{h}_{t-1} + \textbf{b}_f) \tag{2}\\ \text{output gate}&: \quad \textbf{o}_t = \sigma(\textbf{W}_o\textbf{x}_t + \textbf{U}_o\textbf{h}_{t-1} + \textbf{b}_o) \tag{3}\\ \text{new memory cell}&: \quad \tilde{\textbf{c}}_t = \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c\textbf{h}_{t-1} + \textbf{b}_c) \tag{4}\\ \text{final memory cell}& : \quad \textbf{c}_t = \textbf{f}_t \odot \textbf{c}_{t-1} + \textbf{i}_t \odot \tilde{\textbf{c}}_t \tag{5}\\ \text{final hidden state} &: \quad \textbf{h}_t= \textbf{o}_t \odot \text{tanh}(\textbf{c}_t) \tag{6} \end{align} \]
式 $(1) \sim (4) $ 的輸入都同樣,於是能夠合併:
\[ \begin{pmatrix} \textbf{i}_t \\ \textbf{f}_{t} \\ \textbf{o}_t \\ \tilde{\textbf{c}}_t \end{pmatrix} = \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ \text{tanh} \end{pmatrix} \left(\textbf{W} \begin{bmatrix} \textbf{x}_t \\ \textbf{h}_{t-1} \end{bmatrix} + \textbf{b} \right) \]

$\tilde{\textbf{c}}_t $ 爲時刻 t 的候選狀態,\(\textbf{i}_t\) 控制 \(\tilde{\textbf{c}}_t\) 中有多少新信息須要保存,\(\textbf{f}_{t}\) 控制上一時刻的內部狀態 \(\textbf{c}_{t-1}\) 須要遺忘多少信息,\(\textbf{o}_t\) 控制當前時刻的內部狀態 \(\textbf{c}_t\) 有多少信息須要輸出給外部狀態 \(\textbf{h}_t\)

下表顯示 forget gate 和 input gate 的關係,能夠看出 forget gate 其實更應該被稱爲 「remember gate」, 由於其開啓時以前的記憶信息 \(\textbf{c}_{t-1}\) 纔會被保留,關閉時則會遺忘全部:

forget gate input gate result
1 0 保留上一時刻的狀態 \(\textbf{c}_{t-1}\)
1 1 保留上一時刻 \(\textbf{c}_{t-1}\) 和添加新信息 \(\tilde{\textbf{c}}_t\)
0 1 清空歷史信息,引入新信息 \(\tilde{\textbf{c}}_t\)
0 0 清空全部新舊信息


對比 Vanilla RNN,能夠發如今時刻 t,Vanilla RNN 經過 \(\textbf{h}_t\) 來保存和傳遞信息,上文已分析了若是時間間隔較大容易產生梯度消失的問題。 LSTM 則經過記憶單元 \(\textbf{c}_t\) 來傳遞信息,經過 \(\textbf{i}_t\)\(\textbf{f}_{t}\) 的調控,\(\textbf{c}_t\) 能夠在 t 時刻捕捉到某個關鍵信息,並有能力將此關鍵信息保存必定的時間間隔。


原始的 LSTM 中是沒有 forget gate 的,即:
\[ \textbf{c}_t = \textbf{c}_{t-1} + \textbf{i}_t \odot \tilde{\textbf{c}}_t \]
這樣 \(\frac{\partial \textbf{c}_t}{\partial \textbf{c}_{t-1}}\) 恆爲 \(\text{1}\) 。可是這樣 \(\textbf{c}_t\) 會不斷增大,容易飽和從而下降模型性能。後來引入了 forget gate ,則梯度變爲 \(\textbf{f}_{t}\) ,事實上連乘多個 \(\textbf{f}_{t} \in (0,1)\) 一樣會致使梯度消失,可是 LSTM 的一個初始化技巧就是將 forget gate 的 bias 置爲正數(例如 1 或者 5,如 tensorflow 中的默認值就是 \(1.0\) ),這樣一來模型剛開始訓練時 forget gate 的值都接近 1,不會發生梯度消失 (反之若 forget gate 的初始值太小則意味着前一時刻的大部分信息都丟失了,這樣很難捕捉到長距離依賴關係)。 隨着訓練過程的進行,forget gate 就再也不恆爲 1 了。不過,一個訓好的模型裏各個 gate 值每每不是在 [0, 1] 這個區間裏,而是要麼 0 要麼 1,不多有相似 0.5 這樣的中間值,其實至關於一個二元的開關。假如在某個序列裏,forget gate 全是 1,那麼梯度不會消失;某一個 forget gate 是 0,模型選擇遺忘上一時刻的信息。


LSTM 的一種變體增長 peephole 鏈接,這樣三個 gate 不只依賴於 \(\textbf{x}_t\)\(\textbf{h}_{t-1}\),也依賴於記憶單元 \(\textbf{c}\)
\[ \begin{align*} \text{input gate}&: \quad \textbf{i}_t = \sigma(\textbf{W}_i\textbf{x}_t + \textbf{U}_i\textbf{h}_{t-1} + \textbf{V}_i\textbf{c}_{t-1} + \textbf{b}_i) \\ \text{forget gate}&: \quad \textbf{f}_t = \sigma(\textbf{W}_f\textbf{x}_t + \textbf{U}_f\textbf{h}_{t-1} + \textbf{V}_f\textbf{c}_{t-1} +\textbf{b}_f) \\ \text{output gate}&: \quad \textbf{o}_t = \sigma(\textbf{W}_o\textbf{x}_t + \textbf{U}_o\textbf{h}_{t-1} + \textbf{V}_o\textbf{c}_{t} +\textbf{b}_o) \\ \end{align*} \]

注意 input gate 和 forget gate 鏈接的是 \(\textbf{c}_{t-1}\) ,而 output gate 鏈接的是 \(\textbf{c}_t\) 。下圖來自 《LSTM: A Search Space Odyssey》,標註了 peephole 鏈接的樣貌。





GRU



相比於 Vanilla RNN (每一個 time step 有一個輸入 \(\textbf{x}_t\) ),從上面的 \((1) \sim (4)\) 式能夠看出 一個 LSTM 單元有四個輸入 (以下圖,不考慮 peephole) ,於是參數是 Vanilla RNN 的四倍,帶來的結果是訓練起來很慢,於是在2014年 Cho 等人提出了 GRU ,對 LSTM 進行了簡化,在不影響效果的前提下加快了訓練速度。



\(\large\scr{LSTM:}\)
\[ \normalsize \begin{align} \text{input gate}&: \quad \textbf{i}_t = \sigma(\textbf{W}_i\textbf{x}_t + \textbf{U}_i\textbf{h}_{t-1} + \textbf{b}_i)\tag{1} \\ \text{forget gate}&: \quad \textbf{f}_t = \sigma(\textbf{W}_f\textbf{x}_t + \textbf{U}_f\textbf{h}_{t-1} + \textbf{b}_f) \tag{2}\\ \text{output gate}&: \quad \textbf{o}_t = \sigma(\textbf{W}_o\textbf{x}_t + \textbf{U}_o\textbf{h}_{t-1} + \textbf{b}_o) \tag{3}\\ \text{new memory cell}&: \quad \tilde{\textbf{c}}_t = \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c\textbf{h}_{t-1} + \textbf{b}_c) \tag{4}\\ \text{final memory cell}& : \quad \textbf{c}_t = \textbf{f}_t \odot \textbf{c}_{t-1} + \textbf{i}_t \odot \tilde{\textbf{c}}_t \tag{5}\\ \text{final hidden state} &: \quad \textbf{h}_t= \textbf{o}_t \odot \text{tanh}(\textbf{c}_t) \tag{6} \end{align} \]
在式 \((5)​\) 中 forget gate 和 input gate 是互補關係,於是比較冗餘,GRU 將其合併爲一個 update gate。同時 GRU 也不引入額外的記憶單元 (LSTM 中的 \(\textbf{c}​\)) ,而是直接在當前狀態 \(\textbf{h}_t​\) 和歷史狀態 \(\textbf{h}_{t-1}​\) 之間創建線性依賴關係。


\(\large\scr{GRU:}\)
\[ \normalsize \begin{align} \text{reset gate}&: \quad \textbf{r}_t = \sigma(\textbf{W}_r\textbf{x}_t + \textbf{U}_r\textbf{h}_{t-1} + \textbf{b}_r)\tag{7} \\ \text{update gate}&: \quad \textbf{z}_t = \sigma(\textbf{W}_z\textbf{x}_t + \textbf{U}_z\textbf{h}_{t-1} + \textbf{b}_z)\tag{8} \\ \text{new memory cell}&: \quad \tilde{\textbf{h}}_t = \text{tanh}(\textbf{W}_h\textbf{x}_t + \textbf{r}_t \odot (\textbf{U}_h\textbf{h}_{t-1}) + \textbf{b}_h) \tag{9}\\ \text{final hidden state}&: \quad \textbf{h}_t = \textbf{z}_t \odot \textbf{h}_{t-1} + (1 - \textbf{z}_t) \odot \tilde{\textbf{h}}_t \tag{10} \end{align} \]
$ \tilde{\textbf{h}}_t $ 爲時刻 t 的候選狀態,\(\textbf{r}_t\) 控制 $ \tilde{\textbf{h}}_t $ 有多少依賴於上一時刻的狀態 \(\textbf{h}_{t-1}\) ,若是 \(\textbf{r}_t = 1\) ,則式 \((9)\) 與 Vanilla RNN 一致,對於短依賴的 GRU 單元,reset gate 一般會更新頻繁。\(\textbf{z}_t\) 控制當前的內部狀態 \(\textbf{h}_t\) 中有多少來自於上一時刻的 \(\textbf{h}_{t-1}\) 。若是 \(\textbf{z}_t = 1\) ,則會每步都傳遞一樣的信息,和當前輸入 \(\textbf{x}_t\) 無關。


另外一方面看,\(\textbf{r}_t\) 與 LSTM 中的 \(\textbf{o}_t\) 角色有些相似,由於將上面的 \((6)\) 式代入 \((4)\) 式能夠獲得:

\[ \begin{align*} \tilde{\textbf{c}}_t &= \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c\textbf{h}_{t-1} + \textbf{b}_c) \\ \textbf{h}_t &= \textbf{o}_t \odot \text{tanh}(\textbf{c}_t) \end{align*} \quad \Longrightarrow \quad \tilde{\textbf{c}}_t = \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c \left(\textbf{o}_{t-1} \odot \text{tanh}(\textbf{c}_{t-1})\right) + \textbf{b}_c) \]


最後是 cs224n 中提出的 RNN 訓練 tips:





/

相關文章
相關標籤/搜索