Python超參數自動搜索模塊GridSearchCV上手

1. 引言

當咱們跑機器學習程序時,尤爲是調節網絡參數時,一般待調節的參數有不少,參數之間的組合更是繁複。依照注意力>時間>金錢的原則,人力手動調節注意力成本過高,很是不值得。For循環或相似於for循環的方法受限於太過度明的層次,不夠簡潔與靈活,注意力成本高,易出錯。本文介紹sklearn模塊的GridSearchCV模塊,可以在指定的範圍內自動搜索具備不一樣超參數的不一樣模型組合,有效解放注意力。html

2. GridSearchCV模塊簡介

這個模塊是sklearn模塊的子模塊,導入方法很是簡單python

from sklearn.model_selection import GridSearchCV

函數原型:git

class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)

其中cv能夠是整數或者交叉驗證生成器或一個可迭代器,cv參數對應的4種輸入列舉以下:github

  1. None:默認參數,函數會使用默認的3折交叉驗證
  2. 整數k:k折交叉驗證。對於分類任務,使用StratifiedKFold(類別平衡,每類的訓練集佔比同樣多,具體能夠查看官方文檔)。對於其餘任務,使用KFold
  3. 交叉驗證生成器:得本身寫生成器,頭疼,略
  4. 能夠生成訓練集與測試集的迭代器:同上,略

3. 分析結果自動保存

逗號分隔值(Comma-Separated Values,CSV,有時也稱爲字符分隔值,由於分隔字符也能夠不是逗號),其文件以純文本形式存儲表格數據(數字和文本)。純文本意味着該文件是一個,不含必須像二進制數字那樣被解讀的數據。CSV文件由任意數目的記錄組成,記錄間以某種換行符分隔;每條記錄由字段組成,字段間的分隔符是其它字符或字符串,最多見的是逗號或製表符。一般,全部記錄都有徹底相同的字段序列。算法

CSV文件有個突出的優勢,能夠用excel等軟件打開,比起記事本和matlab、python等編程語言界面,便於查看、製做報告、後期整理等。編程

GridSearchCV模塊中,不一樣超參數的組合方式及其計算結果以字典的形式保存在 clf.cv_results_中,python的pandas模塊提供了高效整理數據的方法,只須要3行代碼便可解決問題。網絡

cv_result = pd.DataFrame.from_dict(clf.cv_results_)
with open('cv_result.csv','w') as f:
  cv_result.to_csv(f)

4. 完整例程

代碼清晰易懂,無須解釋。https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search機器學習

 1 import pandas as pd
 2 from sklearn import svm, datasets
 3 from sklearn.model_selection import GridSearchCV
 4 from sklearn.metrics import classification_report
 5 
 6 iris = datasets.load_iris()
 7 parameters = {'kernel':('linear', 'rbf'), 'C':[1, 2, 4], 'gamma':[0.125, 0.25, 0.5 ,1, 2, 4]}
 8 svr = svm.SVC()
 9 clf = GridSearchCV(svr, parameters, n_jobs=-1)
10 clf.fit(iris.data, iris.target)
11 cv_result = pd.DataFrame.from_dict(clf.cv_results_)
12 with open('cv_result.csv','w') as f:
13     cv_result.to_csv(f)
14     
15 print('The parameters of the best model are: ')
16 print(clf.best_params_)
17 
18 y_pred = clf.predict(iris.data)
19 print(classification_report(y_true=iris.target, y_pred=y_pred))

5. 相關資料

  1. sklearn.model_selection.GridSearchCV模塊主頁: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
  2. pandas.DataFrame模塊主頁:http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html
  3. 本文例程 https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search

6.將來展望

   當前的工做侷限於算法超參數搜索,尚未結合預處理方式自動搜索、不一樣算法之間自動搜索、不一樣深度學習模型自動搜索等。如何利用pipeline、keras、tf等模塊,實現整個環節的自動搜索,是下一步學習與總結的方向。編程語言

相關文章
相關標籤/搜索