The Road to Machine Learning: Python Grid Search, Parallel Search, and GridSearchCV Model Validation

 

Git: https://github.com/linyi0604/MachineLearning

How do you decide which parameter values a model should use?

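k-fold cross-validation:
Split the samples into k folds.
In each round, use one fold as test data and the remaining k-1 folds as training data.
Repeat for k rounds of training and testing in total.
This makes full use of the sample data and gives an averaged estimate of how well the model performs on it.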
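The k-fold procedure above can be sketched in plain Python. This is a minimal illustration, not scikit-learn's implementation: the helper names `k_fold_indices` and `cross_validate`, the contiguous (unshuffled) folds, and the toy majority-class "model" are all assumptions made for the example.

```python
def k_fold_indices(n_samples, k):
    """Split sample indices 0..n_samples-1 into k contiguous folds.

    The first n_samples % k folds get one extra sample, so fold sizes differ by at most 1.
    """
    sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cross_validate(xs, ys, k, fit, score):
    """Run k rounds: each fold is the test set once, the other k-1 folds are training data.

    Returns the average of the k test scores.
    """
    folds = k_fold_indices(len(xs), k)
    scores = []
    for i, test_idx in enumerate(folds):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        model = fit([xs[j] for j in train_idx], [ys[j] for j in train_idx])
        scores.append(score(model, [xs[j] for j in test_idx], [ys[j] for j in test_idx]))
    return sum(scores) / k


# Hypothetical toy "model": always predict the most common training label.
def fit_majority(xs, ys):
    return max(set(ys), key=ys.count)


def accuracy(model, xs, ys):
    return sum(1 for y in ys if y == model) / len(ys)


if __name__ == "__main__":
    xs = list(range(10))
    ys = [0] * 7 + [1] * 3
    print(cross_validate(xs, ys, 5, fit_majority, accuracy))  # 0.7
```

With 10 samples labelled `[0]*7 + [1]*3` and k = 5, the per-fold accuracies are 1.0, 1.0, 1.0, 0.5 and 0.0, which average to 0.7. The last fold contains only class 1, which is why real pipelines often shuffle or stratify the folds.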


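Grid search:
A brute-force enumeration method.
List several candidate values for each model parameter,
evaluate the model on every combination of the listed values,
and keep the parameters that perform best.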
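Grid search is essentially a loop over the Cartesian product of the candidate values. A minimal sketch, assuming a hypothetical `evaluate(params)` callback that returns a score where higher is better (in practice, the mean k-fold cross-validation score); `toy_score` is a made-up function used only so the example runs.

```python
from itertools import product


def grid_search(param_grid, evaluate):
    """Brute-force grid search over every combination of candidate values.

    param_grid maps each parameter name to a list of candidate values;
    evaluate(params) returns a score where higher is better.
    Returns (best_params, best_score).
    """
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = evaluate(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Hypothetical scoring function; a real one would train the model with
# `params` and return its mean k-fold cross-validation score.
def toy_score(params):
    return 100.0 - (params["C"] - 10.0) ** 2 - (params["gamma"] - 0.1) ** 2


if __name__ == "__main__":
    grid = {"C": [0.1, 1.0, 10.0], "gamma": [0.01, 0.1, 1.0, 10.0]}
    print(grid_search(grid, toy_score))  # ({'C': 10.0, 'gamma': 0.1}, 100.0)
```

The cost grows multiplicatively: 3 values of C and 4 of gamma already mean 12 evaluations, and every extra parameter multiplies that count again.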

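Parallel search:
Because the parameter combinations are independent of one another,
the grid search can evaluate them at the same time with multiple workers (threads or processes).
This is called parallel search.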
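Because each combination is scored independently, the loop can be handed to a worker pool. This sketch uses Python's standard `concurrent.futures`; the function names, the default of 4 workers, and `toy_score` are hypothetical, and this is not how scikit-learn parallelizes internally (GridSearchCV delegates to joblib).

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def toy_score(params):
    """Hypothetical stand-in for training and cross-validating one combination."""
    return 100.0 - (params["C"] - 10.0) ** 2 - (params["gamma"] - 0.1) ** 2


def parallel_grid_search(param_grid, evaluate, max_workers=4):
    """Score every combination concurrently and return (best_params, best_score)."""
    names = sorted(param_grid)
    candidates = [dict(zip(names, values))
                  for values in product(*(param_grid[name] for name in names))]
    # The combinations do not depend on each other, so they can run at the same time.
    # pool.map returns results in the same order as `candidates`.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        scores = list(pool.map(evaluate, candidates))
    best = max(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best], scores[best]


if __name__ == "__main__":
    grid = {"C": [0.1, 1.0, 10.0], "gamma": [0.01, 0.1, 1.0, 10.0]}
    print(parallel_grid_search(grid, toy_score))  # ({'C': 10.0, 'gamma': 0.1}, 100.0)
```

`ThreadPoolExecutor` keeps the sketch portable and matches the multithreading described above, but pure-Python CPU-bound scoring is limited by the GIL. Swapping in `ProcessPoolExecutor` gives real multi-core parallelism, provided the evaluate function is defined at module level and, on Windows, the call sits under `if __name__ == "__main__":`.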



Python implementation:
from sklearn.datasets import fetch_20newsgroups
# sklearn.cross_validation and sklearn.grid_search were removed in scikit-learn 0.20;
# train_test_split and GridSearchCV now live in sklearn.model_selection
from sklearn.model_selection import train_test_split, GridSearchCV
import numpy as np
from sklearn.svm import SVC
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline

# Blog post: http://www.cnblogs.com/Lin-Yi/p/9000989.html

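# Download the full 20 Newsgroups dataset (requires network access on the first run)
news = fetch_20newsgroups(subset="all")
# Split into training and test data, using only the first 3000 documents to keep the search fast
x_train, x_test, y_train, y_test = train_test_split(
    news.data[:3000], news.target[:3000], test_size=0.25, random_state=33)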

# Use a Pipeline to chain TF-IDF vectorization and the SVM classifier into one estimator
clf = Pipeline([("vect", TfidfVectorizer(stop_words="english", analyzer="word")), ("svc", SVC())])
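
# Two hyperparameters to tune: 4 values of svc__gamma and 3 values of svc__C, 12 combinations in total
# (the "svc__" prefix routes each parameter to the pipeline step named "svc")
# np.logspace(start, stop, num) returns num values evenly spaced on a log scale from 10**start to 10**stop
parameters = {"svc__gamma": np.logspace(-2, 1, 4), "svc__C": np.logspace(-1, 1, 3)}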
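
# Grid search: 12 parameter combinations x 3-fold cross-validation = 36 fits;
# refit=True retrains the best combination on the whole training set afterwards
gs = GridSearchCV(clf, parameters, verbose=2, refit=True, cv=3)
# Parallel version: n_jobs=-1 uses all CPU cores, n_jobs=5 would run 5 jobs at once.
# The workers are separate processes; on Windows (no fork), put the script
# under an `if __name__ == "__main__":` guard before enabling n_jobs.
# gs = GridSearchCV(clf, parameters, verbose=2, refit=True, cv=3, n_jobs=-1)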
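
# Run the grid search with a single job (n_jobs not set).
# fit() returns the fitted GridSearchCV object itself; elapsed time is reported by the verbose log
gs.fit(x_train, y_train)
# best_score_ is the mean cross-validation accuracy of the best parameter combination
print(gs.best_params_, gs.best_score_)
# Accuracy of the best model (refit on the full training set) on the test set
print(gs.score(x_test, y_test))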
'''
Fitting 3 folds for each of 12 candidates, totalling 36 fits
[CV] svc__C=0.1, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.3s
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    8.3s remaining:    0.0s
[CV] svc__C=0.1, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.5s
[CV] svc__C=0.1, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.5s
[CV] svc__C=0.1, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.4s
[CV] svc__C=0.1, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.5s
[CV] svc__C=0.1, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.5s
[CV] svc__C=0.1, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.4s
[CV] svc__C=0.1, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.6s
[CV] svc__C=0.1, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.6s
[CV] svc__C=0.1, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.5s
[CV] svc__C=0.1, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.6s
[CV] svc__C=0.1, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.7s
[CV] svc__C=1.0, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.3s
[CV] svc__C=1.0, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.4s
[CV] svc__C=1.0, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.5s
[CV] svc__C=1.0, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.3s
[CV] svc__C=1.0, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.4s
[CV] svc__C=1.0, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.5s
[CV] svc__C=1.0, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.5s
[CV] svc__C=1.0, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.6s
[CV] svc__C=1.0, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.7s
[CV] svc__C=1.0, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.5s
[CV] svc__C=1.0, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.6s
[CV] svc__C=1.0, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.7s
[CV] svc__C=10.0, svc__gamma=0.01 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.4s
[CV] svc__C=10.0, svc__gamma=0.01 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.4s
[CV] svc__C=10.0, svc__gamma=0.01 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.7s
[CV] svc__C=10.0, svc__gamma=0.1 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
[CV] svc__C=10.0, svc__gamma=0.1 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
[CV] svc__C=10.0, svc__gamma=0.1 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
[CV] svc__C=10.0, svc__gamma=1.0 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=1.0 -   8.5s
[CV] svc__C=10.0, svc__gamma=1.0 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=1.0 -   8.6s
[CV] svc__C=10.0, svc__gamma=1.0 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=1.0 -   9.3s
[CV] svc__C=10.0, svc__gamma=10.0 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.8s
[CV] svc__C=10.0, svc__gamma=10.0 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.9s
[CV] svc__C=10.0, svc__gamma=10.0 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.7s

12 hyperparameter combinations x 3-fold cross-validation = 36 fits, taking 5.2 minutes in total
[Parallel(n_jobs=1)]: Done  36 out of  36 | elapsed:  5.2min finished

Best parameters, best mean cross-validation score
{'svc__C': 10.0, 'svc__gamma': 0.1} 0.7906666666666666
Test-set score of the best model
0.8226666666666667

'''
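The log's header "12 candidates, totalling 36 fits" follows directly from the parameter grid: 4 gamma values times 3 C values gives 12 combinations, each fitted once per fold (the single extra refit of the winner triggered by refit=True is not part of the 36). A small pure-Python check; `logspace` here is a hypothetical stand-in for `np.logspace` with its defaults (base 10, endpoint included) so the snippet runs without NumPy.

```python
from itertools import product


def logspace(start, stop, num):
    """Mimic np.logspace(start, stop, num): num points evenly spaced in the
    exponent between 10**start and 10**stop (both endpoints included)."""
    if num == 1:
        return [10.0 ** start]
    step = (stop - start) / (num - 1)
    return [10.0 ** (start + i * step) for i in range(num)]


gammas = logspace(-2, 1, 4)  # approximately [0.01, 0.1, 1.0, 10.0]
cs = logspace(-1, 1, 3)      # approximately [0.1, 1.0, 10.0]

# Same nesting as the log above: C in the outer loop, gamma in the inner loop
candidates = list(product(cs, gammas))
n_folds = 3
total_fits = len(candidates) * n_folds

print(len(candidates), total_fits)  # 12 36
```

At roughly 8.5 s per fit in the log, 36 sequential fits account for the reported 5.2 minutes, which is the cost that the n_jobs option is meant to spread across cores.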