JAVA實現KNN分類

 KNN算法又叫近鄰算法,是數據挖掘中一種經常使用的分類算法,接單的介紹KNN算法的核心思想就是:尋找與目標最近的K個個體,這些樣本屬於類別最多的那個類別就是目標的類別。好比K爲7,那麼咱們就從數據中找到和目標最近(或者類似度最高)的7個樣本,加入這7個樣本對應的類別分別爲A、B、C、A、A、A、B,那麼目標屬於的分類就是A(由於這7個樣本中屬於A類別的樣本個數最多)。java

 

算法實現android

1、訓練數據格式定義算法

      下面就簡單的介紹下如何用Java來實現KNN分類,首先咱們須要存儲訓練集(包括屬性以及對應的類別),這裏咱們對未知的屬性使用泛型,類別咱們使用字符串存儲。apache

 

[java] view plain copy數組

 print?在CODE上查看代碼片派生到個人代碼片app

  1.  /**   
  2.  *@Description:  KNN分類模型中一條記錄的存儲格式 
  3.  */   
  4. package com.lulei.datamining.knn.bean;    
  5.     
  6. public class KnnValueBean<T>{  
  7.     private T value;//記錄值  
  8.     private String typeId;//分類ID  
  9.       
  10.     public KnnValueBean(T value, String typeId) {  
  11.         this.value = value;  
  12.         this.typeId = typeId;  
  13.     }  
  14.   
  15.     public T getValue() {  
  16.         return value;  
  17.     }  
  18.   
  19.     public void setValue(T value) {  
  20.         this.value = value;  
  21.     }  
  22.   
  23.     public String getTypeId() {  
  24.         return typeId;  
  25.     }  
  26.   
  27.     public void setTypeId(String typeId) {  
  28.         this.typeId = typeId;  
  29.     }  
  30. }  


2、K個最近鄰類別數據格式定義ide

 

      在統計獲得K個最近鄰中,咱們須要記錄前K個樣本的分類以及對應的類似度,咱們這裏使用以下數據格式:函數

 

[java] view plain copy工具

 print?在CODE上查看代碼片派生到個人代碼片測試

  1.  /**   
  2.  *@Description: K個最近鄰的類別得分 
  3.  */   
  4. package com.lulei.datamining.knn.bean;    
  5.     
  6. public class KnnValueSort {  
  7.     private String typeId;//分類ID  
  8.     private double score;//該分類得分  
  9.       
  10.     public KnnValueSort(String typeId, double score) {  
  11.         this.typeId = typeId;  
  12.         this.score = score;  
  13.     }  
  14.     public String getTypeId() {  
  15.         return typeId;  
  16.     }  
  17.     public void setTypeId(String typeId) {  
  18.         this.typeId = typeId;  
  19.     }  
  20.     public double getScore() {  
  21.         return score;  
  22.     }  
  23.     public void setScore(double score) {  
  24.         this.score = score;  
  25.     }  
  26. }  


3、KNN算法基本屬性

 

      在KNN算法中,最重要的一個指標就是K的取值,所以咱們在基類中須要設置一個屬性K以及設置一個數組用於存儲已知分類的數據。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. private List<KnnValueBean> dataArray;  
  2. private int K = 3;  


4、添加已知分類數據

 

      在使用KNN分類以前,咱們須要先向其中添加咱們已知分類的數據,咱們後面就是使用這些數據來預測未知數據的分類。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. /** 
  2.  * @param value 
  3.  * @param typeId 
  4.  * @Author:lulei   
  5.  * @Description: 向模型中添加記錄 
  6.  */  
  7. public void addRecord(T value, String typeId) {  
  8.     if (dataArray == null) {  
  9.         dataArray = new ArrayList<KnnValueBean>();  
  10.     }  
  11.     dataArray.add(new KnnValueBean<T>(value, typeId));  
  12. }  


