1. Create a class that extends UserDefinedAggregateFunction
---------------------------------------------------------------------
package cn.piesat.test

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StructType}

class CountUDAF extends UserDefinedAggregateFunction {

  /**
   * Input type of the aggregate function.
   */
  override def inputSchema: StructType = {
    new StructType().add("ageType", IntegerType)
  }

  /**
   * Type of the intermediate aggregation buffer.
   */
  override def bufferSchema: StructType = {
    new StructType().add("bufferAgeType", IntegerType)
  }

  /**
   * Return type of the UDAF.
   */
  override def dataType: DataType = {
    DataTypes.StringType
  }

  /**
   * Returns true if the function is deterministic; true is fine in most cases.
   */
  override def deterministic: Boolean = true

  /**
   * Initializes the buffer for each group.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  /**
   * Update step: how the group's aggregate value is computed
   * each time a new input row arrives for that group.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val num = input.getAs[Int](0)
    buffer(0) = buffer.getAs[Int](0) + num
  }

  /**
   * Merge step: combines buffers from different partitions.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  /**
   * Produces the final result.
   */
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0).toString
  }
}
--------------------------------------------------------------
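Note: since Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of the typed Aggregator API registered through functions.udaf. The sketch below is not from the original post; it shows the same summing logic rewritten that way, assuming a Spark 3.x build (the object name CountAgg is my own):
--------------------------------------------------------------
package cn.piesat.test

import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// Same logic as CountUDAF, expressed with the typed Aggregator API (Spark 3.0+).
// IN = Int (the age column), BUF = Int (running sum), OUT = String (final result).
object CountAgg extends Aggregator[Int, Int, String] {
  def zero: Int = 0                                        // initial buffer value
  def reduce(buffer: Int, age: Int): Int = buffer + age    // per-row update
  def merge(b1: Int, b2: Int): Int = b1 + b2               // combine partition buffers
  def finish(reduction: Int): String = reduction.toString  // final result
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[String] = Encoders.STRING
}

// Registration replaces "new CountUDAF()":
//   spark.udf.register("countUDF", functions.udaf(CountAgg))
--------------------------------------------------------------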
2. Example usage in the main function
---------------------------------------------------------------
package cn.piesat.test

import org.apache.spark.sql.SparkSession

import scala.collection.mutable.ArrayBuffer

object SparkSQLTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sparkSql").master("local[4]")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer").getOrCreate()
    val sc = spark.sparkContext
    val sqlContext = spark.sqlContext

    val workerRDD = sc.textFile("F://Workers.txt").mapPartitions(itor => {
      val array = new ArrayBuffer[Worker]()
      while (itor.hasNext) {
        val splited = itor.next().split(",")
        // Assuming a name,age,... file layout; the original passed splited(2) for both
        // the second and third arguments, which looks like a typo.
        array.append(new Worker(splited(0), splited(1).toInt, splited(2)))
      }
      array.toIterator
    })
    import spark.implicits._

    // register the UDAF
    spark.udf.register("countUDF", new CountUDAF())

    val workDS = workerRDD.toDS()
    workDS.createOrReplaceTempView("worker")
    val resultDF = spark.sql("select countUDF(age) from worker")
    val resultDS = resultDF.as("WO")
    resultDS.show()
    spark.stop()
  }
}
-----------------------------------------------------------------------------------------------
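The Worker class is not shown in the original post. A minimal sketch of what it could look like, together with a sample Workers.txt, is given below; the field names and the sample data are assumptions, chosen only to make the example self-contained:
-----------------------------------------------------------------------------------------------
package cn.piesat.test

// Hypothetical definition: a case class, so that spark.implicits._ can derive
// an Encoder for workerRDD.toDS(). Field names (name, age, sex) are assumed.
case class Worker(name: String, age: Int, sex: String)

// A matching Workers.txt could look like (comma-separated, one worker per line):
//   zhangsan,25,male
//   lisi,30,female
//   wangwu,28,male
//
// With that input, "select countUDF(age) from worker" sums the ages (25 + 30 + 28),
// evaluate() turns the sum into a String, and resultDS.show() prints a single-row,
// single-column result containing 83.
-----------------------------------------------------------------------------------------------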