【大數據分析經常使用算法】3.K-近鄰算法

簡介

K-近鄰(K-Nearest Neighbors, KNN)是一個很是簡單的機器學習算法,不少機器學習算法書籍都喜歡將該算法做爲入門的算法做爲介紹。java

KNN分類問題是找出一個數據集中與給定查詢數據點最近的K個數據點。這個操做也成爲KNN鏈接(KNN-join)。能夠定義爲:給定兩個數據集R合S,對R中的每個對象,咱們但願從S中找出K個最近的相鄰對象。算法

在數據挖掘中,R和S分別稱爲查詢和訓練(traning)數據集。訓練數據集S表示已經分類的數據,而查詢數據集R表示利用S中的分類來進行分類的數據。apache

KNN是一個比較重要的聚類算法,在數據挖掘(圖像識別)、生物信息(如乳腺癌診斷)、天氣數據生成模型和商品推薦系統中有不少應用。api

缺點:開銷大。特別是有一個龐大的訓練集時。正是這個緣由,使用MapReduce運行該算法顯得很是的有用。數組

一、KNN算法

1.一、KNN分類

KNN的中心思想是創建一個分類方法,使得對於將y(響應變量)與x(預測變量)關聯的平滑函數f的形勢沒有任何的假設: $$ x = (x_{1},x_{2},...,x_{n}) $$app

$$ y = f(x) $$機器學習

函數f是非參數化的,由於它不涉及任何形式的參數估計。在KNN中,給定一個新的點$p=(p_{1},p_{2},...,p_{n})$,要動態的識別訓練集數據集中與p類似的K個觀察(k個近鄰)。近鄰由一個距離或類似度來定義。能夠根據獨立變量計算不一樣觀察之間的距離,咱們採用歐氏距離進行計算: $$ \sqrt{(x_{1} - p_{1})^2 + (x_{2} - p_{2})^2 + ... + (x_{n}-p_{n})^2} $$函數

關於距離的算法以及種類有不少,本章節咱們採用歐氏距離,即座標系距離計算方法。oop

那麼如何找出k個近鄰呢?學習

咱們先計算出歐氏距離的集合,而後將這個查詢對象分配到k個最近訓練數據中大多數對象所在的類。

1.二、距離函數

假設有兩個n維對象: $$ X = (X_{1},X_{2},...,X_{n}) $$

$$ Y = (Y_{1},Y_{2},...,Y_{n}) $$

$distance(X,Y)$能夠定義以下: $$ distance(X,Y) = \sqrt{\sum_{i=1}^{n}(x_{i}-y_{i})^2} $$

注意歐氏距離只適用於連續性數值類型:double。若是是其餘類型,則能夠考慮關聯業務狀況下設置距離函數,將其轉化爲double類型。

關於全部的有關各類距離的介紹,參考博文:

1.三、KNN解析

KNN算法是一種對未分類數據進行分類的直觀方法,他會根據未分類數據與訓練數據集中的數據的類似度或距離完成分類。在下面的例子中,咱們有4個分類$C_{1} - C_{4}$:

能夠看到,咱們的K=6,所以選取了6個近鄰,在這6個近鄰中,出如今上方的那個類中有4個屬於它的點,所以,咱們將P點歸爲上方圓圈包含的這一類型中。

1.四、算法描述

KNN算法能夠總結爲如下的步驟:

  1. 肯定K
  2. 計算新輸入與全部訓練集之間的距離
  3. 對距離排序,並根據第k個最小距離肯定k個近鄰
  4. 手機這些近鄰所屬的類別
  5. 根據多數投票肯定類別

算法複雜度:$O(N^2)$

二、Spark實現

2.一、形式化描述

設R和S是d維數據集,咱們想找出其kNN(RS)。進一步假設全部訓練數據(S)已經分類到$C={C_{1},C_{2},...,C_{n}}$,這裏$C$表示全部可能的分類。R、S和C的定義以下: $$ R = {R_{1},R_{2},...,R_{n}} $$

