最近在網上查看用MapReduce實現的Kmeans算法,例子不錯: http://blog.csdn.net/jshayzf/article/details/22739063
但註釋太少了,並且參數太多,如果是新手學習的話不太好理解。因此自己按照個人的理解寫了一個簡單的例子,並添加了詳細的註釋。
大體的步驟是:
1,Map每讀取一條數據就與中心做對比,求出該條記錄對應的中心,然後以中心的ID爲Key,該條數據爲value將數據輸出。
2,利用reduce的歸併功能將相同的Key歸併到一起,集中與該Key對應的數據,再求出這些數據的平均值,輸出平均值。
3,對比reduce求出的平均值與原來的中心,如果不相同,則清空原中心的數據文件,將reduce的結果寫到中心文件中。(中心的值存在一個HDFS的文件中)
刪掉reduce的輸出目錄以便下次輸出。
繼續運行任務。
4,對比reduce求出的平均值與原來的中心,如果相同,則刪掉reduce的輸出目錄,運行一個沒有reduce的任務將中心ID與值對應輸出。
1 package MyKmeans; 2 3 import java.io.IOException; 4 import java.util.ArrayList; 5 6 import org.apache.hadoop.conf.Configuration; 7 import org.apache.hadoop.fs.Path; 8 import org.apache.hadoop.io.Text; 9 10 import java.util.Arrays; 11 import java.util.Iterator; 12 13 import org.apache.hadoop.io.IntWritable; 14 import org.apache.hadoop.io.LongWritable; 15 import org.apache.hadoop.mapreduce.Job; 16 import org.apache.hadoop.mapreduce.Mapper; 17 import org.apache.hadoop.mapreduce.Reducer; 18 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 19 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 20 21 22 public class MapReduce { 23 24 public static class Map extends Mapper<LongWritable, Text, IntWritable, Text>{ 25 26 //中心集合 27 ArrayList<ArrayList<Double>> centers = null; 28 //用k箇中心 29 int k = 0; 30 31 //讀取中心 32 protected void setup(Context context) throws IOException, 33 InterruptedException { 34 centers = Utils.getCentersFromHDFS(context.getConfiguration().get("centersPath"),false); 35 k = centers.size(); 36 } 37 38 39 /** 40 * 1.每次讀取一條要分類的條記錄與中心作對比,歸類到對應的中心 41 * 2.以中心ID爲key,中心包含的記錄爲value輸出(例如: 1 0.2 。 1爲聚類中心的ID,0.2爲靠近聚類中心的某個值) 42 */ 43 protected void map(LongWritable key, Text value, Context context) 44 throws IOException, InterruptedException { 45 //讀取一行數據 46 ArrayList<Double> fileds = Utils.textToArray(value); 47 int sizeOfFileds = fileds.size(); 48 49 double minDistance = 99999999; 50 int centerIndex = 0; 51 52 //依次取出k箇中心點與當前讀取的記錄作計算 53 for(int i=0;i<k;i++){ 54 double currentDistance = 0; 55 for(int j=0;j<sizeOfFileds;j++){ 56 double centerPoint = Math.abs(centers.get(i).get(j)); 57 double filed = Math.abs(fileds.get(j)); 58 currentDistance += Math.pow((centerPoint - filed) / (centerPoint + filed), 2); 59 } 60 //循環找出距離該記錄最接近的中心點的ID 61 if(currentDistance<minDistance){ 62 minDistance = currentDistance; 63 centerIndex = i; 64 } 65 } 66 //以中心點爲Key 將記錄原樣輸出 67 context.write(new IntWritable(centerIndex+1), value); 68 } 69 70 } 71 72 
//利用reduce的歸併功能以中心爲Key將記錄歸併到一塊兒 73 public static class Reduce extends Reducer<IntWritable, Text, Text, Text>{ 74 75 /** 76 * 1.Key爲聚類中心的ID value爲該中心的記錄集合 77 * 2.計數全部記錄元素的平均值,求出新的中心 78 */ 79 protected void reduce(IntWritable key, Iterable<Text> value,Context context) 80 throws IOException, InterruptedException { 81 ArrayList<ArrayList<Double>> filedsList = new ArrayList<ArrayList<Double>>(); 82 83 //依次讀取記錄集,每行爲一個ArrayList<Double> 84 for(Iterator<Text> it =value.iterator();it.hasNext();){ 85 ArrayList<Double> tempList = Utils.textToArray(it.next()); 86 filedsList.add(tempList); 87 } 88 89 //計算新的中心 90 //每行的元素個數 91 int filedSize = filedsList.get(0).size(); 92 double[] avg = new double[filedSize]; 93 for(int i=0;i<filedSize;i++){ 94 //求沒列的平均值 95 double sum = 0; 96 int size = filedsList.size(); 97 for(int j=0;j<size;j++){ 98 sum += filedsList.get(j).get(i); 99 } 100 avg[i] = sum / size; 101 } 102 context.write(new Text("") , new Text(Arrays.toString(avg).replace("[", "").replace("]", ""))); 103 } 104 105 } 106 107 @SuppressWarnings("deprecation") 108 public static void run(String centerPath,String dataPath,String newCenterPath,boolean runReduce) throws IOException, ClassNotFoundException, InterruptedException{ 109 110 Configuration conf = new Configuration(); 111 conf.set("centersPath", centerPath); 112 113 Job job = new Job(conf, "mykmeans"); 114 job.setJarByClass(MapReduce.class); 115 116 job.setMapperClass(Map.class); 117 118 job.setMapOutputKeyClass(IntWritable.class); 119 job.setMapOutputValueClass(Text.class); 120 121 if(runReduce){ 122 //最後依次輸出不準要reduce 123 job.setReducerClass(Reduce.class); 124 job.setOutputKeyClass(Text.class); 125 job.setOutputValueClass(Text.class); 126 } 127 128 FileInputFormat.addInputPath(job, new Path(dataPath)); 129 130 FileOutputFormat.setOutputPath(job, new Path(newCenterPath)); 131 132 System.out.println(job.waitForCompletion(true)); 133 } 134 135 public static void main(String[] args) throws ClassNotFoundException, IOException, 
InterruptedException { 136 String centerPath = "hdfs://localhost:9000/input/centers.txt"; 137 String dataPath = "hdfs://localhost:9000/input/wine.txt"; 138 String newCenterPath = "hdfs://localhost:9000/out/kmean"; 139 140 int count = 0; 141 142 143 while(true){ 144 run(centerPath,dataPath,newCenterPath,true); 145 System.out.println(" 第 " + ++count + " 次計算 "); 146 if(Utils.compareCenters(centerPath,newCenterPath )){ 147 run(centerPath,dataPath,newCenterPath,false); 148 break; 149 } 150 } 151 } 152 153 }
1 package MyKmeans; 2 3 import java.io.IOException; 4 import java.util.ArrayList; 5 import java.util.List; 6 7 import org.apache.hadoop.conf.Configuration; 8 import org.apache.hadoop.fs.FSDataInputStream; 9 import org.apache.hadoop.fs.FSDataOutputStream; 10 import org.apache.hadoop.fs.FileStatus; 11 import org.apache.hadoop.fs.FileSystem; 12 import org.apache.hadoop.fs.Path; 13 import org.apache.hadoop.io.IOUtils; 14 import org.apache.hadoop.io.Text; 15 import org.apache.hadoop.util.LineReader; 16 17 public class Utils { 18 19 //讀取中心文件的數據 20 public static ArrayList<ArrayList<Double>> getCentersFromHDFS(String centersPath,boolean isDirectory) throws IOException{ 21 22 ArrayList<ArrayList<Double>> result = new ArrayList<ArrayList<Double>>(); 23 24 Path path = new Path(centersPath); 25 26 Configuration conf = new Configuration(); 27 28 FileSystem fileSystem = path.getFileSystem(conf); 29 30 if(isDirectory){ 31 FileStatus[] listFile = fileSystem.listStatus(path); 32 for (int i = 0; i < listFile.length; i++) { 33 result.addAll(getCentersFromHDFS(listFile[i].getPath().toString(),false)); 34 } 35 return result; 36 } 37 38 FSDataInputStream fsis = fileSystem.open(path); 39 LineReader lineReader = new LineReader(fsis, conf); 40 41 Text line = new Text(); 42 43 while(lineReader.readLine(line) > 0){ 44 ArrayList<Double> tempList = textToArray(line); 45 result.add(tempList); 46 } 47 lineReader.close(); 48 return result; 49 } 50 51 //刪掉文件 52 public static void deletePath(String pathStr) throws IOException{ 53 Configuration conf = new Configuration(); 54 Path path = new Path(pathStr); 55 FileSystem hdfs = path.getFileSystem(conf); 56 hdfs.delete(path ,true); 57 } 58 59 public static ArrayList<Double> textToArray(Text text){ 60 ArrayList<Double> list = new ArrayList<Double>(); 61 String[] fileds = text.toString().split(","); 62 for(int i=0;i<fileds.length;i++){ 63 list.add(Double.parseDouble(fileds[i])); 64 } 65 return list; 66 } 67 68 public static boolean compareCenters(String 
centerPath,String newPath) throws IOException{ 69 70 List<ArrayList<Double>> oldCenters = Utils.getCentersFromHDFS(centerPath,false); 71 List<ArrayList<Double>> newCenters = Utils.getCentersFromHDFS(newPath,true); 72 73 int size = oldCenters.size(); 74 int fildSize = oldCenters.get(0).size(); 75 double distance = 0; 76 for(int i=0;i<size;i++){ 77 for(int j=0;j<fildSize;j++){ 78 double t1 = Math.abs(oldCenters.get(i).get(j)); 79 double t2 = Math.abs(newCenters.get(i).get(j)); 80 distance += Math.pow((t1 - t2) / (t1 + t2), 2); 81 } 82 } 83 84 if(distance == 0.0){ 85 //刪掉新的中心文件以便最後依次歸類輸出 86 Utils.deletePath(newPath); 87 return true; 88 }else{ 89 //先清空中心文件,將新的中心文件複製到中心文件中,再刪掉中心文件 90 91 Configuration conf = new Configuration(); 92 Path outPath = new Path(centerPath); 93 FileSystem fileSystem = outPath.getFileSystem(conf); 94 95 FSDataOutputStream overWrite = fileSystem.create(outPath,true); 96 overWrite.writeChars(""); 97 overWrite.close(); 98 99 100 Path inPath = new Path(newPath); 101 FileStatus[] listFiles = fileSystem.listStatus(inPath); 102 for (int i = 0; i < listFiles.length; i++) { 103 FSDataOutputStream out = fileSystem.create(outPath); 104 FSDataInputStream in = fileSystem.open(listFiles[i].getPath()); 105 IOUtils.copyBytes(in, out, 4096, true); 106 } 107 //刪掉新的中心文件以便第二次任務運行輸出 108 Utils.deletePath(newPath); 109 } 110 111 return false; 112 } 113 }
數據集 http://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data
運行結果可以與 http://blog.csdn.net/jshayzf/article/details/22739063 的結果做對比(前提是初始的中心相同)