To run the code that follows, you need to install five third-party Python libraries. This article only demonstrates the installation with sklearn; if you already have sklearn installed, you can update it to the latest version, and the other libraries work the same way.
Install with:

```
pip install sklearn
```

Update with:

```
pip install --upgrade sklearn
```
sklearn documentation (English): https://scikit-learn.org/stable/index.html
sklearn documentation (Chinese): http://sklearn.apachecn.org/#/
```python
# Run in the terminal to install sklearn; the other libraries install the same way
!pip install sklearn
```

```
Requirement already satisfied: sklearn in /Applications/anaconda3/lib/python3.7/site-packages (0.0)
Requirement already satisfied: scikit-learn in /Applications/anaconda3/lib/python3.7/site-packages (from sklearn) (0.20.1)
Requirement already satisfied: numpy>=1.8.2 in /Applications/anaconda3/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.15.4)
Requirement already satisfied: scipy>=0.13.3 in /Applications/anaconda3/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.1.0)
```
```python
import sklearn

# Print the sklearn version
sklearn.__version__
```

```
'0.20.1'
```
```python
# Run in the terminal to update sklearn
!pip install --upgrade sklearn
```

```
Requirement already up-to-date: sklearn in /Applications/anaconda3/lib/python3.7/site-packages (0.0)
Requirement already satisfied, skipping upgrade: scikit-learn in /Applications/anaconda3/lib/python3.7/site-packages (from sklearn) (0.20.1)
Requirement already satisfied, skipping upgrade: numpy>=1.8.2 in /Applications/anaconda3/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.15.4)
Requirement already satisfied, skipping upgrade: scipy>=0.13.3 in /Applications/anaconda3/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.1.0)
```
Models | Functional modules |
---|---|
estimator.fit(X_train, y_train) | estimator.fit(X_train, y_train) |
estimator.predict(X_test) | estimator.transform(X_test) |
get_params([deep]) | get_params([deep]) |
set_params(**params) | set_params(**params) |
Applies to the following models | Applies to the following functional modules |
Classification | Preprocessing |
Regression | Dimensionality Reduction |
Clustering | Feature Selection |
- | Feature Extraction |
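To make the two interfaces in the table concrete, here is a minimal sketch. It assumes nothing beyond scikit-learn itself; `KNeighborsClassifier` and `StandardScaler` are just illustrative stand-ins for any model and any functional module:

```python
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# Model side of the table: fit() then predict()
clf = KNeighborsClassifier()
clf.fit(X, y)
print(clf.predict(X[:3]))        # predicted class labels

# Functional-module side of the table: fit() then transform()
scaler = StandardScaler()
scaler.fit(X)
print(scaler.transform(X[:3]))   # standardized features
```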
This section only gives a quick walkthrough of the workflow for building a machine learning application, namely the following six steps:
1. Collect data
2. Preprocess data
3. Train the model
4. Test the model
5. Optimize the model
6. Persist the model
Each step of this workflow will be explained in detail later.
Whether you are doing supervised or unsupervised learning, the first step in building a machine learning application is always obtaining data. To give you a simple overview of the whole process, we use the iris dataset that ships with sklearn; the later section on collecting data will cover several ways of gathering data in detail.
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from sklearn import datasets
%matplotlib inline

font = FontProperties(fname='/Library/Fonts/Heiti.ttc')

iris = datasets.load_iris()
iris
```
```
{'data': array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       ...,
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, ..., 2, 2, 2]),
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n-------------------- ...',
 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'],
 'filename': '/Applications/anaconda3/lib/python3.6/site-packages/sklearn/datasets/data/iris.csv'}
```

(Output truncated: `data` contains all 150 samples, `target` holds the 150 labels — 50 each of classes 0, 1, and 2 — and `DESCR` the full dataset description.)
```python
X = iris.data
# There are 150 samples in total; print only the first 5
'Number of samples in X: {}'.format(len(X)), 'X: {}'.format(X[0:5])
```

```
('Number of samples in X: 150', 'X: [[5.1 3.5 1.4 0.2]\n [4.9 3.  1.4 0.2]\n [4.7 3.2 1.3 0.2]\n [4.6 3.1 1.5 0.2]\n [5.  3.6 1.4 0.2]]')
```
```python
y = iris.target
'Number of samples in y: {}'.format(len(y)), 'y: {}'.format(y)
```

```
('Number of samples in y: 150', 'y: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2\n 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n 2 2]')
```
```python
# Visualize the data with pandas
df = pd.DataFrame(X, columns=iris.feature_names)
df['target'] = y
df.plot(figsize=(10, 8))
plt.show()
```
```python
# Visualize with matplotlib
# matplotlib suits two-dimensional plots, so use only features 1 and 2: sepal length and sepal width
# Take columns 1 and 2 of every row
X_ = X[:, [0, 1]]
# Iris setosa samples
plt.scatter(X_[0:50, 0], X_[0:50, 1], color='r', label='Iris setosa', s=10)
# Iris versicolor samples
plt.scatter(X_[50:100, 0], X_[50:100, 1], color='g', label='Iris versicolor', s=50)
# Iris virginica samples
plt.scatter(X_[100:150, 0], X_[100:150, 1], color='b', label='Iris virginica', s=100)
plt.legend(prop=font)
plt.xlabel('sepal length', fontproperties=font, fontsize=15)
plt.ylabel('sepal width', fontproperties=font, fontsize=15)
plt.title('sepal length vs sepal width', fontproperties=font, fontsize=20)
plt.show()
```
Notice that for a given feature of the iris data, the gap between the smallest and largest feature values can be very large. To deal with equally weighted features living on very different scales, we can apply min-max normalization, which compresses the values of each feature into the interval \([0, 1]\).

The min-max normalization formula:

\[ x_{norm}^{(i)}={\frac{x^{(i)}-x_{min}}{x_{max}-x_{min}}} \]

where \(i=1,2,\cdots,m\); \(m\) is the number of samples; and \(x_{min}, x_{max}\) are the minimum and maximum values of the feature in question.
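Before turning to sklearn's MinMaxScaler below, here is a direct NumPy translation of the formula as a sanity check; `X_demo` is a small made-up matrix used only for illustration:

```python
import numpy as np

# Apply the min-max formula column by column (one column = one feature)
X_demo = np.array([[5.1, 3.5],
                   [4.9, 3.0],
                   [4.7, 3.2]])
x_min = X_demo.min(axis=0)  # per-feature minimum
x_max = X_demo.max(axis=0)  # per-feature maximum
X_norm = (X_demo - x_min) / (x_max - x_min)
print(X_norm)  # every value now lies in [0, 1]
```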
```python
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
# scaler.fit_transform(X)  # equivalent to calling fit() and then transform()
scaler = scaler.fit(X)
print(X)
X1 = scaler.transform(X)
X1
```
```
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 ...
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]
array([[0.22222222, 0.625     , 0.06779661, 0.04166667],
       [0.16666667, 0.41666667, 0.06779661, 0.04166667],
       [0.11111111, 0.5       , 0.05084746, 0.04166667],
       ...,
       [0.52777778, 0.58333333, 0.74576271, 0.91666667],
       [0.44444444, 0.41666667, 0.69491525, 0.70833333]])
```

(Output truncated: both the printed `X` and the normalized `X1` contain all 150 rows.)
Different problems call for different machine learning algorithms: classification problems need classification algorithms, regression problems need regression algorithms, and so on.
The iris problem is a classification problem, but which classification algorithm should we use? We can consult the sklearn machine learning cheat-sheet (the algorithm map).
The iris dataset has more than 50 samples -> it is a classification problem -> the data is labeled -> there are fewer than 100K samples -> use a linear-kernel SVC (LinearSVC).
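As a quick sketch, the cheat-sheet's suggestion can be tried directly; this assumes the `X` and `y` loaded above:

```python
from sklearn.svm import LinearSVC

# LinearSVC with default settings, fitted on the full iris data
lin_clf = LinearSVC()
lin_clf.fit(X, y)
print(lin_clf.score(X, y))  # training accuracy
```

The sections below use the closely related `SVC(kernel='linear')` instead, since `LinearSVC` does not provide the `predict_proba()` method we want to demonstrate.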
```python
from sklearn.model_selection import train_test_split

# Split the data into a training set and a test set at a 2:1 ratio
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3)
'Training set size: {}'.format(len(y_train)), 'Test set size: {}'.format(len(y_test))
```

```
('Training set size: 100', 'Test set size: 50')
```
```python
y_train
```

```
array([1, 0, 0, 0, 2, 1, 1, 0, 2, 2, 2, 0, 1, 0, 2, 1, 0, 0, 1, 2, 0, 1, 1, 2, 0, 2, 0, 0, 2, 2, 2, 1, 0, 2, 0, 1, 2, 0, 1, 2, 1, 1, 0, 1, 1, 0, 1, 2, 2, 2, 0, 2, 2, 1, 2, 2, 1, 2, 0, 1, 0, 2, 0, 1, 1, 1, 0, 0, 1, 0, 2, 2, 0, 2, 0, 1, 1, 1, 1, 0, 1, 1, 2, 0, 0, 1, 1, 1, 2, 1, 2, 0, 2, 0, 1, 0, 1, 0, 0, 2])
```

```python
y_test
```

```
array([1, 2, 1, 0, 0, 2, 1, 2, 2, 1, 2, 0, 2, 0, 0, 1, 1, 1, 2, 1, 0, 2, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 1, 2, 2, 2, 1, 2, 0, 0, 0, 1, 0, 1, 2, 1, 0, 0, 0])
```
```python
from sklearn.svm import SVC  # from sklearn.svm import LinearSVC works the same way

# probability=True is required to report class probabilities,
# i.e. to use the predict_proba() method shown below
clf = SVC(kernel='linear', probability=True)
# Train the model
clf.fit(X_train, y_train)
# Predict labels for the test data
y_prd = clf.predict(X_test)
y_prd
```

```
array([1, 2, 1, 0, 0, 2, 1, 2, 2, 1, 2, 0, 2, 0, 0, 1, 1, 1, 2, 1, 0, 2, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 1, 2, 2, 2, 2, 2, 0, 0, 0, 2, 0, 1, 2, 1, 0, 0, 0])
```
```python
y_prd - y_test
```

```
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
```
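The two nonzero entries above mark the misclassified test samples. Equivalently, the accuracy can be computed directly; a quick check, not part of the original flow:

```python
# Fraction of matching predictions: 48 of 50 correct -> 0.96
np.mean(y_prd == y_test)
```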
```python
clf.get_params()
```

```
{'C': 1.0, 'cache_size': 200, 'class_weight': None, 'coef0': 0.0,
 'decision_function_shape': 'ovr', 'degree': 3, 'gamma': 'auto_deprecated',
 'kernel': 'linear', 'max_iter': -1, 'probability': True,
 'random_state': None, 'shrinking': True, 'tol': 0.001, 'verbose': False}
```

```python
clf.C
```

```
1.0
```

```python
clf.set_params(C=2)
```

```
SVC(C=2, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
    kernel='linear', max_iter=-1, probability=True, random_state=None,
    shrinking=True, tol=0.001, verbose=False)
```

```python
clf.get_params()
```

```
{'C': 2, 'cache_size': 200, 'class_weight': None, 'coef0': 0.0,
 'decision_function_shape': 'ovr', 'degree': 3, 'gamma': 'auto_deprecated',
 'kernel': 'linear', 'max_iter': -1, 'probability': True,
 'random_state': None, 'shrinking': True, 'tol': 0.001, 'verbose': False}
```
```python
# Print all columns of the first five rows
clf.predict_proba(X_test)[0:5, :]
```

```
array([[0.02073772, 0.94985386, 0.02940841],
       [0.93450081, 0.04756914, 0.01793006],
       [0.00769491, 0.90027802, 0.09202706],
       [0.96549643, 0.02213395, 0.01236963],
       [0.01035414, 0.91467105, 0.07497481]])
```
```python
# Check the model's score -- here, the accuracy
clf.score(X_test, y_test)
```

```
0.96
```
Testing the model, as described in part two, means measuring its performance with performance-measurement tools. The score() call in the previous section is in fact one such tool, but it only gives a rough evaluation of the model. We usually measure model performance with the modules under sklearn.metrics, and evaluate the model's generalization ability with the modules under sklearn.model_selection.
```python
from sklearn.metrics import classification_report

print(classification_report(y, clf.predict(X), target_names=iris.target_names))
```

```
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        50
  versicolor       1.00      0.96      0.98        50
   virginica       0.96      1.00      0.98        50

   micro avg       0.99      0.99      0.99       150
   macro avg       0.99      0.99      0.99       150
weighted avg       0.99      0.99      0.99       150
```
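classification_report is not the only tool under sklearn.metrics; a confusion matrix is another common one. A small sketch:

```python
from sklearn.metrics import confusion_matrix

# Rows are the true classes, columns the predicted classes;
# off-diagonal entries count the misclassified samples
print(confusion_matrix(y, clf.predict(X)))
```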
Here we use k-fold cross-validation to measure model performance.

k-fold cross-validation: the data is split into k folds; each fold takes one turn as the test set while the model is trained on the remaining k-1 folds, and the k scores are then averaged.

The figure below illustrates 10-fold cross-validation.
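To make the procedure concrete, here is a simplified sketch of what cross_val_score does internally, written with KFold; note that for classifiers cross_val_score actually uses stratified folds, so the exact scores may differ:

```python
from sklearn.model_selection import KFold
from sklearn.svm import SVC

kf = KFold(n_splits=10, shuffle=True, random_state=0)
fold_scores = []
for train_idx, test_idx in kf.split(X):
    # Each fold takes one turn as the test set;
    # a fresh SVC is trained on the remaining nine folds
    model = SVC(kernel='linear')
    model.fit(X[train_idx], y[train_idx])
    fold_scores.append(model.score(X[test_idx], y[test_idx]))
fold_scores
```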
```python
from sklearn.model_selection import cross_val_score

# The score of each of the 10 fold models
scores = cross_val_score(clf, X, y, cv=10)
scores
```

```
array([1.        , 1.        , 1.        , 1.        , 0.86666667,
       1.        , 0.93333333, 1.        , 1.        , 1.        ])
```
```python
# Mean score and confidence interval
print('Accuracy: {:.4f} (+/- {:.4f})'.format(scores.mean(), scores.std() * 2))
```

```
Accuracy: 0.9800 (+/- 0.0854)
```
Training and testing the model already gave us its parameters. Optimizing the model amounts to finding the hyperparameters that make the model perform best, which can also be understood as the role of a validation set. Here we optimize the model with grid search to obtain the relatively best set of hyperparameters.
```python
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

# The model
svc = SVC()
# Hyperparameter grid: 4 + 4*4 = 20 parameter combinations in total, each validated
# by cross-validation. 'linear' is the linear kernel, with one hyperparameter, 'C';
# 'rbf' is the Gaussian kernel, with two hyperparameters, 'C' and 'gamma'.
param_grid = [{'C': [0.1, 1, 10, 20], 'kernel': ['linear']},
              {'C': [0.1, 1, 10, 20], 'kernel': ['rbf'], 'gamma': [0.1, 1, 10, 20]}]
# The scoring function
scoring = 'accuracy'
clf = GridSearchCV(estimator=svc, param_grid=param_grid, scoring=scoring, cv=10)
clf = clf.fit(X, y)
clf.predict(X)
```

```
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
```
```python
clf.get_params()
```

```
{'cv': 10,
 'error_score': 'raise-deprecating',
 'estimator__C': 1.0,
 'estimator__cache_size': 200,
 'estimator__class_weight': None,
 'estimator__coef0': 0.0,
 'estimator__decision_function_shape': 'ovr',
 'estimator__degree': 3,
 'estimator__gamma': 'auto_deprecated',
 'estimator__kernel': 'rbf',
 'estimator__max_iter': -1,
 'estimator__probability': False,
 'estimator__random_state': None,
 'estimator__shrinking': True,
 'estimator__tol': 0.001,
 'estimator__verbose': False,
 'estimator': SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
   decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
   kernel='rbf', max_iter=-1, probability=False, random_state=None,
   shrinking=True, tol=0.001, verbose=False),
 'fit_params': None,
 'iid': 'warn',
 'n_jobs': None,
 'param_grid': [{'C': [0.1, 1, 10, 20], 'kernel': ['linear']},
  {'C': [0.1, 1, 10, 20], 'kernel': ['rbf'], 'gamma': [0.1, 1, 10, 20]}],
 'pre_dispatch': '2*n_jobs',
 'refit': True,
 'return_train_score': 'warn',
 'scoring': 'accuracy',
 'verbose': 0}
```
```python
# Check the best set of hyperparameters
clf.best_params_
```

```
{'C': 10, 'kernel': 'linear'}
```

```python
# Check the model's accuracy under the best hyperparameters
clf.best_score_
```

```
0.98
```
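Beyond best_params_ and best_score_, the results for all 20 candidate combinations are recorded in the cv_results_ attribute; a sketch of ranking them with pandas (pd was imported earlier):

```python
# Mean cross-validation score of every parameter combination, best first
results = pd.DataFrame(clf.cv_results_)
results[['params', 'mean_test_score']].sort_values('mean_test_score',
                                                   ascending=False).head()
```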
The model found by grid search has an accuracy of 0.98, which is already a fairly good model. Now that we have it, how do we use it again next time? We generally persist the model by saving it to a .pkl file, then load it straight back from the .pkl file the next time we need it. There are usually two ways to persist a model: Python's built-in pickle library, or the joblib module under the sklearn library.
```python
import pickle

# Serialize the model into a byte string with the pickle module
pkl_str = pickle.dumps(clf)
pkl_str[0:100]
```

```
b'\x80\x03csklearn.model_selection._search\nGridSearchCV\nq\x00)\x81q\x01}q\x02(X\x07\x00\x00\x00scoringq\x03X\x08\x00\x00\x00accuracyq\x04X\t\x00\x00\x00estimato'
```
```python
# Deserialize the byte string back into a model with the pickle module
clf2 = pickle.loads(pkl_str)
clf2.predict(X)
```

```
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
```
```python
from sklearn.externals import joblib

# Save the model to the file clf.pkl
joblib.dump(clf, 'clf.pkl')
# Load the model back from clf.pkl
clf3 = joblib.load('clf.pkl')
clf3.predict(X)
```

```
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
```
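Finally, a short sketch of reusing the persisted model on new data; new_flower is a made-up measurement (sepal length, sepal width, petal length, petal width, in cm):

```python
# A hypothetical new flower; the values are chosen to resemble a setosa
new_flower = np.array([[5.0, 3.4, 1.5, 0.2]])
clf3.predict(new_flower)  # expected to return array([0]), i.e. setosa
```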