$$ S = {S_{1},S_{2},...,S_{n}} $$

$$ C = {C_{1},C_{2},...,C_{n}} $$

在這裏:

  1. $R_{i} = (r_{i},a_{1},a_{2},...,a_{n})$,其中$r_{i}$是當前記錄的ID,$a_{1},...,a_{n}$是$R_{i}$的屬性;
  2. $S_{j} = {r_{j},b_{1},b_{2},...,b_{n}}$同上。
  3. $C_{j}$是$S_{j}$的分類標識符。

咱們的目標是找出$KNN(R,S)$。

2.二、數據集

S數據集以下所示:

100;c1;1.0,1.0
101;c1;1.1,1.2
102;c1;1.2,1.0
103;c1;1.6,1.5
104;c1;1.3,1.7
105;c1;2.0,2.1
106;c1;2.0,2.2
107;c1;2.3,2.3
208;c2;9.0,9.0
209;c2;9.1,9.2
210;c2;9.2,9.0
211;c2;10.6,10.5
212;c2;10.3,10.7
213;c2;9.6,9.1
214;c2;9.4,10.4
215;c2;10.3,10.3
300;c3;10.0,1.0
301;c3;10.1,1.2
302;c3;10.2,1.0
303;c3;10.6,1.5
304;c3;10.3,1.7
305;c3;1.0,2.1
306;c3;10.0,2.2
307;c3;10.3,2.3

其中,第一列爲每條記錄的惟一ID,第二列爲該條記錄的所屬類別,以後的都爲維度信息;

R數據集的信息以下:

1000;3.0,3.0
1001;10.1,3.2
1003;2.7,2.7
1004;5.0,5.0
1005;13.1,2.2
1006;12.7,12.7

其中,第一列爲每條記錄的惟一ID,以後的都爲維度信息;

接下來咱們使用KNN算法,來計算R數據集中每一個記錄所屬的類別。

2.三、Spark實現

package com.sunrun.movieshow.autils.knn;

import com.google.common.base.Splitter;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.*;

public class KNNTester {
    /**
     * 1. 獲取Spark 上下文對象
     * @return
     */
    public static JavaSparkContext getSparkContext(String appName){
        SparkConf sparkConf = new SparkConf()
                .setAppName(appName)
                //.setSparkHome(sparkHome)
                .setMaster("local[*]")
                // 串行化器
                .set("spark.serializer","org.apache.spark.serializer.KryoSerializer")
                .set("spark.testing.memory", "2147480000");

        return new JavaSparkContext(sparkConf);
    }

    /**
     * 2. 將數字字符串轉換爲Double數組
     * @param str 數字字符串: "1,2,3,4,5"
     * @param delimiter 數字之間的分隔符:","
     * @return Double數組
     */
    public static List<Double> transferToDoubleList(String str, String delimiter){
        // 使用Google Splitter切割字符串
        Splitter splitter = Splitter.on(delimiter).trimResults();
        Iterable<String> tokens = splitter.split(str);
        if(tokens == null){
            return null;
        }
        List<Double> list = new ArrayList<>();
        for (String token : tokens) {
            list.add(Double.parseDouble(token));
        }
        return list;
    }

    /**
     * 計算距離
     * @param rRecord R數據集的一條記錄
     * @param sRecord S數據集的一條記錄
     * @param d 記錄的維度
     * @return 兩條記錄的歐氏距離
     */
    public static double calculateDistance(String rRecord, String sRecord, int d){
        double distance = 0D;
        List<Double> r = transferToDoubleList(rRecord,",");
        List<Double> s = transferToDoubleList(sRecord,",");
        // 若維度不一致,說明數據存在問題,返回NAN
        if(r.size() != d || s.size() != d){
            distance =  Double.NaN;
        } else{
            // 保證維度一致以後,計算歐氏距離
            double sum = 0D;
            for (int i = 0; i < s.size(); i++) {
                double diff = s.get(i) - r.get(i);
                sum += diff * diff;
            }
            distance = Math.sqrt(sum);
        }
        return distance;
    }

