機器學習之反向傳播算法

 

http://www.cnblogs.com/python27/p/MachineLearningWeek05.htm

這一章多是Andrew Ng講得最不清楚的一章,爲何這麼說呢?這一章主要講後向傳播(Backpropagration, BP)算法,Ng花了一大半的時間在講如何計算偏差項δ,如何計算Δ的矩陣,以及如何用Matlab去實現後向傳播,然而最關鍵的問題——爲何要這麼計算?前面計算的這些量到底表明着什麼,Ng基本沒有講解,也沒有給出數學的推導的例子。因此此次內容我不打算照着公開課的內容去寫,在查閱了許多資料後,我想先從一個簡單的神經網絡的梯度推導入手,理解後向傳播算法的基本工做原理以及每一個符號表明的實際意義,而後再按照課程的給出BP計算的具體步驟,這樣更有助於理解。html

簡單神經網絡的後向傳播(Backpropagration, BP)算法

1. 回顧以前的前向傳播(ForwardPropagration, FP)算法

FP算法仍是很簡單的,說白了就是根據前一層神經元的值,先加權而後取sigmoid函數獲得後一層神經元的值,寫成數學的形式就是:python

 

a(1)=X

 

 

z(2)=Θ(1)a(1)

 

 

a(2)=g(z(2))

 

 

z(3)=Θ(2)a(2)

 

 

a(3)=g(z(3))

 

 

z(4)=Θ(3)a(3)

 

 

a(4)=g(z(4))

 

2. 回顧神經網絡的代價函數(不含regularization項)

J(Θ)=1m[i=1mk=1Ky(i)klog(hθ(x(i)))k+(1y(i)k)log(1(hθ(x(i)))k)]web

3. 一個簡單神經網絡的BP推導過程

BP算法解決了什麼問題?咱們已經有了代價函數J(Θ),接下來咱們須要利用梯度降低算法(或者其餘高級優化算法)對J(Θ)進行優化從而獲得訓練參數Θ,然而關鍵問題是,優化算法須要傳遞兩個重要的參數,一個代價函數J(Θ),另外一個是代價函數的梯度J(Θ)Θ,BP算法其實就是解決如何計算梯度的問題。算法

下面咱們從一個簡單的例子入手考慮如何從數學上計算代價函數的梯度,考慮以下簡單的神經網絡(爲方便起見,途中已經給出了前向傳播(FP)的計算過程),該神經網絡有三層神經元,對應的有兩個權重矩陣Θ(1)Θ(2),爲計算梯度咱們只須要計算兩個偏導數便可:J(Θ)Θ(1)J(Θ)Θ(2)網絡

首先咱們先計算第2個權重矩陣的偏導數,即Θ(2)J(Θ)。首先咱們須要在J(Θ)Θ(2)之間創建聯繫,很容易能夠看到J(Θ)的值取決於hθ(x),而hθ(x)=a(3)a3又是由z(3)取sigmoid獲得,最後z(3)=Θ(2)×a(2),因此他們之間的聯繫能夠以下表示:數據結構

按照求導的鏈式法則,咱們能夠先求J(Θ)z(3)的導數,而後乘以z(3)Θ(2)的導數,即函數

 

Θ(2)J(Θ)=z(3)J(Θ)×z(3)Θ(2)

 

z(3)=Θ(2)a(2)不難計算z(3)Θ(2)=(a(2))T,令z(3)J(Θ)=δ(3),上式能夠重寫爲post

 

Θ(2)J(Θ)=δ(3)(a(2))T

 

接下來僅須要計算δ(3)便可,由上一章的內容咱們已經知道g(z)=g(z)(1g(z))hθ(x)=a(3)=g(z(3)),忽略前面的1/mi=1m(這裏咱們只對一個example推導,最後累加便可)優化

 

δ(3)=J(Θ)z(3)=(y)1g(z(3))g(z(3))(1y)11g(z(3))[1g(z(3))]=y(1g(z(3)))+(1y)g(z(3))=y+g(z(3))=y+a(3)

 

至此咱們已經獲得J(Θ)Θ(2)的偏導數,即atom

 

J(Θ)Θ(2)=(a(2))Tδ(3)

 

 

δ(3)=a(3)y

 

接下來咱們須要求J(Θ)Θ(1)的偏導數,J(Θ)Θ(1)的依賴關係以下:

根據鏈式求導法則有

 

J(Θ)Θ(1)=J(Θ)z(3)z(3)a(2)a(2)Θ(1)

 

咱們分別計算等式右邊的三項可得:

 

J(Θ)z(3)=δ(3)

 

 

z(3)a(2)=(Θ(2))T

 

 

a(2)Θ(1)=a(2)z(2)z(2)Θ(1)=g(z(2))a(1)

 

帶入後得

 

J(Θ)Θ(1)=(a(1))Tδ(3)(Θ(2))Tg(z(2))

 

δ(2)=δ(3)(Θ(2))Tg(z(2)), 上式能夠重寫爲

 

