長短時記憶網絡的訓練
熟悉咱們這個系列文章的同窗都清楚,訓練部分每每比前向計算部分複雜多了。LSTM的前向計算都這麼複雜,那麼,可想而知,它的訓練算法必定是很是很是複雜的。如今只有作幾回深呼吸,再一頭扎進公式海洋吧。算法
LSTM訓練算法框架
LSTM的訓練算法仍然是反向傳播算法,對於這個算法,咱們已經很是熟悉了。主要有下面三個步驟:網絡
- 前向計算每一個神經元的輸出值,對於LSTM來講,即
ft
、
it
、
ct
、
ot
、
ht
五個向量的值。計算方法已經在上一節中描述過了。
- 反向計算每一個神經元的偏差項
δ
值。與循環神經網絡同樣,LSTM偏差項的反向傳播也是包括兩個方向:一個是沿時間的反向傳播,即從當前t時刻開始,計算每一個時刻的偏差項;一個是將偏差項向上一層傳播。
- 根據相應的偏差項,計算每一個權重的梯度。
關於公式和符號的說明
首先,咱們對推導中用到的一些公式、符號作一下必要的說明。框架
接下來的推導中,咱們設定gate的激活函數爲sigmoid函數,輸出的激活函數爲tanh函數。他們的導數分別爲:ide
σ(z)σ′(z)tanh(z)tanh′(z)=y=11+e−z=y(1−y)=y=ez−e−zez+e−z=1−y2(8)(9)(10)(11)
從上面能夠看出,sigmoid和tanh函數的導數都是原函數的函數。這樣,咱們一旦計算原函數的值,就能夠用它來計算出導數的值。函數
LSTM須要學習的參數共有8組,分別是:遺忘門的權重矩陣
Wf
和偏置項
bf
、輸入門的權重矩陣
Wi
和偏置項
bi
、輸出門的權重矩陣
Wo
和偏置項
bo
,以及計算單元狀態的權重矩陣
Wc
和偏置項
bc
。由於權重矩陣的兩部分在反向傳播中使用不一樣的公式,所以在後續的推導中,權重矩陣
Wf
、
Wi
、
Wc
、
Wo
都將被寫爲分開的兩個矩陣:
Wfh
、
Wfx
、
Wih
、
Wix
、
Woh
、
Wox
、
Wch
、
Wcx
。學習
咱們解釋一下按元素乘
∘
符號。當
∘
做用於兩個向量時,運算以下:atom
a∘b=⎡⎣⎢⎢⎢⎢a1a2a3...an⎤⎦⎥⎥⎥⎥∘⎡⎣⎢⎢⎢⎢⎢b1b2b3...bn⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢a1b1a2b2a3b3...anbn⎤⎦⎥⎥⎥⎥⎥
當
∘
做用於一個向量和一個矩陣時,運算以下:spa
a∘X=⎡⎣⎢⎢⎢⎢a1a2a3...an⎤⎦⎥⎥⎥⎥∘⎡⎣⎢⎢⎢⎢x11x21x31xn1x12x22x32xn2x13x23x33...xn3............x1nx2nx3nxnn⎤⎦⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢a1x11a2x21a3x31anxn1a1x12a2x22a3x32anxn2a1x13a2x23a3x33...anxn3............a1x1na2x2na3x3nanxnn⎤⎦⎥⎥⎥⎥(12)(13)
當
∘
做用於兩個矩陣時,兩個矩陣對應位置的元素相乘。按元素乘能夠在某些狀況下簡化矩陣和向量運算。例如,當一個對角矩陣右乘一個矩陣時,至關於用對角矩陣的對角線組成的向量按元素乘那個矩陣:code
diag[a]X=a∘X
當一個行向量右乘一個對角矩陣時,至關於這個行向量按元素乘那個矩陣對角線組成的向量:orm
aTdiag[b]=a∘b
上面這兩點,在咱們後續推導中會屢次用到。
在t時刻,LSTM的輸出值爲
ht
。咱們定義t時刻的偏差項
δt
爲:
δt=def∂E∂ht
注意,和前面幾篇文章不一樣,咱們這裏假設偏差項是損失函數對輸出值的導數,而不是對加權輸入
netlt
的導數。由於LSTM有四個加權輸入,分別對應
ft
、
it
、
ct
、
ot
,咱們但願往上一層傳遞一個偏差項而不是四個。但咱們仍然須要定義出這四個加權輸入,以及他們對應的偏差項。
netf,tneti,tnetc~,tneto,tδf,tδi,tδc~,tδo,t=Wf[ht−1,xt]+bf=Wfhht−1+Wfxxt+bf=Wi[ht−1,xt]+bi=Wihht−1+Wixxt+bi=Wc[ht−1,xt]+bc=Wchht−1+Wcxxt+bc=Wo[ht−1,xt]+bo=Wohht−1+Woxxt+bo=def∂E∂netf,t=def∂E∂neti,t=def∂E∂netc~,t=def∂E∂neto,t(14)(15)(16)(17)(18)(19)(20)(21)(22)(23)(24)(25)
偏差項沿時間的反向傳遞
沿時間反向傳遞偏差項,就是要計算出t-1時刻的偏差項
δt−1
。
δTt−1=∂E∂ht−1=∂E∂ht∂ht∂ht−1=δTt∂ht∂ht−1(26)(27)(28)
咱們知道,
∂ht∂ht−1
是一個Jacobian矩陣。若是隱藏層h的維度是N的話,那麼它就是一個
N×N
矩陣。爲了求出它,咱們列出
ht
的計算公式,即前面的式6和式4:
htct=ot∘tanh(ct)=ft∘ct−1+it∘c~t(29)(30)
顯然,
ot
、
ft
、
it
、
c~t
都是
ht−1
的函數,那麼,利用全導數公式可得:
δTt∂ht∂ht−1=δTt∂ht∂ot∂ot∂neto,t∂neto,t∂ht−1+δTt∂ht∂ct∂ct∂ft∂ft∂netf,t∂netf,t∂ht−1+δTt∂ht∂ct∂ct∂it∂it∂neti,t∂neti,t∂ht−1+δTt∂ht∂ct∂ct∂c~t∂c~t∂netc~,t∂netc~,t∂ht−1=δTo,t∂neto,t∂ht−1+δTf,t∂netf,t∂ht−1+δTi,t∂neti,t∂ht−1+δTc~,t∂netc~,t∂ht−1(式7)(31)(32)
下面,咱們要把式7中的每一個偏導數都求出來。根據式6,咱們能夠求出:
∂ht∂ot∂ht∂ct=diag[tanh(ct)]=diag[ot∘(1−tanh(ct)2)](33)(34)
根據式4,咱們能夠求出:
∂ct∂ft∂ct∂it∂ct∂c~t=diag[ct−1]=diag[c~t]=diag[it](35)(36)(37)
由於:
otneto,tftnetf,titneti,tc~tnetc~,t=σ(neto,t)=Wohht−1+Woxxt+bo=σ(netf,t)=Wfhht−1+Wf