第三節--k近鄰算法

第三節–k近鄰算法

k近鄰法(k-nearest neighbor,KNN)是一種基本分類與迴歸方法.k近鄰法的輸入爲實例的特徵變量,對應於特徵空間的點,輸出爲實例的類別,能夠取多類.k近鄰法假設給定一個訓練數據集,其中的實例類別已定,分類時,對新的實例,根據其k個最近鄰的訓練實例的類別,經過多數表決等方式進行預測.所以,k近鄰法不具備顯式的學習過程,k近鄰法實際上利用訓練數據集對特徵向量空間進行劃分,並做爲其分類的"模型".k值的選擇,距離度量分類決策規則是k近鄰法的三個基本要素html

首先敘述k近鄰算法,而後討論k近鄰法的模型及三個基本要素,最後講述k近鄰法的一個實現方法–kd樹node

一.k近鄰算法

k近鄰算法簡單,直觀:給定一個訓練數據集,對新的輸入實例,在訓練數據集中找到與該實例最鄰近的k個實例,這k個實例的多樹屬於某個類,就把該輸入實例分爲這個類python

輸入:訓練數據集
T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) ,   , ( x N , y N ) } T=\left\{\left(x_{1}, y_{1}\right),\left(x_{2}, y_{2}\right), \cdots,\left(x_{N}, y_{N}\right)\right\} web

其中, x i X R n x_{i} \in \mathcal{X} \subseteq \mathbf{R}^{n} 爲實例的特徵向量, y i Y = { c 1 , c 2 ,   , c K } y_{i} \in \mathcal{Y}=\left\{c_{1}, c_{2}, \cdots, c_{K}\right\} 爲實例的類別, i = 1 , 2 ,   , N i=1,2, \cdots, N ;實例特徵向量x算法

輸出:實例x所屬的類y數據結構

  1. 根據給定的距離度量,在訓練集T中找出與x最近鄰的k個點,涵蓋這k個點的x的鄰域記做爲 N k ( x ) N_{k}(x)
  2. N k ( x ) N_{k}(x) 中根據分類決策規則(如多數表決)決定x的類別y:
    y = arg max c j x i N k ( x ) I ( y i = c j ) , i = 1 , 2 ,   , N , j = 1 , 2 ,   , K y=\arg \max _{c_{j}} \sum_{x_{i} \in N_{k}(x)} I\left(y_{i}=c_{j}\right),i=1,2, \cdots, N, \quad j=1,2, \cdots, K
    其中I爲指示函數,即當 y i = c j y_{i}=c_{j} 時I爲1,不然I爲0

k近鄰法的特殊狀況是k=1的情形,稱爲最近鄰算法,對於輸入的實例點(特徵向量)x,最近鄰法將訓練數據集中與x最近鄰點的類做爲x的類app

k近鄰法沒有顯式的學習過程dom

二.k近鄰模型

k近鄰法使用的模型實際上對應於對特徵空間的劃分,模型由三個基本要素—距離度量,k值的選擇分類決策規則決定ide

1.模型

k近鄰法中,當訓練集,距離度量(如歐式距離),k值及分類決策規則(如多數表決)肯定後.對於任何一個新的輸入實例,它所屬的類惟一地肯定.這至關於根據上述要素將特徵空間劃分爲一些子空間,肯定子空間裏的每一個點所屬的類,這一事實從最近鄰算法中能夠看得很清楚svg

特徵空間中,對每一個訓練實例點 x i x_{i} ,距離該點比其餘點更近的全部點組成一個區域,叫做單元(cell).每一個訓練實例點擁有一個單元,全部訓練實例點的單元構成對特徵空間的一個劃分,最近鄰法將實例 x i x_{i} 的類 y i y_{i} 做爲其單元中全部點的類標記(class label).這樣,每一個單元的實例點的類別是肯定的

from IPython.display import Image
Image(filename="./data/3_1.png",width=500)

在這裏插入圖片描述

2.距離度量

特徵空間中兩個實例點的距離是兩個實例點類似程度的反映.k近鄰模型的特徵空間通常是n維實數向量空間 R n \mathbf{R}^{n} .使用的距離是歐式距離,但也能夠是其餘距離.如更通常的 L p L_{p} 距離( L p L_{p} distance)或Minkowski距離(Minkowski distance)