J(Θ)Θ(1)=(a(1))Tδ(2)

 

 

δ(2)=δ(3)(Θ(2))Tg(z(2))

 

把上面的結果放在一塊兒,咱們獲得J(Θ)對兩個權重矩陣的偏導數爲:

 

δ(3)=a(3)y

 

 

J(Θ)Θ(2)=(a(2))Tδ(3)

 

 

δ(2)=δ(3)(Θ(2))Tg(z(2))

 

 

J(Θ)Θ(1)=(a(1))Tδ(2)

 

觀察上面的四個等式,咱們發現

  • 偏導數能夠由當前層神經元向量a(l)與下一層的偏差向量δ(l+1)相乘獲得
  • 當前層的偏差向量δ(l)能夠由下一層的偏差向量δ(l+1)與權重矩陣Δl的乘積獲得

因此能夠從後往前逐層計算偏差向量(這就是後向傳播的來源),而後經過簡單的乘法運算獲得代價函數對每一層權重矩陣的偏導數。到這裏算是終於明白爲何要計算偏差向量,以及爲何偏差向量之間有遞歸關係了。儘管這裏的神經網絡十分簡單,推導過程也不是十分嚴謹,可是經過這個簡單的例子,基本可以理解後向傳播算法的工做原理了。

嚴謹的後向傳播算法(計算梯度)

假設咱們有m個訓練example,L層神經網絡,而且此處考慮正則項,即

J(Θ)=1m[i=1mk=1Ky(i)klog(hθ(x(i)))k+(1y(i)k)log(1(hθ(x(i)))k)]+λ2ml=1L1i=1slj=1sl+1(Θ(l)ji)2

初始化:設置Δ(l)ij=0 (理解爲對第l層的權重矩陣的偏導累加值)

For i = 1 : m

  • 設置 a(1)=X
  • 經過前向傳播算法(FP)計算對各層的預測值a(l),其中l=2,3,4,,L
  • 計算最後一層的偏差向量 δ(L)=a(L)y,利用後向傳播算法(BP)從後至前逐層計算偏差向量 δ(L1),δ(L1),,δ(2), 計算公式爲δ(l)=(Θ(l))Tδ(l+1).g(z(l))
  • 更新Δ(l)=Δ(l)+δ(l+1)(a(l))T

end // for

計算梯度:

 

D(l)ij=1mΔ(l)ij,j=0

 

 

D(l)ij=1mΔ(l)ij+λΘ(l)ij,j0

 

 

J(Θ)Θ(l)=D(l)

 

BP實際運用中的技巧

1. 將參數展開成向量

對於四層三個權重矩陣參數Θ(1),Θ(2),Θ(3)將其展開成一個參數向量,Matlab code以下: 

1
thetaVec = [Theta1(:); Theta2(:); Theta3(:)];

2. 梯度檢查

爲了保證梯度計算的正確性,能夠用數值解進行檢查,根據導數的定義

 

dJ(θ)dθJ(θ+ϵ)J(θϵ)2ϵ

 

Matlab Code 以下

1
2
3
4
5
6
7
for  i  = 1 : n
     thetaPlus = theta;
     thetaPlus( i ) = thetaPlus( i ) + EPS;
     thetaMinus = theta;
     thetaMinus( i ) = thetaMinus( i ) - EPS;
     gradApprox( i ) = (J(thetaPlus) - J(thetaMinus)) / (2 * EPS);
end

最後檢查 gradApprox 是否約等於以前計算的梯度值便可。須要注意的是:由於近似的梯度計算代價很大,在梯度檢查後記得關閉梯度檢查的代碼。

3. 隨機初始化

初始權重矩陣的初始化應該打破對稱性 (symmetry breaking),避免使用全零矩陣進行初始化。能夠採用隨機數進行初始化,即 Θ(l)ij[ϵ,+ϵ]

如何訓練一個神經網絡

  1. 隨機初始化權重矩陣
  2. 利用前向傳播算法(FP)計算模型預測值hθ(x)
  3. 計算代價函數J(Θ)
  4. 利用後向傳播算法(BP)計算代價函數的梯度 J(Θ)Θ(l)
  5. 利用數值算法進行梯度檢查(gradient checking),確保正確後關閉梯度檢查
  6. 利用梯度降低(或者其餘優化算法)求得最優參數Θ

附:一個簡短的後向傳播教學視頻

 

參考文獻

[1] Andrew Ng Coursera 公開課第五週

[2] Derivation of Backpropagation. http://web.cs.swarthmore.edu/~meeden/cs81/s10/BackPropDeriv.pdf

[3] Wikipedia: Backpropagation. https://en.wikipedia.org/wiki/Backpropagation

[4] How the backpropagation algorithm works. http://neuralnetworksanddeeplearning.com/chap2.html

[5] 神經網絡和反向傳播算法推導. http://www.mamicode.com/info-detail-671452.html

相關文章
相關標籤/搜索