    /**
     * 根據(距離,類別),找出距離最低的K個近鄰
     * @param neighbors 當前求出的近鄰數量
     * @param k 尋找多少個近鄰
     * @return K個近鄰組成的SortedMap
     */
    public static SortedMap<Double, String>findNearestK(Iterable<Tuple2<Double,String>> neighbors, int k){
        TreeMap<Double, String> kNeighbors = new TreeMap<>();
        for (Tuple2<Double, String> neighbor : neighbors) {
            // 距離
            Double distance = neighbor._1;
            // 類別
            String classify = neighbor._2;
            kNeighbors.put(distance, classify);
            // 若是當前已經寫入K個元素,那麼刪除掉距離最遠的一個元素(位於末端)
            if(kNeighbors.size() > k){
                kNeighbors.remove(kNeighbors.lastKey());
            }
        }
        return kNeighbors;
    }

    /**
     * 計算對每一個類別的投票次數
     * @param kNeighbors 選取的K個最近的點
     * @return 對每一個類別的投票結果
     */
    public static Map<String, Integer> buildClassifyCount(Map<Double, String> kNeighbors){
        HashMap<String, Integer> majority = new HashMap<>();
        for (Map.Entry<Double, String> entry : kNeighbors.entrySet()) {
            String classify = entry.getValue();
            Integer count = majority.get(classify);
            // 當前沒有出現過,設置爲1,不然+1
            if(count == null){
                majority.put(classify,1);
            }else{
                majority.put(classify,count + 1);
            }
        }
        return  majority;
    }

    /**
     * 根據投票結果,選取最終的類別
     * @param majority 投票結果
     * @return 最終的類別
     */
    public static String classifyByMajority(Map<String, Integer> majority){
        String selectedClassify = null;
        int maxVotes = 0;
        // 從投票結果中選取票數最多的一類做爲最終選舉結果
        for (Map.Entry<String, Integer> entry : majority.entrySet()) {
            if(selectedClassify == null){
                selectedClassify = entry.getKey();
                maxVotes = entry.getValue();
            }else{
                int nowVotes = entry.getValue();
                if(nowVotes > maxVotes){
                    selectedClassify = entry.getKey();
                    maxVotes = nowVotes;
                }
            }
        }
        return selectedClassify;
    }