5、兩個樣本之間的類似度(或者距離)

 

      在KNN算法中,最重要的一個方法就是如何肯定兩個樣本之間的類似度(或者距離),因爲這裏咱們使用的是泛型,並無辦法肯定兩個對象之間的類似度,一次這裏咱們把它設置爲抽象方法,讓子類來實現。這裏咱們方法定義爲類似度,也就是返回值越大,二者越類似,之間的距離越短

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. /** 
  2.  * @param o1 
  3.  * @param o2 
  4.  * @return 
  5.  * @Author:lulei   
  6.  * @Description: o1 o2之間的類似度 
  7.  */  
  8. public abstract double similarScore(T o1, T o2);  


6、獲取最近的K個樣本的分類

 

      KNN算法的核心思想就是找到最近的K個近鄰,所以這一步也是整個算法的核心部分。這裏咱們使用數組來保存類似度最大的K個樣本的分類和類似度,在計算的過程當中經過循環遍歷全部的樣本,數組保存截至當前計算點最類似的K個樣本對應的類別和類似度,具體實現以下:

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. /** 
  2.  * @param value 
  3.  * @return 
  4.  * @Author:lulei   
  5.  * @Description: 獲取距離最近的K個分類 
  6.  */  
  7. private KnnValueSort[] getKType(T value) {  
  8.     int k = 0;  
  9.     KnnValueSort[] topK = new KnnValueSort[K];  
  10.     for (KnnValueBean<T> bean : dataArray) {  
  11.         double score = similarScore(bean.getValue(), value);  
  12.         if (k == 0) {  
  13.             //數組中的記錄個數爲0是直接添加  
  14.             topK[k] = new KnnValueSort(bean.getTypeId(), score);  
  15.             k++;  
  16.         } else {  
  17.             if (!(k == K && score < topK[k -1].getScore())){  
  18.                 int i = 0;  
  19.                 //找到要插入的點  
  20.                 for (; i < k && score < topK[i].getScore(); i++);  
  21.                 int j = k - 1;  
  22.                 if (k < K) {  
  23.                     j = k;  
  24.                     k++;  
  25.                 }  
  26.                 for (; j > i; j--) {  
  27.                     topK[j] = topK[j - 1];  
  28.                 }  
  29.                 topK[i] = new KnnValueSort(bean.getTypeId(), score);  
  30.             }  
  31.         }  
  32.     }  
  33.     return topK;  
  34. }  


7、統計K個樣本出現次數最多的類別

 

      這一步就是一個簡單的計數,統計K個樣本中出現次數最多的分類,該分類就是咱們要預測的目標數據的分類。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1. /** 
  2.  * @param value 
  3.  * @return 
  4.  * @Author:lulei   
  5.  * @Description: KNN分類判斷value的類別 
  6.  */  
  7. public String getTypeId(T value) {  
  8.     KnnValueSort[] array = getKType(value);  
  9.     HashMap<String, Integer> map = new HashMap<String, Integer>(K);  
  10.     for (KnnValueSort bean : array) {  
  11.         if (bean != null) {  
  12.             if (map.containsKey(bean.getTypeId())) {  
  13.                 map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);  
  14.             } else {  
  15.                 map.put(bean.getTypeId(), 1);  
  16.             }  
  17.         }  
  18.     }  
  19.     String maxTypeId = null;  
  20.     int maxCount = 0;  
  21.     Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();  
  22.     while (iter.hasNext()) {  
  23.         Entry<String, Integer> entry = iter.next();  
  24.         if (maxCount < entry.getValue()) {  
  25.             maxCount = entry.getValue();  
  26.             maxTypeId = entry.getKey();  
  27.         }  
  28.     }  
  29.     return maxTypeId;  
  30. }  


      到如今爲止KNN分類的抽象基類已經編寫完成,在測試以前咱們先多說幾句,KNN分類是統計K個樣本中出現次數最多的分類,這種在有些狀況下並非特別合理,好比K=5,前5個樣本對應的分類分別爲A、A、B、B、B,對應的類似度得分分別爲十、九、二、二、1,若是使用上面的方法,那預測的分類就是B,可是看這些數據,預測的分類是A感受更合理。基於這種狀況,本身對KNN算法提出以下優化(這裏並不提供代碼,只提供簡單的思路):在獲取最類似K個樣本和類似度後,能夠對類似度和出現次數K作一種函數運算,好比加權,獲得的函數值最大的分類就是目標的預測分類。

 

