這是我參與8月更文挑戰的第5天,活動詳情查看: 8月更文挑戰html
本篇文章講解如何緩解over-fitting。首先看下面三張圖,under-fitted代表預測的函數模型所包含的參數量、複雜度低於實際模型,但這種狀況已經愈來愈少見了,由於如今的網絡都足夠深python
而over-fitted代表預測的函數模型的參數量、複雜度遠高於實際模型 git
在此背景下,有人提出了「Occam’s Razor」,即more things should not be used than are necessary,不是必要的東西不要使用,在神經網絡中,不是必要的網絡參數,要儘可能選擇最小的、最有可能的參數量markdown
目前對於防止over-fitting,有如下幾種主流的作法網絡
這裏咱們利用的是Regularization,對於一個二分類問題,它的Cross Entropy公式爲app
此時若增長一個參數 , 表明網絡參數 等,再將 的某一範數(下面公式用的是L1-norm)乘以一個因子 ,則公式變爲函數
思考一下,咱們原本是要優化Loss,也就是 的值,使其接近於0,如今咱們優化的是 ,其實就是在迫使Loss接近於0的過程當中,使得參數的L1-norm 也接近於0post
那爲何參數的範數值接近於0,模型的複雜度就會減少呢?咱們設想如今有一個模型 ,經過正則化之後,參數的範數值優化爲很接近於0的值,此時可能 的值都變得很小,假設都是 ,那模型近似變爲一個二次方程,就不是原來七次那麼複雜了性能
這種方法又稱Weight Decay學習
注意到上圖左側圖的分割面較複雜,不是光滑的曲線,代表函數模型的分割性較好、表達能力強,但有學到了一些噪聲樣本。右側圖是添加了regularization後的圖,函數模型沒有學習到一些噪聲樣本,表達能力沒有那麼強,能進行更好的劃分,而這就是咱們想要的
Regularization有兩種比較常見的方式,一種是加L1-norm,另外一種是加L2-norm,最經常使用的是L2-regularization,代碼以下
net = MLP()
optimizer = optim.SGD(net.parameters(), lr=learning_rater, weight_decay=0.01)
# SGD會獲得全部的網絡參數,設置weight_decay=0.01以迫使二範數逐漸趨近於0
# 但要注意的是,若沒有overfitting現象仍設置weight_decay參數,會使性能急劇降低
criteon = nn.CrossEntropyLoss()
複製代碼
pytorch對L1-regularization暫時並無很好的支持,所以須要人爲設定代碼
regularization_loss = 0
for param in model.parameters():
regularization_loss += torch.sum(torch.abs(param))
classify_loss = criteon(logits, target)
loss = classify_loss + 0.01 * regularization_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
複製代碼