【火爐煉AI】機器學習022-使用均值漂移聚類算法構建模型

【火爐煉AI】機器學習022-使用均值漂移聚類算法構建模型

(本文所使用的Python庫和版本號: Python 3.5, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 )git

無監督學習算法有不少種,前面已經講解過了K-means聚類算法,並用該算法對圖片進行矢量量化壓縮。下面咱們來學習第二種無監督學習算法----均值漂移算法。github


1. 均值漂移算法簡介

均值漂移算法是一種基於密度梯度上升的非參數方法,它常常被應用在圖像識別中的目標跟蹤,數據聚類,分類等場景。算法

其核心思想是:首先隨便選擇一箇中心點,而後計算該中心點必定範圍以內全部點到中心點的距離向量的平均值,計算該平均值獲得一個偏移均值,而後將中心點移動到偏移均值位置,經過這種不斷重複的移動,可使中心點逐步逼近到最佳位置。這種思想相似於梯度降低方法,經過不斷的往梯度降低的方向移動,能夠到達梯度上的局部最優解或全局最優解。機器學習

以下是漂移均值算法的思想呈現,首先隨機選擇一箇中心點(綠色點),而後計算該點必定範圍內全部點到這個點的距離均值,而後將該中心點移動距離均值,到黃色點處,同理,再計算該黃色點必定範圍內的全部點到黃點的距離均值,通過屢次計算均值--移動中心點等方式,可使得中心點逐步逼近最佳中心點位置,即圖中紅色點處。函數

均值漂移算法的核心思想

1.1 均值漂移算法的基礎公式

從上面核心思想能夠看出,均值漂移的過程就是不斷的重複計算距離均值,移動中心點的過程,故而計算偏移均值和移動距離即是很是關鍵的兩個步驟,以下爲計算偏移均值的基礎公式。post

其中Sh:以x爲中心點,半徑爲h的高維球區域; k:包含在Sh範圍內點的個數; xi:包含在Sh範圍內的點學習

第二個步驟是計算移動必定距離以後的中心點位置,其計算公式爲:spa

其中,Mt爲t狀態下求得的偏移均值; xt爲t狀態下的中心.net

很顯然,移動以後的中心點位置是移動前位置加上偏移均值。rest

1.2 引入核函數的偏移均值算法

上述雖然介紹了均值漂移算法的基礎公式,可是該公式存在必定的問題,咱們知道,高維球區域內的全部樣本點對求解的貢獻是不同的,而基礎公式卻當作貢獻同樣來處理,即全部點的權重同樣,這是不符合邏輯的,那麼怎麼改進了?咱們能夠引入核函數,用來求出每一個樣本點的貢獻權重。固然這種求解權重的核函數有不少種,高斯函數就是其中的一種,以下公式是引入高斯核函數後的偏移均值的計算公式:

上面就是核函數內部的樣子。

1.3 均值漂移算法的運算步驟

均值漂移算法的應用很是普遍,好比在聚類,圖像分割,目標跟蹤等領域,其運算步驟每每包含有以下幾個步驟:

1,在數據點中隨機選擇一個點做爲初始中心點。

2,找出離該中心點距離在帶寬以內的全部點,記作集合M,認爲這些點屬於簇C.

3,計算從中心點開始到集合M中每一個元素的向量,將這些向量相加,獲得偏移向量。

4,將該中心點沿着偏移的方向移動,移動距離就是該偏移向量的模。

5,重複上述步驟2,3,4,直到偏移向量的大小知足設定的閾值要求,記住此時的中心點。

6,重複上述1,2,3,4,5直到全部的點都被歸類。

7,分類:根據每一個類,對每一個點的訪問頻率,取訪問頻率最大的那個類,做爲當前點集的所屬類。

1.4 均值漂移算法的優點

均值漂移算法用於集羣數據點時,把數據點的分佈當作是機率密度函數,但願在特徵空間中根據函數分佈特徵找出數據點的模式,這些模式就對應於一羣羣局部最密集分佈的點。

雖然咱們前面講解了K-means算法,但K-means算法在實際應用時,須要知道咱們要把數據劃分爲幾個類別,若是類別數量出錯,則每每難以獲得使人滿意的分類結果,而要劃分的類別每每很難事先肯定。這就是K-means算法的應用難點。

而均值漂移算法卻不須要事先知道要集羣的數量,這種算法能夠在咱們不知道要尋找多少集羣的狀況下自動劃分最合適的族羣,這就是均值漂移算法的一個很明顯優點。