基類源碼

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1.  /**   
  2.  *@Description: KNN分類 
  3.  */   
  4. package com.lulei.datamining.knn;    
  5.   
  6. import java.util.ArrayList;  
  7. import java.util.HashMap;  
  8. import java.util.Iterator;  
  9. import java.util.List;  
  10. import java.util.Map.Entry;  
  11.   
  12. import com.lulei.datamining.knn.bean.KnnValueBean;  
  13. import com.lulei.datamining.knn.bean.KnnValueSort;  
  14. import com.lulei.util.JsonUtil;  
  15.     
  16. @SuppressWarnings({"rawtypes"})  
  17. public abstract class KnnClassification<T> {  
  18.     private List<KnnValueBean> dataArray;  
  19.     private int K = 3;  
  20.       
  21.     public int getK() {  
  22.         return K;  
  23.     }  
  24.     public void setK(int K) {  
  25.         if (K < 1) {  
  26.             throw new IllegalArgumentException("K must greater than 0");  
  27.         }  
  28.         this.K = K;  
  29.     }  
  30.   
  31.     /** 
  32.      * @param value 
  33.      * @param typeId 
  34.      * @Author:lulei   
  35.      * @Description: 向模型中添加記錄 
  36.      */  
  37.     public void addRecord(T value, String typeId) {  
  38.         if (dataArray == null) {  
  39.             dataArray = new ArrayList<KnnValueBean>();  
  40.         }  
  41.         dataArray.add(new KnnValueBean<T>(value, typeId));  
  42.     }  
  43.       
  44.     /** 
  45.      * @param value 
  46.      * @return 
  47.      * @Author:lulei   
  48.      * @Description: KNN分類判斷value的類別 
  49.      */  
  50.     public String getTypeId(T value) {  
  51.         KnnValueSort[] array = getKType(value);  
  52.         System.out.println(JsonUtil.parseJson(array));  
  53.         HashMap<String, Integer> map = new HashMap<String, Integer>(K);  
  54.         for (KnnValueSort bean : array) {  
  55.             if (bean != null) {  
  56.                 if (map.containsKey(bean.getTypeId())) {  
  57.                     map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);  
  58.                 } else {  
  59.                     map.put(bean.getTypeId(), 1);  
  60.                 }  
  61.             }  
  62.         }  
  63.         String maxTypeId = null;  
  64.         int maxCount = 0;  
  65.         Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();  
  66.         while (iter.hasNext()) {  
  67.             Entry<String, Integer> entry = iter.next();  
  68.             if (maxCount < entry.getValue()) {  
  69.                 maxCount = entry.getValue();  
  70.                 maxTypeId = entry.getKey();  
  71.             }  
  72.         }  
  73.         return maxTypeId;  
  74.     }  
  75.       
  76.     /** 
  77.      * @param value 
  78.      * @return 
  79.      * @Author:lulei   
  80.      * @Description: 獲取距離最近的K個分類 
  81.      */  
  82.     private KnnValueSort[] getKType(T value) {  
  83.         int k = 0;  
  84.         KnnValueSort[] topK = new KnnValueSort[K];  
  85.         for (KnnValueBean<T> bean : dataArray) {  
  86.             double score = similarScore(bean.getValue(), value);  
  87.             if (k == 0) {  
  88.                 //數組中的記錄個數爲0是直接添加  
  89.                 topK[k] = new KnnValueSort(bean.getTypeId(), score);  
  90.                 k++;  
  91.             } else {  
  92.                 if (!(k == K && score < topK[k -1].getScore())){  
  93.                     int i = 0;  
  94.                     //找到要插入的點  
  95.                     for (; i < k && score < topK[i].getScore(); i++);  
  96.                     int j = k - 1;  
  97.                     if (k < K) {  
  98.                         j = k;  
  99.                         k++;  
  100.                     }  
  101.                     for (; j > i; j--) {  
  102.                         topK[j] = topK[j - 1];  
  103.                     }  
  104.                     topK[i] = new KnnValueSort(bean.getTypeId(), score);  
  105.                 }  
  106.             }  
  107.         }  
  108.         return topK;  
  109.     }  
  110.       
  111.     /** 
  112.      * @param o1 
  113.      * @param o2 
  114.      * @return 
  115.      * @Author:lulei   
  116.      * @Description: o1 o2之間的類似度 
  117.      */  
  118.     public abstract double similarScore(T o1, T o2);  
  119. }  

 

 

