Kmeans聚類算法

Kmeans是最流行的,以及最簡單的用於挖掘數據潛在結構的機器學習算法之一。Kmeans的目標很簡單:根據數據的均值,將數據劃分爲若干個簇。假定每一個簇的均值能夠很好地表明簇內的每個觀察值。算法

Kmeans算法

假定咱們想把數據劃分爲k個簇,那麼咱們就須要找出這k個簇的k箇中心。該如何定義,以及尋找這些中心呢?app

咱們只須要求解方程:$min\sum_i^{N}\sum_j^KO_i^j||x_i-u_j||^2$,其中當觀察點i爲簇j的中心時$O_i^j=1$,不然爲0.機器學習

咱們正在尋找k箇中心,使得各個簇內的點到簇中心的距離最小。這是一個最優化問題,可是上面的目標函數是非凸的,並且有一些變量是二元的,沒法用傳統的梯度降低法求解。函數

解決這個問題的方法以下:學習

  1. 隨機初始化k箇中心。
  2. 更新每一箇中心。新的中心爲相應簇中的全部觀察值的平均值。
  3. 更新收斂準則。

用R語言實現K-means算法

既然,咱們已經有了僞代碼算法,接下來咱們能夠用R語言實現Kmeans算法。首先,咱們建立5類數據,它們都服從2維高斯分佈。優化

require(MASS)
require(ggplot2)
set.seed(1234)
set1 <- mvrnorm(n = 300, mu = c(-4, 10), Sigma = matrix(c(1.5, 1, 1, 1.5), 2))
set2 <- mvrnorm(n = 300, mu = c(5, 7), Sigma = matrix(c(1, 2, 2, 6), 2))
set3 <- mvrnorm(n = 300, mu = c(-1, 1), Sigma = matrix(c(4, 0, 0, 4), 2))
set4 <- mvrnorm(n = 300, mu = c(10, -10), Sigma = matrix(c(4, 0, 0, 4), 2))
set5 <- mvrnorm(n = 300, mu = c(3, -3), Sigma = matrix(c(4, 0, 0, 4), 2))
DF <- data.frame(rbind(set1, set2, set3, set4, set5), cluster=as.factor(rep(1:5, each = 300)))

ggplot(DF, aes(x=X1, y=X2, color=cluster)) + geom_point()

圖片描述
在這個數據集中,Kmeans算法將會獲得一個很好的結果,由於每一個分佈都呈現圓形,如上圖所示。ui

初始化簇中心

簇中心的初始化很是重要,它可以影響算法。所以,咱們從數據集中隨機選取K的點。spa

# 中心初始化
centroids <- data[sample.int(nrow(data), K), ]
# 中止準則初始化
current_stop_crit <- 10e10
# 初始化每一個點的指定中心
cluster <- rep(0, nrow(data))
# 算法是否收斂
converged <- FALSE
iter <- 1

將每一個點歸爲指定的簇

在每一次迭代中,每一個點將會歸爲離它最近的簇。咱們可使用歐幾里得距離來計算每一個點到每一箇中心的距離,並保存各個點到各個中心的最近距離以及相應的簇中心。code

# 在觀察值上進行迭代
for (i in 1:nrow(data)){
  # 設置最小距離
  min_dist <- 10e10
  # 在中心上進行迭代
  for (centroid in 1:nrow(centroids)){
    # 計算歐式距離
    distance_to_centroid <- sum((centroids[centroid, ] - data[i, ]) ^ 2)
    # 距離點最近的中心
    if (distance_to_centroid <= min_dist){
      # 這個點將會被歸爲這個簇
      cluster[i] <- centroid
      min_dist <- distance_to_centroid
    }
  }
}

簇中心更新

一旦每一個觀察值都被歸爲距離最近的簇,每一個簇的中心座標就會被更新。簇中心的新座標爲相應簇中全部點的平均值。收斂準則爲:當各個簇的中心中止變更時,算法就應該終止。orm

while (current_stop_crit >= stop_crit & converged == FALSE){
  iter <- iter + 1
  if (current_stop_crit <= stop_crit){
    converged <- TRUE
  }
  old_centroids <- centroids
  # 更新中止準則
  current_stop_crit <- mean((old_centroids - centroids) ^ 2))
}

完整Kmeans函數

kmeans <- function(data, K=4, stop_crit=10e-5){
  # 初始化簇中心
  centroids <- data[sample.int(nrow(data), K), ]
  current_stop_crit <- 1000
  cluster <- rep(0, nrow(data))
  converged <- FALSE
  iter <- 1
  while (current_stop_crit >= stop_crit & converged == FALSE){
    iter <- iter + 1
    if (current_stop_crit <= stop_crit){
      converged <- TRUE
    }
    old_centroids <- centroids
    # 把每一個點納入相應的簇
    for (i in 1:nrow(data)){
      min_dist <- 10e10
      for (centroid in 1:nrow(centroids)){
        distance_to_centroid <- sum((centroids[centroid, ] - data[i, ]) ^ 2)
        if (distance_to_centroid <= min_dist){
          cluster[i] <- centroid
          min_dist <- distance_to_centroid
        }
      }
    }
    # 更新簇中心
    for (i in 1:nrow(centroids)){
      centroids[i, ] <- apply(data[cluster == i, ], 2, mean)
    }
    current_stop_crit <- mean((old_centroids - centroids) ^ 2)
  }
  return(list(data=data.frame(data, cluster), centroids=centroids))
}

咱們可使用上述代碼來觀察咱們的數據:

res <- kmeans(DF[1:2], K=5)
res$centroids$cluster <- 1:5
res$data$isCentroid <- FALSE
res$centroids$isCentroid <- TRUE
data_plot <- rbind(res$centroids, res$data)
ggplot(data_plot,aes(x=X1,y=X2,color=as.factor(cluster),size=isCentroid,alpha=isCentroid)) +
  geom_point()

圖片描述

相關文章
相關標籤/搜索