重點介紹下K-means聚類算法。K-means算法是比較經典的聚類算法,算法的基本思想是選取K個點(隨機)做爲中心進行聚類,而後對聚類的結果計算該類的質心,經過迭代的方法不斷更新質心,直到質心不變或稍微移動爲止,則最後的聚類結果就是最後的聚類結果。下面首先介紹下K-means具體的算法步驟。java
K-means算法算法
在前面已經大概的介紹了下K-means,下面就介紹下具體的算法描述:數組
1)選取K個點做爲初始質心;app
2)對每一個樣本分別計算到K個質心的類似度或距離,將該樣本劃分到類似度最高或距離最短的質心所在類;dom
3)對該輪聚類結果,計算每個類別的質心,新的質心做爲下一輪的質心;ide
4)判斷算法是否知足終止條件,知足終止條件結束,不然繼續第二、三、4步。this
在介紹算法以前,咱們首先看下K-means算法聚類平面200,000個點聚成34個類別的結果(以下圖)spa
算法實現.net
K-means聚類算法總體思想比較簡單,下面 就分步介紹如何用Java來實現K-means算法。code
1、K-means算法基礎屬性
在K-means算法中,有幾個重要的指標,好比K值、最大迭代次數等,對於這些指標,咱們統一把它們設置爲類的屬性,以下:
[java] view plain copy
print?
- private List<T> dataArray;//待分類的原始值
- private int K = 3;//將要分紅的類別個數
- private int maxClusterTimes = 500;//最大迭代次數
- private List<List<T>> clusterList;//聚類的結果
- private List<T> clusteringCenterT;//質心
2、初始質心的選擇
K-means聚類算法的結果很大程度收到初始質心的選取,這了爲了保證有充分的隨機性,對於初始質心的選擇這裏採用徹底隨機的方法,先把待分類的數據隨機打亂,而後把前K個樣本做爲初始質心(經過屢次迭代,會減小初始質心的影響)。
[java] view plain copy
print?
- List<T> centerT = new ArrayList<T>(size);
- //對數據進行打亂
- Collections.shuffle(dataArray);
- for (int i = 0; i < size; i++) {
- centerT.add(dataArray.get(i));
- }
3、一輪聚類
在K-means算法中,大部分的時間都在作一輪一輪的聚類,具體功能也很簡單,就是對每個樣本分別計算和全部質心的類似度或距離,找到與該樣本最類似的質心或者距離最近的質心,而後把該樣本劃分到該類中,具體邏輯介紹參照代碼中的註釋。
[java] view plain copy
print?
- private void clustering(List<T> preCenter, int times) {
- if (preCenter == null || preCenter.size() < 2) {
- return;
- }
- //打亂質心的順序
- Collections.shuffle(preCenter);
- List<List<T>> clusterList = getListT(preCenter.size());
- for (T o1 : this.dataArray) {
- //尋找最類似的質心
- int max = 0;
- double maxScore = similarScore(o1, preCenter.get(0));
- for (int i = 1; i < preCenter.size(); i++) {
- if (maxScore < similarScore(o1, preCenter.get(i))) {
- maxScore = similarScore(o1, preCenter.get(i));
- max = i;
- }
- }
- clusterList.get(max).add(o1);
- }
- //計算本次聚類結果每一個類別的質心
- List<T> nowCenter = new ArrayList<T> ();
- for (List<T> list : clusterList) {
- nowCenter.add(getCenterT(list));
- }
- //是否達到最大迭代次數
- if (times >= this.maxClusterTimes || preCenter.size() < this.K) {
- this.clusterList = clusterList;
- return;
- }
- this.clusteringCenterT = nowCenter;
- //判斷質心是否發生移動,若是沒有移動,結束本次聚類,不然進行下一輪
- if (isCenterChange(preCenter, nowCenter)) {
- clear(clusterList);
- clustering(nowCenter, times + 1);
- } else {
- this.clusterList = clusterList;
- }
- }
4、質心是否移動
在第三步中,提到了一個重要的步驟:每輪聚類結束後,都要從新計算質心,而且計算質心是否發生移動。對於新質心的計算、樣本之間的類似度和判斷兩個樣本是否相等這幾個功能因爲並不知道樣本的具體數據類型,所以把他們定義成抽象方法,供子類來實現。下面就重點介紹如何判斷質心是否發生移動。
[java] view plain copy
print?
- private boolean isCenterChange(List<T> preT, List<T> nowT) {
- if (preT == null || nowT == null) {
- return false;
- }
- for (T t1 : preT) {
- boolean bol = true;
- for (T t2 : nowT) {
- if (equals(t1, t2)) {//t1在t2中有相等的,認爲該質心未移動
- bol = false;
- break;
- }
- }
- //有一個質心發生移動,認爲須要進行下一次計算
- if (bol) {
- return bol;
- }
- }
- return false;
- }
從上述代碼能夠看到,算法的思想就是對於先後兩個質心數組分別前一組的質心是否在後一個質心組中出現,有一個沒有出現,就認爲質心發生了變更。
完整代碼
上面四步已經完整的介紹了K-means算法的具體算法思想,下面就看下完整的代碼實現。
[java] view plain copy
print?
- /**
- *@Description: K-means聚類
- */
- package com.lulei.datamining.knn;
-
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.List;
-
- public abstract class KMeansClustering <T>{
- private List<T> dataArray;//待分類的原始值
- private int K = 3;//將要分紅的類別個數
- private int maxClusterTimes = 500;//最大迭代次數
- private List<List<T>> clusterList;//聚類的結果
- private List<T> clusteringCenterT;//質心
-
- public int getK() {
- return K;
- }
- public void setK(int K) {
- if (K < 1) {
- throw new IllegalArgumentException("K must greater than 0");
- }
- this.K = K;
- }
- public int getMaxClusterTimes() {
- return maxClusterTimes;
- }
- public void setMaxClusterTimes(int maxClusterTimes) {
- if (maxClusterTimes < 10) {
- throw new IllegalArgumentException("maxClusterTimes must greater than 10");
- }
- this.maxClusterTimes = maxClusterTimes;
- }
- public List<T> getClusteringCenterT() {
- return clusteringCenterT;
- }
- /**
- * @return
- * @Author:lulei
- * @Description: 對數據進行聚類
- */
- public List<List<T>> clustering() {
- if (dataArray == null) {
- return null;
- }
- //初始K個點爲數組中的前K個點
- int size = K > dataArray.size() ? dataArray.size() : K;
- List<T> centerT = new ArrayList<T>(size);
- //對數據進行打亂
- Collections.shuffle(dataArray);
- for (int i = 0; i < size; i++) {
- centerT.add(dataArray.get(i));
- }
- clustering(centerT, 0);
- return clusterList;
- }
-
- /**
- * @param preCenter
- * @param times
- * @Author:lulei
- * @Description: 一輪聚類
- */
- private void clustering(List<T> preCenter, int times) {
- if (preCenter == null || preCenter.size() < 2) {
- return;
- }
- //打亂質心的順序
- Collections.shuffle(preCenter);
- List<List<T>> clusterList = getListT(preCenter.size());
- for (T o1 : this.dataArray) {
- //尋找最類似的質心
- int max = 0;
- double maxScore = similarScore(o1, preCenter.get(0));
- for (int i = 1; i < preCenter.size(); i++) {
- if (maxScore < similarScore(o1, preCenter.get(i))) {
- maxScore = similarScore(o1, preCenter.get(i));
- max = i;
- }
- }
- clusterList.get(max).add(o1);
- }
- //計算本次聚類結果每一個類別的質心
- List<T> nowCenter = new ArrayList<T> ();
- for (List<T> list : clusterList) {
- nowCenter.add(getCenterT(list));
- }
- //是否達到最大迭代次數
- if (times >= this.maxClusterTimes || preCenter.size() < this.K) {
- this.clusterList = clusterList;
- return;
- }
- this.clusteringCenterT = nowCenter;
- //判斷質心是否發生移動,若是沒有移動,結束本次聚類,不然進行下一輪
- if (isCenterChange(preCenter, nowCenter)) {
- clear(clusterList);
- clustering(nowCenter, times + 1);
- } else {
- this.clusterList = clusterList;
- }
- }
-
- /**
- * @param size
- * @return
- * @Author:lulei
- * @Description: 初始化一個聚類結果
- */
- private List<List<T>> getListT(int size) {
- List<List<T>> list = new ArrayList<List<T>>(size);
- for (int i = 0; i < size; i++) {
- list.add(new ArrayList<T>());
- }
- return list;
- }
-
- /**
- * @param lists
- * @Author:lulei
- * @Description: 清空無用數組
- */
- private void clear(List<List<T>> lists) {
- for (List<T> list : lists) {
- list.clear();
- }
- lists.clear();
- }
-
- /**
- * @param value
- * @Author:lulei
- * @Description: 向模型中添加記錄
- */
- public void addRecord(T value) {
- if (dataArray == null) {
- dataArray = new ArrayList<T>();
- }
- dataArray.add(value);
- }
-
- /**
- * @param preT
- * @param nowT
- * @return
- * @Author:lulei
- * @Description: 判斷質心是否發生移動
- */
- private boolean isCenterChange(List<T> preT, List<T> nowT) {
- if (preT == null || nowT == null) {
- return false;
- }
- for (T t1 : preT) {
- boolean bol = true;
- for (T t2 : nowT) {
- if (equals(t1, t2)) {//t1在t2中有相等的,認爲該質心未移動
- bol = false;
- break;
- }
- }
- //有一個質心發生移動,認爲須要進行下一次計算
- if (bol) {
- return bol;
- }
- }
- return false;
- }
-
- /**
- * @param o1
- * @param o2
- * @return
- * @Author:lulei
- * @Description: o1 o2之間的類似度
- */
- public abstract double similarScore(T o1, T o2);
-
- /**
- * @param o1
- * @param o2
- * @return
- * @Author:lulei
- * @Description: 判斷o1 o2是否相等
- */
- public abstract boolean equals(T o1, T o2);
-
- /**
- * @param list
- * @return
- * @Author:lulei
- * @Description: 求一組數據的質心
- */
- public abstract T getCenterT(List<T> list);
- }
二維數聚類實現
在算法描述中,介紹了一個200,000個點聚成34個類別的效果圖,下面就針對二維座標數據實現其具體子類。
1、類似度
對於二維座標的類似度,這裏咱們採起兩點間聚類的相反數,具體實現以下:
[java] view plain copy
print?
- @Override
- public double similarScore(XYbean o1, XYbean o2) {
- double distance = Math.sqrt((o1.getX() - o2.getX()) * (o1.getX() - o2.getX()) + (o1.getY() - o2.getY()) * (o1.getY() - o2.getY()));
- return distance * -1;
- }
2、樣本/質心是否相等
判斷樣本/質心是否相等只須要判斷兩點的座標是否相等便可,具體實現以下:
[java] view plain copy
print?
- @Override
- public boolean equals(XYbean o1, XYbean o2) {
- return o1.getX() == o2.getX() && o1.getY() == o2.getY();
- }
3、獲取一個分類下的新質心
對於二維座標數據,可使用全部點的重心做爲分類的質心,具體以下:
[java] view plain copy
print?
- @Override
- public XYbean getCenterT(List<XYbean> list) {
- int x = 0;
- int y = 0;
- try {
- for (XYbean xy : list) {
- x += xy.getX();
- y += xy.getY();
- }
- x = x / list.size();
- y = y / list.size();
- } catch(Exception e) {
-
- }
- return new XYbean(x, y);
- }
4、main方法
對於具體二維座標的源碼這裏就再也不貼出來,就是實現前面介紹的抽象類,並實現其中的3個抽象方法,下面咱們就隨機產生200,000個點,而後聚成34個類別,具體代碼以下:
[java] view plain copy
print?
- public static void main(String[] args) {
-
- int width = 600;
- int height = 400;
- int K = 34;
- XYCluster xyCluster = new XYCluster();
- for (int i = 0; i < 200000; i++) {
- int x = (int)(Math.random() * width) + 1;
- int y = (int)(Math.random() * height) + 1;
- xyCluster.addRecord(new XYbean(x, y));
- }
- xyCluster.setK(K);
- long a = System.currentTimeMillis();
- List<List<XYbean>> cresult = xyCluster.clustering();
- List<XYbean> center = xyCluster.getClusteringCenterT();
- System.out.println(JsonUtil.parseJson(center));
- long b = System.currentTimeMillis();
- System.out.println("耗時:" + (b - a) + "ms");
- new ImgUtil().drawXYbeans(width, height, cresult, "d:/2.png", 0, 0);
- }
對於這隨機產生的200,000個點聚成34類,總耗時5485ms。(計算機配置:i5 + 8G內存)