Implementing SMOTE nearest-neighbour sampling in Spark

1. SMOTE background theory

(1).

SMOTE is an improvement on plain oversampling. Plain oversampling simply duplicates existing samples, so the training set ends up with many repeated samples.

SMOTE stands for Synthetic Minority Over-Sampling Technique.

Rather than simply resampling the minority class, SMOTE uses an algorithm to synthesize new artificial minority-class samples.

For ease of exposition, assume the positive class is the minority class and the negative class is the majority class.

The algorithm for synthesizing a new positive (minority-class) sample is as follows:

  1. Pick a positive sample s.
  2. Find the k samples nearest to s; k can be 5, 10, or similar. These k samples may include both positive and negative samples.
  3. Randomly pick one of these k samples and call it r.
  4. Synthesize a new positive sample s′ = λs + (1 − λ)r, where λ is a random number in (0, 1). In other words, the new point lies on the line segment between r and s.

 

Repeating these steps generates as many positive samples as needed; a runnable sketch of these steps follows.
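To make steps 1-4 concrete, here is a minimal, self-contained Java sketch under the same setup (plain double[] feature vectors and Euclidean distance; the class and method names are illustrative, not taken from the code linked below):

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

public class SmoteSketch {

    // Step 4: s' = lambda * s + (1 - lambda) * r, lambda drawn from (0, 1).
    static double[] synthesize(double[] s, double[] r, Random rnd) {
        double lambda = rnd.nextDouble();
        double[] out = new double[s.length];
        for (int i = 0; i < s.length; i++)
            out[i] = lambda * s[i] + (1 - lambda) * r[i];
        return out;
    }

    // Step 2: indices of the k samples nearest to s (Euclidean distance).
    // If s itself is among the candidates, it appears at distance 0 and
    // should be skipped by the caller.
    static int[] kNearest(double[] s, double[][] candidates, int k) {
        Integer[] idx = new Integer[candidates.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        Arrays.sort(idx, Comparator.comparingDouble(i -> distance(s, candidates[i])));
        return Arrays.stream(idx).limit(k).mapToInt(Integer::intValue).toArray();
    }

    static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++)
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(sum);
    }

    // Steps 1 and 3: pick a sample, pick one of its k neighbours, interpolate.
    public static void main(String[] args) {
        Random rnd = new Random(42);
        double[][] positives = {{1, 2}, {2, 3}, {3, 3}, {8, 8}};
        double[] s = positives[0];                       // step 1
        int[] nn = kNearest(s, positives, 3);            // step 2 (nn[0] is s itself)
        double[] r = positives[nn[1 + rnd.nextInt(2)]];  // step 3: random non-self neighbour
        System.out.println(Arrays.toString(synthesize(s, r, rnd))); // step 4
    }
}

Note that the Spark code in section 2 below groups vectors by label first, so there the neighbours are drawn from within the same class.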

======= Update: a few figures to illustrate ======

Illustrating the SMOTE procedure with figures:

1. First pick a positive sample (assume the positive class is the minority class).

2. Find the k nearest neighbours of this positive sample (say k = 5). The 5 neighbours are circled in the figure.

3. Randomly choose one of these k neighbours (circled in green in the figure).

4. Pick a random point on the line segment between the positive sample and the chosen neighbour. That point is the new synthetic positive sample (marked with a green plus sign).

The description above is taken from http://sofasofa.io/forum_main_post.php?postid=1000817

 

(2).

With this approach, the positive class is over-sampled by taking each minority class sample and introducing synthetic examples along the line segments joining any/all of the k minority class nearest neighbours. Depending upon the amount of over-sampling required, neighbours from the k nearest neighbours are randomly chosen. This process is illustrated in the following Figure, where x_i is the selected point, x_i1 to x_i4 are some selected nearest neighbours, and r_1 to r_4 are the synthetic data points created by the randomized interpolation. The implementation of this work uses only one nearest neighbour with the Euclidean distance, and balances both classes to 50% distribution.

Synthetic samples are generated in the following way: Take the difference between the feature vector (sample) under consideration and its nearest neighbour. Multiply this difference by a random number between 0 and 1, and add it to the feature vector under consideration. This causes the selection of a random point along the line segment between two specific features. This approach effectively forces the decision region of the minority class to become more general. An example is detailed in the next Figure.
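As a worked example of that recipe (numbers invented for illustration): with feature vector x = (1, 2), nearest neighbour r = (3, 6), and a random gap of 0.25, the synthetic sample is x_new = x + 0.25 · (r − x) = (1, 2) + 0.25 · (2, 4) = (1.5, 3), a point one quarter of the way along the segment from x to r.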

In short, the main idea is to form new minority class examples by interpolating between several minority class examples that lie together. In contrast with the common replication techniques (for example random oversampling), in which the decision region usually becomes more specific, with SMOTE the overfitting problem is somewhat avoided by causing the decision boundaries for the minority class to be larger and to spread further into the majority class space, since it provides related minority class samples to learn from. Specifically, selecting a small k-value could also avoid the risk of including some noise in the data.

The description above is taken from https://sci2s.ugr.es/multi-imbalanced

 

2. Implementing SMOTE in Spark

The core code is shown below; the complete code is at https://github.com/jiangnanboy/spark-smote/blob/master/spark%20smote.txt

 

// Imports required by this excerpt; generateSample(...) and EncoderInit are
// defined in the full source linked above.
import static org.apache.spark.sql.functions.count;
import static org.apache.spark.sql.functions.max;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

/**
 * (1) For each sample x of a minority class X, compute its distance to every other
 *     sample in X to obtain its k nearest neighbours.
 * (2) Set a sampling rate according to the class-imbalance ratio. For each minority
 *     sample x, randomly choose sampling_rate neighbours from its k nearest
 *     neighbours, say x(1), x(2), ..., x(sampling_rate).
 * (3) For each randomly chosen neighbour x(i) (i = 1, 2, ..., sampling_rate), build
 *     a new sample from the original one with the formula
 *         x_new = x + rand(0,1) * (x(i) - x)
 *
 * http://sofasofa.io/forum_main_post.php?postid=1000817
 * http://sci2s.ugr.es/multi-imbalanced
 * @param session the SparkSession
 * @param labelFeatures rows of (label, features)
 * @param knn number of nearest neighbours per sample
 * @param samplingRate neighbour sampling rate; knn * samplingRate neighbours are chosen from the k nearest
 * @param rationToMax target ratio to the largest class; 0.1 means a minority class is
 *                    oversampled until it reaches 1/10 of the largest class's size
 * @return the dataset with synthetic minority-class samples added
 */
public static Dataset<Row> smote(SparkSession session, Dataset<Row> labelFeatures, int knn, double samplingRate, double rationToMax) {

    // Count the samples per label and find the size of the largest class.
    Dataset<Row> labelCountDataset = labelFeatures.groupBy("label").agg(count("label").as("keyCount"));
    List<Row> listRow = labelCountDataset.collectAsList();
    ConcurrentMap<String, Long> keyCountConMap = new ConcurrentHashMap<>(); // sample count per label
    for (Row row : listRow)
        keyCountConMap.put(row.getString(0), row.getLong(1));
    Row maxSizeRow = labelCountDataset.select(max("keyCount").as("maxSize")).first();
    long maxSize = maxSizeRow.getAs("maxSize"); // size of the largest class

    // (label, features) pairs.
    JavaPairRDD<String, SparseVector> sparseVectorJPR = labelFeatures.toJavaRDD().mapToPair(row -> {
        String label = row.getString(0);
        SparseVector features = (SparseVector) row.get(1);
        return new Tuple2<>(label, features);
    });

    // Group all feature vectors of each label into one list: (label, List<SparseVector>).
    JavaPairRDD<String, List<SparseVector>> combineByKeyPairRDD = sparseVectorJPR.combineByKey(
            sparseVector -> {
                List<SparseVector> list = new ArrayList<>();
                list.add(sparseVector);
                return list;
            },
            (list, sparseVector) -> { list.add(sparseVector); return list; },
            (listA, listB) -> { listA.addAll(listB); return listA; });

    // Broadcast the per-label counts and the parameters to the executors.
    JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
    final Broadcast<ConcurrentMap<String, Long>> keyCountBroadcast = jsc.broadcast(keyCountConMap);
    final Broadcast<Long> maxSizeBroadcast = jsc.broadcast(maxSize);
    final Broadcast<Integer> knnBroadcast = jsc.broadcast(knn);
    final Broadcast<Double> samplingRateBroadcast = jsc.broadcast(samplingRate);
    final Broadcast<Double> rationToMaxBroadcast = jsc.broadcast(rationToMax);

    // For every class smaller than maxSize * rationToMax, synthesize the missing
    // samples with generateSample (see the full source); larger classes pass through.
    JavaPairRDD<String, List<SparseVector>> pairRDD = combineByKeyPairRDD
            .filter(slt -> slt._2().size() > 1) // need at least two samples to interpolate
            .mapToPair(slt -> {
                String label = slt._1();
                ConcurrentMap<String, Long> keySizeConMap = keyCountBroadcast.getValue();
                long oldSampleSize = keySizeConMap.get(label);
                long max = maxSizeBroadcast.getValue();
                double ration = rationToMaxBroadcast.getValue();
                int k = knnBroadcast.getValue();
                double rate = samplingRateBroadcast.getValue();
                if (oldSampleSize < max * ration) {
                    int needSampleSize = (int) (max * ration - oldSampleSize);
                    List<SparseVector> list = generateSample(slt._2(), needSampleSize, k, rate);
                    return new Tuple2<>(label, list);
                } else {
                    return slt;
                }
            });

    // Flatten (label, List<SparseVector>) back into one Row per sample.
    JavaRDD<Row> javaRowRDD = pairRDD.flatMapToPair(slt -> {
        List<Tuple2<String, SparseVector>> pairList = new ArrayList<>();
        String label = slt._1();
        for (SparseVector sv : slt._2())
            pairList.add(new Tuple2<>(label, sv));
        return pairList.iterator();
    }).map(svt -> RowFactory.create(svt._1(), svt._2()));

    return session.createDataset(javaRowRDD.rdd(), EncoderInit.getlabelFeaturesRowEncoder());
}
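A hypothetical invocation of the method above (the input path, column names, and parameter values are illustrative assumptions, not taken from the linked repository):

SparkSession session = SparkSession.builder()
        .appName("smote-example")
        .master("local[*]")
        .getOrCreate();

// Expected input schema: (label: String, features: SparseVector).
Dataset<Row> labelFeatures = session.read().parquet("label_features.parquet"); // hypothetical path

// k = 5 neighbours per sample, choose 5 * 0.4 = 2 of them per point,
// and oversample each minority class up to 10% of the largest class.
Dataset<Row> balanced = smote(session, labelFeatures, 5, 0.4, 0.1);
balanced.groupBy("label").count().show();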