K近鄰算法的Python實現

做爲『十大機器學習算法』之一的K-近鄰(K-Nearest Neighbors)算法是思想簡單、易於理解的一種分類和迴歸算法。今天,咱們來一塊兒學習KNN算法的基本原理,並用Python實現該算法,最後,經過一個案例闡述其應用價值。html


KNN算法的直觀理解

它基於這樣的簡單假設:彼此靠近的點更有可能屬於同一個類別。用大俗話來講就是『臭味相投』,或者說『近朱者赤,近墨者黑』。python


它並未試圖創建一個顯示的預測模型,而是直接經過預測點的臨近訓練集點來肯定其所屬類別。git

K近鄰算法的實現主要基於三大基本要素:算法

  • K的選擇;bash

  • 距離度量方法的肯定;markdown

  • 分類決策規則。app

下面,即圍繞這三大基本要素,探究它的分類實現原理。dom


KNN算法的原理

算法步驟

K近鄰算法的實施步驟以下:機器學習

  1. 根據給定的距離度量,在訓練集TT中尋找出與xx最近鄰的kk個點,涵蓋這kk個點的xx的鄰域記做Nk(x)Nk(x);函數

  2. 在Nk(x)Nk(x)中根據分類決策規則決定樣本的所屬類別yy:

y=arg maxcj∑xi∈Nk(x)I(yi=cj),i=1,2,⋯,N;j=1,2,⋯,K.y=arg maxcj∑xi∈Nk(x)I(yi=cj),i=1,2,⋯,N;j=1,2,⋯,K.


K的選擇

K近鄰算法對K的選擇很是敏感。K值越小意味着模型複雜度越高,從而容易產生過擬合;K值越大則意味着總體的模型變得簡單,學習的近似近似偏差會增大。

在實際的應用中,通常採用一個比較小的K值。並採用交叉驗證的方法,選取一個最優的K值。


距離度量

距離度量通常採用歐式距離。也能夠根據須要採用LpLp距離或明氏距離。


分類決策規則

K近鄰算法中的分類決策多采用多數表決的方法進行。它等價於尋求經驗風險最小化。


但這個規則存在一個潛在的問題:有可能多個類別的投票數同爲最高。這個時候,究竟應該判爲哪個類別?

能夠經過如下幾個途徑解決該問題:

  • 從投票數相同的最高類別中隨機地選擇一個;

  • 經過距離來進一步給票數加權;

  • 減小K的個數,直到找到一個惟一的最高票數標籤。


KNN算法的優缺點

優勢

  • 精度高

  • 對異常值不敏感

  • 沒有對數據的分佈假設


缺點

  • 計算複雜度高

  • 在高維狀況下,會遇到『維數詛咒』的問題


KNN算法的算法實現

import os os.chdir('D:\\my_python_workfile\\Project\\Writting')os.getcwd()
'D:\\my_python_workfile\\Project\\Writting'
from __future__ import divisionfrom collections import Counter#from linear_algebra import distance#from statistics import meanimport math, randomimport matplotlib.pyplot as plt# 定義投票函數def raw_majority_vote(labels):
    votes = Counter(labels)
    winner,_ = votes.most_common(1)[0]
    return winner
複製代碼

以上的投票函數存在潛在的問題:有可能多個類別的投票數同爲最高。


下面的函數則實現瞭解決方案中的第三種分類決策方法。

# def majority_vote(labels):
    """assumes that labels are ordered from nearest to farthest """
    vote_counts = Counter(labels)
    winner,winner_count = vote_counts.most_common(1)[0]
    num_winners = len([count 
                          for count in vote_counts.values()
                          if count == winner_count])
    if num_winners == 1:
        return winner
    else:
        return majority_vote(labels[:-1]) # try again wthout the farthest
    
# define distance functionimport math#### 減法定義def vector_substract(v,w):
    """substracts coresponding elements"""
    return [v_i - w_i
                   for v_i,w_i in zip(v,w)]def squared_distance(v,w):
    """"""
    return sum_of_squares(vector_substract(v,w))def distance(v,w):
    return math.sqrt(squared_distance(v,w))########################################### define sum_of_squares### 向量的點乘def dot(v,w):
    return sum(v_i * w_i 
                      for v_i,w_i in zip(v,w))### 向量的平房和def sum_of_squares(v):
    """v_1*v_1+v_2*v_2+...+v_n*v_n"""
    return dot(v,v)
# classifierdef knn_classify(k,labeled_points,new_point):
    """each labeled point should be a pair (point,label)"""
    
    # order the labeled points from nearest to farthest
    by_distance = sorted(labeled_points,
                        key = lambda (point,_):distance(point,new_point))
    
    # find the labels for the k cloest
    k_nearest_labels = [label for _,label in by_distance[:k]]
    
    # and let them vote
    return majority_vote(k_nearest_labels)
複製代碼

KNN算法的應用:

案例分析

