First, what is a UDAF (User Defined Aggregate Function)? How does an aggregate function differ from an ordinary function? An ordinary function takes one row of input and produces one output; an aggregate function takes a group of inputs (usually many rows) and produces a single output, i.e., it condenses a whole group of values into one.
We might instinctively assume that a UDAF has to be used together with GROUP BY, but in fact it can be used either with or without GROUP BY. This is easy to see by analogy with MySQL functions such as max and min. You can write:
select max(foo) from foobar group by bar;
which groups the rows by the bar column and then finds the maximum within each group; there can be many groups, and the function is applied to each of them. You can also write:
select max(foo) from foobar;
In this case the whole table is treated as a single group, and the maximum is computed over that one group (which happens to be the entire table). So an aggregate function really operates on a group, without caring how many rows the group contains.
The recipe for using UserDefinedAggregateFunction:
1. Define a class that extends UserDefinedAggregateFunction and implement the method for each stage.
2. Register the UDAF with Spark and bind it to a name.
3. Call it in SQL statements using that name.
Below is a UDAF example that computes an average. First, define an object that extends UserDefinedAggregateFunction:
package cc11001100.spark.sql.udaf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {

  // Schema of the aggregate function's input
  override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Data type of the aggregate function's return value
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e., the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the aggregation buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Feed one new input row into the aggregation
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two aggregation buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)

}
Then register and use it:
package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object SparkSqlUDAFDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()

    spark.read.json("data/user").createOrReplaceTempView("v_user")
    spark.udf.register("u_avg", AverageUserDefinedAggregateFunction)

    // Treat the whole table as one group and compute everyone's average age
    spark.sql("select count(1) as count, u_avg(age) as avg_age from v_user").show()

    // Group by sex and compute the average age per group
    spark.sql("select sex, count(1) as count, u_avg(age) as avg_age from v_user group by sex").show()

  }

}
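Registering with spark.udf.register is only needed when the UDAF is called from SQL text. A UserDefinedAggregateFunction also has an apply(columns: Column*) method that returns a Column, so the same object can be used directly through the DataFrame API without registration. A minimal sketch (the userDf variable and column names are assumed to match the demo above):

import org.apache.spark.sql.functions.col

// Same aggregation as the SQL version, expressed with the DataFrame API;
// the UDAF object is referenced directly, so no registration is required.
val userDf = spark.read.json("data/user")
userDf
  .groupBy("sex")
  .agg(AverageUserDefinedAggregateFunction(col("age")).as("avg_age"))
  .show()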
The dataset used:
{"id": 1001, "name": "foo", "sex": "man", "age": 20} {"id": 1002, "name": "bar", "sex": "man", "age": 24} {"id": 1003, "name": "baz", "sex": "man", "age": 18} {"id": 1004, "name": "foo1", "sex": "woman", "age": 17} {"id": 1005, "name": "bar2", "sex": "woman", "age": 19} {"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
The results of running it:
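As a sanity check, the numbers can be worked out by hand from the dataset above: the first query should report count = 6 and avg_age = (20 + 24 + 18 + 17 + 19 + 20) / 6 ≈ 19.67, while the grouped query should report man: count = 3, avg_age = (20 + 24 + 18) / 3 ≈ 20.67 and woman: count = 3, avg_age = (17 + 19 + 20) / 3 ≈ 18.67.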
There is another approach: extend the Aggregator class, whose advantage is that it is strongly typed:
package cc11001100.spark.sql.udaf

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
  * Computes an average
  */
object AverageAggregator extends Aggregator[User, Average, Double] {

  // Initialize the buffer
  override def zero: Average = Average(0L, 0L)

  // Process one new record
  override def reduce(b: Average, a: User): Average = {
    b.sum += a.age
    b.count += 1L
    b
  }

  // Merge two aggregation buffers
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count

  // Encoder for the intermediate buffer type (keeps intermediate data transfer small)
  override def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the final output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}

/**
  * Buffer used while computing the average
  *
  * @param sum
  * @param count
  */
case class Average(var sum: Long, var count: Long)

case class User(id: Long, name: String, sex: String, age: Long)
Calling it:
package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object AverageAggregatorDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    import spark.implicits._

    val user = spark.read.json("data/user").as[User]
    user.select(AverageAggregator.toColumn.name("avg")).show()

  }

}
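The typed Aggregator above can only be applied to a Dataset[User] via toColumn. Since Spark 3.0, org.apache.spark.sql.functions.udaf can wrap an Aggregator so it can also be registered and called from SQL (and UserDefinedAggregateFunction itself is deprecated in favor of this). A minimal sketch, assuming Spark 3.0+ and an Aggregator that takes the raw Long column as input; LongAverageAggregator and u_avg2 are names made up for this example, and the Average buffer case class is reused from above:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Same average logic, but typed on the input column (Long) instead of the whole User row
object LongAverageAggregator extends Aggregator[Long, Average, Double] {
  override def zero: Average = Average(0L, 0L)
  override def reduce(b: Average, a: Long): Average = { b.sum += a; b.count += 1L; b }
  override def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  override def finish(r: Average): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[Average] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Wrap it as an untyped UDAF and call it from SQL, just like the earlier example
spark.udf.register("u_avg2", functions.udaf(LongAverageAggregator))
spark.sql("select sex, u_avg2(age) as avg_age from v_user group by sex").show()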
Running it produces a single column named avg; with the dataset above its value is 118 / 6 ≈ 19.67.