長短時記憶網絡(LSTM)的訓練

長短時記憶網絡的訓練

熟悉咱們這個系列文章的同窗都清楚,訓練部分每每比前向計算部分複雜多了。LSTM的前向計算都這麼複雜,那麼,可想而知,它的訓練算法必定是很是很是複雜的。如今只有作幾回深呼吸,再一頭扎進公式海洋吧。算法

LSTM訓練算法框架

LSTM的訓練算法仍然是反向傳播算法,對於這個算法,咱們已經很是熟悉了。主要有下面三個步驟:網絡

  1. 前向計算每一個神經元的輸出值,對於LSTM來講,即 ft it ct ot ht 五個向量的值。計算方法已經在上一節中描述過了。
  2. 反向計算每一個神經元的偏差項 δ 值。與循環神經網絡同樣,LSTM偏差項的反向傳播也是包括兩個方向:一個是沿時間的反向傳播,即從當前t時刻開始,計算每一個時刻的偏差項;一個是將偏差項向上一層傳播。
  3. 根據相應的偏差項,計算每一個權重的梯度。

關於公式和符號的說明

首先,咱們對推導中用到的一些公式、符號作一下必要的說明。框架

接下來的推導中,咱們設定gate的激活函數爲sigmoid函數,輸出的激活函數爲tanh函數。他們的導數分別爲:ide

σ(z)σ(z)tanh(z)tanh(z)=y=11+ez=y(1y)=y=ezezez+ez=1y2(8)(9)(10)(11)

從上面能夠看出,sigmoid和tanh函數的導數都是原函數的函數。這樣,咱們一旦計算原函數的值,就能夠用它來計算出導數的值。函數

LSTM須要學習的參數共有8組,分別是:遺忘門的權重矩陣 Wf 和偏置項 bf 、輸入門的權重矩陣 Wi 和偏置項 bi 、輸出門的權重矩陣 Wo 和偏置項 bo ,以及計算單元狀態的權重矩陣 Wc 和偏置項 bc 。由於權重矩陣的兩部分在反向傳播中使用不一樣的公式,所以在後續的推導中,權重矩陣 Wf Wi Wc Wo 都將被寫爲分開的兩個矩陣: Wfh Wfx Wih Wix Woh Wox Wch Wcx 學習

咱們解釋一下按元素乘 符號。當 做用於兩個向量時,運算以下:atom

ab=a1a2a3...anb1b2b3...bn=a1b1a2b2a3b3...anbn

做用於一個向量和一個矩陣時,運算以下:spa

aX=a1a2a3...anx11x21x31xn1x12x22x32xn2x13x23x33...xn3............x1nx2nx3nxnn=a1x11a2x21a3x31anxn1a1x12a2x22a3x32anxn2a1x13a2x23a3x33...anxn3............a1x1na2x2na3x3nanxnn(12)(13)

做用於兩個矩陣時,兩個矩陣對應位置的元素相乘。按元素乘能夠在某些狀況下簡化矩陣和向量運算。例如,當一個對角矩陣右乘一個矩陣時,至關於用對角矩陣的對角線組成的向量按元素乘那個矩陣:code

diag[a]X=aX

當一個行向量右乘一個對角矩陣時,至關於這個行向量按元素乘那個矩陣對角線組成的向量:orm

aTdiag[b]=ab

上面這兩點,在咱們後續推導中會屢次用到。

在t時刻,LSTM的輸出值爲 ht 。咱們定義t時刻的偏差項 δt 爲:

δt=defEht

注意,和前面幾篇文章不一樣,咱們這裏假設偏差項是損失函數對輸出值的導數,而不是對加權輸入 netlt 的導數。由於LSTM有四個加權輸入,分別對應 ft it ct ot ,咱們但願往上一層傳遞一個偏差項而不是四個。但咱們仍然須要定義出這四個加權輸入,以及他們對應的偏差項。

netf,tneti,tnetc~,tneto,tδf,tδi,tδc~,tδo,t=Wf[ht1,xt]+bf=Wfhht1+Wfxxt+bf=Wi[ht1,xt]+bi=Wihht1+Wixxt+bi=Wc[ht1,xt]+bc=Wchht1+Wcxxt+bc=Wo[ht1,xt]+bo=Wohht1+Woxxt+bo=defEnetf,t=defEneti,t=defEnetc~,t=defEneto,t(14)(15)(16)(17)(18)(19)(20)(21)(22)(23)(24)(25)

偏差項沿時間的反向傳遞

沿時間反向傳遞偏差項,就是要計算出t-1時刻的偏差項 δt1

δTt1=Eht1=Ehththt1=δTththt1(26)(27)(28)

咱們知道, htht1 是一個Jacobian矩陣。若是隱藏層h的維度是N的話,那麼它就是一個 N×N 矩陣。爲了求出它,咱們列出 ht 的計算公式,即前面的式6式4

htct=ottanh(ct)=ftct1+itc~t(29)(30)

顯然, ot ft it c~t 都是 ht1 的函數,那麼,利用全導數公式可得:

δTththt1=δTthtototneto,tneto,tht1+δTthtctctftftnetf,tnetf,tht1+δTthtctctititneti,tneti,tht1+δTthtctctc~tc~tnetc~,tnetc~,tht1=δTo,tneto,tht1+δTf,tnetf,tht1+δTi,tneti,tht1+δTc~,tnetc~,tht1(7)(31)(32)

下面,咱們要把式7中的每一個偏導數都求出來。根據式6,咱們能夠求出:

htothtct=diag[tanh(ct)]=diag[ot(1tanh(ct)2)](33)(34)

根據式4,咱們能夠求出:

ctftctitctc~t=diag[ct1]=diag[c~t]=diag[it](35)(36)(37)

由於:

otneto,tftnetf,titneti,tc~tnetc~,t=σ(neto,t)=Wohht1+Woxxt+bo=σ(netf,t)=Wfhht1+Wf
相關文章
相關標籤/搜索