JAVA實現K-means聚類

重點介紹下K-means聚類算法。K-means算法是比較經典的聚類算法,算法的基本思想是選取K個點(隨機)做爲中心進行聚類,而後對聚類的結果計算該類的質心,經過迭代的方法不斷更新質心,直到質心不變或稍微移動爲止,則最後的聚類結果就是最後的聚類結果。下面首先介紹下K-means具體的算法步驟。java

 

K-means算法算法

      在前面已經大概的介紹了下K-means,下面就介紹下具體的算法描述:數組

1)選取K個點做爲初始質心;app

2)對每一個樣本分別計算到K個質心的類似度或距離,將該樣本劃分到類似度最高或距離最短的質心所在類;dom

3)對該輪聚類結果,計算每個類別的質心,新的質心做爲下一輪的質心;ide

4)判斷算法是否知足終止條件,知足終止條件結束,不然繼續第二、三、4步。this

      在介紹算法以前,咱們首先看下K-means算法聚類平面200,000個點聚成34個類別的結果(以下圖)spa

img

 

算法實現.net

      K-means聚類算法總體思想比較簡單,下面 就分步介紹如何用Java來實現K-means算法。code

 

1、K-means算法基礎屬性

      在K-means算法中,有幾個重要的指標,好比K值、最大迭代次數等,對於這些指標,咱們統一把它們設置爲類的屬性,以下:

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. private List<T> dataArray;//待分類的原始值  
  2. private int K = 3;//將要分紅的類別個數  
  3. private int maxClusterTimes = 500;//最大迭代次數  
  4. private List<List<T>> clusterList;//聚類的結果  
  5. private List<T> clusteringCenterT;//質心  

 

 

2、初始質心的選擇

      K-means聚類算法的結果很大程度收到初始質心的選取,這了爲了保證有充分的隨機性,對於初始質心的選擇這裏採用徹底隨機的方法,先把待分類的數據隨機打亂,而後把前K個樣本做爲初始質心(經過屢次迭代,會減小初始質心的影響)。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. List<T> centerT = new ArrayList<T>(size);  
  2. //對數據進行打亂  
  3. Collections.shuffle(dataArray);  
  4. for (int i = 0; i < size; i++) {  
  5.     centerT.add(dataArray.get(i));  
  6. }  

 

 

3、一輪聚類

      在K-means算法中,大部分的時間都在作一輪一輪的聚類,具體功能也很簡單,就是對每個樣本分別計算和全部質心的類似度或距離,找到與該樣本最類似的質心或者距離最近的質心,而後把該樣本劃分到該類中,具體邏輯介紹參照代碼中的註釋。

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. private void clustering(List<T> preCenter, int times) {  
  2.     if (preCenter == null || preCenter.size() < 2) {  
  3.         return;  
  4.     }  
  5.     //打亂質心的順序  
  6.     Collections.shuffle(preCenter);  
  7.     List<List<T>> clusterList =  getListT(preCenter.size());  
  8.     for (T o1 : this.dataArray) {  
  9.         //尋找最類似的質心  
  10.         int max = 0;  
  11.         double maxScore = similarScore(o1, preCenter.get(0));  
  12.         for (int i = 1; i < preCenter.size(); i++) {  
  13.             if (maxScore < similarScore(o1, preCenter.get(i))) {  
  14.                 maxScore = similarScore(o1, preCenter.get(i));  
  15.                 max = i;  
  16.             }  
  17.         }  
  18.         clusterList.get(max).add(o1);  
  19.     }  
  20.     //計算本次聚類結果每一個類別的質心  
  21.     List<T> nowCenter = new ArrayList<T> ();  
  22.     for (List<T> list : clusterList) {  
  23.         nowCenter.add(getCenterT(list));  
  24.     }  
  25.     //是否達到最大迭代次數  
  26.     if (times >= this.maxClusterTimes || preCenter.size() < this.K) {  
  27.         this.clusterList = clusterList;  
  28.         return;  
  29.     }  
  30.     this.clusteringCenterT = nowCenter;  
  31.     //判斷質心是否發生移動,若是沒有移動,結束本次聚類,不然進行下一輪  
  32.     if (isCenterChange(preCenter, nowCenter)) {  
  33.         clear(clusterList);  
  34.         clustering(nowCenter, times + 1);  
  35.     } else {  
  36.         this.clusterList = clusterList;  
  37.     }  
  38. }  

 

 

