k近鄰法:一個簡單的預測模型

k近鄰法的思想

你打算預測我會在大選中投票給誰。假設你對我一無所知,一個明智的方法是看看個人鄰居們都投票給誰。固然,你可能還知道個人年齡、收入、有幾個孩子,等等,根據個人行爲受這些維度影響的程度,你能夠觀察在這些維度上最接近個人鄰居們而不是我全部的鄰居會獲得更好的預測結果。這就是最近鄰分類(nearest neighbors classification)方法背後的思想。python

k近鄰法的優缺點

k近鄰法是最簡單的預測模型之一,它沒有多少數學上的假設,也不要求任何複雜的數學處理,它所要求的僅僅是:git

  • 某種距離的概念編程

  • 彼此接近的點具備類似性質的假設(不然用近鄰來預測結果就是不合理的)app

k近鄰法有意忽略了大量信息,對每一個新的數據點的預測只依賴少許最接近它的點。
此外,它不能解釋爲何。例如,基於我鄰居的投票行爲來預測個人投票並不能告訴你我爲何要這樣投票,而基於個人收入、婚姻等因素來預測個人投票行爲的模型則能揭示我投票的緣由。編程語言

案例:最喜歡的編程語言

假設咱們有一份數據,這份數據是各個城市經緯度及該城市最受歡迎的編程語言的集合。數據以列表存儲。函數

cities = [(-86.75,33.5666666666667,'Python'),(-88.25,30.6833333333333,'Python'),(-112.016666666667,33.4333333333333,'Java')......]

數據可視化

咱們將每種語言及其經度(x)、緯度(y)按以下的格式存儲到字典中:鍵是語言,值是成對的數據。spa

plots={"Java":([],[]),"python":([],[]),"R":([],[])}

每種語言用不一樣的符號和顏色標記:rest

markers = { "Java" : "o", "Python" : "s", "R" : "^" }
colors  = { "Java" : "r", "Python" : "b", "R" : "g" }

將cities列表中的數據存放到plots字典中:code

for (longitude, latitude), language in cities:
    plots[language][0].append(longitude)
    plots[language][1].append(latitude)

咱們能夠用items方法來返回字典元素的列表:排序

In [1]: plots.items()
Out[1]:
[('Python',([-86.75, -88.25, -118.15......],[33.5666666666667, 30.6833333333333, 33.8166666666667.....])),('R'......),('Java'.....)]

這樣咱們就能夠用for語句來循環遍歷字典元素,爲每種語言建立一個散點序列:

import matplotlib.pyplot as plt    

for language, (x, y) in plots.items():
    plt.scatter(x, y, color=colors[language], marker=markers[language],
                      label=language, zorder=10)

plt.legend(loc=0) #讓matplotlib選擇一個位置
plt.axis([-130,-60,20,55]) #設置座標軸
plt.title("most popular language") #設置圖標標題
put.show()

最終效果以下:
figure_1.png

k近鄰法的python實現

下面的點是cities列表中的第一個點,這個城市最受歡迎的編程語言是python:

In [2]: cities[0]
Out[2]: ([-86.75, 33.5666666666667], 'Python')

假設咱們不知道這個城市最受歡迎的語言是什麼。根據k近鄰法的思想,爲了預測結果,(1)咱們首先須要知道這個點即這個城市與其餘全部點的距離,(2)而後找到離這個點最近的某個點或幾個點最受歡迎的編程語言是什麼,以此做爲預測結果,(3)若是是幾個點,咱們須要計算哪一種語言出現的次數最多,以此做爲預測結果。

思路清楚了,讓咱們來一步步實現吧。

先將除其餘城市存放在other_cities列表中(這裏用列表解析式遍歷全部城市,找到與該城市不一樣的全部城市):

other_cities=[other_city for other_city in cities if other_city != (cities[0][0],cities[0][1])]

按其餘城市與待預測城市之間的距離從近到遠對other_cities列表進行排序:

from linear_algebra import distance

by_distance=sorted(other_cities, key= lambda point_label:
   distance(point_label[0], cities[0][0]))

找到最近的一個城市:

In [3]: k_nearest_labels=[label for _, label in by_distance[:1]]

In [4]: k_nearest_labels
Out[4]: ['Python']

固然,咱們也能夠找到最近的3個城市、5個城市或7個城市。

In [5]: k_nearest_labels=[label for _, label in by_distance[:3]]

In [6]: k_nearest_labels
Out[6]: ['Python', 'R', 'Python']

In [7]: k_nearest_labels=[label for _, label in by_distance[:5]]

In [8]: k_nearest_labels
Out[8]: ['Python', 'R', 'Python', 'Java', 'R']

In [9]: k_nearest_labels=[label for _, label in by_distance[:7]]

In [10]: k_nearest_labels
Out[10]: ['Python', 'R', 'Python', 'Java', 'R', 'Python', 'Java']

這時,咱們須要一個計數器找到出現次數最多的語言:

def majority_vote(labels):
"""假設labels已經從最近到最遠排序"""
    vote_counts = Counter(labels)  #Counter返回的是字典,以label爲鍵,出現次數爲值
    winner, winner_count = vote_counts.most_common(1)[0] #most_common方法能夠找出vote_counts中出現次數前1(前幾由括號內參數指定)的鍵和值,以元組組織
    num_winners = len([count
               for count in vote_counts.values()
               if count == winner_count]) #計算vote_counts中前1的出現次數出現了幾回,有幾個勝出者

     if num_winners == 1:
         return winner                     # 若是隻有一個勝出者,直接返回
    else:
         return majority_vote(labels[:-1]) # 若是有幾個勝出者,排除lavels中最遠的點,再試一次

在計算距離時,咱們僅僅計算了一個點與其餘點的距離。因爲全部點的邏輯都是同樣的,這種狀況下,咱們能夠構造函數來執行同類操做。

def knn_classify(k, labeled_points, new_point):
    """k決定取最近的幾個點;labeled_points指帶標籤的點,即(point, label)的數據對,是除待預測的點以外全部的點;new_ponit爲待預測的點"""

    # 將帶標籤的點,即其餘點從近到遠排序
    by_distance = sorted(labeled_points,
                         key=lambda point_label: distance(point_label[0], new_point))

    # 找到最近的k個點
    k_nearest_labels = [label for _, label in by_distance[:k]]

    # 對每一個點進行計數
    return majority_vote(k_nearest_labels)

如今,讓咱們看看嘗試利用近鄰城市來預測每一個城市的偏心語言會獲得什麼結果:

for k in [1, 3, 5, 7]:
    num_correct = 0

    for location, actual_language in cities:

        other_cities = [other_city
                        for other_city in cities
                        if other_city != (location, actual_language)]

        predicted_language = knn_classify(k, other_cities, location)

        if predicted_language == actual_language:
            num_correct += 1

    print(k, "neighbor[s]:", num_correct, "correct out of", len(cities))

能夠看到,k=3時的預測準確率最高,59%的時間能給出正確答案:

(1, 'neighbor[s]:', 40, 'correct out of', 75)
(3, 'neighbor[s]:', 44, 'correct out of', 75)
(5, 'neighbor[s]:', 41, 'correct out of', 75)
(7, 'neighbor[s]:', 35, 'correct out of', 75)

參考資料:

Joel Grus《數據科學入門》第12章

相關文章
相關標籤/搜索