In big-data processing, taking the Top N records within each group is a very common operation. The example below shows how to perform this grouped Top-N computation with Spark, first on an RDD and then on a DataFrame.
```python
from pyspark import SparkContext

sc = SparkContext()
```
Prepare the data and convert it into an RDD:
```python
data_list = [
    (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
    (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
    (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
    (3, "cat8", 135.6)
]
data = sc.parallelize(data_list)
data.collect()
```
```
[(0, 'cat26', 130.9), (0, 'cat13', 122.1), (0, 'cat95', 119.6), (0, 'cat105', 11.3), (1, 'cat67', 128.5), (1, 'cat4', 126.8), (1, 'cat13', 112.6), (1, 'cat23', 15.3), (2, 'cat56', 139.6), (2, 'cat40', 129.7), (2, 'cat187', 127.9), (2, 'cat68', 19.8), (3, 'cat8', 135.6)]
```
Group the data with groupBy. After grouping, each element has the form (key, iterable_of_records):
```python
d1 = data.groupBy(lambda x: x[0])
temp = d1.collect()
print(list(temp[0][1]))
print(temp)
```
```
[(0, 'cat26', 130.9), (0, 'cat13', 122.1), (0, 'cat95', 119.6), (0, 'cat105', 11.3)]
[(0, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C710>), (1, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C780>), (2, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C898>), (3, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C9B0>)]
```
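The grouped values are ResultIterable objects, so they must be wrapped in list() to be inspected, as done above for the first group. For a quick readable view of every group, a one-liner like the following should also work (not part of the original walkthrough):

```python
# Turn each group's ResultIterable into a plain list for inspection
d1.mapValues(list).collect()
```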
Use mapValues to sort each group's records, then slice off as many as needed to get the Top N. Here we take 3 records per group; note that sorted is ascending by default, so this keeps the 3 smallest TotalValue records in each group (a descending variant is sketched after the output below).
```python
d2 = d1.mapValues(lambda x: sorted(x, key=lambda y: y[2])[:3])
d2.collect()
```
```
[(0, [(0, 'cat105', 11.3), (0, 'cat95', 119.6), (0, 'cat13', 122.1)]), (1, [(1, 'cat23', 15.3), (1, 'cat13', 112.6), (1, 'cat4', 126.8)]), (2, [(2, 'cat68', 19.8), (2, 'cat187', 127.9), (2, 'cat40', 129.7)]), (3, [(3, 'cat8', 135.6)])]
```
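Because the sort above is ascending, the output keeps the three smallest TotalValue records per group. If "Top 3" should instead mean the three largest values, as in the DataFrame example later, a descending sort can be used; a minimal variant (not from the original post):

```python
# Variant: keep the 3 largest TotalValue records per group instead of the smallest
d2_desc = d1.mapValues(lambda x: sorted(x, key=lambda y: y[2], reverse=True)[:3])
d2_desc.collect()
```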
Use flatMap to flatten the grouped result back into a single list of records:
```python
d3 = d2.flatMap(lambda x: [i for i in x[1]])
d3.collect()
```
```
[(0, 'cat105', 11.3), (0, 'cat95', 119.6), (0, 'cat13', 122.1), (1, 'cat23', 15.3), (1, 'cat13', 112.6), (1, 'cat4', 126.8), (2, 'cat68', 19.8), (2, 'cat187', 127.9), (2, 'cat40', 129.7), (3, 'cat8', 135.6)]
```
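As a small aside, the list comprehension is not strictly necessary: each value is already a list, so the slightly simpler form below should produce the same result.

```python
# Equivalent form: flatMap can consume the per-group list directly
d3 = d2.flatMap(lambda x: x[1])
d3.collect()
```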
### Code summary (RDD version)

```python
from pyspark import SparkContext

# sc = SparkContext()
topN = 3
data_list = [
    (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
    (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
    (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
    (3, "cat8", 135.6)
]
data = sc.parallelize(data_list)
d1 = data.groupBy(lambda x: x[0])
d2 = d1.mapValues(lambda x: sorted(x, key=lambda y: y[2])[:topN])
d3 = d2.flatMap(lambda x: [i for i in x[1]])
d3.collect()
```
```
[(0, 'cat105', 11.3), (0, 'cat95', 119.6), (0, 'cat13', 122.1), (1, 'cat23', 15.3), (1, 'cat13', 112.6), (1, 'cat4', 126.8), (2, 'cat68', 19.8), (2, 'cat187', 127.9), (2, 'cat40', 129.7), (3, 'cat8', 135.6)]
```
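One caveat with groupBy is that every record of a group is materialized before the slice is taken. For very large groups, an alternative is to keep only N records per key while aggregating. A minimal sketch of that idea, assuming the same data RDD and the same ascending ordering as above (the helper names add_record and merge_accs are illustrative, not from the original post):

```python
import heapq

topN = 3

def add_record(acc, record):
    # Add one record to the per-key accumulator, keeping only the N smallest TotalValue
    return heapq.nsmallest(topN, acc + [record], key=lambda r: r[2])

def merge_accs(acc1, acc2):
    # Merge two partial accumulators for the same key
    return heapq.nsmallest(topN, acc1 + acc2, key=lambda r: r[2])

topn_per_key = (data.map(lambda x: (x[0], x))              # key each record by its group
                    .aggregateByKey([], add_record, merge_accs)
                    .flatMap(lambda kv: kv[1]))
topn_per_key.collect()
```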
For data in DataFrame form, a simple way to take the Top N per group is to use a Window function:
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql import Window

spark = SparkSession.builder.getOrCreate()
data_list = [
    (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
    (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
    (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
    (3, "cat8", 135.6)
]
```
Create a DataFrame from the data and name the columns:
```python
df = spark.createDataFrame(data_list, ["Hour", "Category", "TotalValue"])
df.show()
```
```
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|     130.9|
|   0|   cat13|     122.1|
|   0|   cat95|     119.6|
|   0|  cat105|      11.3|
|   1|   cat67|     128.5|
|   1|    cat4|     126.8|
|   1|   cat13|     112.6|
|   1|   cat23|      15.3|
|   2|   cat56|     139.6|
|   2|   cat40|     129.7|
|   2|  cat187|     127.9|
|   2|   cat68|      19.8|
|   3|    cat8|     135.6|
+----+--------+----------+
```
Define a window: partitionBy takes the grouping key and orderBy takes the sorting key, here in descending order via desc(). withColumn(colName, col) adds a column to df containing the row number generated by the window function, and where keeps the rows whose rn value is at most 3, i.e. the Top 3 per group:
```python
w = Window.partitionBy(df.Hour).orderBy(df.TotalValue.desc())
top3 = df.withColumn('rn', func.row_number().over(w)).where('rn <= 3')
top3.show()
```
```
+----+--------+----------+---+
|Hour|Category|TotalValue| rn|
+----+--------+----------+---+
|   0|   cat26|     130.9|  1|
|   0|   cat13|     122.1|  2|
|   0|   cat95|     119.6|  3|
|   1|   cat67|     128.5|  1|
|   1|    cat4|     126.8|  2|
|   1|   cat13|     112.6|  3|
|   3|    cat8|     135.6|  1|
|   2|   cat56|     139.6|  1|
|   2|   cat40|     129.7|  2|
|   2|  cat187|     127.9|  3|
+----+--------+----------+---+
```
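For reference, the same query can also be expressed in Spark SQL rather than the DataFrame API; a sketch assuming the df defined above (the view name t is arbitrary):

```python
# Register the DataFrame as a temporary view and run the window query in SQL
df.createOrReplaceTempView("t")
spark.sql("""
    SELECT Hour, Category, TotalValue, rn
    FROM (
        SELECT *,
               ROW_NUMBER() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) AS rn
        FROM t
    ) ranked
    WHERE rn <= 3
""").show()
```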
### Code summary (DataFrame version)

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql import Window

spark = SparkSession.builder.getOrCreate()
data_list = [
    (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
    (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
    (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
    (3, "cat8", 135.6)
]
df = spark.createDataFrame(data_list, ["Hour", "Category", "TotalValue"])
w = Window.partitionBy(df.Hour).orderBy(df.TotalValue.desc())
top3 = df.withColumn('rn', func.row_number().over(w)).where('rn <= 3')
top3.show()
```
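One last note: row_number() assigns a unique number even when TotalValue ties, so exactly three rows per group are kept. If tied rows should all survive the cut, rank() or dense_rank() could be used instead; a possible variation (not in the original post):

```python
# Keep all rows that tie for a place in the top 3 of their group
top3_with_ties = df.withColumn('rn', func.rank().over(w)).where('rn <= 3')
top3_with_ties.show()
```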