機器學習中用來防止過擬合的方法有哪些

雷鋒網(公衆號:雷鋒網)按:本文做者 qqfly,上海交通大學機器人所博士生,本科畢業於清華大學機械工程系,主要研究方向機器視覺與運動規劃,會寫一些好玩的內容在微信公衆號:Nao(ID:qRobotics)。本文整理自知乎回答:機器學習中用來防止過擬合的方法有哪些?html


給《機器視覺與應用》課程出大做業的時候,正好涉及到這方面內容,因此簡單整理了一下(參考 Hinton 的課程)。按照以前的套路寫:算法


是什麼數據庫


過擬合(overfitting)是指在模型參數擬合過程當中的問題,因爲訓練數據包含抽樣偏差,訓練時,複雜的模型將抽樣偏差也考慮在內,將抽樣偏差也進行了很好的擬合。微信


具體表現就是最終模型在訓練集上效果好;在測試集上效果差。模型泛化能力弱。網絡


機器學習中用來防止過擬合的方法有哪些?


爲何dom


爲何要解決過擬合現象?這是由於咱們擬合的模型通常是用來預測未知的結果(不在訓練集內),過擬合雖然在訓練集上效果好,可是在實際使用時(測試集)效果差。同時,在不少問題上,咱們沒法窮盡全部狀態,不可能將全部狀況都包含在訓練集上。因此,必需要解決過擬合問題。iphone


爲何在機器學習中比較常見?這是由於機器學習算法爲了知足儘量複雜的任務,其模型的擬合能力通常遠遠高於問題複雜度,也就是說,機器學習算法有「擬合出正確規則的前提下,進一步擬合噪聲」的能力。機器學習


而傳統的函數擬合問題(如機器人系統辨識),通常都是經過經驗、物理、數學等推導出一個含參模型,模型複雜度肯定了,只須要調整個別參數便可。模型「無多餘能力」擬合噪聲。函數


怎麼樣性能


既然過擬合這麼討厭,咱們應該怎麼防止過擬合呢?最近深度學習比較火,我就以神經網絡爲例吧:


機器學習中用來防止過擬合的方法有哪些?


1. 獲取更多數據


這是解決過擬合最有效的方法,只要給足夠多的數據,讓模型「看見」儘量多的「例外狀況」,它就會不斷修正本身,從而獲得更好的結果:


機器學習中用來防止過擬合的方法有哪些?


如何獲取更多數據,能夠有如下幾個方法:



  • 從數據源頭獲取更多數據:這個是容易想到的,例如物體分類,我就再多拍幾張照片好了;可是,在不少狀況下,大幅增長數據自己就不容易;另外,咱們不清楚獲取多少數據纔算夠;


  • 根據當前數據集估計數據分佈參數,使用該分佈產生更多數據:這個通常不用,由於估計分佈參數的過程也會代入抽樣偏差。


  • 數據加強(Data Augmentation):經過必定規則擴充數據。如在物體分類問題裏,物體在圖像中的位置、姿態、尺度,總體圖片明暗度等都不會影響分類結果。咱們就能夠經過圖像平移、翻轉、縮放、切割等手段將數據庫成倍擴充;


機器學習中用來防止過擬合的方法有哪些?


2. 使用合適的模型


前面說了,過擬合主要是有兩個緣由形成的:數據太少 + 模型太複雜。因此,咱們能夠經過使用合適複雜度的模型來防止過擬合問題,讓其足夠擬合真正的規則,同時又不至於擬合太多抽樣偏差。


(PS:若是能經過物理、數學建模,肯定模型複雜度,這是最好的方法,這也就是爲何深度學習這麼火的如今,我還堅持說初學者要學掌握傳統的建模方法。)


對於神經網絡而言,咱們能夠從如下四個方面來限制網絡能力


2.1 網絡結構 Architecture


這個很好理解,減小網絡的層數、神經元個數等都可以限制網絡的擬合能力;


機器學習中用來防止過擬合的方法有哪些?


2.2 訓練時間 Early stopping


對於每一個神經元而言,其激活函數在不一樣區間的性能是不一樣的:


機器學習中用來防止過擬合的方法有哪些?


當網絡權值較小時,神經元的激活函數工做在線性區,此時神經元的擬合能力較弱(相似線性神經元)。