4、質心是否移動

      在第三步中,提到了一個重要的步驟:每輪聚類結束後,都要從新計算質心,而且計算質心是否發生移動。對於新質心的計算、樣本之間的類似度和判斷兩個樣本是否相等這幾個功能因爲並不知道樣本的具體數據類型,所以把他們定義成抽象方法,供子類來實現。下面就重點介紹如何判斷質心是否發生移動。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. private boolean isCenterChange(List<T> preT, List<T> nowT) {  
  2.     if (preT == null || nowT == null) {  
  3.         return false;  
  4.     }  
  5.     for (T t1 : preT) {  
  6.         boolean bol = true;  
  7.         for (T t2 : nowT) {  
  8.             if (equals(t1, t2)) {//t1在t2中有相等的,認爲該質心未移動  
  9.                 bol = false;  
  10.                 break;  
  11.             }  
  12.         }  
  13.         //有一個質心發生移動,認爲須要進行下一次計算  
  14.         if (bol) {  
  15.             return bol;  
  16.         }  
  17.     }  
  18.     return false;  
  19. }  

      從上述代碼能夠看到,算法的思想就是對於先後兩個質心數組分別前一組的質心是否在後一個質心組中出現,有一個沒有出現,就認爲質心發生了變更。

 

完整代碼

      上面四步已經完整的介紹了K-means算法的具體算法思想,下面就看下完整的代碼實現。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1.  /**   
  2.  *@Description:  K-means聚類 
  3.  */   
  4. package com.lulei.datamining.knn;    
  5.   
  6. import java.util.ArrayList;  
  7. import java.util.Collections;  
  8. import java.util.List;  
  9.     
  10. public abstract class KMeansClustering <T>{  
  11.     private List<T> dataArray;//待分類的原始值  
  12.     private int K = 3;//將要分紅的類別個數  
  13.     private int maxClusterTimes = 500;//最大迭代次數  
  14.     private List<List<T>> clusterList;//聚類的結果  
  15.     private List<T> clusteringCenterT;//質心  
  16.       
  17.     public int getK() {  
  18.         return K;  
  19.     }  
  20.     public void setK(int K) {  
  21.         if (K < 1) {  
  22.             throw new IllegalArgumentException("K must greater than 0");  
  23.         }  
  24.         this.K = K;  
  25.     }  
  26.     public int getMaxClusterTimes() {  
  27.         return maxClusterTimes;  
  28.     }  
  29.     public void setMaxClusterTimes(int maxClusterTimes) {  
  30.         if (maxClusterTimes < 10) {  
  31.             throw new IllegalArgumentException("maxClusterTimes must greater than 10");  
  32.         }  
  33.         this.maxClusterTimes = maxClusterTimes;  
  34.     }  
  35.     public List<T> getClusteringCenterT() {  
  36.         return clusteringCenterT;  
  37.     }  
  38.     /** 
  39.      * @return 
  40.      * @Author:lulei   
  41.      * @Description: 對數據進行聚類 
  42.      */  
  43.     public List<List<T>> clustering() {  
  44.         if (dataArray == null) {  
  45.             return null;  
  46.         }  
  47.         //初始K個點爲數組中的前K個點  
  48.         int size = K > dataArray.size() ? dataArray.size() : K;  
  49.         List<T> centerT = new ArrayList<T>(size);  
  50.         //對數據進行打亂  
  51.         Collections.shuffle(dataArray);  
  52.         for (int i = 0; i < size; i++) {  
  53.             centerT.add(dataArray.get(i));  
  54.         }  
  55.         clustering(centerT, 0);  
  56.         return clusterList;  
  57.     }  
  58.       
  59.     /** 
  60.      * @param preCenter 
  61.      * @param times 
  62.      * @Author:lulei   
  63.      * @Description: 一輪聚類 
  64.      */  
  65.     private void clustering(List<T> preCenter, int times) {  
  66.         if (preCenter == null || preCenter.size() < 2) {  
  67.             return;  
  68.         }  
  69.         //打亂質心的順序  
  70.         Collections.shuffle(preCenter);  
  71.         List<List<T>> clusterList =  getListT(preCenter.size());  
  72.         for (T o1 : this.dataArray) {  
  73.             //尋找最類似的質心  
  74.             int max = 0;  
  75.             double maxScore = similarScore(o1, preCenter.get(0));  
  76.             for (int i = 1; i < preCenter.size(); i++) {  
  77.                 if (maxScore < similarScore(o1, preCenter.get(i))) {  
  78.                     maxScore = similarScore(o1, preCenter.get(i));  
  79.                     max = i;  
  80.                 }  
  81.             }  
  82.             clusterList.get(max).add(o1);  
  83.         }  
  84.         //計算本次聚類結果每一個類別的質心  
  85.         List<T> nowCenter = new ArrayList<T> ();  
  86.         for (List<T> list : clusterList) {  
  87.             nowCenter.add(getCenterT(list));  
  88.         }  
  89.         //是否達到最大迭代次數  
  90.         if (times >= this.maxClusterTimes || preCenter.size() < this.K) {  
  91.             this.clusterList = clusterList;  
  92.             return;  
  93.         }  
  94.         this.clusteringCenterT = nowCenter;  
  95.         //判斷質心是否發生移動,若是沒有移動,結束本次聚類,不然進行下一輪  
  96.         if (isCenterChange(preCenter, nowCenter)) {  
  97.             clear(clusterList);  
  98.             clustering(nowCenter, times + 1);  
  99.         } else {  
  100.             this.clusterList = clusterList;  
  101.         }  
  102.     }  
  103.       
  104.     /** 
  105.      * @param size 
  106.      * @return 
  107.      * @Author:lulei   
  108.      * @Description: 初始化一個聚類結果 
  109.      */  
  110.     private List<List<T>> getListT(int size) {  
  111.         List<List<T>> list = new ArrayList<List<T>>(size);  
  112.         for (int i = 0; i < size; i++) {  
  113.             list.add(new ArrayList<T>());  
  114.         }  
  115.         return list;  
  116.     }  
  117.       
  118.     /** 
  119.      * @param lists 
  120.      * @Author:lulei   
  121.      * @Description: 清空無用數組 
  122.      */  
  123.     private void clear(List<List<T>> lists) {  
  124.         for (List<T> list : lists) {  
  125.             list.clear();  
  126.         }  
  127.         lists.clear();  
  128.     }  
  129.       
  130.     /** 
  131.      * @param value 
  132.      * @Author:lulei   
  133.      * @Description: 向模型中添加記錄 
  134.      */  
  135.     public void addRecord(T value) {  
  136.         if (dataArray == null) {  
  137.             dataArray = new ArrayList<T>();  
  138.         }  
  139.         dataArray.add(value);  
  140.     }  
  141.       
  142.     /** 
  143.      * @param preT 
  144.      * @param nowT 
  145.      * @return 
  146.      * @Author:lulei   
  147.      * @Description: 判斷質心是否發生移動 
  148.      */  
  149.     private boolean isCenterChange(List<T> preT, List<T> nowT) {  
  150.         if (preT == null || nowT == null) {  
  151.             return false;  
  152.         }  
  153.         for (T t1 : preT) {  
  154.             boolean bol = true;  
  155.             for (T t2 : nowT) {  
  156.                 if (equals(t1, t2)) {//t1在t2中有相等的,認爲該質心未移動  
  157.                     bol = false;  
  158.                     break;  
  159.                 }  
  160.             }  
  161.             //有一個質心發生移動,認爲須要進行下一次計算  
  162.             if (bol) {  
  163.                 return bol;  
  164.             }  
  165.         }  
  166.         return false;  
  167.     }  
  168.       
  169.     /** 
  170.      * @param o1 
  171.      * @param o2 
  172.      * @return 
  173.      * @Author:lulei   
  174.      * @Description: o1 o2之間的類似度 
  175.      */  
  176.     public abstract double similarScore(T o1, T o2);  
  177.       
  178.     /** 
  179.      * @param o1 
  180.      * @param o2 
  181.      * @return 
  182.      * @Author:lulei   
  183.      * @Description: 判斷o1 o2是否相等 
  184.      */  
  185.     public abstract boolean equals(T o1, T o2);  
  186.       
  187.     /** 
  188.      * @param list 
  189.      * @return 
  190.      * @Author:lulei   
  191.      * @Description: 求一組數據的質心 
  192.      */  
  193.     public abstract T getCenterT(List<T> list);  
  194. }  

 

