【大數據分析經常使用算法】6.K均值

簡介

一、K-均值距離函數

1.一、歐式距離

歐式距離的計算公式 $$ d(x,y) = \sqrt{(x_1 - y_1)^2 + (x_2 - y_2)^2 + ... + (x_n - y_n)^2} $$java

其中,x,y分別表明兩個點,同時,兩個點具備相同的維度:n。$x_1,x_2,...,x_n$表明點x的每一個維度的值,$y_1,y_2,...,y_n$表明點y的各個維度的值。算法

1.二、歐氏距離的性質

假設有$p_1,p_2,p_{k}$3個點。apache

  • $d(p_1,p_2) \ge 0$數組

  • $d(p_i,p_i) = 0$網絡

  • $d(p_i,p_j) = d(p_j,p_i)$app

  • $d(p_i,p_j) \le d(p_i,p_k) + d(p_k,p_j)$ide

最後一個性質也說明了一個很常見的現象:兩點間的距離,線段最短。函數

1.三、源碼實現

import java.util.List;
/**
 * 歐式距離計算
 */
public class EuclideanDistance {
    public static double caculate(List<Double> p1, List<Double> p2){
        double sum = 0.0;
        int length = p1.size();
        for (int i = 0; i < length; i++) {
            sum += Math.pow(p1.get(i) - p2.get(i),2.0);
        }
        return Math.sqrt(sum);
    }
}

二、形式化描述

K-均值算法是一個完成聚類分析的簡單學習算法。K-均值聚類算法的目標是找出n項的最佳劃分,也就是將n個對象劃分到K個組中,是的一個組中的成員語氣相應的質心(表示這個組)之間的總距離最小。採用形式化表示,目標就是將n項劃分到K個集合$$ {S_i,i=1,2,...,K} $$ 中,使得簇內平方和或組內平方和(within-cluster sum of squares,WCSS)最小,WCSS定義爲 $$ \min \sum_{j=1}^k \sum_{i=1}^n ||x_{i}^j - c_j|| $$工具

這裏的$||x_i^j - c_j||$表示實體點質心之間的距離。oop

三、MapReduce實現

3.一、數據集

以下所示,咱們選用的二位數據集。

1.0,2.0
1.0,3.0
1.0,4.0
2.0,5.0
2.0,3.0
2.0,7.0
2.0,8.0
3.0,100.0
3.0,101.0
3.0,102.0
3.0,103.0
3.0,104.0

3.二、Mapper

package mapreduce;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KMeansMapper extends Mapper<LongWritable, Text, IntWritable, Text> {

    private List<List<Double>> centers = null;

    // K
    private int k = 0;

    /**
     * map 開始時調用一次。
     * @param context
     * @throws IOException
     * @throws InterruptedException
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        // config
        String centerPath = context.getConfiguration().get("centerPath");
        // 讀取質心點信息
        this.centers = KMeansUtil.getCenterFromFileSystem(centerPath);
        // 獲取K值(中心點個數)
        k = centers.size();
        System.out.println("當前的質心數據爲:" + centers);
    }

    /**
     * 1.每次讀取一條要分類的條記錄與中心作對比,歸類到對應的中心
     * 2.以中心ID爲key,中心包含的記錄爲value輸出(例如: 1 0.2---->1爲聚類中心的ID,0.2爲靠近聚類中心的某個值)
     */
    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        // 讀取一行數據
        List<Double> fields = KMeansUtil.textToList(value);
        // 點維度
        int dimension = fields.size();

        double minDistance = Double.MAX_VALUE;

        int centerIndex = 0;

        // 依次取出K箇中心點與當前讀取的記錄作計算
        for (int i = 0; i < k; i++) {
            double currentDistance  = 0.0;
            // 之因此跳過0,是由於1表明的是該點的ID,不歸入計算的範疇
            for (int j = 1; j < dimension; j++) {
                // 獲取中心點
                double centerPoint = Math.abs(centers.get(i).get(j));
                // 當前須要計算的點
                double field = Math.abs(fields.get(j));
                // 計算歐氏距離
                currentDistance += Math.pow((centerPoint - field) / (centerPoint + field), 2);
            }

            // 找出距離該記錄最近的中心點的ID,記錄最小值、該點的索引
            if(currentDistance < minDistance){
                minDistance = currentDistance;
                centerIndex = i;
            }
        }

        // 以中心點爲key,原樣輸出,這樣以該中心點爲key的點都會做爲一個簇在reducer端匯聚
        context.write(new IntWritable(centerIndex),value);
    }
}

