Knn算法實現

kNN 算法實現(Python)

 

 

 

k近鄰算法

 

0.引入依賴

In [8]:
import numpy as np
import pandas as pd

#這裏直接引入sklearn裏面的數據集,iris 鳶尾花
from sklearn.datasets  import  load_iris
from sklearn.model_selection import train_test_split   # 切分數據集爲訓練集和測試集
from sklearn.metrics import accuracy_score   #計算分類預測的準確率
 

1.數據加載和預處理

In [23]:
# Load the iris dataset and wrap the feature matrix in a DataFrame whose
# columns carry the dataset's feature names.
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
# Add a 'class' column holding the species name looked up from each integer target code.
df['class'] = pd.Series(iris.target).map(lambda code: iris.target_names[code])
# Summary statistics of the numeric feature columns (displayed as the cell output).
df.describe()
Out[23]:
 
  sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)
count 150.000000 150.000000 150.000000 150.000000
mean 5.843333 3.057333 3.758000 1.199333
std 0.828066 0.435866 1.765298 0.762238
min 4.300000 2.000000 1.000000 0.100000
25% 5.100000 2.800000 1.600000 0.300000
50% 5.800000 3.000000 4.350000 1.300000
75% 6.400000 3.300000 5.100000 1.800000
max 7.900000 4.400000 6.900000 2.500000
In [24]:
# Feature matrix taken directly from the loaded dataset.
x = iris.data
# Target labels reshaped into a single column (one row per sample).
y = iris.target.reshape(-1,1)
In [33]:
# Split into training and test sets: test_size=0.3 holds out 30% of the samples,
# random_state=35 fixes the shuffle, and stratify=y stratifies the split on the labels.
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.3,random_state=35,stratify = y)
Out[33]:
array([1.7, 1. , 1.3, 1.5, 3.9, 1.8, 2.1, 7. , 6.2, 0.5, 1.9, 6.2, 2.7,
       7.1, 6.9, 0. , 2. , 2.6, 1.9, 2.3, 2.6, 6.7, 3.8, 7.1, 6.7, 4.9,
       2.2, 2.1, 2.7, 1.3, 2. , 0.8, 2.7, 2.6, 1.4, 1.9, 3.7, 6.9, 2.3,
       2.2, 1.9, 1.2, 1.7, 6.6, 0.5, 6.8, 6.9, 2.5, 6.2, 6.8, 6.7, 3.6,
       7. , 1.5, 1.7, 2.1, 2.7, 3. , 2.2, 1.8, 1.8, 1.7, 2.7, 7.2, 6.9,
       2.9, 7.2, 1.4, 2.9, 2.2, 4.2, 1.5, 6.6, 6.1, 1.5, 4.6, 6.5, 1.4,
       1.3, 0.5, 3.8, 6.3, 6.8, 6.6, 1.8, 2.5, 7.4, 2.6, 6.8, 6.8, 4. ,
       1.7, 7.1, 6.5, 7.9, 1.4, 2.4, 6.6, 6.4, 7.3, 1.9, 1.8, 7.6, 0.9,
       0.8])
In [102]:
# Scratch check of the indexing pattern used by the classifier:
# positions of the three smallest values in the sample array.
arr = np.array([1, 5, 3, 4]).argsort()[:3]
# Same selection done element by element (list of numpy scalars) ...
test = list(np.array([1, 5, 3, 4])[arr])
# ... and in one vectorised lookup (numpy array).
test_2 = np.take(np.array([1, 5, 3, 4]), arr)
# How many of the selected values equal 1 (displayed as the cell output).
test_2.tolist().count(1)
Out[102]:
1
In [109]:
# Scratch check: position of the largest element in the list.
np.argmax([1,5,3,4])
# Left disabled on purpose: np.bincount cannot convert the string '1x' to an
# integer and raises ValueError (see the traceback recorded below this cell).
# np.bincount([1,1,2,3,'1x'])
 
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-109-feb333e2c58a> in <module>
      1 np.argmax([1,5,3,4])
----> 2np.bincount([1,1,2,3,'1x'])

ValueError: invalid literal for int() with base 10: '1x'
 

2.核心算法實現

In [150]:
# 距離函數定義
def l1_distance(a,b):
    """Return the L1 (Manhattan) distance between each row of ``a`` and ``b``.

    The element-wise difference ``a - b`` is formed with normal broadcasting,
    its absolute values are summed along axis 1, and one distance per row is
    returned.
    """
    abs_diff = np.abs(a - b)
    return abs_diff.sum(axis=1)