二維數聚類實現

      在算法描述中,介紹了一個200,000個點聚成34個類別的效果圖,下面就針對二維座標數據實現其具體子類。

 

1、類似度

      對於二維座標的類似度,這裏咱們採起兩點間聚類的相反數,具體實現以下:

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. @Override  
  2. public double similarScore(XYbean o1, XYbean o2) {  
  3.     double distance = Math.sqrt((o1.getX() - o2.getX()) * (o1.getX() - o2.getX()) + (o1.getY() - o2.getY()) * (o1.getY() - o2.getY()));  
  4.     return distance * -1;  
  5. }  

 

 

2、樣本/質心是否相等

      判斷樣本/質心是否相等只須要判斷兩點的座標是否相等便可,具體實現以下:

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. @Override  
  2. public boolean equals(XYbean o1, XYbean o2) {  
  3.     return o1.getX() == o2.getX() && o1.getY() == o2.getY();  
  4. }  

 

 

3、獲取一個分類下的新質心

      對於二維座標數據,可使用全部點的重心做爲分類的質心,具體以下:

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. @Override  
  2. public XYbean getCenterT(List<XYbean> list) {  
  3.     int x = 0;  
  4.     int y = 0;  
  5.     try {  
  6.         for (XYbean xy : list) {  
  7.             x += xy.getX();  
  8.             y += xy.getY();  
  9.         }  
  10.         x = x / list.size();  
  11.         y = y / list.size();  
  12.     } catch(Exception e) {  
  13.           
  14.     }  
  15.     return new XYbean(x, y);  
  16. }  

 

 

