機器學習:Mean Shift聚類算法

本文由ChardLau原創,轉載請添加原文連接https://www.chardlau.com/mean-shift/python

今天的文章介紹如何利用Mean Shift算法的基本形式對數據進行聚類操做。而有關Mean Shift算法加入核函數計算漂移向量部分的內容將不在本文講述範圍內。實際上除了聚類,Mean Shift算法還能用於計算機視覺等場合,有關該算法的理論知識請參考這篇文章git

Mean Shift算法原理

下圖展現了Mean Shift算法計算飄逸向量的過程:
Mean Shiftgithub

Mean Shift算法的關鍵操做是經過感興趣區域內的數據密度變化計算中心點的漂移向量,從而移動中心點進行下一次迭代,直到到達密度最大處(中心點不變)。從每一個數據點出發均可以進行該操做,在這個過程,統計出如今感興趣區域內的數據的次數。該參數將在最後做爲分類的依據。算法

K-Means算法不同的是,Mean Shift算法能夠自動決定類別的數目。與K-Means算法同樣的是,二者都用集合內數據點的均值進行中心點的移動。app

算法步驟

下面是有關Mean Shift聚類算法的步驟:函數

  1. 在未被標記的數據點中隨機選擇一個點做爲起始中心點center;
  2. 找出以center爲中心半徑爲radius的區域中出現的全部數據點,認爲這些點同屬於一個聚類C。同時在該聚類中記錄數據點出現的次數加1。
  3. 以center爲中心點,計算從center開始到集合M中每一個元素的向量,將這些向量相加,獲得向量shift。
  4. center = center + shift。即center沿着shift的方向移動,移動距離是||shift||。
  5. 重複步驟二、三、4,直到shift的很小(就是迭代到收斂),記住此時的center。注意,這個迭代過程當中遇到的點都應該歸類到簇C。
  6. 若是收斂時當前簇C的center與其它已經存在的簇C2中心的距離小於閾值,那麼把C2和C合併,數據點出現次數也對應合併。不然,把C做爲新的聚類。
  7. 重複一、二、三、四、5直到全部的點都被標記爲已訪問。
  8. 分類:根據每一個類,對每一個點的訪問頻率,取訪問頻率最大的那個類,做爲當前點集的所屬類。

算法實現

下面使用Python實現了Mean Shift算法的基本形式:spa

import numpy as np
import matplotlib.pyplot as plt

# Input data set
X = np.array([
    [-4, -3.5], [-3.5, -5], [-2.7, -4.5],
    [-2, -4.5], [-2.9, -2.9], [-0.4, -4.5],
    [-1.4, -2.5], [-1.6, -2], [-1.5, -1.3],
    [-0.5, -2.1], [-0.6, -1], [0, -1.6],
    [-2.8, -1], [-2.4, -0.6], [-3.5, 0],
    [-0.2, 4], [0.9, 1.8], [1, 2.2],
    [1.1, 2.8], [1.1, 3.4], [1, 4.5],
    [1.8, 0.3], [2.2, 1.3], [2.9, 0],
    [2.7, 1.2], [3, 3], [3.4, 2.8],
    [3, 5], [5.4, 1.2], [6.3, 2]
])


def mean_shift(data, radius=2.0):
    clusters = []
    for i in range(len(data)):
        cluster_centroid = data[i]
        cluster_frequency = np.zeros(len(data))

        # Search points in circle
        while True:
            temp_data = []
            for j in range(len(data)):
                v = data[j]
                # Handle points in the circles
                if np.linalg.norm(v - cluster_centroid) <= radius:
                    temp_data.append(v)
                    cluster_frequency[i] += 1

            # Update centroid
            old_centroid = cluster_centroid
            new_centroid = np.average(temp_data, axis=0)
            cluster_centroid = new_centroid
            # Find the mode
            if np.array_equal(new_centroid, old_centroid):
                break

        # Combined 'same' clusters
        has_same_cluster = False
        for cluster in clusters:
            if np.linalg.norm(cluster['centroid'] - cluster_centroid) <= radius:
                has_same_cluster = True
                cluster['frequency'] = cluster['frequency'] + cluster_frequency
                break

        if not has_same_cluster:
            clusters.append({
                'centroid': cluster_centroid,
                'frequency': cluster_frequency
            })

    print('clusters (', len(clusters), '): ', clusters)
    clustering(data, clusters)
    show_clusters(clusters, radius)


# Clustering data using frequency
def clustering(data, clusters):
    t = []
    for cluster in clusters:
        cluster['data'] = []
        t.append(cluster['frequency'])
    t = np.array(t)
    # Clustering
    for i in range(len(data)):
        column_frequency = t[:, i]
        cluster_index = np.where(column_frequency == np.max(column_frequency))[0][0]
        clusters[cluster_index]['data'].append(data[i])


# Plot clusters
def show_clusters(clusters, radius):
    colors = 10 * ['r', 'g', 'b', 'k', 'y']
    plt.figure(figsize=(5, 5))
    plt.xlim((-8, 8))
    plt.ylim((-8, 8))
    plt.scatter(X[:, 0], X[:, 1], s=20)
    theta = np.linspace(0, 2 * np.pi, 800)
    for i in range(len(clusters)):
        cluster = clusters[i]
        data = np.array(cluster['data'])
        plt.scatter(data[:, 0], data[:, 1], color=colors[i], s=20)
        centroid = cluster['centroid']
        plt.scatter(centroid[0], centroid[1], color=colors[i], marker='x', s=30)
        x, y = np.cos(theta) * radius + centroid[0], np.sin(theta) * radius + centroid[1]
        plt.plot(x, y, linewidth=1, color=colors[i])
    plt.show()


mean_shift(X, 2.5)

代碼連接code

上述代碼執行結果以下:
執行結果orm

其餘

Mean Shift算法還有不少內容未說起。其中有「動態計算感興趣區域半徑」、「加入核函數計算漂移向量」等。本文做爲入門引導,暫時只覆蓋這些內容。blog

相關文章
相關標籤/搜索