Spark:求出分組內的TopN

製作測試數據源(文件 group.txt,每行一條記錄,格式為「userId rating」,以空格分隔):

c1 85
c2 77
c3 88
c1 22
c1 66
c3 95
c3 54
c2 91
c2 66
c1 54
c1 65
c2 41
c4 65

spark scala實現代碼:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * Top-N per group with Spark SQL: for every `userId`, keep the 3 highest `rating`s.
 *
 * Input: a text file with one `userId rating` record per line (whitespace separated).
 * Output: `userId`, `rating` and the 1-based rank `rn` within the user's group.
 */
object GroupTopN1 {
  // Sets hadoop.home.dir before Spark starts (presumably for a local winutils on Windows).
  System.setProperty("hadoop.home.dir", "D:\\Java_Study\\hadoop-common-2.2.0-bin-master")

  /** One parsed input record. */
  case class Rating(userId: String, rating: Long)

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("ALS with ML Pipeline")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .master("local")
      .config("spark.sql.warehouse.dir", "/")
      .getOrCreate()

    import spark.implicits._
    import spark.sql

    val lines = spark.read.textFile("C:\\Users\\Administrator\\Desktop\\group.txt")
    // Parse each line once; skip blank lines instead of failing on them.
    val classScores = lines
      .filter(_.trim.nonEmpty)
      .map { line =>
        val fields = line.trim.split("\\s+")
        Rating(fields(0), fields(1).toLong)
      }
    classScores.createOrReplaceTempView("tb_test")

    // A window-function alias cannot be filtered at the level where it is defined, and
    // HAVING without GROUP BY is a global aggregate since Spark 3.0. So rank in a
    // subquery and filter the rank with WHERE in the outer query.
    val df = sql(
      """|select userId, rating, rn
         |from (
         |  select
         |    userId,
         |    rating,
         |    row_number() over (partition by userId order by rating desc) rn
         |  from tb_test
         |) ranked
         |where rn <= 3
         |""".stripMargin)
    df.show()
    spark.stop()
  }
}

打印結果:

+------+------+---+
|userId|rating| rn|
+------+------+---+
|    c1|    85|  1|
|    c1|    66|  2|
|    c1|    65|  3|
|    c4|    65|  1|
|    c3|    95|  1|
|    c3|    88|  2|
|    c3|    54|  3|
|    c2|    91|  1|
|    c2|    77|  2|
|    c2|    66|  3|
+------+------+---+

spark java代碼實現:

import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import scala.Function1; import javax.management.RuntimeErrorException; import java.util.List; import java.util.ArrayList; public class Test { public static void main(String[] args) { System.out.println("Hello"); SparkConf sparkConf = new SparkConf().setAppName("ALS with ML Pipeline"); SparkSession spark = SparkSession .builder() .config(sparkConf) .master("local") .config("spark.sql.warehouse.dir", "/") .getOrCreate(); // Create an RDD
        JavaRDD<String> peopleRDD = spark.sparkContext() .textFile("C:\\Users\\Administrator\\Desktop\\group.txt", 1) .toJavaRDD(); // The schema is encoded in a string
        String schemaString = "userId rating"; // Generate the schema based on the string of schema
        List<StructField> fields = new ArrayList<>(); StructField field1 = DataTypes.createStructField("userId", DataTypes.StringType, true); StructField field2 = DataTypes.createStructField("rating", DataTypes.LongType, true); fields.add(field1); fields.add(field2); StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows
        JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> { String[] attributes = record.split(" "); if(attributes.length!=2){ throw new Exception(); } return RowFactory.create(attributes[0],Long.valueOf( attributes[1].trim())); }); // Apply the schema to the RDD
        Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema); peopleDataFrame.createOrReplaceTempView("tb_test"); Dataset<Row> items = spark.sql("select userId,rating,row_number()over(partition by userId order by rating desc) rn " +
                "from tb_test " +
                "having(rn<=3)"); items.show(); spark.stop(); } }

輸出結果與上邊的輸出結果相同。

Java 中使用combineByKey實現TopN:

import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import scala.Tuple2; import java.util.*; public class SparkJava { public static void main(String[] args) { SparkSession spark = SparkSession.builder().master("local[*]").appName("Spark").getOrCreate(); final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext()); List<String> data = Arrays.asList("a,110,a1", "b,122,b1", "c,123,c1", "a,210,a2", "b,212,b2", "a,310,a3", "b,312,b3", "a,410,a4", "b,412,b4"); JavaRDD<String> javaRDD = ctx.parallelize(data); JavaPairRDD<String, Integer> javaPairRDD = javaRDD.mapToPair(new PairFunction<String, String, Integer>() { public Tuple2<String, Integer> call(String key) throws Exception { return new Tuple2<String, Integer>(key.split(",")[0], Integer.valueOf(key.split(",")[1])); } }); final int topN = 3; JavaPairRDD<String, List<Integer>> combineByKeyRDD2 = javaPairRDD.combineByKey(new Function<Integer, List<Integer>>() { public List<Integer> call(Integer v1) throws Exception { List<Integer> items = new ArrayList<Integer>(); items.add(v1); return items; } }, new Function2<List<Integer>, Integer, List<Integer>>() { public List<Integer> call(List<Integer> v1, Integer v2) throws Exception { if (v1.size() > topN) { Integer item = Collections.min(v1); v1.remove(item); v1.add(v2); } return v1; } }, new Function2<List<Integer>, List<Integer>, List<Integer>>() { public 
List<Integer> call(List<Integer> v1, List<Integer> v2) throws Exception { v1.addAll(v2); while (v1.size() > topN) { Integer item = Collections.min(v1); v1.remove(item); } return v1; } }); // 由K:String,V:List<Integer> 轉化爲 K:String,V:Integer // old:[(a,[210, 310, 410]), (b,[122, 212, 312]), (c,[123])] // new:[(a,210), (a,310), (a,410), (b,122), (b,212), (b,312), (c,123)]
        JavaRDD<Tuple2<String, Integer>> javaTupleRDD = combineByKeyRDD2.flatMap(new FlatMapFunction<Tuple2<String, List<Integer>>, Tuple2<String, Integer>>() { public Iterator<Tuple2<String, Integer>> call(Tuple2<String, List<Integer>> stringListTuple2) throws Exception { List<Tuple2<String, Integer>> items=new ArrayList<Tuple2<String, Integer>>(); for(Integer v:stringListTuple2._2){ items.add(new Tuple2<String, Integer>(stringListTuple2._1,v)); } return items.iterator(); } }); JavaRDD<Row> rowRDD = javaTupleRDD.map(new Function<Tuple2<String, Integer>, Row>() { public Row call(Tuple2<String, Integer> kv) throws Exception { String key = kv._1; Integer num = kv._2; return RowFactory.create(key, num); } }); ArrayList<StructField> fields = new ArrayList<StructField>(); StructField field = null; field = DataTypes.createStructField("key", DataTypes.StringType, true); fields.add(field); field = DataTypes.createStructField("TopN_values", DataTypes.IntegerType, true); fields.add(field); StructType schema = DataTypes.createStructType(fields); Dataset<Row> df = spark.createDataFrame(rowRDD, schema); df.printSchema(); df.show(); spark.stop(); } }

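輸出(注意:下表中 b 組缺少最大值 412。原 mergeValue 只在列表長度已大於 topN 時才加入新值,因此同一分區內同一 key 的後續值會被直接丟棄;b 組正確的 Top3 應為 212、312、412):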

root
 |-- key: string (nullable = true)
 |-- TopN_values: integer (nullable = true)

+---+-----------+
|key|TopN_values|
+---+-----------+
|  a|        210|
|  a|        310|
|  a|        410|
|  b|        122|
|  b|        212|
|  b|        312|
|  c|        123|
+---+-----------+

Spark使用combineByKeyWithClassTag函數實現TopN

combineByKeyWithClassTag 函數:藉助 TreeSet 的有序性(元素按升序排列,最小值始終在最前),此例取組內最大的 N 個元素。以下是代碼:

  • createCombiner:簡單地將首個元素放進 TreeSet,然後返回即可;
  • mergeValue:插入元素之後,如果元素個數大於 N,就刪除最小的元素;
  • mergeCombiners:合併之後,如果總個數大於 N,就依次刪除最小的元素,直到 TreeSet 內只剩 N 個元素。(注意:TreeSet 會去重,同一 key 的重複值只計一次。)
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

/**
 * Top-N per key with `combineByKeyWithClassTag`.
 *
 * Each key's combiner is an ascending `mutable.TreeSet[Long]` holding at most [[N]] values,
 * so the smallest kept value is always `head`. Because it is a set, duplicate values of one
 * key count once (e.g. the two `("google.com", 4L)` records).
 */
object Main {
  /** How many of the largest values to keep per key. */
  val N = 3

  /** Drops the smallest elements of `set` until at most `N` remain; mutates and returns `set`. */
  private def keepTopN(set: mutable.TreeSet[Long]): mutable.TreeSet[Long] = {
    while (set.size > N) set -= set.head
    set
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Spark")
      .getOrCreate()
    val sc = spark.sparkContext

    val sampleDataset = List(
      ("apple.com", 3L), ("apple.com", 4L), ("apple.com", 1L), ("apple.com", 9L),
      ("google.com", 4L), ("google.com", 1L), ("google.com", 2L), ("google.com", 3L),
      ("google.com", 11L), ("google.com", 32L),
      ("slashdot.org", 11L), ("slashdot.org", 12L), ("slashdot.org", 13L), ("slashdot.org", 14L),
      ("slashdot.org", 15L), ("slashdot.org", 16L), ("slashdot.org", 17L), ("slashdot.org", 18L),
      ("microsoft.com", 5L), ("microsoft.com", 2L), ("microsoft.com", 6L), ("microsoft.com", 9L),
      ("google.com", 4L))
    val urdd: RDD[(String, Long)] = sc.parallelize(sampleDataset)

    val topNs: RDD[(String, mutable.TreeSet[Long])] = urdd.combineByKeyWithClassTag(
      // createCombiner: start a new set with the key's first value in this partition.
      (first: Long) => mutable.TreeSet[Long](first),
      // mergeValue: insert the value, then drop the smallest ones beyond N.
      (set: mutable.TreeSet[Long], value: Long) => keepTopN(set += value),
      // mergeCombiners: fold one partial set into the other, then trim back to N.
      (left: mutable.TreeSet[Long], right: mutable.TreeSet[Long]) => keepTopN(left ++= right)
    )

    import spark.implicits._
    // One (key, value) row per kept value, largest first. Values stay numeric instead of being
    // round-tripped through "key/value" strings, which broke for keys containing '/'.
    topNs
      .flatMap { case (key, values) => values.toList.reverse.map(value => (key, value)) }
      .toDF("key", "TopN_values")
      .show()

    spark.stop()
  }
}

參考《https://blog.csdn.net/gpwner/article/details/78455234》

輸出結果:

+-------------+-----------+
|          key|TopN_values|
+-------------+-----------+
|   google.com|          4|
|   google.com|         11|
|   google.com|         32|
|microsoft.com|          9|
|microsoft.com|          6|
|microsoft.com|          5|
|    apple.com|          4|
|    apple.com|          9|
|    apple.com|          3|
| slashdot.org|         16|
| slashdot.org|         17|
| slashdot.org|         18|
+-------------+-----------+
相關文章
相關標籤/搜索