具體子類實現

      對於上面介紹的都在KNN分類的抽象基類中,對於實際的問題咱們須要繼承基類並實現基類中的類似度抽象方法,這裏咱們作一個簡單的實現。

 

[java] view plain copy

 print?在CODE上查看代碼片派生到個人代碼片

  1.  /**   
  2.  *@Description:      
  3.  */   
  4. package com.lulei.datamining.knn.test;    
  5.   
  6. import com.lulei.datamining.knn.KnnClassification;  
  7. import com.lulei.util.JsonUtil;  
  8.     
  9. public class Test extends KnnClassification<Integer>{  
  10.       
  11.     @Override  
  12.     public double similarScore(Integer o1, Integer o2) {  
  13.         return -1 * Math.abs(o1 - o2);  
  14.     }  
  15.       
  16.     /**   
  17.      * @param args 
  18.      * @Author:lulei   
  19.      * @Description:   
  20.      */  
  21.     public static void main(String[] args) {  
  22.         Test test = new Test();  
  23.         for (int i = 1; i < 10; i++) {  
  24.             test.addRecord(i, i > 5 ? "0" : "1");  
  25.         }  
  26.         System.out.println(JsonUtil.parseJson(test.getTypeId(0)));  
  27.           
  28.     }  
  29. }  

 

 

      這裏咱們一共添加了一、二、三、四、五、六、七、八、9這9組數據,前5組的類別爲1,後4組的類別爲0,兩個數據之間的類似度爲二者之間的差值的絕對值的相反數,下面預測0應該屬於的分類,這裏K的默認值爲3,所以最近的K個樣本分別爲一、二、3,對應的分類分別爲"1"、"1"、"1",由於最後預測的分類爲"1"。

 

KNN算法全名爲k-Nearest Neighbor,就是K最近鄰的意思。KNN也是一種分類算法。可是與以前說的決策樹分類算法相比,這個算法算是最簡單的一個了。算法的主要過程爲:

一、給定一個訓練集數據,每一個訓練集數據都是已經分好類的。
二、設定一個初始的測試數據a,計算a到訓練集全部數據的歐幾里得距離,並排序。                       

三、選出訓練集中離a距離最近的K個訓練集數據。

四、比較k個訓練集數據,選出裏面出現最多的分類類型,此分類類型即爲最終測試數據a的分類。

下面百度百科上的一張簡圖:

KNN算法實現

首先測試數據須要2塊,1個是訓練集數據,就是已經分好類的數據,好比上圖中的非綠色的點。還有一個是測試數據,就是上面的綠點,固然這裏的測試數據不會是一個,而是一組。這裏的數據與數據之間的距離用數據的特徵向量作計算,特徵向量能夠是多維度的。經過計算特徵向量與特徵向量之間的歐幾里得距離來推算類似度。定義訓練集數據trainInput.txt:

 

[java] view plain copy

 print?

  1. a 1 2 3 4 5   
  2. b 5 4 3 2 1   
  3. c 3 3 3 3 3   
  4. d -3 -3 -3 -3 -3   
  5. a 1 2 3 4 4   
  6. b 4 4 3 2 1   
  7. c 3 3 3 2 4   
  8. d 0 0 1 1 -2   

待測試數據testInput,只有特徵向量值:

 

 

[java] view plain copy

 print?

  1. 1 2 3 2 4   
  2. 2 3 4 2 1   
  3. 8 7 2 3 5   
  4. -3 -2 2 4 0   
  5. -4 -4 -4 -4 -4   
  6. 1 2 3 4 4   
  7. 4 4 3 2 1   
  8. 3 3 3 2 4   
  9. 0 0 1 1 -2   

