KNN算法又叫近鄰算法,是數據挖掘中一種經常使用的分類算法,接單的介紹KNN算法的核心思想就是:尋找與目標最近的K個個體,這些樣本屬於類別最多的那個類別就是目標的類別。好比K爲7,那麼咱們就從數據中找到和目標最近(或者類似度最高)的7個樣本,加入這7個樣本對應的類別分別爲A、B、C、A、A、A、B,那麼目標屬於的分類就是A(由於這7個樣本中屬於A類別的樣本個數最多)。java
算法實現android
1、訓練數據格式定義算法
下面就簡單的介紹下如何用Java來實現KNN分類,首先咱們須要存儲訓練集(包括屬性以及對應的類別),這裏咱們對未知的屬性使用泛型,類別咱們使用字符串存儲。apache
[java] view plain copy數組
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
app
- /**
- *@Description: KNN分類模型中一條記錄的存儲格式
- */
- package com.lulei.datamining.knn.bean;
-
- public class KnnValueBean<T>{
- private T value;//記錄值
- private String typeId;//分類ID
-
- public KnnValueBean(T value, String typeId) {
- this.value = value;
- this.typeId = typeId;
- }
-
- public T getValue() {
- return value;
- }
-
- public void setValue(T value) {
- this.value = value;
- }
-
- public String getTypeId() {
- return typeId;
- }
-
- public void setTypeId(String typeId) {
- this.typeId = typeId;
- }
- }
2、K個最近鄰類別數據格式定義ide
在統計獲得K個最近鄰中,咱們須要記錄前K個樣本的分類以及對應的類似度,咱們這裏使用以下數據格式:函數
[java] view plain copy工具
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
測試
- /**
- *@Description: K個最近鄰的類別得分
- */
- package com.lulei.datamining.knn.bean;
-
- public class KnnValueSort {
- private String typeId;//分類ID
- private double score;//該分類得分
-
- public KnnValueSort(String typeId, double score) {
- this.typeId = typeId;
- this.score = score;
- }
- public String getTypeId() {
- return typeId;
- }
- public void setTypeId(String typeId) {
- this.typeId = typeId;
- }
- public double getScore() {
- return score;
- }
- public void setScore(double score) {
- this.score = score;
- }
- }
3、KNN算法基本屬性
在KNN算法中,最重要的一個指標就是K的取值,所以咱們在基類中須要設置一個屬性K以及設置一個數組用於存儲已知分類的數據。
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- private List<KnnValueBean> dataArray;
- private int K = 3;
4、添加已知分類數據
在使用KNN分類以前,咱們須要先向其中添加咱們已知分類的數據,咱們後面就是使用這些數據來預測未知數據的分類。
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- /**
- * @param value
- * @param typeId
- * @Author:lulei
- * @Description: 向模型中添加記錄
- */
- public void addRecord(T value, String typeId) {
- if (dataArray == null) {
- dataArray = new ArrayList<KnnValueBean>();
- }
- dataArray.add(new KnnValueBean<T>(value, typeId));
- }
5、兩個樣本之間的類似度(或者距離)
在KNN算法中,最重要的一個方法就是如何肯定兩個樣本之間的類似度(或者距離),因爲這裏咱們使用的是泛型,並無辦法肯定兩個對象之間的類似度,一次這裏咱們把它設置爲抽象方法,讓子類來實現。這裏咱們方法定義爲類似度,也就是返回值越大,二者越類似,之間的距離越短。
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- /**
- * @param o1
- * @param o2
- * @return
- * @Author:lulei
- * @Description: o1 o2之間的類似度
- */
- public abstract double similarScore(T o1, T o2);
6、獲取最近的K個樣本的分類
KNN算法的核心思想就是找到最近的K個近鄰,所以這一步也是整個算法的核心部分。這裏咱們使用數組來保存類似度最大的K個樣本的分類和類似度,在計算的過程當中經過循環遍歷全部的樣本,數組保存截至當前計算點最類似的K個樣本對應的類別和類似度,具體實現以下:
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: 獲取距離最近的K個分類
- */
- private KnnValueSort[] getKType(T value) {
- int k = 0;
- KnnValueSort[] topK = new KnnValueSort[K];
- for (KnnValueBean<T> bean : dataArray) {
- double score = similarScore(bean.getValue(), value);
- if (k == 0) {
- //數組中的記錄個數爲0是直接添加
- topK[k] = new KnnValueSort(bean.getTypeId(), score);
- k++;
- } else {
- if (!(k == K && score < topK[k -1].getScore())){
- int i = 0;
- //找到要插入的點
- for (; i < k && score < topK[i].getScore(); i++);
- int j = k - 1;
- if (k < K) {
- j = k;
- k++;
- }
- for (; j > i; j--) {
- topK[j] = topK[j - 1];
- }
- topK[i] = new KnnValueSort(bean.getTypeId(), score);
- }
- }
- }
- return topK;
- }
7、統計K個樣本出現次數最多的類別
這一步就是一個簡單的計數,統計K個樣本中出現次數最多的分類,該分類就是咱們要預測的目標數據的分類。
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: KNN分類判斷value的類別
- */
- public String getTypeId(T value) {
- KnnValueSort[] array = getKType(value);
- HashMap<String, Integer> map = new HashMap<String, Integer>(K);
- for (KnnValueSort bean : array) {
- if (bean != null) {
- if (map.containsKey(bean.getTypeId())) {
- map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
- } else {
- map.put(bean.getTypeId(), 1);
- }
- }
- }
- String maxTypeId = null;
- int maxCount = 0;
- Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
- while (iter.hasNext()) {
- Entry<String, Integer> entry = iter.next();
- if (maxCount < entry.getValue()) {
- maxCount = entry.getValue();
- maxTypeId = entry.getKey();
- }
- }
- return maxTypeId;
- }
到如今爲止KNN分類的抽象基類已經編寫完成,在測試以前咱們先多說幾句,KNN分類是統計K個樣本中出現次數最多的分類,這種在有些狀況下並非特別合理,好比K=5,前5個樣本對應的分類分別爲A、A、B、B、B,對應的類似度得分分別爲十、九、二、二、1,若是使用上面的方法,那預測的分類就是B,可是看這些數據,預測的分類是A感受更合理。基於這種狀況,本身對KNN算法提出以下優化(這裏並不提供代碼,只提供簡單的思路):在獲取最類似K個樣本和類似度後,能夠對類似度和出現次數K作一種函數運算,好比加權,獲得的函數值最大的分類就是目標的預測分類。
基類源碼
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- /**
- *@Description: KNN分類
- */
- package com.lulei.datamining.knn;
-
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.Iterator;
- import java.util.List;
- import java.util.Map.Entry;
-
- import com.lulei.datamining.knn.bean.KnnValueBean;
- import com.lulei.datamining.knn.bean.KnnValueSort;
- import com.lulei.util.JsonUtil;
-
- @SuppressWarnings({"rawtypes"})
- public abstract class KnnClassification<T> {
- private List<KnnValueBean> dataArray;
- private int K = 3;
-
- public int getK() {
- return K;
- }
- public void setK(int K) {
- if (K < 1) {
- throw new IllegalArgumentException("K must greater than 0");
- }
- this.K = K;
- }
-
- /**
- * @param value
- * @param typeId
- * @Author:lulei
- * @Description: 向模型中添加記錄
- */
- public void addRecord(T value, String typeId) {
- if (dataArray == null) {
- dataArray = new ArrayList<KnnValueBean>();
- }
- dataArray.add(new KnnValueBean<T>(value, typeId));
- }
-
- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: KNN分類判斷value的類別
- */
- public String getTypeId(T value) {
- KnnValueSort[] array = getKType(value);
- System.out.println(JsonUtil.parseJson(array));
- HashMap<String, Integer> map = new HashMap<String, Integer>(K);
- for (KnnValueSort bean : array) {
- if (bean != null) {
- if (map.containsKey(bean.getTypeId())) {
- map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
- } else {
- map.put(bean.getTypeId(), 1);
- }
- }
- }
- String maxTypeId = null;
- int maxCount = 0;
- Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
- while (iter.hasNext()) {
- Entry<String, Integer> entry = iter.next();
- if (maxCount < entry.getValue()) {
- maxCount = entry.getValue();
- maxTypeId = entry.getKey();
- }
- }
- return maxTypeId;
- }
-
- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: 獲取距離最近的K個分類
- */
- private KnnValueSort[] getKType(T value) {
- int k = 0;
- KnnValueSort[] topK = new KnnValueSort[K];
- for (KnnValueBean<T> bean : dataArray) {
- double score = similarScore(bean.getValue(), value);
- if (k == 0) {
- //數組中的記錄個數爲0是直接添加
- topK[k] = new KnnValueSort(bean.getTypeId(), score);
- k++;
- } else {
- if (!(k == K && score < topK[k -1].getScore())){
- int i = 0;
- //找到要插入的點
- for (; i < k && score < topK[i].getScore(); i++);
- int j = k - 1;
- if (k < K) {
- j = k;
- k++;
- }
- for (; j > i; j--) {
- topK[j] = topK[j - 1];
- }
- topK[i] = new KnnValueSort(bean.getTypeId(), score);
- }
- }
- }
- return topK;
- }
-
- /**
- * @param o1
- * @param o2
- * @return
- * @Author:lulei
- * @Description: o1 o2之間的類似度
- */
- public abstract double similarScore(T o1, T o2);
- }
具體子類實現
對於上面介紹的都在KNN分類的抽象基類中,對於實際的問題咱們須要繼承基類並實現基類中的類似度抽象方法,這裏咱們作一個簡單的實現。
[java] view plain copy
print?![在CODE上查看代碼片](http://static.javashuo.com/static/loading.gif)
![派生到個人代碼片](http://static.javashuo.com/static/loading.gif)
- /**
- *@Description:
- */
- package com.lulei.datamining.knn.test;
-
- import com.lulei.datamining.knn.KnnClassification;
- import com.lulei.util.JsonUtil;
-
- public class Test extends KnnClassification<Integer>{
-
- @Override
- public double similarScore(Integer o1, Integer o2) {
- return -1 * Math.abs(o1 - o2);
- }
-
- /**
- * @param args
- * @Author:lulei
- * @Description:
- */
- public static void main(String[] args) {
- Test test = new Test();
- for (int i = 1; i < 10; i++) {
- test.addRecord(i, i > 5 ? "0" : "1");
- }
- System.out.println(JsonUtil.parseJson(test.getTypeId(0)));
-
- }
- }
這裏咱們一共添加了一、二、三、四、五、六、七、八、9這9組數據,前5組的類別爲1,後4組的類別爲0,兩個數據之間的類似度爲二者之間的差值的絕對值的相反數,下面預測0應該屬於的分類,這裏K的默認值爲3,所以最近的K個樣本分別爲一、二、3,對應的分類分別爲"1"、"1"、"1",由於最後預測的分類爲"1"。
KNN算法全名爲k-Nearest Neighbor,就是K最近鄰的意思。KNN也是一種分類算法。可是與以前說的決策樹分類算法相比,這個算法算是最簡單的一個了。算法的主要過程爲:
一、給定一個訓練集數據,每一個訓練集數據都是已經分好類的。
二、設定一個初始的測試數據a,計算a到訓練集全部數據的歐幾里得距離,並排序。
三、選出訓練集中離a距離最近的K個訓練集數據。
四、比較k個訓練集數據,選出裏面出現最多的分類類型,此分類類型即爲最終測試數據a的分類。
下面百度百科上的一張簡圖:
![](http://static.javashuo.com/static/loading.gif)
KNN算法實現
首先測試數據須要2塊,1個是訓練集數據,就是已經分好類的數據,好比上圖中的非綠色的點。還有一個是測試數據,就是上面的綠點,固然這裏的測試數據不會是一個,而是一組。這裏的數據與數據之間的距離用數據的特徵向量作計算,特徵向量能夠是多維度的。經過計算特徵向量與特徵向量之間的歐幾里得距離來推算類似度。定義訓練集數據trainInput.txt:
[java] view plain copy
print?
- a 1 2 3 4 5
- b 5 4 3 2 1
- c 3 3 3 3 3
- d -3 -3 -3 -3 -3
- a 1 2 3 4 4
- b 4 4 3 2 1
- c 3 3 3 2 4
- d 0 0 1 1 -2
待測試數據testInput,只有特徵向量值:
[java] view plain copy
print?
- 1 2 3 2 4
- 2 3 4 2 1
- 8 7 2 3 5
- -3 -2 2 4 0
- -4 -4 -4 -4 -4
- 1 2 3 4 4
- 4 4 3 2 1
- 3 3 3 2 4
- 0 0 1 1 -2
下面是主程序:
[java] view plain copy
print?
- package DataMing_KNN;
-
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.FileReader;
- import java.io.IOException;
- import java.util.ArrayList;
- import java.util.Arrays;
- import java.util.Collection;
- import java.util.Collections;
- import java.util.Comparator;
- import java.util.HashMap;
- import java.util.Map;
-
- import org.apache.activemq.filter.ComparisonExpression;
-
- /**
- * k最近鄰算法工具類
- *
- * @author lyq
- *
- */
- public class KNNTool {
- // 爲4個類別設置權重,默認權重比一致
- public int[] classWeightArray = new int[] { 1, 1, 1, 1 };
- // 測試數據地址
- private String testDataPath;
- // 訓練集數據地址
- private String trainDataPath;
- // 分類的不一樣類型
- private ArrayList<String> classTypes;
- // 結果數據
- private ArrayList<Sample> resultSamples;
- // 訓練集數據列表容器
- private ArrayList<Sample> trainSamples;
- // 訓練集數據
- private String[][] trainData;
- // 測試集數據
- private String[][] testData;
-
- public KNNTool(String trainDataPath, String testDataPath) {
- this.trainDataPath = trainDataPath;
- this.testDataPath = testDataPath;
- readDataFormFile();
- }
-
- /**
- * 從文件中閱讀測試數和訓練數據集
- */
- private void readDataFormFile() {
- ArrayList<String[]> tempArray;
-
- tempArray = fileDataToArray(trainDataPath);
- trainData = new String[tempArray.size()][];
- tempArray.toArray(trainData);
-
- classTypes = new ArrayList<>();
- for (String[] s : tempArray) {
- if (!classTypes.contains(s[0])) {
- // 添加類型
- classTypes.add(s[0]);
- }
- }
-
- tempArray = fileDataToArray(testDataPath);
- testData = new String[tempArray.size()][];
- tempArray.toArray(testData);
- }
-
- /**
- * 將文件轉爲列表數據輸出
- *
- * @param filePath
- * 數據文件的內容
- */
- private ArrayList<String[]> fileDataToArray(String filePath) {
- File file = new File(filePath);
- ArrayList<String[]> dataArray = new ArrayList<String[]>();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- return dataArray;
- }
-
- /**
- * 計算樣本特徵向量的歐幾里得距離
- *
- * @param f1
- * 待比較樣本1
- * @param f2
- * 待比較樣本2
- * @return
- */
- private int computeEuclideanDistance(Sample s1, Sample s2) {
- String[] f1 = s1.getFeatures();
- String[] f2 = s2.getFeatures();
- // 歐幾里得距離
- int distance = 0;
-
- for (int i = 0; i < f1.length; i++) {
- int subF1 = Integer.parseInt(f1[i]);
- int subF2 = Integer.parseInt(f2[i]);
-
- distance += (subF1 - subF2) * (subF1 - subF2);
- }
-
- return distance;
- }
-
- /**
- * 計算K最近鄰
- * @param k
- * 在多少的k範圍內
- */
- public void knnCompute(int k) {
- String className = "";
- String[] tempF = null;
- Sample temp;
- resultSamples = new ArrayList<>();
- trainSamples = new ArrayList<>();
- // 分類類別計數
- HashMap<String, Integer> classCount;
- // 類別權重比
- HashMap<String, Integer> classWeight = new HashMap<>();
- // 首先講測試數據轉化到結果數據中
- for (String[] s : testData) {
- temp = new Sample(s);
- resultSamples.add(temp);
- }
-
- for (String[] s : trainData) {
- className = s[0];
- tempF = new String[s.length - 1];
- System.arraycopy(s, 1, tempF, 0, s.length - 1);
- temp = new Sample(className, tempF);
- trainSamples.add(temp);
- }
-
- // 離樣本最近排序的的訓練集數據
- ArrayList<Sample> kNNSample = new ArrayList<>();
- // 計算訓練數據集中離樣本數據最近的K個訓練集數據
- for (Sample s : resultSamples) {
- classCount = new HashMap<>();
- int index = 0;
- for (String type : classTypes) {
- // 開始時計數爲0
- classCount.put(type, 0);
- classWeight.put(type, classWeightArray[index++]);
- }
- for (Sample tS : trainSamples) {
- int dis = computeEuclideanDistance(s, tS);
- tS.setDistance(dis);
- }
-
- Collections.sort(trainSamples);
- kNNSample.clear();
- // 挑選出前k個數據做爲分類標準
- for (int i = 0; i < trainSamples.size(); i++) {
- if (i < k) {
- kNNSample.add(trainSamples.get(i));
- } else {
- break;
- }
- }
- // 斷定K個訓練數據的多數的分類標準
- for (Sample s1 : kNNSample) {
- int num = classCount.get(s1.getClassName());
- // 進行分類權重的疊加,默認類別權重平等,可自行改變,近的權重大,遠的權重小
- num += classWeight.get(s1.getClassName());
- classCount.put(s1.getClassName(), num);
- }
-
- int maxCount = 0;
- // 篩選出k個訓練集數據中最多的一個分類
- for (Map.Entry entry : classCount.entrySet()) {
- if ((Integer) entry.getValue() > maxCount) {
- maxCount = (Integer) entry.getValue();
- s.setClassName((String) entry.getKey());
- }
- }
-
- System.out.print("測試數據特徵:");
- for (String s1 : s.getFeatures()) {
- System.out.print(s1 + " ");
- }
- System.out.println("分類:" + s.getClassName());
- }
- }
- }
Sample樣本數據類:
[java] view plain copy
print?
- package DataMing_KNN;
-
- /**
- * 樣本數據類
- *
- * @author lyq
- *
- */
- public class Sample implements Comparable<Sample>{
- // 樣本數據的分類名稱
- private String className;
- // 樣本數據的特徵向量
- private String[] features;
- //測試樣本之間的間距值,以此作排序
- private Integer distance;
-
- public Sample(String[] features){
- this.features = features;
- }
-
- public Sample(String className, String[] features){
- this.className = className;
- this.features = features;
- }
-
- public String getClassName() {
- return className;
- }
-
- public void setClassName(String className) {
- this.className = className;
- }
-
- public String[] getFeatures() {
- return features;
- }
-
- public void setFeatures(String[] features) {
- this.features = features;
- }
-
- public Integer getDistance() {
- return distance;
- }
-
- public void setDistance(int distance) {
- this.distance = distance;
- }
-
- @Override
- public int compareTo(Sample o) {
- // TODO Auto-generated method stub
- return this.getDistance().compareTo(o.getDistance());
- }
-
- }
測試場景類:
[java] view plain copy
print?
- /**
- * k最近鄰算法場景類型
- * @author lyq
- *
- */
- public class Client {
- public static void main(String[] args){
- String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
- String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";
-
- KNNTool tool = new KNNTool(trainDataPath, testDataPath);
- tool.knnCompute(3);
-
- }
-
-
-
- }
執行的結果爲:
[java] view plain copy
print?
- 測試數據特徵:1 2 3 2 4 分類:a
- 測試數據特徵:2 3 4 2 1 分類:c
- 測試數據特徵:8 7 2 3 5 分類:b
- 測試數據特徵:-3 -2 2 4 0 分類:a
- 測試數據特徵:-4 -4 -4 -4 -4 分類:d
- 測試數據特徵:1 2 3 4 4 分類:a
- 測試數據特徵:4 4 3 2 1 分類:b
- 測試數據特徵:3 3 3 2 4 分類:c
- 測試數據特徵:0 0 1 1 -2 分類:d
程序的輸出結果如上所示,若是不相信的話能夠本身動手計算進行驗證。
KNN算法的注意點:
一、knn算法的訓練集數據必需要相對公平,各個類型的數據數量應該是平均的,不然當A數據由1000個B數據由100個,到時不管如何A數據的樣本仍是佔優的。
二、knn算法若是純粹憑藉分類的多少作判斷,仍是能夠繼續優化的,好比近的數據的權重能夠設大,最後根據全部的類型權重和進行比較,而不是單純的憑藉數量。
三、knn算法的缺點是計算量大,這個從程序中也應該看得出來,裏面每一個測試數據都要計算到全部的訓練集數據之間的歐式距離,時間複雜度就已經爲O(n*n),若是真實數據的n很是大,這個算法的開銷的確態度,因此KNN不適合大規模數據量的分類。
KNN算法編碼時遇到的困難:
按理來講這麼簡單的KNN算法本應該是沒有多少的難度,可是在多歐式距離的排序上被深深的坑了一段時間,本人起初用Collections.sort(list)的方式進行按距離排序,也把Sample類實現了Compareable接口,可是排序就是不變,最後才知道,distance的int類型要改成Integer引用類型,在compareTo重載方法中調用distance的.CompareTo()方法就成功了,這個小細節平時沒注意,難道屬性的比較最終必定要調用到引用類型的compareTo()方法?這個小問題居然花費了我一段時間,最後仔細的比較了一下網上的例子最後才發現......