模型的選擇與調優:交叉驗證與網格搜索

擊上方
「藍色字」
可關注咱們!


今日分享:交叉驗證與網格搜索 web


一:交叉驗證算法

交叉驗證:爲了讓被評估的模型更加準確可信數組

交叉驗證過程:將拿到的數據,分爲訓練集和驗證集(注意這裏的數據是在訓練集中進行劃分的,也就是將原始數據劃分獲得的訓練集再次劃分爲訓練集和驗集),交叉驗證通常結合網格搜索使用。微信

如下圖爲例:將數據分紅5份,其中一份做爲驗證集。而後通過5次(組)的測試,每次都更換不一樣的驗證集,又稱5折交叉驗證經過對上面的五組數據分別進行建模,每一組都會獲得一個模型的精確度,而後取平均值做爲該模型最後的精確度,而後再進行後續的步驟。學習


(五折交叉驗證)測試


二:超參數搜索-網格搜索spa

一般狀況下,有不少參數是須要手動指定的(如k-近鄰算法中的K值),這種叫超參數。可是手動過程繁雜,因此須要對模型預設幾種超參數組合。每組超參數都採用交叉驗證來進行評估。最後選出最優參數組合創建模型。.net



對於超參數較少的K-近鄰來講,也許能夠經過 for 循環開找到較優的k值,可是對於其餘的模型,若是須要調的參數較多,for循環就不太方便了,好比兩個超參數時,每一個參數分別指定4個值,則下來就有16中參數組成的模型。3d


三:網格搜索APIcode


sklearn.model_selection.GridSearchCV


四:API參數介紹


sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
對估計器的指定參數值進行詳盡搜索

estimator:估計器對象 就是哪個模型

param_grid:估計器參數(dict){「n_neighbors」:[1,3,5]}

cv:指定幾折交叉驗證

fit:輸入訓練數據

score:準確率

結果分析:
best_score_:在交叉驗證中測試的最好結果
best_estimator_:最好的參數模型
cv_results_:每次交叉驗證後的測試集準確率結果和訓練集準確率結果


五:K-近鄰網格搜索


使用鳶尾花數據集來進行網格搜索示例演示


#導入相關庫
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV


def knn_iris():
   '''K-近鄰模型對鳶尾花進行分類'''
   
   #加載數據集
   iris = load_iris()
   
   #劃分數據集
   #切記 x_train,x_test,y_train,y_test 順序位置必定不能寫錯
   #括號中參數分別爲 (特徵值 目標值 測試集大小佔比) 佔比可自行設定 經常使用0.25
   x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.25)
   
   '''特徵工程(標準化)'''
   std = StandardScaler()
   
   #對測試集和訓練集的特徵值進行標準化
   x_train = std.fit_transform(x_train)

   x_test = std.transform(x_test)
   
   #Knn模型實例化
   knn = KNeighborsClassifier()
   
   # 以字典形式構造一些參數的值進行搜索,若存在別的參數時,只需添加相應的鍵值
   # 這裏指定參數爲 1,3,5,7,10
   param = {"n_neighbors": [1,3,5,7,10]}

   # 進行網格搜索 3折交叉驗證
   gc = GridSearchCV(knn, param_grid=param, cv=3)

   gc.fit(x_train, y_train)

   print("每一個超參數每次交叉驗證的結果:\n", gc.cv_results_)
   
   print("在測試集上準確率:\n", gc.score(x_test, y_test))

   print("在交叉驗證當中最好的結果:\n", gc.best_score_)

   print("選擇最好的模型是:\n", gc.best_estimator_)

   
if __name__ == '__main__':
   knn_iris()


輸出結果

每一個超參數每次交叉驗證的結果:
{'split1_train_score': array([1.        , 0.97333333, 0.97333333, 0.97333333, 0.94666667]), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 10}], 'std_train_score': array([0.        , 0.01110803, 0.00637203, 0.00637203, 0.0108896 ]), 'mean_train_score': array([1.        , 0.97315315, 0.97765766, 0.97765766, 0.95981982]),
'split2_train_score': array([1.        , 0.98666667, 0.98666667, 0.98666667, 0.97333333]), 'mean_test_score': array([0.94642857, 0.95535714, 0.96428571, 0.95535714, 0.96428571]),
'split0_train_score': array([1.        , 0.95945946, 0.97297297, 0.97297297, 0.95945946]),
'split2_test_score': array([0.89189189, 0.94594595, 0.97297297, 0.91891892, 0.94594595]),
'split0_test_score': array([1., 1., 1., 1., 1.]), 'mean_fit_time': array([0.00100843, 0.00099985, 0.        , 0.        , 0.        ]), 'rank_test_score': array([5, 3, 1, 3, 1]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7, 10],
            mask=[False, False, False, False, False],
      fill_value='?',
           dtype=object), 'std_test_score': array([0.04423072, 0.03382427, 0.03372858, 0.03382427, 0.02559281]), 'mean_score_time': array([0.00100978, 0.00100025, 0.        , 0.        , 0.        ]), 'std_score_time': array([1.46122043e-05, 4.89903609e-07, 0.00000000e+00, 0.00000000e+00,
      0.00000000e+00]), 'std_fit_time': array([3.53483630e-05, 2.24783192e-07, 0.00000000e+00, 0.00000000e+00,
      0.00000000e+00]),
'split1_test_score': array([0.94594595, 0.91891892, 0.91891892, 0.94594595, 0.94594595])}

在測試集上準確率:
0.9210526315789473

在交叉驗證當中最好的結果:
0.9642857142857143

選擇最好的模型是:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
          metric_params=None, n_jobs=1, n_neighbors=5, p=2,
          weights='uniform')


由輸出結果可知,在指定的幾個k值中,超參數k=5時,模型效果最好




Python基礎知識集錦

爬蟲專題文章整理篇!!!

Python數據分析乾貨整理篇

Matplotlib數據可視化專題集錦貼



公衆號     QQ羣

掃QQ羣二維碼進交流學習羣

或在後臺回覆:加羣

本文分享自微信公衆號 - 數據指南(BigDataDT)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索