本文由ChardLau原創,轉載請添加原文連接https://www.chardlau.com/mean-shift/python
今天的文章介紹如何利用Mean Shift
算法的基本形式對數據進行聚類操做。而有關Mean Shift
算法加入核函數計算漂移向量部分的內容將不在本文講述範圍內。實際上除了聚類,Mean Shift
算法還能用於計算機視覺等場合,有關該算法的理論知識請參考這篇文章。git
Mean Shift
算法原理下圖展現了Mean Shift
算法計算飄逸向量的過程:
github
Mean Shift
算法的關鍵操做是經過感興趣區域內的數據密度變化計算中心點的漂移向量,從而移動中心點進行下一次迭代,直到到達密度最大處(中心點不變)。從每一個數據點出發均可以進行該操做,在這個過程,統計出如今感興趣區域內的數據的次數。該參數將在最後做爲分類的依據。算法
與K-Means
算法不同的是,Mean Shift
算法能夠自動決定類別的數目。與K-Means
算法同樣的是,二者都用集合內數據點的均值進行中心點的移動。app
下面是有關Mean Shift
聚類算法的步驟:函數
下面使用Python
實現了Mean Shift
算法的基本形式:spa
import numpy as np import matplotlib.pyplot as plt # Input data set X = np.array([ [-4, -3.5], [-3.5, -5], [-2.7, -4.5], [-2, -4.5], [-2.9, -2.9], [-0.4, -4.5], [-1.4, -2.5], [-1.6, -2], [-1.5, -1.3], [-0.5, -2.1], [-0.6, -1], [0, -1.6], [-2.8, -1], [-2.4, -0.6], [-3.5, 0], [-0.2, 4], [0.9, 1.8], [1, 2.2], [1.1, 2.8], [1.1, 3.4], [1, 4.5], [1.8, 0.3], [2.2, 1.3], [2.9, 0], [2.7, 1.2], [3, 3], [3.4, 2.8], [3, 5], [5.4, 1.2], [6.3, 2] ]) def mean_shift(data, radius=2.0): clusters = [] for i in range(len(data)): cluster_centroid = data[i] cluster_frequency = np.zeros(len(data)) # Search points in circle while True: temp_data = [] for j in range(len(data)): v = data[j] # Handle points in the circles if np.linalg.norm(v - cluster_centroid) <= radius: temp_data.append(v) cluster_frequency[i] += 1 # Update centroid old_centroid = cluster_centroid new_centroid = np.average(temp_data, axis=0) cluster_centroid = new_centroid # Find the mode if np.array_equal(new_centroid, old_centroid): break # Combined 'same' clusters has_same_cluster = False for cluster in clusters: if np.linalg.norm(cluster['centroid'] - cluster_centroid) <= radius: has_same_cluster = True cluster['frequency'] = cluster['frequency'] + cluster_frequency break if not has_same_cluster: clusters.append({ 'centroid': cluster_centroid, 'frequency': cluster_frequency }) print('clusters (', len(clusters), '): ', clusters) clustering(data, clusters) show_clusters(clusters, radius) # Clustering data using frequency def clustering(data, clusters): t = [] for cluster in clusters: cluster['data'] = [] t.append(cluster['frequency']) t = np.array(t) # Clustering for i in range(len(data)): column_frequency = t[:, i] cluster_index = np.where(column_frequency == np.max(column_frequency))[0][0] clusters[cluster_index]['data'].append(data[i]) # Plot clusters def show_clusters(clusters, radius): colors = 10 * ['r', 'g', 'b', 'k', 'y'] plt.figure(figsize=(5, 5)) plt.xlim((-8, 8)) plt.ylim((-8, 8)) plt.scatter(X[:, 0], X[:, 1], s=20) theta = np.linspace(0, 2 * np.pi, 800) for i in range(len(clusters)): cluster = clusters[i] data = np.array(cluster['data']) plt.scatter(data[:, 0], data[:, 1], color=colors[i], s=20) centroid = cluster['centroid'] plt.scatter(centroid[0], centroid[1], color=colors[i], marker='x', s=30) x, y = np.cos(theta) * radius + centroid[0], np.sin(theta) * radius + centroid[1] plt.plot(x, y, linewidth=1, color=colors[i]) plt.show() mean_shift(X, 2.5)
代碼連接code
上述代碼執行結果以下:
orm
Mean Shift
算法還有不少內容未說起。其中有「動態計算感興趣區域半徑」、「加入核函數計算漂移向量」等。本文做爲入門引導,暫時只覆蓋這些內容。blog