首先,我們以一個雙層神經網絡爲例展示神經網絡關於數據標籤的計算過程(即前向傳播)。
其中,
Wl
和
bl
分別表示第
l
層神經元的權重參數和偏置項,
sl=WlTal−1+bl
。
gl
表示第
l
層神經元的激活函數,不同層可以選取不同的函數作爲激活函數。
al
表示第
l
層神經元的輸出。本例最終的輸出
a2
即是該神經網絡針對數據集
X
計算得到的預測值
ŷ
。
我們可以構建出本神經網絡的成本函數
J(ŷ )
。一個常見的方式是採用最小二乘法,使得殘差最小化:
J(ŷ )=1m∑i=1m(yi−ŷ i)2=1m(Y−Ŷ )T(Y−Ŷ )
我們以上圖爲例,將每層神經元的計算過程以數學公式表示:
{s1=W1a0+b1a1=g1(s1){s2=W2a1+b2a2=g2(s2)
然後,我們來擴展成本函數
J(ŷ )
:
J(ŷ )=J(a2)=J[g2(s2)]=J[g2(W2a1+b2)]=J{g2[W2g1(W1a0+b1)+b2]}=J{g2[W2g1(W1X+b1)+b2]}
爲易於觀察,對於不同函數
J,g2,g1
,上式採用了不同的括號。上式即嵌套的函數:
J(ŷ )=J(g2(g1(X)))
。因此,使得成本函數
J(ŷ )
最小化,我們可以使用
梯度下降法得到此例中的自變量
W1,W2,b1
和
b2
:
{W2=W2−α▽J(W2)b2=b2−α▽J(b2){W1=W1−α▽J(W1)b1=b1−α▽J(b1)
通用的更新公式爲:
Wl=Wl−α▽J(Wl)bl=bl−α▽J(bl)
上式便是神經網絡的反向傳播算法,即其學習策略。下面我將繼續以文章開始處的例子詳細解釋反向傳播算法。
其中,
dWl
和
dbl
分別表示成本函數
J
對於
Wl
和
bl
的偏導數,
ds1
亦是如此。我們可以先計算一下
W2
和
b2
的更新公式(因爲它們離成本函數最近,偏導的計算量最小):
{W2=W2−α▽J(W2)b2=b2−α▽J(b2)
其中,
▽J(W2)=∂J∂W2=dW2
,
▽J(b2)=∂J∂b2=db2
。
da2=⎡⎣⎢⎢⎢⎢⎢da21da22⋮da2l2⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢∂J∂a21∂J∂a22⋮∂J∂a2l2⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢−2m(y1i−a21i)−2m(y2i−a22i)⋮−2m(yl2i−a2l2i)⎤⎦⎥⎥⎥⎥⎥⎥⎥
其中,
l2
表示神經網絡第2層的神經元數目,
J=1m∑i=1m(yi−ŷ i)2
。
ds2=⎡⎣⎢⎢⎢⎢⎢ds21ds22⋮ds2l2⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢da21g2′(s21)da22g2′(s22)⋮da2l2g2′(s2l2)⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢g2′(s21)0⋮00g2′(s22)0………00g2′(s2l2)⎤⎦⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢da21da22⋮da2l2⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢g2′(s21)0⋮00g2′(s22)0………00g2′(s2l2)⎤⎦⎥⎥⎥⎥⎥da2
然後,求
dW2
和
db2
:
dW2=⎡⎣⎢⎢⎢⎢⎢dw211dw221⋮dw2l21dw212dw222dw2l22………dw21l1dw22l1dw2l2l1⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢ds21a11ds22a11⋮ds2l2a11ds21a12ds22a12ds2l2a12………ds21a1l1ds22a1l1ds2l2a1l1⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢ds21ds22⋮ds2l2⎤⎦⎥⎥⎥⎥⎥[a11a12…a1l1]=ds2a1T
db2=⎡⎣⎢⎢⎢⎢⎢db21db22⋮db2l2⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢ds21ds22⋮ds2l2⎤⎦⎥⎥⎥⎥⎥=ds2
對於
W1
和
b1
的更新公式:
{W1=W1−α▽J(W1)b1=b1−α▽J(b1)
其中,
▽J(W1)=ds1a0T
,
▽J(b1)=ds1
(推導過程同上)。其中:
ds1=⎡⎣⎢⎢⎢⎢⎢g1′(s11)0⋮00g1′(s12)0………00g1′(s1l1)⎤⎦⎥⎥⎥⎥⎥da1
da1=⎡⎣⎢⎢⎢⎢⎢da11da12⋮da1l1⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢ds2T[w211w221…w2l21]Tds2T[w212w222…w2l22]T⋮ds2T[w21l1w22l1…w2l2l1]T⎤⎦⎥⎥⎥⎥⎥⎥=W2Tds2
因此,根據鏈式規則可得更爲通用的公式:
dsl=gl′(sl)Wl+1Tdsl+1dslast=glast′(slast)∂J∂alast
最後,我將本例的前向傳播和反向傳播的圖示結合起來,並給出完整的反向傳播更新公式。
{Wl=Wl−α▽J(Wl)=Wl−αdslal−1Tbl=bl−α▽J(bl)=bl−αdsl{bl−αdsl{dsl=gl′(sl)Wl+1Tdsl+1ddsl+1dslast=glast′(slast)∂J∂alastlast=