此次使用的訓練集由 sklearn
模塊提供,關於虹膜(一種鳶尾屬植物)的數據。數組
from sklearn import datasets iris = datasets.load_iris()
數據存儲在 .data
成員中,它是一個 (n_samples, n_features) numpy
數組:函數
print(iris.data) # [[ 5.1 3.5 1.4 0.2] # [ 4.9 3. 1.4 0.2] # ...
它有四個特徵,萼片長度,萼片寬度,花瓣長度,花瓣寬度 (sepal length, sepal width, petal length and petal width)。測試
它的品種分類有山鳶尾,變色鳶尾,菖蒲錦葵(Iris setosa, Iris versicolor, Iris virginica.)三種。spa
print iris.data.shape # output:(150L, 4L)
這是一個含有 150 個數據的訓練集。code
from sklearn import neighbors knn = neighbors.KNeighborsClassifier(n_neighbors=5)
n_neighbors
參數級是指定獲取 K 個鄰近點。rem
訓練的函數通常就是 fit
:get
knn.fit(iris.data, iris.target)
模擬一些測試數據,使用剛剛的模型進行預測:it
predict = knn.predict([[0.1, 0.2, 0.3, 0.4]]) print(predict) # output: [0]