Regularization

這是我參與8月更文挑戰的第5天,活動詳情查看: 8月更文挑戰html

本篇文章講解如何緩解over-fitting。首先看下面三張圖,under-fitted代表預測的函數模型所包含的參數量、複雜度低於實際模型,但這種狀況已經愈來愈少見了,由於如今的網絡都足夠深python

而over-fitted代表預測的函數模型的參數量、複雜度遠高於實際模型 git

在此背景下,有人提出了「Occam’s Razor」,即more things should not be used than are necessary,不是必要的東西不要使用,在神經網絡中,不是必要的網絡參數,要儘可能選擇最小的、最有可能的參數量markdown

目前對於防止over-fitting,有如下幾種主流的作法網絡

  • More data
  • Constraint model complexity
  • shallow
  • regularization
  • Dropout
  • Data argumentation
  • Early Stopping

這裏咱們利用的是Regularization,對於一個二分類問題,它的Cross Entropy公式爲app

J 1 ( θ ) = 1 m i = 1 m [ y i ln y ^ i + ( 1 y i ) ln ( 1 y ^ i ) ] J_1(\theta)=-\frac{1}{m}\sum_{i=1}^m[y_i\ln\hat y_i+(1-y_i)\ln(1-\hat y_i)]

此時若增長一個參數 θ \theta θ \theta 表明網絡參數 ( w 1 , b 1 , w 2 ) (w1,b1,w2) 等,再將 θ \theta 的某一範數(下面公式用的是L1-norm)乘以一個因子 λ > 0 \lambda>0 ,則公式變爲函數

J 2 ( θ ) = J 1 ( θ ) + λ i = 1 n θ i J_2(\theta)=J_1(\theta)+\lambda\sum_{i=1}^n|\theta_i|

思考一下,咱們原本是要優化Loss,也就是 J 1 ( θ ) J_1(\theta) 的值,使其接近於0,如今咱們優化的是 J 2 ( θ ) J_2(\theta) ,其實就是在迫使Loss接近於0的過程當中,使得參數的L1-norm i θ i \sum_i|\theta_i| 也接近於0post

那爲何參數的範數值接近於0,模型的複雜度就會減少呢?咱們設想如今有一個模型 y = β 0 + β 1 x + . . . + β 7 x 7 y=\beta_0+\beta_1x+...+\beta_7x^7 ,經過正則化之後,參數的範數值優化爲很接近於0的值,此時可能 β 3 , . . . , β 7 \beta_3,...,\beta_7 的值都變得很小,假設都是 0.01 0.01 ,那模型近似變爲一個二次方程,就不是原來七次那麼複雜了性能

這種方法又稱Weight Decay學習

注意到上圖左側圖的分割面較複雜,不是光滑的曲線,代表函數模型的分割性較好、表達能力強,但有學到了一些噪聲樣本。右側圖是添加了regularization後的圖,函數模型沒有學習到一些噪聲樣本,表達能力沒有那麼強,能進行更好的劃分,而這就是咱們想要的

Regularization有兩種比較常見的方式,一種是加L1-norm,另外一種是加L2-norm,最經常使用的是L2-regularization,代碼以下

net = MLP()
optimizer = optim.SGD(net.parameters(), lr=learning_rater, weight_decay=0.01)
# SGD會獲得全部的網絡參數,設置weight_decay=0.01以迫使二範數逐漸趨近於0
# 但要注意的是,若沒有overfitting現象仍設置weight_decay參數,會使性能急劇降低
criteon = nn.CrossEntropyLoss()
複製代碼

pytorch對L1-regularization暫時並無很好的支持,所以須要人爲設定代碼

regularization_loss = 0
for param in model.parameters():
    regularization_loss += torch.sum(torch.abs(param))

classify_loss = criteon(logits, target)
loss = classify_loss + 0.01 * regularization_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
複製代碼
相關文章
相關標籤/搜索