以上部份內容來源於博客文章,在此表示感謝。


2. 構建均值漂移模型來聚類數據

本文所使用的數據集和讀取數據集的方式與上一篇文章【火爐煉AI】機器學習020-使用K-means算法對數據進行聚類分析如出一轍,故而此處省略。

下面是構建MeanShift對象的代碼,使用MeanShift以前,咱們須要評估帶寬,帶寬就是上面所講到的距離中心點的必定距離,咱們要把全部包含在這個距離以內的點都放入一個集合M中,用於計算偏移向量。

# 構建MeanShift對象,但須要評估帶寬
from sklearn.cluster import MeanShift, estimate_bandwidth
bandwidth=estimate_bandwidth(dataset_X,quantile=0.1,
                             n_samples=len(dataset_X))
meanshift=MeanShift(bandwidth=bandwidth,bin_seeding=True) # 構建對象
meanshift.fit(dataset_X) # 並用MeanShift對象來訓練該數據集

centroids=meanshift.cluster_centers_ # 質心的座標,對應於feature0, feature1
print(centroids) # 能夠看出有4行,即4個質心
labels=meanshift.labels_  # 數據集中每一個數據點對應的label
# print(labels)

cluster_num=len(np.unique(labels)) # label的個數,即自動劃分的族羣的個數
print('cluster num: {}'.format(cluster_num))
複製代碼

-------------------------------------輸---------出----------------

[[ 8.22338235 1.34779412]
[ 4.10104478 -0.81164179]
[ 1.18820896 2.10716418]
[ 4.995 4.99967742]]
cluster num: 4

--------------------------------------------完--------------------

能夠看出,此處咱們獲得了四個質心,這四個質心的座標位置能夠經過meanshift.cluster_centers_獲取,而meanshift.labels_ 獲得的就是原來樣本數據的label,也就是咱們經過均值漂移算法本身找到的label,這就是無監督學習的優點所在:雖然沒有給樣本數據指定label,可是該算法能本身找到其對應的label。

一樣的,該怎麼查看該MeanShift算法的好壞了,能夠經過下面的函數直接觀察數據集劃分的效果。

def visual_meanshift_effect(meanshift,dataset):
    assert dataset.shape[1]==2,'only support dataset with 2 features'
    X=dataset[:,0]
    Y=dataset[:,1]
    X_min,X_max=np.min(X)-1,np.max(X)+1
    Y_min,Y_max=np.min(Y)-1,np.max(Y)+1
    X_values,Y_values=np.meshgrid(np.arange(X_min,X_max,0.01),
                                  np.arange(Y_min,Y_max,0.01))
    # 預測網格點的標記
    predict_labels=meanshift.predict(np.c_[X_values.ravel(),Y_values.ravel()])
    predict_labels=predict_labels.reshape(X_values.shape)
    plt.figure()
    plt.imshow(predict_labels,interpolation='nearest',
               extent=(X_values.min(),X_values.max(),
                       Y_values.min(),Y_values.max()),
               cmap=plt.cm.Paired,
               aspect='auto',
               origin='lower')
    
    # 將數據集繪製到圖表中
    plt.scatter(X,Y,marker='v',facecolors='none',edgecolors='k',s=30)
    
    # 將中心點繪製到圖中
    centroids=meanshift.cluster_centers_
    plt.scatter(centroids[:,0],centroids[:,1],marker='o',
                s=100,linewidths=2,color='k',zorder=5,facecolors='b')
    plt.title('MeanShift effect graph')
    plt.xlim(X_min,X_max)
    plt.ylim(Y_min,Y_max)
    plt.xlabel('feature_0')
    plt.ylabel('feature_1')
    plt.show()
    
visual_meanshift_effect(meanshift,dataset_X)
複製代碼

MeanShift在數據集上的劃分效果圖

########################小**********結###################

1,MeanShift的構建和訓練方法和K-means的方式幾乎同樣,可是MeanShift能夠自動計算出數據集的族羣數量,而不須要人爲事先指定,這使得MeanShift比K-means要好用一些。

2, 訓練以後的MeanShift對象中包含有該數據集的質心座標,數據集的各個樣本對應的label信息,這些信息能夠很方便的獲取。

#######################################################


注:本部分代碼已經所有上傳到(個人github)上,歡迎下載。

參考資料:

1, Python機器學習經典實例,Prateek Joshi著,陶俊傑,陳小莉譯

相關文章
相關標籤/搜索