# cities = [(-86.75,33.5666666666667,'Python'),(-88.25,30.6833333333333,'Python'),(-112.016666666667,33.4333333333333,'Java'),
          (-110.933333333333,32.1166666666667,'Java'),(-92.2333333333333,34.7333333333333,'R'),(-121.95,37.7,'R'),
          (-118.15,33.8166666666667,'Python'), (-118.233333333333,34.05,'Java'),(-122.316666666667,37.8166666666667,'R'),
          (-117.6,34.05,'Python'),(-116.533333333333,33.8166666666667,'Python'),
          (-121.5,38.5166666666667,'R'),(-117.166666666667,32.7333333333333,'R'),(-122.383333333333,37.6166666666667,'R'),
          (-121.933333333333,37.3666666666667,'R'),(-122.016666666667,36.9833333333333,'Python'),
          (-104.716666666667,38.8166666666667,'Python'),(-104.866666666667,39.75,'Python'),(-72.65,41.7333333333333,'R'),
          (-75.6,39.6666666666667,'Python'),(-77.0333333333333,38.85,'Python'),(-80.2666666666667,25.8,'Java'),
          (-81.3833333333333,28.55,'Java'),(-82.5333333333333,27.9666666666667,'Java'),(-84.4333333333333,33.65,'Python'),
          (-116.216666666667,43.5666666666667,'Python'),(-87.75,41.7833333333333,'Java'),(-86.2833333333333,39.7333333333333,'Java'),
          (-93.65,41.5333333333333,'Java'),(-97.4166666666667,37.65,'Java'),(-85.7333333333333,38.1833333333333,'Python'),
          (-90.25,29.9833333333333,'Java'),(-70.3166666666667,43.65,'R'),(-76.6666666666667,39.1833333333333,'R'),
          (-71.0333333333333,42.3666666666667,'R'),(-72.5333333333333,42.2,'R'),(-83.0166666666667,42.4166666666667,'Python'),
          (-84.6,42.7833333333333,'Python'),(-93.2166666666667,44.8833333333333,'Python'),(-90.0833333333333,32.3166666666667,'Java'),
          (-94.5833333333333,39.1166666666667,'Java'),(-90.3833333333333,38.75,'Python'),(-108.533333333333,45.8,'Python'),
          (-115.166666666667,36.0833333333333,'Java'),(-71.4333333333333,42.9333333333333,'R'),(-74.1666666666667,40.7,'R'),
          (-106.616666666667,35.05,'Python'),(-78.7333333333333,42.9333333333333,'R'),(-73.9666666666667,40.7833333333333,'R'),
          (-80.9333333333333,35.2166666666667,'Python'),(-78.7833333333333,35.8666666666667,'Python'),(-100.75,46.7666666666667,'Java'),
          (-84.5166666666667,39.15,'Java'),(-81.85,41.4,'Java'),(-82.8833333333333,40,'Java'),(-97.6,35.4,'Python'),
          (-122.666666666667,45.5333333333333,'Python'),(-75.25,39.8833333333333,'Python'),(-80.2166666666667,40.5,'Python'),
          (-71.4333333333333,41.7333333333333,'R'),(-81.1166666666667,33.95,'R'),(-96.7333333333333,43.5666666666667,'Python'),
          (-90,35.05,'R'),(-86.6833333333333,36.1166666666667,'R'),(-97.7,30.3,'Python'),(-96.85,32.85,'Java'),
          (-98.4666666666667,29.5333333333333,'Java'),(-111.966666666667,40.7666666666667,'Python'),(-73.15,44.4666666666667,'R'),
          (-77.3333333333333,37.5,'Python'),(-122.3,47.5333333333333,'Python'),(-95.9,41.3,'Python'),(-95.35,29.9666666666667,'Java'),
          (-89.3333333333333,43.1333333333333,'R'),(-104.816666666667,41.15,'Java')]cities = [([longitude,latitude],language) for longitude,latitude,language in cities]
# plot_state_bordersimport resegments = []points = []lat_long_regex = r"<point lat=\"(.*)\" lng=\"(.*)\""with open("states.txt", "r") as f:
    lines = [line for line in f]for line in lines:
    if line.startswith("</state>"):
        for p1, p2 in zip(points, points[1:]):
            segments.append((p1, p2))
        points = []
    s = re.search(lat_long_regex, line)
    if s:
        lat, lon = s.groups()
        points.append((float(lon), float(lat)))def plot_state_borders(plt, color='0.8'):
    for (lon1, lat1), (lon2, lat2) in segments:
        plt.plot([lon1, lon2], [lat1, lat2], color=color)
# key is language, value is pairplots = {"Java":([],[]),"Python":([],[]),"R":([],[])}#mark and colormarkers = {"Java":"o","Python":"s","R":"^"}colors = {"Java":"r","Python":"b","R":"g"}for (logitude,latitude),language in cities:
    plots[language][0].append(logitude)
    plots[language][1].append(latitude)
    # create a scatter series for each languagefor language,(x,y) in plots.iteritems():
    plt.scatter(x,y,color = colors[language],marker = markers[language],label = language,zorder = 10)
    plot_state_borders(plt)plt.legend(loc = 0)plt.axis([-130,-60,20,55])plt.title("Favorite Programming Languages")plt.show()
複製代碼

# try several different values for kfor k in [1,3,5,7]:
    num_correct = 0
    
    for city in cities:
        location,actual_language = city
        other_cities = [other_city for other_city in cities if other_city != city]
        
        predicted_language = knn_classify(k,other_cities,location)
        
        if predicted_language == actual_language:
            num_correct += 1
            
    print k,"neighbor[s]:",num_correct,"correct out of ",len(cities)
        
1 neighbor[s]: 40 correct out of  753 neighbor[s]: 44 correct out of  755 neighbor[s]: 41 correct out of  757 neighbor[s]: 35 correct out of  75
複製代碼

閱讀原文

相關文章
相關標籤/搜索