def l2_distance(a,b):
    """Return the L2 (Euclidean) distance between each row of ``a`` and ``b``.

    The element-wise difference ``a - b`` is formed with normal broadcasting,
    squared, summed along axis 1, and the square root of each sum is returned.
    """
    delta = a - b
    squared_sum = np.square(delta).sum(axis=1)
    return np.sqrt(squared_sum)


# 分類器實現
class kNN(object):
    """Minimal k-nearest-neighbours classifier.

    Attributes:
        n_neighbors: number of nearest training samples that vote on a label.
        dist_func: callable ``dist_func(x_train, x_test)`` returning one
            distance per training row.
        x_train, y_train: training data stored by :meth:`fit`.
    """

    def __init__(self,n_neighbors=1,dist_func= l1_distance):
        """Store the neighbour count and the distance function to use."""
        self.n_neighbors = n_neighbors
        self.dist_func = dist_func

    def fit(self, x, y):
        """Memorise the training samples ``x`` and their labels ``y``.

        No computation happens here; all the work is done in :meth:`predict`.
        ``y`` must be a numpy array (its ``dtype`` is read in ``predict``)
        holding non-negative integer labels, one per row of ``x``.
        """
        self.x_train = x
        self.y_train = y

    def predict(self, x):
        """Predict a label for every row of ``x``.

        Returns:
            Array of shape ``(x.shape[0], 1)`` with the dtype of the training
            labels. Each entry is the most frequent label among the
            ``n_neighbors`` closest training samples; on a tie the smallest
            label wins, because ``np.argmax`` returns the first maximum.
        """
        # Pre-allocate the prediction column with the same dtype as the labels.
        y_pred = np.zeros((x.shape[0], 1), dtype=self.y_train.dtype)
        for i, x_test in enumerate(x):
            # Distance from this query point to every training sample.
            distances = self.dist_func(self.x_train, x_test)
            # Indices of the k closest training samples, nearest first.
            nn_indexes = np.argsort(distances)[:self.n_neighbors]
            # Use the labels stored by fit(). Reading a module-level `y_train`
            # here would tie the model to whatever global happens to exist.
            neighbor_labels = self.y_train[nn_indexes].ravel().tolist()
            # bincount tallies votes per label; argmax picks the winner.
            y_pred[i] = np.argmax(np.bincount(neighbor_labels))
        return y_pred
In [160]:
# Build a 5-neighbour classifier that uses the L1 distance function.
kNN_model=kNN(n_neighbors=5,dist_func= l1_distance)
# Store the training split inside the model.
kNN_model.fit(x_train,y_train)
# Predict labels for the held-out test split.
y_pred=kNN_model.predict(x_test)
In [161]:
# Score the test-set predictions against the true test labels with accuracy_score.
accuracy_score(y_test,y_pred)
Out[161]:
0.9777777777777777
In [166]:
# Compare parameter choices: evaluate both distance functions with several values of k.
knn = kNN()
knn.fit(x_train, y_train)
result_list = []
for p in (1, 2):
    # p == 1 selects the L1 distance; any other value selects the L2 distance.
    knn.dist_func = l2_distance if p != 1 else l1_distance
    # Try the odd neighbour counts 1, 3, 5, 7, 9.
    for k in range(1, 10, 2):
        knn.n_neighbors = k
        y_pred = knn.predict(x_test)
        accuracy = accuracy_score(y_test, y_pred)
        print(accuracy)
        # One result row per (k, distance function) combination.
        result_list.append([k, knn.dist_func.__name__, accuracy])
# Collect all rows into a table and display it as the cell output.
df = pd.DataFrame(result_list, columns=['k', "距離函數", "準確率"])
df
 
0.9333333333333333
0.9333333333333333
0.9777777777777777
0.9555555555555556
0.9555555555555556
0.9333333333333333
0.9333333333333333
0.9777777777777777
0.9777777777777777
0.9777777777777777
Out[166]:
 
  k 距離函數 準確率
0 1 l1_distance 0.933333
1 3 l1_distance 0.933333
2 5 l1_distance 0.977778
3 7 l1_distance 0.955556
4 9 l1_distance 0.955556
5 1 l2_distance 0.933333
6 3 l2_distance 0.933333
7 5 l2_distance 0.977778
8 7 l2_distance 0.977778
9 9 l2_distance 0.977778
In [ ]:
 
In [ ]:
相關文章
相關標籤/搜索