    public static void main(String[] args) {
        // === 1.建立SparkContext
        JavaSparkContext sc = getSparkContext("KNN");

        // === 2.KNN算法相關參數:廣播共享對象
        String HDFSUrl = "hdfs://10.21.1.24:9000/output/";
        // k(K)
        Broadcast<Integer> broadcastK = sc.broadcast(6);
        // d(維度)
        Broadcast<Integer> broadcastD = sc.broadcast(2);

        // === 3.爲查詢和訓練數據集建立RDD
        // R and S
        String RPath = "data/knn/R.txt";
        String SPath = "data/knn/S.txt";
        JavaRDD<String> R = sc.textFile(RPath);
        JavaRDD<String> S = sc.textFile(SPath);
//        // === 將R和S的數據存儲到hdfs
//        R.saveAsTextFile(HDFSUrl + "S");
//        S.saveAsTextFile(HDFSUrl + "R");

        // === 5.計算R&S的笛卡爾積
        JavaPairRDD<String, String> cart = R.cartesian(S);
        /**
         * (1000;3.0,3.0,100;c1;1.0,1.0)
         * (1000;3.0,3.0,101;c1;1.1,1.2)
         */

        // === 6.計算R中每一個點與S各個點之間的距離:(rid,(distance,classify))
        // (1000;3.0,3.0,100;c1;1.0,1.0) => 1000 is rId, 100 is sId, c1 is classify.
        JavaPairRDD<String, Tuple2<Double, String>> knnPair = cart.mapToPair(t -> {
            String rRecord = t._1;
            String sRecord = t._2;

            // 1000;3.0,3.0
            String[] splitR = rRecord.split(";");
            String rId = splitR[0]; // 1000
            String r = splitR[1];// "3.0,3.0"

            // 100;c1;1.0,1.0
            String[] splitS = sRecord.split(";");
            // sId對於當前算法沒有多大意義,咱們只須要獲取類別細信息,即第二個字段的信息便可
            String sId = splitS[0]; // 100
            String classify = splitS[1]; // c1
            String s = splitS[2];// "3.0,3.0"

            // 獲取廣播變量中的維度信息
            Integer d = broadcastD.value();
            // 計算當前兩個點的距離
            double distance = calculateDistance(r, s, d);
            Tuple2<Double, String> V = new Tuple2<>(distance, classify);
            // (Rid,(distance,classify))
            return new Tuple2<>(rId, V);
        });
        /**
         * (1005,(2.801785145224379,c3))
         * (1006,(4.75078940808788,c2))
         * (1006,(4.0224370722237515,c2))
         * (1006,(3.3941125496954263,c2))
         * (1006,(12.0074976577137,c3))
         * (1006,(11.79025020938911,c3)
         */


        // === 7. 按R中的r根據每一個記錄進行分組
        JavaPairRDD<String, Iterable<Tuple2<Double, String>>> knnGrouped = knnPair.groupByKey();
        // (1005,[(12.159358535712318,c1),....,(7.3171032519706865,c3), (7.610519036176179,c3)]),
        // (1000,[(2.8284271247461903,c1), (2.6172504656604803,c1), (2.690724....])

        // === 8.找出每一個R節點的k個近鄰
        JavaPairRDD<String, String> knnOutput = knnGrouped.mapValues(t -> {
            // K
            Integer k = broadcastK.value();
            SortedMap<Double, String> nearestK = findNearestK(t, k);
            // {2.596150997149434=c3, 2.801785145224379=c3, 2.8442925306655775=c3, 3.0999999999999996=c3, 3.1384709652950433=c3, 3.1622776601683795=c3}

            // 統計每一個類別的投票次數
            Map<String, Integer> majority = buildClassifyCount(nearestK);
            // {c3=1, c1=5}

            // 按多數優先原則選擇最終分類
            String selectedClassify = classifyByMajority(majority);
            return selectedClassify;
        });

        // 存儲最終結果
        knnOutput.saveAsTextFile(HDFSUrl + "/result");
        /**
         * [root@h24 hadoop]# hadoop fs -cat /output/result/p*
         * (1005,c3)
         * (1001,c3)
         * (1006,c2)
         * (1003,c1)
         * (1000,c1)
         * (1004,c1)
         */
    }
}

步驟7和8也能夠經過reduceByKey或者CombineByKey進行一步到位。先來看看咱們的轉換過程:

RDD:
—— knnPair: JavaPairRDD<String, Tuple2<Double, String>>
—— knnGrouped: JavaPairRDD<String, Iterable<Tuple2<Double, String>>>
—— knnOutput:JavaPairRDD<String, String>

變換過程:

knnPair    => groupBy   => knnGrouped
knnGrouped => mapValues => knnOutput

顯然,咱們沒法使用reduceByKey,所以他要求輸出類型等同於輸入類型。彙集的返回類型不一樣於彙集值的類型時就要使用combineByKey變換。所以,咱們將使用combineByKey把步驟7和8合併到一塊兒。這個合併步驟以下:

RDD:

—— knnPair: JavaPairRDD<String, Tuple2<Double, String>>
—— knnOutput: JavaPairRDD<String, String>

變換過程:

—— knnPair => combineByKey => knnOutput
相關文章
相關標籤/搜索