3.三、Reuder

package mapreduce;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 利用reduce歸併功能以中心爲key將記錄歸併在一塊兒
 */
public class KMeansReducer extends Reducer<IntWritable, Text, NullWritable, Text>{
    /**
     * 1.K-V: Key爲聚類中心的ID;value爲該中心的記錄集合;
     * 2.計數全部記錄元素的平均值,求出新的中心;KMeans算法的最終結果選取的質心點通常不是原數據集中的點
     */
    @Override
    protected void reduce(IntWritable key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
        List<List<Double>> result = new ArrayList<List<Double>>();
        // 依次讀取記錄集,每行轉化爲一個List<Double>
        for (Text value : values) {
            result.add(KMeansUtil.textToList(value));
        }

        // 計算新的質心點:經過各個維的平均值
        int dimension = result.get(0).size();
        double[] averages = new double[dimension];

        for (int i = 0; i < dimension; i++) {
            double sum = 0.0;
            int size = result.size();

            for (int j = 0; j < size; j++) {
                sum += result.get(j).get(i);
            }

            averages[i] = sum / size;
        }
        context.write(NullWritable.get(),new Text(Arrays.toString(averages).replace("[","").replace("]","")));
    }
}

3.四、Driver

package mapreduce;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;
import java.util.List;

public class KMeansDriver {

    public static void main(String[] args) throws Exception{
        String dfs = "hdfs://192.168.35.128:9000";
        // 存放中心點座標值
        String centerPath = dfs + "/kmeans/center/";
        // 存放待處理數據
        String dataPath = dfs + "/kmeans/kmeans_input_file.txt";
        // 新中心點存放目錄
        String newCenterPath = dfs + "/kmeans/newCenter/";
        // delta
        double delta = 0.1D;

        int count = 0;

        final int K = 3;

        // 選取初始的K個質心點
        List<List<Double>> pick = KMeansUtil.pick(K, dfs + "/kmeans/kmeans_input_file.txt");

        // 存儲到結果集
        KMeansUtil.writeCurrentKClusterToCenter(centerPath + "center.data",pick);

        while(true){
            ++ count;
            System.out.println(" 第 " + count + " 次計算 ");
            run(dataPath, centerPath, newCenterPath);
            System.out.println("計算迭代變化值");
            // 比較新舊質點變化幅度
            if(KMeansUtil.compareCenters(centerPath, newCenterPath,delta)){
                System.out.println("迭代結束");
                break;
            }
        }
        /**
         * 第 1 次計算
         * 當前的質心數據爲:[[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
         * task running status is : 1
         * 計算迭代變化值
         * 當前的質心點迭代變化值: 2125.9917355371904
         *  第 2 次計算
         * 當前的質心數據爲:[[1.0, 1.0], [1.0, 2.0], [2.272727272727273, 49.09090909090909]]
         * task running status is : 1
         * 計算迭代變化值
         * 當前的質心點迭代變化值: 2806.839601956485
         *  第 3 次計算
         * 當前的質心數據爲:[[1.0, 1.0], [1.5714285714285714, 4.571428571428571], [3.0, 102.0]]
         * task running status is : 1
         * 計算迭代變化值
         * 當前的質心點迭代變化值: 0.44274376417233585
         *  第 4 次計算
         * 當前的質心數據爲:[[1.0, 1.5], [1.6666666666666667, 5.0], [3.0, 102.0]]
         * task running status is : 1
         * 計算迭代變化值
         * 當前的質心點迭代變化值: 0.0
         * 迭代結束
         */
    }

    public static void run(String dataPath, String centerPath, String newCenterPath) throws IOException, ClassNotFoundException, InterruptedException {
        Configuration configuration = new Configuration();
        configuration.set("centerPath", centerPath);

        Job job = Job.getInstance(configuration);

        job.setJarByClass(KMeansDriver.class);
        job.setMapperClass(KMeansMapper.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);

        job.setReducerClass(KMeansReducer.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);


        FileInputFormat.setInputPaths(job,new Path(dataPath));
        FileOutputFormat.setOutputPath(job,new Path(newCenterPath) );
        System.out.println("task running status is : " + (job.waitForCompletion(true)? 1:0));
    }
}

咱們還能夠寫一個Combiner優化網絡傳輸的流量,不過此處因爲測試的緣故,就不寫不是本章節主題的代碼了。

另外,這幾個類還使用了一個輔助工具類

