python sklearn包——grid search筆記

Preface:算法不夠好,須要調試參數時必不可少。好比SVM的懲罰因子C,核函數kernel,gamma參數等,對於不一樣的數據使用不一樣的參數,結果效果可能差1-5個點,sklearn爲咱們提供專門調試參數的函數grid_search。html

在sklearn中以API的形式給出介紹。在離線包中函數較多,但經常使用爲GridSearchCV()這個函數。python

1.GridSearchCV:git

看例子最爲容易懂得使用其的方法。算法

sklearn包中介紹的例子:數組

滷煮直接從官網上貼上例子:grid_search_digits.py數據結構

 

[python]  view plain  copy
 
 
 
  在CODE上查看代碼片派生到個人代碼片
  1. from __future__ import print_function  
  2.   
  3. from sklearn import datasets  
  4. from sklearn.cross_validation import train_test_split  
  5. from sklearn.grid_search import GridSearchCV  
  6. from sklearn.metrics import classification_report  
  7. from sklearn.svm import SVC  
  8.   
  9. print(__doc__)  
  10.   
  11. # Loading the Digits dataset  
  12. digits = datasets.load_digits()  
  13.   
  14. # To apply an classifier on this data, we need to flatten the image, to  
  15. # turn the data in a (samples, feature) matrix:  
  16. n_samples = len(digits.images)  
  17. X = digits.images.reshape((n_samples, -1))  
  18. y = digits.target  
  19.   
  20. # Split the dataset in two equal parts  
  21. X_train, X_test, y_train, y_test = train_test_split(  
  22.     X, y, test_size=0.5, random_state=0)  
  23.   
  24. # Set the parameters by cross-validation  
  25. tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],  
  26.                      'C': [1, 10, 100, 1000]},  
  27.                     {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]  
  28.   
  29. scores = ['precision', 'recall']  
  30.   
  31. for score in scores:  
  32.     print("# Tuning hyper-parameters for %s" % score)  
  33.     print()  
  34.   
  35.     clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,  
  36.                        scoring='%s_weighted' % score)  
  37.     clf.fit(X_train, y_train)  
  38.   
  39.     print("Best parameters set found on development set:")  
  40.     print()  
  41.     print(clf.best_params_)  
  42.     print()  
  43.     print("Grid scores on development set:")  
  44.     print()  
  45.     for params, mean_score, scores in clf.grid_scores_:  
  46.         print("%0.3f (+/-%0.03f) for %r"  
  47.               % (mean_score, scores.std() * 2, params))  
  48.     print()  
  49.   
  50.     print("Detailed classification report:")  
  51.     print()  
  52.     print("The model is trained on the full development set.")  
  53.     print("The scores are computed on the full evaluation set.")  
  54.     print()  
  55.     y_true, y_pred = y_test, clf.predict(X_test)  
  56.     print(classification_report(y_true, y_pred))  
  57.     print()  

 

其中,將參數放在列表中app

 

tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [1, 10, 100, 1000]}, {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
創建分類器clf時,調用GridSearchCV()函數,將上述參數列表的變量傳入函數。而且可傳入交叉驗證cv參數,設置爲5折交叉驗證。對訓練集訓練完成後調用best_params_變量,打印出訓練的最佳參數組。

 

Figure :運行結果dom

能夠看出,其得出最佳參數組字典,還有每一次用參數組進行訓練得出的得分。最後在測試集上,給出10個類別的測試報告,對於類別0,RPF都爲1,。。。。這裏使用sklearn.metrics下的classification_report()函數便可,輸入測試集真實的結果和預測的結果即返回每一個類別的準確率召回率F值以及宏平均值。函數

對於SVM分類器,這裏只列出線性核和RBF核,其中線性核沒必要用gamma這個參數,RBF核可用不一樣懲罰值C和不一樣的gamma值做爲組合。上述列出的結果便可看出有哪些組合。這裏的結果是RBF核,懲罰項爲10,gamma值爲0.001效果最佳。滷煮覺得RBF核是比較好的,可是在最近的學習中,確實是不必定,用了線性核效果更好些,但選訓練很是慢,數據集不同效果差不少吧,可能。學習

另外有個grid_search_text_feature_extraction.py程序寫得也很不錯,只是滷煮fetch_20newsgroup數據集沒有準備好,跑不了

相關文章
相關標籤/搜索