下面是主程序:

 

 

[java] view plain copy

 print?

  1. package DataMing_KNN;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.Arrays;  
  9. import java.util.Collection;  
  10. import java.util.Collections;  
  11. import java.util.Comparator;  
  12. import java.util.HashMap;  
  13. import java.util.Map;  
  14.   
  15. import org.apache.activemq.filter.ComparisonExpression;  
  16.   
  17. /** 
  18.  * k最近鄰算法工具類 
  19.  *  
  20.  * @author lyq 
  21.  *  
  22.  */  
  23. public class KNNTool {  
  24.     // 爲4個類別設置權重,默認權重比一致  
  25.     public int[] classWeightArray = new int[] { 1, 1, 1, 1 };  
  26.     // 測試數據地址  
  27.     private String testDataPath;  
  28.     // 訓練集數據地址  
  29.     private String trainDataPath;  
  30.     // 分類的不一樣類型  
  31.     private ArrayList<String> classTypes;  
  32.     // 結果數據  
  33.     private ArrayList<Sample> resultSamples;  
  34.     // 訓練集數據列表容器  
  35.     private ArrayList<Sample> trainSamples;  
  36.     // 訓練集數據  
  37.     private String[][] trainData;  
  38.     // 測試集數據  
  39.     private String[][] testData;  
  40.   
  41.     public KNNTool(String trainDataPath, String testDataPath) {  
  42.         this.trainDataPath = trainDataPath;  
  43.         this.testDataPath = testDataPath;  
  44.         readDataFormFile();  
  45.     }  
  46.   
  47.     /** 
  48.      * 從文件中閱讀測試數和訓練數據集 
  49.      */  
  50.     private void readDataFormFile() {  
  51.         ArrayList<String[]> tempArray;  
  52.   
  53.         tempArray = fileDataToArray(trainDataPath);  
  54.         trainData = new String[tempArray.size()][];  
  55.         tempArray.toArray(trainData);  
  56.   
  57.         classTypes = new ArrayList<>();  
  58.         for (String[] s : tempArray) {  
  59.             if (!classTypes.contains(s[0])) {  
  60.                 // 添加類型  
  61.                 classTypes.add(s[0]);  
  62.             }  
  63.         }  
  64.   
  65.         tempArray = fileDataToArray(testDataPath);  
  66.         testData = new String[tempArray.size()][];  
  67.         tempArray.toArray(testData);  
  68.     }  
  69.   
  70.     /** 
  71.      * 將文件轉爲列表數據輸出 
  72.      *  
  73.      * @param filePath 
  74.      *            數據文件的內容 
  75.      */  
  76.     private ArrayList<String[]> fileDataToArray(String filePath) {  
  77.         File file = new File(filePath);  
  78.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  79.   
  80.         try {  
  81.             BufferedReader in = new BufferedReader(new FileReader(file));  
  82.             String str;  
  83.             String[] tempArray;  
  84.             while ((str = in.readLine()) != null) {  
  85.                 tempArray = str.split(" ");  
  86.                 dataArray.add(tempArray);  
  87.             }  
  88.             in.close();  
  89.         } catch (IOException e) {  
  90.             e.getStackTrace();  
  91.         }  
  92.   
  93.         return dataArray;  
  94.     }  
  95.   
  96.     /** 
  97.      * 計算樣本特徵向量的歐幾里得距離 
  98.      *  
  99.      * @param f1 
  100.      *            待比較樣本1 
  101.      * @param f2 
  102.      *            待比較樣本2 
  103.      * @return 
  104.      */  
  105.     private int computeEuclideanDistance(Sample s1, Sample s2) {  
  106.         String[] f1 = s1.getFeatures();  
  107.         String[] f2 = s2.getFeatures();  
  108.         // 歐幾里得距離  
  109.         int distance = 0;  
  110.   
  111.         for (int i = 0; i < f1.length; i++) {  
  112.             int subF1 = Integer.parseInt(f1[i]);  
  113.             int subF2 = Integer.parseInt(f2[i]);  
  114.   
  115.             distance += (subF1 - subF2) * (subF1 - subF2);  
  116.         }  
  117.   
  118.         return distance;  
  119.     }  
  120.   
  121.     /** 
  122.      * 計算K最近鄰 
  123.      * @param k 
  124.      * 在多少的k範圍內 
  125.      */  
  126.     public void knnCompute(int k) {  
  127.         String className = "";  
  128.         String[] tempF = null;  
  129.         Sample temp;  
  130.         resultSamples = new ArrayList<>();  
  131.         trainSamples = new ArrayList<>();  
  132.         // 分類類別計數  
  133.         HashMap<String, Integer> classCount;  
  134.         // 類別權重比  
  135.         HashMap<String, Integer> classWeight = new HashMap<>();  
  136.         // 首先講測試數據轉化到結果數據中  
  137.         for (String[] s : testData) {  
  138.             temp = new Sample(s);  
  139.             resultSamples.add(temp);  
  140.         }  
  141.   
  142.         for (String[] s : trainData) {  
  143.             className = s[0];  
  144.             tempF = new String[s.length - 1];  
  145.             System.arraycopy(s, 1, tempF, 0, s.length - 1);  
  146.             temp = new Sample(className, tempF);  
  147.             trainSamples.add(temp);  
  148.         }  
  149.   
  150.         // 離樣本最近排序的的訓練集數據  
  151.         ArrayList<Sample> kNNSample = new ArrayList<>();  
  152.         // 計算訓練數據集中離樣本數據最近的K個訓練集數據  
  153.         for (Sample s : resultSamples) {  
  154.             classCount = new HashMap<>();  
  155.             int index = 0;  
  156.             for (String type : classTypes) {  
  157.                 // 開始時計數爲0  
  158.                 classCount.put(type, 0);  
  159.                 classWeight.put(type, classWeightArray[index++]);  
  160.             }  
  161.             for (Sample tS : trainSamples) {  
  162.                 int dis = computeEuclideanDistance(s, tS);  
  163.                 tS.setDistance(dis);  
  164.             }  
  165.   
  166.             Collections.sort(trainSamples);  
  167.             kNNSample.clear();  
  168.             // 挑選出前k個數據做爲分類標準  
  169.             for (int i = 0; i < trainSamples.size(); i++) {  
  170.                 if (i < k) {  
  171.                     kNNSample.add(trainSamples.get(i));  
  172.                 } else {  
  173.                     break;  
  174.                 }  
  175.             }  
  176.             // 斷定K個訓練數據的多數的分類標準  
  177.             for (Sample s1 : kNNSample) {  
  178.                 int num = classCount.get(s1.getClassName());  
  179.                 // 進行分類權重的疊加,默認類別權重平等,可自行改變,近的權重大,遠的權重小  
  180.                 num += classWeight.get(s1.getClassName());  
  181.                 classCount.put(s1.getClassName(), num);  
  182.             }  
  183.   
  184.             int maxCount = 0;  
  185.             // 篩選出k個訓練集數據中最多的一個分類  
  186.             for (Map.Entry entry : classCount.entrySet()) {  
  187.                 if ((Integer) entry.getValue() > maxCount) {  
  188.                     maxCount = (Integer) entry.getValue();  
  189.                     s.setClassName((String) entry.getKey());  
  190.                 }  
  191.             }  
  192.   
  193.             System.out.print("測試數據特徵:");  
  194.             for (String s1 : s.getFeatures()) {  
  195.                 System.out.print(s1 + " ");  
  196.             }  
  197.             System.out.println("分類:" + s.getClassName());  
  198.         }  
  199.     }  
  200. }  