package mapreduce;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;

/**
 * KMeans工具
 */
public class KMeansUtil {

    public static FileSystem getFileSystem() throws URISyntaxException, IOException, InterruptedException {
        // 獲取一個具體的文件系統對象
        return FileSystem.get(new URI("hdfs://192.168.35.128:9000"),new Configuration(),"root");
    }


    /**
     * 在數據集中選取前K個點做爲質心
     * @param k
     * @param filePath
     * @return
     */
    public static List<List<Double>> pick(int k, String filePath) throws Exception {
        List<List<Double>> result = new ArrayList<List<Double>>();
        Path path = new Path(filePath);
        FileSystem fileSystem = getFileSystem();
        FSDataInputStream open = fileSystem.open(path);
        LineReader lineReader = new LineReader(open);
        Text line = new Text();
        // 讀取每一行信息
        while(lineReader.readLine(line) > 0 && k > 0){
            List<Double> doubles = textToList(line);
            result.add(doubles);
            k = k - 1;
        }
        lineReader.close();
        return result;
    }

    /**
     * 將當前的結果寫入數據中心
     */
    public static void writeCurrentKClusterToCenter(String centerPath,List<List<Double>> data) throws Exception {
        FSDataOutputStream out = getFileSystem().create(new Path(centerPath));

        for (List<Double> d : data) {
            String str = d.toString();
            out.write(str.replace("[","").replace("]","\n").getBytes());
        }
        out.close();
    }


    /**
     * 從數據中心獲取質心點數據
     * @param filePath 路徑
     * @return 質心數據
     */
    public static List<List<Double>> getCenterFromFileSystem(String filePath) throws IOException {

        List<List<Double>> result = new ArrayList<List<Double>>();

        Path path = new Path(filePath);
        Configuration configuration = new Configuration();
        FileSystem fileSystem = null;
        try {
            fileSystem = getFileSystem();
        } catch (Exception e) {
            e.printStackTrace();
        }

        FileStatus[] listFiles = fileSystem.listStatus(path);
        for (FileStatus file : listFiles) {
            FSDataInputStream open = fileSystem.open(file.getPath());
            LineReader lineReader = new LineReader(open, configuration);
            Text line = new Text();
            // 讀取每一行信息
            while(lineReader.readLine(line) > 0){
                List<Double> doubles = textToList(line);
                result.add(doubles);
            }
        }
        return result;
    }

    /**
     * 將Text轉化爲數組
     * @param text
     * @return
     */
    public static List<Double> textToList(Text text){
        List<Double> list = new ArrayList<Double>();

        String[] split = text.toString().split(",");

        for (int i = 0; i < split.length; i++) {
            list.add(Double.parseDouble(split[i]));
        }
        return list;
    }

    /**
     * 比較新舊數據點的變化狀況
     * @return
     * @throws Exception
     */
    public static boolean compareCenters(String center, String newCenter, double delta) throws Exception{
        List<List<Double>> oldCenters = getCenterFromFileSystem(center);
        List<List<Double>> newCenters = getCenterFromFileSystem(newCenter);

        // 質心點數
        int size = oldCenters.size();
        // 維度
        int fieldSize = oldCenters.get(0).size();

        double distance = 0.0;

        for (int i = 0; i < size; i++) {
            for (int j = 0; j < fieldSize; j++) {
                double p1 = Math.abs(oldCenters.get(i).get(j));
                double p2 = Math.abs(newCenters.get(i).get(j));
                // this is used euclidean distance.
                distance += Math.pow(p1 - p2, 2);
            }
        }

        System.out.println("當前的質心點迭代變化值: " + distance);
        // 在區間內
        if(distance <= delta){
            return true;
        }else{
            Path centerPath = new Path(center);
            Path newCenterPath = new Path(newCenter);
            FileSystem fs = getFileSystem();

            // 刪除當前質點文件
            fs.delete(centerPath,true );

            // 將新質點文件結果移動到當前質點文件
            fs.rename(newCenterPath,centerPath);
        }
        return false;
    }
}

能夠看到,咱們的K=3,而且選擇的是數據集中的前三個點做爲初始迭代的質心點。固然,更好的算法應該是從數據集中隨機選取3個點或者以貼合業務的選取方式選取初始點,從算法中咱們能夠了解到,初始點的選擇在必定迭代次數內是對結果有很大的影響的。

3.五、繪圖

最終,咱們獲得的結果以下,其中的紅點即爲質心點

相關文章
相關標籤/搜索