公號:碼農充電站pro
主頁:https://codeshellme.github.iohtml
上篇文章介紹了KNN 算法的原理,今天來介紹如何使用KNN 算法識別手寫數字?python
手寫數字數據集是一個用於圖像處理的數據集,這些數據描繪了 [0, 9] 的數字,咱們能夠用KNN 算法來識別這些數字。git
MNIST 是完整的手寫數字數據集,其中包含了60000 個訓練樣本和10000 個測試樣本。github
sklearn 中也有一個自帶的手寫數字數據集:算法
咱們抽出 5 個樣原本看下:shell
0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,0 0,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,1 0,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,2 0,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,3 0,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4
使用該數據集,須要先加載:數據結構
>>> from sklearn.datasets import load_digits >>> digits = load_digits()
查看第一個圖像數據:dom
>>> digits.images[0] array([[ 0., 0., 5., 13., 9., 1., 0., 0.], [ 0., 0., 13., 15., 10., 15., 5., 0.], [ 0., 3., 15., 2., 0., 11., 8., 0.], [ 0., 4., 12., 0., 0., 8., 8., 0.], [ 0., 5., 8., 0., 0., 9., 8., 0.], [ 0., 4., 11., 0., 1., 12., 7., 0.], [ 0., 2., 14., 5., 10., 12., 0., 0.], [ 0., 0., 6., 13., 10., 0., 0., 0.]])
咱們能夠用 matplotlib 將該圖像畫出來:函數
>>> import matplotlib.pyplot as plt >>> plt.imshow(digits.images[0]) >>> plt.show()
畫出來的圖像以下,表明 0:測試
sklearn 庫的 neighbors 模塊實現了KNN 相關算法,其中:
KNeighborsClassifier
類用於分類問題KNeighborsRegressor
類用於迴歸問題這兩個類的構造方法基本一致,這裏咱們主要介紹 KNeighborsClassifier
類,原型以下:
KNeighborsClassifier( n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs)
來看下幾個重要參數的含義:
首先加載數據集:
from sklearn.datasets import load_digits digits = load_digits() data = digits.data # 特徵集 target = digits.target # 目標集
將數據集拆分爲訓練集(75%)和測試集(25%):
from sklearn.model_selection import train_test_split train_x, test_x, train_y, test_y = train_test_split( data, target, test_size=0.25, random_state=33)
構造KNN 分類器:
from sklearn.neighbors import KNeighborsClassifier # 採用默認參數 knn = KNeighborsClassifier()
擬合模型:
knn.fit(train_x, train_y)
預測數據:
predict_y = knn.predict(test_x)
計算模型準確度:
from sklearn.metrics import accuracy_score score = accuracy_score(test_y, predict_y) print score # 0.98
最終計算出來模型的準確度是 98%,準確度仍是不錯的。
本篇文章使用KNN 算法處理了一個實際的分類問題,主要介紹瞭如下幾點:
neighbors.KNeighborsClassifier
類的用法。KNeighborsClassifier
來識別手寫數字。(本節完。)
推薦閱讀:
歡迎關注做者公衆號,獲取更多技術乾貨。