有了上述共識以後,咱們就能夠解釋爲何限制訓練時間(early stopping)有用:由於咱們在初始化網絡的時候通常都是初始爲較小的權值。訓練時間越長,部分網絡權值可能越大。若是咱們在合適時間中止訓練,就能夠將網絡的能力限制在必定範圍內。


2.3 限制權值 Weight-decay,也叫正則化(regularization)


原理同上,可是這類方法直接將權值的大小加入到 Cost 裏,在訓練的時候限制權值變大。以 L2 regularization 爲例:


機器學習中用來防止過擬合的方法有哪些?


訓練過程須要下降總體的 Cost,這時候,一方面能下降實際輸出與樣本之間的偏差C0,也能下降權值大小。


2.4 增長噪聲 Noise


給網絡加噪聲也有不少方法:


2.4.1 在輸入中加噪聲:


噪聲會隨着網絡傳播,按照權值的平方放大,並傳播到輸出層,對偏差 Cost 產生影響。推導直接看 Hinton 的 PPT 吧:


機器學習中用來防止過擬合的方法有哪些?


在輸入中加高斯噪聲,會在輸出中生成機器學習中用來防止過擬合的方法有哪些?的干擾項。訓練時,減少偏差,同時也會對噪聲產生的干擾項進行懲罰,達到減少權值的平方的目的,達到與 L2 regularization 相似的效果(對比公式)。


2.4.2 在權值上加噪聲


在初始化網絡的時候,用 0 均值的高斯分佈做爲初始化。Alex Graves 的手寫識別 RNN 就是用了這個方法



Graves, Alex, et al. "A novel connectionist system for unconstrained handwriting recognition." IEEE transactions on pattern analysis and machine intelligence 31.5 (2009): 855-868.



- It may work better, especially in recurrent networks (Hinton)


2.4.3 對網絡的響應加噪聲


如在前向傳播過程當中,讓默寫神經元的輸出變爲 binary 或 random。顯然,這種有點亂來的作法會打亂網絡的訓練過程,讓訓練更慢,但據 Hinton 說,在測試集上效果會有顯著提高 (But it does significantly better on the test set!)。


3. 結合多種模型


簡而言之,訓練多個模型,以每一個模型的平均輸出做爲結果。


從 N 個模型裏隨機選擇一個做爲輸出的指望偏差機器學習中用來防止過擬合的方法有哪些?,會比全部模型的平均輸出的偏差機器學習中用來防止過擬合的方法有哪些?(我不知道公式裏的圓括號爲何顯示不了)


機器學習中用來防止過擬合的方法有哪些?


大概基於這個原理,就能夠有不少方法了:


3.1  Bagging


簡單理解,就是分段函數的概念:用不一樣的模型擬合不一樣部分的訓練集。以隨機森林(Rand Forests)爲例,就是訓練了一堆互不關聯的決策樹。但因爲訓練神經網絡自己就須要耗費較多自由,因此通常不單獨使用神經網絡作 Bagging。


3.2 Boosting


既然訓練複雜神經網絡比較慢,那咱們就能夠只使用簡單的神經網絡(層數、神經元數限制等)。經過訓練一系列簡單的神經網絡,加權平均其輸出。


機器學習中用來防止過擬合的方法有哪些?


3.3 Dropout


這是一個很高效的方法。


機器學習中用來防止過擬合的方法有哪些?


在訓練時,每次隨機(如 50% 機率)忽略隱層的某些節點;這樣,咱們至關於隨機從 2^H 個模型中採樣選擇模型;同時,因爲每一個網絡只見過一個訓練數據(每次都是隨機的新網絡),因此相似 bagging 的作法,這就是我爲何將它分類到「結合多種模型」中;


此外,而不一樣模型之間權值共享(共同使用這 H 個神經元的鏈接權值),至關於一種權值正則方法,實際效果比 L2 regularization 更好。


4. 貝葉斯方法


這部分我尚未想好怎麼才能講得清楚,爲了避免誤導初學者,我就先空着,之後若是想清楚了再更新。固然,這也是防止過擬合的一類重要方法。


機器學習中用來防止過擬合的方法有哪些?


綜上:


機器學習中用來防止過擬合的方法有哪些?





「TensorFlow & 神經網絡算法高級應用班」開課了!


ThoughtWorks大牛教你玩轉TensorFlow !


課程連接:http://www.leiphone.com/special/custom/mooc04.html




雷鋒網版權文章,未經受權禁止轉載。詳情見轉載須知

相關文章
相關標籤/搜索