設特徵空間 X \mathcal{X} 是n維實數向量空間 R n \mathbf{R}^{n} , x i , x j X , x i = ( x i ( 1 ) , x i ( 2 ) ,   , x i ( n ) ) , x j = ( x j ( 1 ) , x j ( 2 ) ,   , x j ( n ) ) T x_{i}, x_{j} \in \mathcal{X}, \quad x_{i}=\left(x_{i}^{(1)}, x_{i}^{(2)}, \cdots, x_{i}^{(n)}\right)^{\top},x_{j}=\left(x_{j}^{(1)}, x_{j}^{(2)}, \cdots, x_{j}^{(n)}\right)^{\mathrm{T}} , x i , x j x_{i}, x_{j} L p L_{p} 距離定義爲:
L p ( x i , x j ) = ( i = 1 n x i ( l ) x j ( l ) p ) 1 p L_{p}\left(x_{i}, x_{j}\right)=\left(\sum_{i=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^{p}\right)^{\frac{1}{p}}

這裏p≥1.當p=2時,稱爲歐式距離(Euclidean distance),即:
L 2 ( x i , x j ) = ( i = 1 n x i ( l ) x j ( l ) 2 ) 1 2 L_{2}\left(x_{i}, x_{j}\right)=\left(\sum_{i=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^{2}\right)^{\frac{1}{2}}

當p=1時,稱爲曼哈頓距離(Manhattan distance),即:
L 1 ( x i , x j ) = l = 1 n x i ( l ) x j ( l ) L_{1}\left(x_{i}, x_{j}\right)=\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|

p = p=\infty 時,稱爲閔式距離(Minkowski distance),它是各個座標距離的最大值,即:
L ( x i , x j ) = max l x i ( l ) x j ( l ) L_{\infty}\left(x_{i}, x_{j}\right)=\max _{l}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|

下圖給出了二維空間中p取不一樣值時,與原點的 L p L_{p} 距離爲1( L p L_{p} =1)的點的圖片

Image(filename="./data/3_2.png",width=500)

在這裏插入圖片描述

下面的例子說明,由不一樣的距離度量所肯定的最近鄰點是不一樣的

實例1:已知二維空間的3個點 x 1 = ( 1 , 1 ) T , x 2 = ( 5 , 1 ) T , x 3 = ( 4 , 4 ) T x_{1}=(1,1)^{\mathrm{T}}, x_{2}=(5,1)^{\mathrm{T}}, x_{3}=(4,4)^{\mathrm{T}} ,試求在p取不一樣值時, L p L_{p} 距離下 x 1 x_{1} 的最近鄰點

Image(filename="./data/3_3.png",width=500)

在這裏插入圖片描述

因而獲得:p等於1或2時, x 2 x_{2} x 1 x_{1} 的最近鄰點;p大於等於3時, x 3 x_{3} x 1 x_{1} 的最近鄰點

3.k值的選擇

k值的選擇會對k近鄰法的結果產生重大影響

若是選擇較小的k值,就至關於用較小的鄰域中的訓練實例進行預測,"學習"的近似偏差(approximation error)會減少,只有與輸入實例較近的訓練實例纔會對預測結果其做用,但缺點是"學習"的估計偏差(estimation error)會增大,預測結果會近鄰的實例點很是敏感.若是鄰近的實例點恰巧是噪聲.預測就會出錯.換句話說,k值的減少就意味着總體模型變得複雜,容易發生過擬合

若是選擇較大的k值,就至關於用較大鄰域中的訓練實例進行預測,其優勢是能夠減小學習的估計偏差,但缺點是學習的近似偏差會增大.這時與輸入實例較遠訓練實例也會對預測其做用,使預測發生錯誤,k值的增大就意味着總體的模型變得簡單

4.分類決策規則

k近鄰法中的分類決策規則每每是多數表決,即由輸入實例的k個鄰近的訓練實例中的多數類決定輸入實例的類

多數表決規則(majority voting rule)有以下解釋:若是分類的損失函數爲0-1損失函數,分類函數爲:
f : R n { c 1 , c 2 ,   , c K } f : \mathbf{R}^{n} \rightarrow\left\{c_{1}, c_{2}, \cdots, c_{K}\right\}

那麼誤分類的機率是:
P ( Y f ( X ) ) = 1 P ( Y = f ( X ) ) P(Y \neq f(X))=1-P(Y=f(X))

對給定的實例 x X x \in \mathcal{X} ,其最近鄰的k個訓練實例點構成集合 N k ( x ) N_{k}(x) .若是涵蓋 N k ( x ) N_{k}(x) 的區域的類別是 c j c_{j} ,那麼誤分類率是:
1 k x i N k ( x ) I ( y i c j ) = 1 1 k x i N k ( x ) I ( y i = c j ) \frac{1}{k} \sum_{x_{i} \in N_{k}(x)} I\left(y_{i} \neq c_{j}\right)=1-\frac{1}{k} \sum_{x_{i} \in N_{k}(x)} I\left(y_{i}=c_{j}\right)

要使誤分類率最小即檢驗風險最小,就要使 x i N k ( x ) I ( y i = c j ) \sum_{x_{i} \in N_{k}(x)} I\left(y_{i}=c_{j}\right) 最大,因此多數表決規則等價於經驗風險最小化

三.k近鄰法的實現:kd樹

實現k近鄰法時,主要考慮的問題是如何對訓練數據進行快速k近鄰搜索.這點在特徵空間的維數大及訓練數據容量大時尤爲必要

k近鄰法最簡單的實現方法是線性掃描(linear scan).這時要計算輸入實例與每個訓練實例的距離.當訓練集很大時,計算很是耗時,這種方法是不可行的

爲了提升k近鄰搜索的效率,能夠考慮使用特殊的結構存儲訓練數據,以減小計算距離的次數.kd樹(kd tree)方法就是一種

1.構造kd樹

kd樹是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構.kd樹是二叉樹,表示對k維空間的一個劃分(partition).構造kd樹至關於不斷地用垂直於座標軸的超平面將k維空間劃分,構成一系列的k維超矩形區域.kd樹的每一個結點對應於一個k維矩形區域

構造kd樹的方法以下:構造根結點,使根結點對應於k維空間中包含全部實例點的超矩形區域;經過下面的遞歸方法,不斷地對k維空間進行切分,生成子結點,在超矩形區域(結點)上選擇一個座標軸和在此座標軸上的一個切分點,肯定一個超平面,這個超平面經過選定的切分點並垂直於選定的座標軸.將當前超矩形區域切分爲左右兩個子區域(子結點);這時實例被分到兩個子區域,這個過程直到子區域內沒有實例時中止(終止時的結點爲葉結點).在此過程當中,將實例保存在相應的結點上

一般,依次選擇座標軸對空間切分.選擇訓練實例點在選定座標軸上的中位數(median)爲切分點,這樣獲得的kd樹是平衡的.注意,平衡的kd樹搜索時的效率未必是最優的

構造平衡kd樹

輸入:k維空間數據集 T = { x 1 , x 2 ,   , x N } T=\left\{x_{1}, x_{2}, \cdots, x_{N}\right\}
其中 x i = ( x i ( 0 ) , x i ( 2 ) ,   , x i ( k ) ) T , i = 1 , 2 ,   , N x_{i}=\left(x_{i}^{(0)}, x_{i}^{(2)}, \cdots, x_{i}^{(k)}\right)^{\mathrm{T}}, \quad i=1,2, \cdots, N

輸出:kd樹

  1. 開始:構造根結點,根結點對應於包含 T T 的k維空間的超矩形區域
    選擇 x ( 1 ) x^{(1)} 爲座標軸,以 T T 中全部實例的 x ( 1 ) x^{(1)} 的中位數爲切分點,將根結點對應的超矩形區域切分爲兩個子區域.切分由經過切分點並與座標軸 x ( 1 ) x^{(1)} 垂直的超平面

由根結點生成深度爲1的左右子結點;左子結點對應座標 x ( 1 ) x^{(1)} 小於切分點的子區域,右子結點對應於座標 x ( 1 ) x^{(1)} 大於切分點的子區域

將落在切分超平面上的實例點保存在根結點

  1. 重複:對深度爲j的結點,選擇 x ( l ) x^{(l)} 爲切分的座標軸, l = j (   m o d   k ) + 1 l=j(\bmod k)+1 以該結點的區域中全部實例的 x ( l ) x^{(l)} 座標的中位數爲切分點,將該結點對應的超矩形區域切分爲兩個子區域,切分由經過切分點並與座標軸 x ( l ) x^{(l)} 垂直的超平面實現

由該結點生成深度爲j+1的左右子結點:左子結點對應座標 x ( l ) x^{(l)} 小於切分點的子區域,右子結點對應座標 x ( l ) x^{(l)} 大於切分點的子區域

將落在切分超平面上的實例點保存在該結點

  1. 直到兩個子區域沒有實例存在時中止,從而造成kd樹的區域劃分

實例2:給定一個二維空間的數據集
T = { ( 2 , 3 ) T , ( 5 , 4 ) T , ( 9 , 6 ) T , ( 4 , 7 ) T , ( 8 , 1 ) T , ( 7 , 2 ) T } T=\left\{(2,3)^{\mathrm{T}},(5,4)^{\mathrm{T}},(9,6)^{\mathrm{T}},(4,7)^{\mathrm{T}},(8,1)^{\mathrm{T}},(7,2)^{\mathrm{T}}\right\}
構造一個平衡kd樹

解:根結點對應包含數據集 T T 的矩形,選擇 x ( 1 ) x^{(1)} 軸,6個數據點的 x ( 1 ) x^{(1)} 座標的中位數是7(注意:2,4,5,7,8,9在數學中的中位數爲6,但因該算法的中值需在點集合以內,因此中值計算用的是len(points)/2=3,points[3]=(7,2)),以平面 x ( 1 ) = 7 x^{(1)}=7 將空間分爲左右兩個子矩形(子結點);接着,左矩形以 x ( 2 ) = 4 x^{(2)}=4 分爲兩個子矩形,右矩形以 x ( 2 ) = 6 x^{(2)}=6 分爲兩個子矩形,如此遞歸,最後獲得如圖所示的特徵空間劃分和kd樹

Image(filename="./data/3_5.png",width=500)

在這裏插入圖片描述

Image(filename="./data/3_4.png",width=500)

在這裏插入圖片描述

2.搜索kd樹

利用kd樹能夠省去對大部分數據點的搜索,從而減小搜索的計算量,這裏以最近鄰爲例加以敘述,一樣的方法能夠應用到k近鄰

給定一個目標點,搜索其最近鄰.首先找到包含目標點的葉結點;而後從該葉結點出發,依次回退到父結點;不斷查找與目標點最近鄰的結點,當肯定不可能存在更近的結點時終止,這樣搜索就被限制在空間的局部區域上,效率大爲提升

用kd樹的最近鄰搜索

輸入:已構造的kd樹,目標點x
輸出:x的最近鄰

  1. 在kd樹中找出包含目標點x的葉結點;從根結點出發,遞歸地向下訪問kd樹,若目標點x當前維的座標小於切分點的座標,則移動左子結點,不然移動到右子結點,直到子結點爲葉節點爲止

  2. 以此葉節點爲"當前最近點"

  3. 遞歸地向上回退,在每一個節點進行一下操做
    a.若是該結點保存的實例點比當前最近點距離目標點更近,則以該實例點爲"當前最近點"
    b.當前最近點必定存在於該結點一個子結點對應的區域.檢查該子結點的父結點的另外一個結點對應的區域是否有更近的點.具體地,檢查另外一子結點對應的區域是否與以目標點爲球心,以目標點與"當前最近點"間的距離爲半徑的超球體相交
    若是相交,可能在另外一個子結點對應的區域內存在距目標點更近的點,移動到另外一個子結點.接着遞歸地進行最近鄰搜索
    若是不想交,向上回退

  4. 當回退到根結點時,搜索結束,最後的"當前最近點"即爲x的最近鄰點
    若是實例點是隨機分佈的,kd樹搜索的平均計算複雜度是 O ( log N ) O(\log N)

實例:給定一個以下圖所示的kd樹,根節點爲A,其子結點爲B,C等,樹上共存儲7個實例點;另外一個輸入目標實例點S,求S的最近鄰

Image(filename="./data/3_8.png",width=500)

在這裏插入圖片描述

首先在kd樹中找到包含點S的葉結點D(圖中的右下區域),以點D做爲近似最近鄰.真正最近鄰必定在以點S爲中心經過點D的圓的內部.而後返回結點D的父結點B,在結點B的另外一子結點F的區域內搜索最近鄰,結點F的區域與圓不想交,不可能有最近鄰點,繼續返回上一級父結點A,在結點A的另外一子結點C的區域內搜索最近鄰,結點C的區域與圓相交;該區域在園內的實例點有點E,點E比點D更近,成爲新的最近鄰近似,最後獲得點E是點S的最近鄰

四.代碼實現

1.度量距離

import math
from itertools import combinations
# p=1 Manhattan distance
# p=2 Euclidean distance
# p=3 Minkowski distance
def L(x,y,p=2):
    # x1=[1,1] x2=[5,1]
    if len(x)==len(y) and len(x)>1:
        sum=0
        for i in range(len(x)):
            sum+=math.pow(abs(x[i]-y[i]),p)
        return math.pow(sum,1/p)
    else:
        return 0
# 實例1
x1=[1,1]
x2=[5,1]
x3=[4,4]
# x1,x3
for i in range(1,5):
    r={"1-{}".format(c):L(x1,c,p=i) for c in [x2,x3]}
    print(min(zip(r.values(),r.keys())))
(4.0, '1-[5, 1]')
(4.0, '1-[5, 1]')
(3.7797631496846193, '1-[4, 4]')
(3.5676213450081633, '1-[4, 4]')

2.自定義KNN分析iris

遍歷全部數據點,找出n個距離最近的點的分類狀況,少數服從多數

%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from collections import Counter
iris=load_iris()
df=pd.DataFrame(iris.data,columns=iris.feature_names)
df["label"]=iris.target
df.columns=["sepal length","sepal width","petal length","petal width","label"]
# data=np.array(df.iloc[:100,[0,1,-1]])
df
sepal length sepal width petal length petal width label
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
5 5.4 3.9 1.7 0.4 0
6 4.6 3.4 1.4 0.3 0
7 5.0 3.4 1.5 0.2 0
8 4.4 2.9 1.4 0.2 0
9 4.9 3.1 1.5 0.1 0
10 5.4 3.7 1.5 0.2 0
11 4.8 3.4 1.6 0.2 0
12 4.8 3.0 1.4 0.1 0
13 4.3 3.0 1.1 0.1 0
14 5.8 4.0 1.2 0.2 0
15 5.7 4.4 1.5 0.4 0
16 5.4 3.9 1.3 0.4 0
17 5.1 3.5 1.4 0.3 0
18 5.7 3.8 1.7 0.3 0
19 5.1 3.8 1.5 0.3 0
20 5.4 3.4 1.7 0.2 0
21 5.1 3.7 1.5 0.4 0
22 4.6 3.6 1.0 0.2 0
23 5.1 3.3 1.7 0.5 0
24 4.8 3.4 1.9 0.2 0
25 5.0 3.0 1.6 0.2 0
26 5.0 3.4 1.6 0.4 0
27 5.2 3.5 1.5 0.2 0
28 5.2 3.4 1.4 0.2 0
29 4.7 3.2 1.6 0.2 0
... ... ... ... ... ...
120 6.9 3.2 5.7 2.3 2
121 5.6 2.8 4.9 2.0 2
122 7.7 2.8 6.7 2.0 2
123 6.3 2.7 4.9 1.8 2
124 6.7 3.3 5.7 2.1 2
125 7.2 3.2 6.0 1.8 2
126 6.2 2.8 4.8 1.8 2
127 6.1 3.0 4.9 1.8 2
128 6.4 2.8 5.6 2.1 2
129 7.2 3.0 5.8 1.6 2
130 7.4 2.8 6.1 1.9 2
131 7.9 3.8 6.4 2.0 2
132 6.4 2.8 5.6 2.2 2
133 6.3 2.8 5.1 1.5 2
134 6.1 2.6 5.6 1.4 2
135 7.7 3.0 6.1 2.3 2
136 6.3 3.4 5.6 2.4 2
137 6.4 3.1 5.5 1.8 2
138 6.0 3.0 4.8 1.8 2
139 6.9 3.1 5.4 2.1 2
140 6.7 3.1 5.6 2.4 2
141 6.9 3.1 5.1 2.3 2
142 5.8 2.7 5.1 1.9 2
143 6.8 3.2 5.9 2.3 2
144 6.7 3.3 5.7 2.5 2
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2

150 rows × 5 columns

plt.scatter(df[:50]["sepal length"],df[:50]["sepal width"],label="0")
plt.scatter(df[50:100]["sepal length"],df[50:100]["sepal width"],label="1")
plt.xlabel("sepal length")
plt.ylabel("sepal width")
plt.legend()
<matplotlib.legend.Legend at 0x18193faffd0>

在這裏插入圖片描述

data=np.array(df.iloc[:100,[0,1,-1]])
X,y=data[:,:-1],data[:,-1]
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2)
class KNN(object):
    def __init__(self,X_train,y_train,n_neighbors=3,p=2):
        """ parameter:n_neighbors 臨近點個數 parameter:p 距離度量 """
        self.n=n_neighbors
        self.p=p
        self.X_train=X_train
        self.y_train=y_train
        
    def predict(self,X):
        # 取出n個點
        knn_list=[]
        for i in range(self.n):
            dist=np.linalg.norm(X-self.X_train[i],ord=self.p)
            knn_list.append((dist,self.y_train[i]))
            
        for i in range(self.n,len(self.X_train)):
            max_index=knn_list.index(max(knn_list,key=lambda x:x[0]))
            dist=np.linalg.norm(X-self.X_train[i],ord=self.p)
            if knn_list[max_index][0]>dist:
                knn_list[max_index]=(dist,self.y_train[i])
                
        # 統計
        knn=[k[-1] for k in knn_list]
        count_pairs=Counter(knn)
        max_count=sorted(count_pairs,key=lambda x:x)[-1]
        return max_count
    
    def score(self,X_test,y_test):
        right_count=0
        n=10
        for X,y in zip(X_test,y_test):
            label=self.predict(X)
            if label==y:
                right_count+=1
        return right_count/len(X_test)
clf=KNN(X_train,y_train)
clf.score(X_test,y_test)
1.0
test_point=[6.0,3.0]
print("Test Point:{}".format(clf.predict(test_point)))
Test Point:1.0
plt.scatter(df[:50]["sepal length"],df[:50]["sepal width"],label="0")
plt.scatter(df[50:100]["sepal length"],df[50:100]["sepal width"],label="1")
plt.plot(test_point[0],test_point[1],"bo",label="test_point")
plt.xlabel("sepal length")
plt.ylabel("sepal width")
plt.legend()
<matplotlib.legend.Legend at 0x181944c5588>

在這裏插入圖片描述

3.sklearn實現KNN

from sklearn.neighbors import KNeighborsClassifier

clf_sk=KNeighborsClassifier()
clf_sk.fit(X_train,y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=5, p=2,
           weights='uniform')
clf_sk.score(X_test,y_test)
1.0

sklearn.neighbors.KNeighborsClassifier主要參數說明:

  • n_neighbors:臨近點個數
  • p:度量距離
  • algorithm:近鄰算法,可選{「auto」,「ball_tree」,「kd_tree」,「brute」}
  • weights:肯定近鄰的權重

4.kd樹

# kd-tree每一個結點中主要包含的數據結構以下 
class KdNode(object):
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # k維向量節點(k維空間中的一個樣本點)
        self.split = split      # 整數(進行分割維度的序號)
        self.left = left        # 該結點分割超平面左子空間構成的kd-tree
        self.right = right      # 該結點分割超平面右子空間構成的kd-tree
 
 
class KdTree(object):
    def __init__(self, data):
        k = len(data[0])  # 數據維度
        
        def CreateNode(split, data_set): # 按第split維劃分數據集exset建立KdNode
            if not data_set:    # 數據集爲空
                return None
            # key參數的值爲一個函數,此函數只有一個參數且返回一個值用來進行比較
            # operator模塊提供的itemgetter函數用於獲取對象的哪些維的數據,參數爲須要獲取的數據在對象中的序號
            #data_set.sort(key=itemgetter(split)) # 按要進行分割的那一維數據排序
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2      # //爲Python中的整數除法
            median = data_set[split_pos]        # 中位數分割點 
            split_next = (split + 1) % k        # cycle coordinates
            
            # 遞歸的建立kd樹
            return KdNode(median, split, 
                          CreateNode(split_next, data_set[:split_pos]),     # 建立左子樹
                          CreateNode(split_next, data_set[split_pos + 1:])) # 建立右子樹
                                
        self.root = CreateNode(0, data)         # 從第0維份量開始構建kd樹,返回根節點


# KDTree的前序遍歷
def preorder(root):  
    print (root.dom_elt)  
    if root.left:      # 節點不爲空
        preorder(root.left)  
    if root.right:  
        preorder(root.right)
# 對構建好的kd樹進行搜索,尋找與目標點最近的樣本點:
from math import sqrt
from collections import namedtuple

# 定義一個namedtuple,分別存放最近座標點、最近距離和訪問過的節點數
result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")
  
def find_nearest(tree, point):
    k = len(point) # 數據維度
    def travel(kd_node, target, max_dist):
        if kd_node is None:     
            return result([0] * k, float("inf"), 0) # python中用float("inf")和float("-inf")表示正負無窮
 
        nodes_visited = 1
        
        s = kd_node.split        # 進行分割的維度
        pivot = kd_node.dom_elt  # 進行分割的「軸」
        
        if target[s] <= pivot[s]:           # 若是目標點第s維小於分割軸的對應值(目標離左子樹更近)
            nearer_node  = kd_node.left     # 下一個訪問節點爲左子樹根節點
            further_node = kd_node.right    # 同時記錄下右子樹
        else:                               # 目標離右子樹更近
            nearer_node  = kd_node.right    # 下一個訪問節點爲右子樹根節點
            further_node = kd_node.left
 
        temp1 = travel(nearer_node, target, max_dist)  # 進行遍歷找到包含目標點的區域
        
        nearest = temp1.nearest_point       # 以此葉結點做爲「當前最近點」
        dist = temp1.nearest_dist           # 更新最近距離
        
        nodes_visited += temp1.nodes_visited  
 
        if dist < max_dist:     
            max_dist = dist    # 最近點將在以目標點爲球心,max_dist爲半徑的超球體內
            
        temp_dist = abs(pivot[s] - target[s])    # 第s維上目標點與分割超平面的距離
        if  max_dist < temp_dist:                # 判斷超球體是否與超平面相交
            return result(nearest, dist, nodes_visited) # 不相交則能夠直接返回,不用繼續判斷
            
        #---------------------------------------------------------------------- 
        # 計算目標點與分割點的歐氏距離 
        temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))     
        
        if temp_dist < dist:         # 若是「更近」
            nearest = pivot          # 更新最近點
            dist = temp_dist         # 更新最近距離
            max_dist = dist          # 更新超球體半徑
        
        # 檢查另外一個子結點對應的區域是否有更近的點
        temp2 = travel(further_node, target, max_dist) 
        
        nodes_visited += temp2.nodes_visited
        if temp2.nearest_dist < dist:        # 若是另外一個子結點內存在更近距離
            nearest = temp2.nearest_point    # 更新最近點
            dist = temp2.nearest_dist        # 更新最近距離
 
        return result(nearest, dist, nodes_visited)
 
    return travel(tree.root, point, float("inf"))  # 從根節點開始遞歸
# 實例2
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KdTree(data)
preorder(kd.root)
[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]
from time import clock
from random import random

# 產生一個k維隨機向量,每維份量值在0~1之間
def random_point(k):
    return [random() for _ in range(k)]
 
# 產生n個k維隨機向量 
def random_points(k, n):
    return [random_point(k) for _ in range(n)]
ret = find_nearest(kd, [3,4.5])
print (ret)
Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)
N = 400000
t0 = clock()
kd2 = KdTree(random_points(3, N))            # 構建包含四十萬個3維空間樣本點的kd樹
ret2 = find_nearest(kd2, [0.1,0.5,0.8])      # 四十萬個樣本點中尋找離目標最近的點
t1 = clock()
print ("time: ",t1-t0, "s")
print (ret2)
E:\Anaconda\envs\mytensorflow\lib\site-packages\ipykernel_launcher.py:2: DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead
  


time:  6.159827752999263 s
Result_tuple(nearest_point=[0.09732020950704356, 0.49930092577904095, 0.8029864162744909], nearest_dist=0.004072918366121865, nodes_visited=42)


E:\Anaconda\envs\mytensorflow\lib\site-packages\ipykernel_launcher.py:5: DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead
  """
相關文章
相關標籤/搜索