Sample樣本數據類:

 

 

[java] view plain copy

 print?

  1. package DataMing_KNN;  
  2.   
  3. /** 
  4.  * 樣本數據類 
  5.  *  
  6.  * @author lyq 
  7.  *  
  8.  */  
  9. public class Sample implements Comparable<Sample>{  
  10.     // 樣本數據的分類名稱  
  11.     private String className;  
  12.     // 樣本數據的特徵向量  
  13.     private String[] features;  
  14.     //測試樣本之間的間距值,以此作排序  
  15.     private Integer distance;  
  16.       
  17.     public Sample(String[] features){  
  18.         this.features = features;  
  19.     }  
  20.       
  21.     public Sample(String className, String[] features){  
  22.         this.className = className;  
  23.         this.features = features;  
  24.     }  
  25.   
  26.     public String getClassName() {  
  27.         return className;  
  28.     }  
  29.   
  30.     public void setClassName(String className) {  
  31.         this.className = className;  
  32.     }  
  33.   
  34.     public String[] getFeatures() {  
  35.         return features;  
  36.     }  
  37.   
  38.     public void setFeatures(String[] features) {  
  39.         this.features = features;  
  40.     }  
  41.   
  42.     public Integer getDistance() {  
  43.         return distance;  
  44.     }  
  45.   
  46.     public void setDistance(int distance) {  
  47.         this.distance = distance;  
  48.     }  
  49.   
  50.     @Override  
  51.     public int compareTo(Sample o) {  
  52.         // TODO Auto-generated method stub  
  53.         return this.getDistance().compareTo(o.getDistance());  
  54.     }  
  55.       
  56. }  

