三招提高數據不平衡模型的性能(附python代碼)

     對於深度學習而言,數據集很是重要,但在實際項目中,或多或少會遇見數據不平衡問題。什麼是數據不平衡呢?舉例來講,如今有一個任務是判斷西瓜是否成熟,這是一個二分類問題——西瓜是生的仍是熟的,該任務的數據集由兩部分數據組成,成熟西瓜與生西瓜,假設生西瓜的樣本數量遠遠大於成熟西瓜樣本的數量,針對這樣的數據集訓練出來的算法「偏向」於識別新樣本爲生西瓜,存心讓你買不到甜的西瓜以解夏天之苦,這就是一個數據不平衡問題。針對數據不平衡問題有相應的處理辦法,好比對多數樣本進行採樣使得其樣本數量級與少樣本數相近,或者是對少數樣本重複使用等。最近剛好在面試中遇到一個數據不平衡問題,這也是面試中常常會出現的問題之一,現向讀者分享這次解決問題的心得。git

1_jpeg

數據集

       訓練數據中有三個標籤,分別標記爲[一、二、3],這意味着該問題是一個多分類問題。訓練數據集有17個特徵以及38829個獨立數據點。而在測試數據中,有16個沒有標籤的特徵和16641個數據點。該訓練數據集很是不平衡,大部分數據是1類(95%),而2類和3類分別有3.0%和0.87%的數據,以下圖所示。github

2

算法

       通過初步觀察,決定採用隨機森林(RF)算法,由於它優於支持向量機、Xgboost以及LightGBM算法。在這個項目中選擇RF還有幾個緣由:面試

  • 1機森林對過擬合具備很強的魯棒性;
  • 2.參數化仍然很是直觀;
  • 3.在這個項目中,有許多成功的用例將隨機森林算法用於高度不平衡的數據集;
  • 4.我的有先前的算法實施經驗;
           爲了找到最佳參數,使用scikit-sklearn實現的GridSearchCV對指定的參數值執行網格搜索,更多細節能夠在本人的Github上找到。

爲了處理數據不平衡問題,使用瞭如下三種技術:

A.使用集成交叉驗證(CV):

       在這個項目中,使用交叉驗證來驗證模型的魯棒性。整個數據集被分紅五個子集。在每一個交叉驗證中,使用其中的四個子集用於訓練,剩餘的子集用於驗證模型,此外模型還對測試數據進行了預測。在交叉驗證結束時,會獲得五個測試預測機率。最後,對全部類別的機率取平均值。模型的訓練表現穩定,每一個交叉驗證上具備穩定的召回率和f1分數。這項技術也幫助我在Kaggle比賽中取得了很好的成績(前1%)。如下部分代碼片斷顯示了集成交叉驗證的實現:算法

相關文章
相關標籤/搜索