KNN

1、KNN分類算法python

  K最近鄰(K-Nearest Neighbor,KNN)算法,是著名的模式識別統計學方法,在機器學習分類算法中佔有至關大的地位。它是一個理論上比較成熟的方法。既是最簡單的機器學習算法之一,也是基於實例的學習方法中最基本的,又是最好的文本分類算法之一。算法

  一般,在分類任務中可以使用「投票法」,即選擇這k個實例中出現最多的標記類別做爲預測結果;在迴歸任務中可以使用「平均法」,即將這k個實例的實值輸出標記的平均值做爲預測結果;還可基於距離遠近進行加權平均或加權投票,距離越近的實例權重越大。app

2、算法圖示python2.7

◊ 從訓練集中找到和新數據最接近的k條記錄,而後根據多數類來決定新數據類別。機器學習

◊算法涉及3個主要因素:學習

1) 訓練數據集spa

2) 距離或類似度的計算衡量rest

3) k的大小code

 

◊算法描述blog

1) 已知兩類「先驗」數據,分別是藍方塊和紅三角,他們分佈在一個二維空間中

2) 有一個未知類別的數據(綠點),須要判斷它是屬於「藍方塊」仍是「紅三角」類

3) 考察離綠點最近的3個(或k個)數據點的類別,佔多數的類別即爲綠點斷定類別

 

3、KNN分類算法python實現(python2.7)

需求:

有如下先驗數據,使用knn算法對未知類別數據分類

屬性1

屬性2

類別

1.0

0.9

A

1.0

1.0

A

0.1

0.2

B

0.0

0.1

B

 

未知類別數據

屬性1

屬性2

類別

1.2

1.0

?

0.1

0.3

?

 

KNN.py

# coding=utf-8from numpy import *
import operator

def createDataSet():
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

def kNNClassify(newInput, dataSet, labels, k):
    numSamples = dataSet.shape[0]   
    diff = tile(newInput, (numSamples, 1)) - dataSet  
    squaredDiff = diff ** 2
    squaredDist = sum(squaredDiff, axis = 1)
    distance = squaredDist ** 0.5
    sortedDistIndices = argsort(distance)
    classCount = {} # define a dictionary (can be append element)
    for i in xrange(k):
        voteLabel = labels[sortedDistIndices[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    maxCount = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key

    return maxIndex

 

KNNTest.py

#!/usr/bin/python
# coding=utf-8
from KNN import KNN
from numpy import *
dataSet, labels = KNN.createDataSet()
testX = array([1.2, 1.0])
k = 3
outputLabel = KNN.kNNClassify(testX, dataSet, labels, 3)
print "Your input is:", testX, "and classified to class: ", outputLabel

testX = array([0.1, 0.3])
outputLabel = KNN.kNNClassify(testX, dataSet, labels, 3)
print "Your input is:", testX, "and classified to class: ", outputLabel

 

結果:

Your input is: [1.2 1. ] and classified to class:  A
Your input is: [0.1 0.3] and classified to class:  B
相關文章
相關標籤/搜索