4、main方法

      對於具體二維座標的源碼這裏就再也不貼出來,就是實現前面介紹的抽象類,並實現其中的3個抽象方法,下面咱們就隨機產生200,000個點,而後聚成34個類別,具體代碼以下:

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. public static void main(String[] args) {  
  2.       
  3.     int width = 600;  
  4.     int height = 400;  
  5.     int K = 34;  
  6.     XYCluster xyCluster = new XYCluster();  
  7.     for (int i = 0; i < 200000; i++) {  
  8.         int x = (int)(Math.random() * width) + 1;  
  9.         int y = (int)(Math.random() * height) + 1;  
  10.         xyCluster.addRecord(new XYbean(x, y));  
  11.     }  
  12.     xyCluster.setK(K);  
  13.     long a = System.currentTimeMillis();  
  14.     List<List<XYbean>> cresult = xyCluster.clustering();  
  15.     List<XYbean> center = xyCluster.getClusteringCenterT();  
  16.     System.out.println(JsonUtil.parseJson(center));  
  17.     long b = System.currentTimeMillis();  
  18.     System.out.println("耗時:" + (b - a) + "ms");  
  19.     new ImgUtil().drawXYbeans(width, height, cresult, "d:/2.png", 0, 0);  
  20. }  

 

 

      對於這隨機產生的200,000個點聚成34類,總耗時5485ms。(計算機配置:i5 + 8G內存)

相關文章
相關標籤/搜索