LSTM系列的梯度問題

1.前沿

本文主要圍繞NN、RNN、LSTM和GRU,討論後向傳播中所存在的梯度問題,以及解決方法,力求深入淺出。

2.神經網絡開始

神經網絡包括前向過程和後向過程,前向過程定義網絡結構,後向過程對網絡進行訓練(也就是優化參數),經過多輪迭代得到最終網絡(參數已定)
我們先來分析一個非常簡單的三層神經網絡:
這裏寫圖片描述
數據集 D = ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x m , y m )

2.1前向過程:

在輸入層,假設該層節點數爲d,也就是特徵x的維度, x i 作爲該層輸出;

在隱藏層中,該層節點數爲q,每個節點的輸入 α h 就是上一層所有節點輸出 x i 線性組合值,該節點的輸出 b h α j 激活值,這裏假設使用sigmoid激活函數;

在輸出層,該層節點數爲l,也就是輸出y的維度,同理,每個節點的輸入 β j b h 的線性組合值,輸出 y j β j ,根據不同任務選擇不同激活函數,比如二分類任務一般是用sigmoid激活函數把 y j [ 0 , 1 ]

2.2後向過程

1)首先我們根據網絡輸出和真實Label來定義Loss函數,這裏定義爲簡單的均方誤差:

E k = 1 2 j = 1 l ( y j y j ) 2

那麼我們的目標就是最小化Loss,調整參數 w_{hj} 和 v_{ih} ,使得網絡儘量去擬合真實數據。如何求最小值?那當然是求導了,根據loss函數對參數求導,然後往梯度下降的方向去更新參數,可以降低loss值。梯度主宰更新,如果梯度太小,會帶來梯度消失問題,導致參數更新很慢;那如果梯度很大,又會造成梯度爆炸問題

2)對於輸出層參數 w i j E w h j 進行鏈式求導,也就是,E先對節點的輸出 y j 求導,再對節點的輸入 β j 求導,最後 w h j 求導,結果爲:
E w h j = E y j y j β j β j w h j = ( y j y j ) y j ( 1 y j b h

這裏我們令 g j = ( y j y j ) y j ( 1 y j ) ,就可以得到參數 w h j 的更新量爲:

Δ w h j = η g j b h

3)對於隱藏層參數 v i h ,也是鏈式求導,E先對該層節點的輸出 b j 求導,再對節點的輸入 α j 求導,最後對 v i h 求導,其實在前面我們已經求出了部分梯度,最後結果爲:
E v i h = E b h b h α h α h v i h = ( j = 1 l E y j y j β j β j b h ) b h α h α h v i h

注意到, E y j y j β j 其實我們剛剛求過,其實就是 g j 這貨,因此我們可得:
E v i h = ( j = 1 l g j w h j ) b h ( 1 b h ) x i

再次令 e h = ( j = 1 l g j w h j ) b h ( 1 b h )   ,可以得到 v i h 的更新量爲:

Δ v i h = η e h x i

也就可以愉快地將更新 v i h = v i h + Δ v i h 了。

2.3 等等,事情好像並沒有這麼簡單

1) g j :這是上一層傳遞過來的梯度,如果上一層的梯度本來已經很小,那麼在這一層進行相乘,會導致這一層的梯度也很小。所以如果網絡層比較深,那麼在鏈式求導的過程中,越是低層的網絡層梯度在連乘過程中可能會變得越來越小,導致梯度消失

2) w h j :這是這一層的權重,這一項是造成梯度爆炸的主要原因,如果權重很大,也可能會導致相乘後的梯度也比較大。(梯度爆炸不是問題,做個梯度裁剪就行了,對梯度乘以一個縮放因子,我們主要考慮的是梯度消失問題)

3) b h ( 1 b h ) :這是sigmoid激活函數的導數,sigmoid激活值本身已經是一個比較小的數了,這兩個小於1的數相乘會變得更小,就可能會造成梯度消失。

我們直接來看sigmoid的這個圖吧,只有在靠近0的區域梯度比較大(然而也不會超過0.25),在接近無窮小或者無窮大的時候梯度幾乎是0了:
這裏寫圖片描述

