分類算法之鄰近算法:KNN(應用篇)

起步

此次使用的訓練集由 sklearn 模塊提供,關於虹膜(一種鳶尾屬植物)的數據。數組

1278644294.png

數據載入

from sklearn import datasets
iris = datasets.load_iris()

數據存儲在 .data 成員中,它是一個 (n_samples, n_features) numpy 數組:函數

print(iris.data)
# [[ 5.1  3.5  1.4  0.2]
#  [ 4.9  3.   1.4  0.2]
#  ...

它有四個特徵,萼片長度,萼片寬度,花瓣長度,花瓣寬度 (sepal length, sepal width, petal length and petal width)。測試

kahi2.jpg

它的品種分類有山鳶尾,變色鳶尾,菖蒲錦葵(Iris setosa, Iris versicolor, Iris virginica.)三種。spa

print iris.data.shape
# output:(150L, 4L)

這是一個含有 150 個數據的訓練集。code

構造 KNN 分類器

from sklearn import neighbors
knn = neighbors.KNeighborsClassifier(n_neighbors=5)

n_neighbors 參數級是指定獲取 K 個鄰近點。rem

訓練

訓練的函數通常就是 fitget

knn.fit(iris.data, iris.target)

測試

模擬一些測試數據,使用剛剛的模型進行預測:it

predict = knn.predict([[0.1, 0.2, 0.3, 0.4]])
print(predict) # output: [0]
相關文章
相關標籤/搜索