本文主要圍繞NN、RNN、LSTM和GRU,討論後向傳播中所存在的梯度問題,以及解決方法,力求深入淺出。
神經網絡包括前向過程和後向過程,前向過程定義網絡結構,後向過程對網絡進行訓練(也就是優化參數),經過多輪迭代得到最終網絡(參數已定)
我們先來分析一個非常簡單的三層神經網絡:
數據集
在輸入層,假設該層節點數爲d,也就是特徵x的維度, 作爲該層輸出;
在隱藏層中,該層節點數爲q,每個節點的輸入 就是上一層所有節點輸出 的線性組合值,該節點的輸出 的激活值,這裏假設使用sigmoid激活函數;
在輸出層,該層節點數爲l,也就是輸出y的維度,同理,每個節點的輸入 是 的線性組合值,輸出 ,根據不同任務選擇不同激活函數,比如二分類任務一般是用sigmoid激活函數把
1)首先我們根據網絡輸出和真實Label來定義Loss函數,這裏定義爲簡單的均方誤差:
那麼我們的目標就是最小化Loss,調整參數 w_{hj} 和 v_{ih} ,使得網絡儘量去擬合真實數據。如何求最小值?那當然是求導了,根據loss函數對參數求導,然後往梯度下降的方向去更新參數,可以降低loss值。梯度主宰更新,如果梯度太小,會帶來梯度消失問題,導致參數更新很慢;那如果梯度很大,又會造成梯度爆炸問題。
2)對於輸出層參數
進行鏈式求導,也就是,E先對節點的輸出
求導,再對節點的輸入
求導,最後 對
求導,結果爲:
這裏我們令 ,就可以得到參數 的更新量爲:
3)對於隱藏層參數
,也是鏈式求導,E先對該層節點的輸出
求導,再對節點的輸入
求導,最後對
求導,其實在前面我們已經求出了部分梯度,最後結果爲:
注意到,
其實我們剛剛求過,其實就是
這貨,因此我們可得:
再次令 ,可以得到 的更新量爲:
也就可以愉快地將更新 了。
1) :這是上一層傳遞過來的梯度,如果上一層的梯度本來已經很小,那麼在這一層進行相乘,會導致這一層的梯度也很小。所以如果網絡層比較深,那麼在鏈式求導的過程中,越是低層的網絡層梯度在連乘過程中可能會變得越來越小,導致梯度消失。
2) :這是這一層的權重,這一項是造成梯度爆炸的主要原因,如果權重很大,也可能會導致相乘後的梯度也比較大。(梯度爆炸不是問題,做個梯度裁剪就行了,對梯度乘以一個縮放因子,我們主要考慮的是梯度消失問題)
3) :這是sigmoid激活函數的導數,sigmoid激活值本身已經是一個比較小的數了,這兩個小於1的數相乘會變得更小,就可能會造成梯度消失。
我們直接來看sigmoid的這個圖吧,只有在靠近0的區域梯度比較大(然而也不會超過0.25),在接近無窮小或者無窮大的時候梯度幾乎是0了:
所以sigmoid是造成梯度消失的一個重要原因,激活函數其實是爲了引入了非線性操作,使得神經網絡可以逼近非線性函數。因此如果不是輸出層必須要用sigmoid來限制輸出範圍,我一般是不用sigmoid的。
那麼從激活函數出發,緩解梯度消失有以下方法:
1)不行就換,比如把sigmoid換成relu,在x>0的時候可以穩穩維持1的梯度。
2)不想換那也行,既然我們知道sigmoid在靠近0的取值範圍內梯度比較大,但我們可以把數據儘量規範化到一個比較合適的範圍,也就是接下來要談到的Normaliztion。
接下來我們再探討一下RNN系列,也就是展開型的神經網絡。
RNN是最簡單的循環神經網絡,其實就是對神經網絡展開k個step,所有step共享同一個神經網絡模塊S,我們還是直接來看圖吧:
這是一個序列預測任務,可以看到在RNN中 W_s 和 W_x這兩個參數是共享的,注意噢:這裏也有個共享的W_o ,但不是包含在RNN中的,只是用於序列預測而已。
在step t下,RNN的輸出向量 是:
接下來
進行相乘得到step t下的預測值
(加激活函數也可以)。假設step t 的正確label是
,我們現在還是將Loss函數定義爲均方誤差:
.
現在我們來看看怎麼更新W_x,可以看到在step t 下,計算 o_t 不僅涉及到了step t下的W_x ,也涉及到了前面step下的W_x,來看這個反向傳播路徑圖:
因此在step t下,
求導需要對前面所有step的
依次進行求導,再加起來:
注意到有一個碩大的連乘符號,事情好像又開始變得不簡單起來,我們來繼續求導下去,在RNN中 s的激活函數是tanh函數:
路和前面的神經網絡是一樣的!這裏又涉及到了激活函數的梯度,以及網絡的其它權重
,而tanh其實只是將sigmoid的範圍從[0, 1]變到[-1, 1]而已:
另外,我們從矩陣的角度來看, 是個Jacobian矩陣(向量對向量求導),如果矩陣值太大顯然會帶來梯度爆炸(這個不是重點),重點是如果值比較小,而且又經過矩陣連乘,梯度值迅速收縮,最後可能會造成梯度消失。
剛剛我們推導了 W_x的梯度, W_s其實也是一樣的,這裏不再重複推導。而 W_o,前面講到它不是屬於RNN的,但是我們也不妨來推導一下:
咦!沒錯,在step t下, 只和這個step的 有關,和前面step的 都沒關係,所以 的梯度對我們並沒有什麼威脅。
上面講到,RNN的梯度問題是產生於
這一項,LSTM作爲RNN的改進版本,改進了共享的神經網絡模塊,引入了cell結構,其實也是爲了在這一項中保持一定的梯度,把連乘操作改爲連加操作。
LSTM相信很多人看過這個:[譯] 理解 LSTM 網絡,但是我發現cs231n的公式更加簡潔,把四個門層結構的權重參數合成一個W
求導過程比較複雜,我們先看一下c_t這一項:
和前面一樣,我們來求一下 ,這裏注意 都是
和前面一樣,我們來求一下 ,這裏注意 都是 和gt 都是 的複合函數: