反向傳播算法學習筆記

反向傳播算法(Back propagation)

目的及思想

咱們如今有一堆輸入,咱們但願能有一個網絡,使得經過這個網絡的構成的映射關係知足咱們的期待。也就是說,咱們在解決這個問題以前先假設,這種映射能夠用網絡的模型來比較好的描述。爲何是網絡而不是什麼別的形式呢?不懂了。。html

這個網絡究竟是個怎樣的形式呢?以下圖所示,\(i1,i2\)是輸入,\(o1,o2\)是輸出,其中\(w1...w8, b1, b2\)是這個網絡中的參數。對於一個結點來講,它的全部輸出都等於它的每一個輸入,對於對應\(w\)的加權求和帶入激活函數的結果。
853467-20160630141449671-1058672778.png-140.9kBc++

而如今\(w1...w8, b1, b2\)這些參數都是未知的,咱們但願能經過一些方法逼近這些參數的真實結果。算法

咱們將\(w1...w8, b1, b2\)這些參數,考慮成一個高維空間中的點,與三維還有二維的狀況相似的,咱們貪心的朝着周圍都走一小步,找到那個能得到相對最優解的方向,並接受此次移動,這是經典的梯度降低的思想。因而,咱們引入了損失函數,使用它來描述這個點的優秀程度。\(w1...w8, b1, b2\)是這個函數的輸入,經過調整這些輸入,咱們但願能得到一個使得損失函數得到最值的位置,然而實際上,咱們得到的顯然是一個極值,並不必定是最值,除非能證實這個損失函數關於這些參數是凸的。可是,做爲一個比較優秀的解,這樣作仍是有價值的。網絡

後半部分的思想過程瓜熟蒂落,感受整套方法最有價值和啓發意義的就是這個網絡模型。函數

具體算法

  1. 設定輸入量\(i_1,i_2...i_n\),以及\(w_1...w_{n*n*2}, b_1, b_2\),若是可能儘可能設定在離真實解較近的位置,最好在一個坑裏?
  2. 激活函數選取經典的sigmoid函數 \(f(x) = \frac{1}{1+e^{-x}}\)
  3. 損失函數取 \(L(w_1...w_{n*n*2}, b_1, b_2) = \frac{1}{2} \sum_{i=1}^n (target_j - o_j)^2\), 咱們定義\(i_j\)對應的目標輸出爲\(target_j\)
  4. 對於當前網絡帶入\(i_1,i_2...i_n\),求出對應的\(o_1,o_2,...,o_n\). 這個過程顯然就是在一張dag上按照拓撲序遞推更它的後繼節點便可,每到一個點計算它的激活函數的輸出,而後更新它的後繼節點
  5. 更新完以後,咱們就得到了\(o_1,o_2,...,o_n\). 如今須要求解 L 關於這每一個參數的在當前輸入狀況下的偏導。容易利用鏈式法則解決(懶得寫了)這裏有超詳細推導
    一文弄懂神經網絡中的反向傳播法——BackPropagation
  6. 返回操做 4,直到得到使人滿意的精度

代碼

c++寫了個實現。太醜了不發了。。最麻煩的部分就是鏈式求導算梯度的幾個式子推導,有了式子以後仍是挺好寫的。很是有意思的是,一開始的寫法,沒有加入參數 b1,b2,所以迭代 500000 次左右才能使L達到 1e-22 的精度,可是當咱們,補上 b1 和 b2 時,只用迭代 200000 次便可達到,一個式子形式的設計或者說網絡結構的設計,對於算法的效果影響仍是很巨大的。spa

相關文章
相關標籤/搜索