Spark Notes: Using UDAFs (User Defined Aggregate Functions)

 

1. Introduction to UDAF

First, what is a UDAF (User Defined Aggregate Function)? The difference between an aggregate function and an ordinary function is this: an ordinary function takes one row of input and produces one output, while an aggregate function takes a group of inputs (usually multiple rows) and produces a single output, i.e. it condenses a set of values into one.
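To make this concrete, here is a minimal sketch contrasting an ordinary function with an aggregate, using Spark's built-in upper and avg (it assumes the data/user dataset introduced later in this post; OrdinaryVsAggregateDemo is an illustrative name):

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, upper}

object OrdinaryVsAggregateDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    val df = spark.read.json("data/user")

    // Ordinary function: one output row per input row
    df.select(upper(col("name"))).show()

    // Aggregate function: one output row for the whole input
    df.select(avg(col("age"))).show()
  }

}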

 

A common misconception about UDAFs

We might instinctively assume that a UDAF must be used together with group by. In fact, a UDAF can be used with or without group by. This is easy to understand if you think of functions like max and min in MySQL. You can write:

select max(foo) from foobar group by bar;

This groups by the bar column and then computes the maximum of each group; there can be many groups, and the function is applied to each one. You can also write:

select max(foo) from foobar;

In this case the whole table is treated as a single group, and the maximum is computed over that group (which happens to be the entire table). So an aggregate function really operates on a group, without caring how many records the group contains.

 

2. Using UDAFs

2.1 Extending UserDefinedAggregateFunction

The recipe for using UserDefinedAggregateFunction:

1. Define a class that extends UserDefinedAggregateFunction and implement each of its lifecycle methods

2. Register the UDAF with Spark and bind it to a name

3. Call it by that name in SQL statements

 

Below is a UDAF example that computes an average. First, define a class extending UserDefinedAggregateFunction:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {

  // Input schema of the aggregate function
  override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Return type of the aggregate function
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e. whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer: slot 0 is the running sum, slot 1 is the count
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Feed one new input row into the aggregation (null inputs are skipped)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two partial aggregation buffers (e.g. produced on different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result (note: an empty group yields 0.0 / 0 = NaN)
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)

}
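A note on how these methods are driven: Spark calls initialize to create a fresh buffer, update once for each input row within a partition, merge to combine the partial buffers produced by different partitions, and finally evaluate on the fully merged buffer to produce the output value.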

Then register and use it:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object SparkSqlUDAFDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    spark.read.json("data/user").createOrReplaceTempView("v_user")
    spark.udf.register("u_avg", AverageUserDefinedAggregateFunction)
    // Treat the whole table as one group and compute everyone's average age
    spark.sql("select count(1) as count, u_avg(age) as avg_age from v_user").show()
    // Group by sex and compute the average age of each group
    spark.sql("select sex, count(1) as count, u_avg(age) as avg_age from v_user group by sex").show()

  }

}
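Besides going through SQL, a UserDefinedAggregateFunction can also be applied directly in the DataFrame API via its apply(Column*) method. A minimal sketch, reusing the object above inside the same kind of main method:

import org.apache.spark.sql.functions.col

val df = spark.read.json("data/user")
// Average age per sex, without registering the UDAF under a SQL name
df.groupBy(col("sex"))
  .agg(AverageUserDefinedAggregateFunction(col("age")).as("avg_age"))
  .show()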

The dataset used:

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

Running it produces:

[image: output of the first query]

[image: output of the second query]
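For reference, from the sample data above the overall average age should come out to (20 + 24 + 18 + 17 + 19 + 20) / 6 ≈ 19.67 with a count of 6, and the per-sex averages to (20 + 24 + 18) / 3 ≈ 20.67 for man and (17 + 19 + 20) / 3 ≈ 18.67 for woman.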

 

2.2 Extending Aggregator

Another approach is to extend the Aggregator class; its advantage is that it is strongly typed:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
  * Computes the average
  */
object AverageAggregator extends Aggregator[User, Average, Double] {

  // The initial (zero) value of the buffer
  override def zero: Average = Average(0L, 0L)

  // Fold one new record into the buffer
  override def reduce(b: Average, a: User): Average = {
    b.sum += a.age
    b.count += 1L
    b
  }

  // Merge two aggregation buffers
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // 減小中間數據傳輸
  override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count

  // Encoder for the intermediate buffer type
  override def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the final output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}

/**
  * Buffer used while computing the average
  *
  * @param sum   running sum of the values seen so far
  * @param count number of records seen so far
  */
case class Average(var sum: Long, var count: Long)

case class User(id: Long, name: String, sex: String, age: Long)

Invoking it:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object AverageAggregatorDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    import spark.implicits._
    val user = spark.read.json("data/user").as[User]
    user.select(AverageAggregator.toColumn.name("avg")).show()

  }

}
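The demo above treats the whole Dataset as a single group. With the typed API, the same Aggregator can also be applied per group via groupByKey; a minimal sketch that could replace the select line above (it relies on the import spark.implicits._ already present):

user.groupByKey(_.sex)
  .agg(AverageAggregator.toColumn.name("avg_age"))
  .show()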

Running it produces:

[image: query result]
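One closing note: UserDefinedAggregateFunction is deprecated as of Spark 3.0, where the recommended approach is to wrap an Aggregator with functions.udaf and register the result for SQL. A sketch, assuming Spark 3.x (AverageLongAggregator and u_avg2 are illustrative names of my own, not from the original code):

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}

// A column-typed variant of the Aggregator: its input is a single Long value
// rather than a whole User row, so it can be bound to a column in SQL
object AverageLongAggregator extends Aggregator[Long, Average, Double] {
  override def zero: Average = Average(0L, 0L)
  override def reduce(b: Average, a: Long): Average = { b.sum += a; b.count += 1L; b }
  override def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  override def finish(r: Average): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[Average] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object Spark3UdafDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    spark.read.json("data/user").createOrReplaceTempView("v_user")
    // functions.udaf wraps an Aggregator as an untyped UDAF usable from SQL (Spark 3.0+)
    spark.udf.register("u_avg2", functions.udaf(AverageLongAggregator))
    spark.sql("select sex, u_avg2(age) as avg_age from v_user group by sex").show()
  }

}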

 
