KNN分類器-Java實現

KNN,即K近鄰算法。其基本思想或者說是實現步驟以下:  算法

(1)計算樣本數據點到每一個已知類別的數據集中點的距離  數組

(2)將(1)中獲得的距離按遞增順序排列  dom

(3)選取(2)中前K個點(即與當前樣本距離最小的K個已知類別的數據點)  函數

(4)統計(3)中獲得的K個點所在類別的出現頻率  測試

(5)返回(4)中出現頻率最高的類別做爲樣本點的預測類別  this

在給出具體實現代碼以前,說明一點:Java下的矩陣操做類基於開源jama包,我本身基於它的源碼,作了部分必要的擴充和修改。 orm

具體實現代碼以下: 排序

 /** get

 * Created by Song on 2016/9/30.  源碼

*/

 public class KnnHandler implements DMHandler { 

 //訓練集中,每一個特徵的最小值 

 private Matrix minVals; 

 //訓練集中,每一個特徵的最大值 

 private Matrix maxVals; 

 //訓練集中,每一個特徵的取值範圍 

 private Matrix ranges; 


public KnnHandler(Matrix dataSet){ 

      double [][] minMax = dataSet.getMinMax(); 

      this.minVals = new Matrix(minMax[0],1); 

      this.maxVals = new Matrix(minMax[1],1); 

      this.ranges = maxVals.minus(minVals); 

 } 

 /**

 * 歸一化特徵值 

 * @param dataSet 特徵集 

 */ 

  public Matrix autoNorm(Matrix dataSet){ 

       double[][] norm = dataSet.getArray(); 

       for(int j=0;j<dataSet.getColumnDimension();j++){ 

            for(int i=0;i<norm.length;i++){ 

                  norm[i][j] = (norm[i][j]-minVals.get(0,j))/ranges.get(0,j); 

            } 

       } 

       return new Matrix(norm); 

 } 

 /** 

 * K近鄰算法 

 * @param sample 待評估樣本 

 * @param dataSet 數據集 

 * @param labels 數據集中,每行數據對應的類別 

 * @param rate 將距離按由小至大排列,按比例選擇固定數量的類別 

 */ 

 public double classify(Matrix sample,Matrix dataSet,Matrix labels,double rate){ 

       //統計樣本頻率 

      HashMap<Double,Integer> levels = new HashMap<Double, Integer>(); 

      //遍歷類別,得出一共有幾類 

     for(int i=0;i<labels.getRowDimension();i++){ 

           if(!levels.containsKey(labels.get(i,0))) levels.put(labels.get(i,0),0); 

     } 

     //得到距離,並遞增排序 

    Matrix sortedDistance = sample.distance(dataSet).expand(labels,true).sort(); 

    //取前num個數據 

    int num = (int)Math.ceil(sortedDistance.getRowDimension()*rate); 

    for(int i=0;i<num;i++){        levels.put(sortedDistance.get(i,1),levels.get(sortedDistance.get(i,1))+1); 

 } 

 //按頻率排序 

 double targetLevel = 0; 

 int count = 0; 

 for(double key:levels.keySet()){ 

       if(levels.get(key)>count) { 

              count = levels.get(key); 

              targetLevel = key; 

           } 

 } 

 return targetLevel; 

 } 

 //測試

public static void main(String [] args){ 

//隨機生成訓練集(已知類別) 

Random random = new Random(); 

 double [][] dataSet = new double[100][4]; 

 for(int i=0;i<100;i++){ 

       for(int j=0;j<4;j++){ 

             dataSet[i][j]=random.nextInt(10); 

        } 

 } 

 //訓練集中100組數據對應的類別 

 double [] lables = new double[100]; 

 for(int i=0;i<100;i++){ 

           lables[i]=i/10; 

 } 

 //生成待分類樣本 

 double [] sample = {1,2,3,4}; 

 //KNN操做類實例化 

 KnnHandler handler = new KnnHandler(new Matrix(dataSet)); //handler.autoNorm(new Matrix(dataSet)).print(4,3); 

 //輸出分類結果 

 System.out.println(handler.classify(new Matrix(sample,1),new Matrix(dataSet),new Matrix(lables,1).transpose(),0.3)); 

    } 

其中部分函數,例如構造器中得到數據集中每一個特徵的最小最大取值(即一個二維數組中每列值的最小最大值)方法getMinMax()等,都是本身基於jama源碼擴充獲得的,原理很簡單,此處就不列出來了。 能夠看出,KNN分類是一種很是基礎的分類算法,適用於數值型數據。經過計算未知數據點到已知數據點的距離,來判斷其具體分類。 

相關文章
相關標籤/搜索