所以sigmoid是造成梯度消失的一個重要原因,激活函數其實是爲了引入了非線性操作,使得神經網絡可以逼近非線性函數。因此如果不是輸出層必須要用sigmoid來限制輸出範圍,我一般是不用sigmoid的。

那麼從激活函數出發,緩解梯度消失有以下方法:
1)不行就換,比如把sigmoid換成relu,在x>0的時候可以穩穩維持1的梯度。
這裏寫圖片描述
2)不想換那也行,既然我們知道sigmoid在靠近0的取值範圍內梯度比較大,但我們可以把數據儘量規範化到一個比較合適的範圍,也就是接下來要談到的Normaliztion。

3. 從RNN到LSTM再到GRU

接下來我們再探討一下RNN系列,也就是展開型的神經網絡。

3.1 RNN

RNN是最簡單的循環神經網絡,其實就是對神經網絡展開k個step,所有step共享同一個神經網絡模塊S,我們還是直接來看圖吧:
這裏寫圖片描述

這是一個序列預測任務,可以看到在RNN中 W_s 和 W_x這兩個參數是共享的,注意噢:這裏也有個共享的W_o ,但不是包含在RNN中的,只是用於序列預測而已。

在step t下,RNN的輸出向量 s t 是:

s t = t a n h ( W x x t + W s s t 1 + b )

接下來 W o s t 進行相乘得到step t下的預測值 o t (加激活函數也可以)。假設step t 的正確label是 y t ,我們現在還是將Loss函數定義爲均方誤差:
E = 1 2 t = 1 T ( y t o t ) 2 .

現在我們來看看怎麼更新W_x,可以看到在step t 下,計算 o_t 不僅涉及到了step t下的W_x ,也涉及到了前面step下的W_x,來看這個反向傳播路徑圖:
這裏寫圖片描述
因此在step t下, E t w x 求導需要對前面所有step的 W x 依次進行求導,再加起來:

E t W x = i = 1 t E t o t o t s t ( j = i + 1 t s j s j 1 ) s i W x

注意到有一個碩大的連乘符號,事情好像又開始變得不簡單起來,我們來繼續求導下去,在RNN中 s的激活函數是tanh函數:

j = i + 1 t s j s j 1 = j = i + 1 t t a n h W s

路和前面的神經網絡是一樣的!這裏又涉及到了激活函數的梯度,以及網絡的其它權重 W s ,而tanh其實只是將sigmoid的範圍從[0, 1]變到[-1, 1]而已:
這裏寫圖片描述

另外,我們從矩陣的角度來看, s j s j 1 是個Jacobian矩陣(向量對向量求導),如果矩陣值太大顯然會帶來梯度爆炸(這個不是重點),重點是如果值比較小,而且又經過矩陣連乘,梯度值迅速收縮,最後可能會造成梯度消失

剛剛我們推導了 W_x的梯度, W_s其實也是一樣的,這裏不再重複推導。而 W_o,前面講到它不是屬於RNN的,但是我們也不妨來推導一下:

E t W o = E t o t o t W o

咦!沒錯,在step t下, o t 只和這個step的 W o 有關,和前面step的 W o 都沒關係,所以 W o 的梯度對我們並沒有什麼威脅。

3.2 LSTM出場

上面講到,RNN的梯度問題是產生於 j = i + 1 t s j s j 1 這一項,LSTM作爲RNN的改進版本,改進了共享的神經網絡模塊,引入了cell結構,其實也是爲了在這一項中保持一定的梯度,把連乘操作改爲連加操作
引用自 Stanford CS231n slides
LSTM相信很多人看過這個:[譯] 理解 LSTM 網絡,但是我發現cs231n的公式更加簡潔,把四個門層結構的權重參數合成一個W

求導過程比較複雜,我們先看一下c_t這一項:

c t = f t c t 1 + i t g t

和前面一樣,我們來求一下 c t c t 1 ,這裏注意 f t i t g t 都是

和前面一樣,我們來求一下 c t c t 1 ,這裏注意 f t i t g t 都是 gt f t i t g t 都是 c t 1 的複合函數:

相關文章
相關標籤/搜索