測試場景類:

 

 

[java] view plain copy

 print?

  1. /** 
  2.  * k最近鄰算法場景類型 
  3.  * @author lyq 
  4.  * 
  5.  */  
  6. public class Client {  
  7.     public static void main(String[] args){  
  8.         String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";  
  9.         String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";  
  10.           
  11.         KNNTool tool = new KNNTool(trainDataPath, testDataPath);  
  12.         tool.knnCompute(3);  
  13.           
  14.     }  
  15.       
  16.   
  17.   
  18. }  

執行的結果爲:

 

 

[java] view plain copy

 print?

  1. 測試數據特徵:1 2 3 2 4 分類:a  
  2. 測試數據特徵:2 3 4 2 1 分類:c  
  3. 測試數據特徵:8 7 2 3 5 分類:b  
  4. 測試數據特徵:-3 -2 2 4 0 分類:a  
  5. 測試數據特徵:-4 -4 -4 -4 -4 分類:d  
  6. 測試數據特徵:1 2 3 4 4 分類:a  
  7. 測試數據特徵:4 4 3 2 1 分類:b  
  8. 測試數據特徵:3 3 3 2 4 分類:c  
  9. 測試數據特徵:0 0 1 1 -2 分類:d  

 

程序的輸出結果如上所示,若是不相信的話能夠本身動手計算進行驗證。

KNN算法的注意點:

一、knn算法的訓練集數據必需要相對公平,各個類型的數據數量應該是平均的,不然當A數據由1000個B數據由100個,到時不管如何A數據的樣本仍是佔優的。

二、knn算法若是純粹憑藉分類的多少作判斷,仍是能夠繼續優化的,好比近的數據的權重能夠設大,最後根據全部的類型權重和進行比較,而不是單純的憑藉數量。

三、knn算法的缺點是計算量大,這個從程序中也應該看得出來,裏面每一個測試數據都要計算到全部的訓練集數據之間的歐式距離,時間複雜度就已經爲O(n*n),若是真實數據的n很是大,這個算法的開銷的確態度,因此KNN不適合大規模數據量的分類。

KNN算法編碼時遇到的困難:

按理來講這麼簡單的KNN算法本應該是沒有多少的難度,可是在多歐式距離的排序上被深深的坑了一段時間,本人起初用Collections.sort(list)的方式進行按距離排序,也把Sample類實現了Compareable接口,可是排序就是不變,最後才知道,distance的int類型要改成Integer引用類型,在compareTo重載方法中調用distance的.CompareTo()方法就成功了,這個小細節平時沒注意,難道屬性的比較最終必定要調用到引用類型的compareTo()方法?這個小問題居然花費了我一段時間,最後仔細的比較了一下網上的例子最後才發現......

相關文章
相關標籤/搜索