Using UDAFs in Spark SQL

1. Create a class that extends UserDefinedAggregateFunction.

---------------------------------------------------------------------

package cn.piesat.test

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StructType}

class CountUDAF extends UserDefinedAggregateFunction {

  /**
   * Input type of the aggregate function
   * @return
   */
  override def inputSchema: StructType = {
    new StructType().add("ageType", IntegerType)
  }

  /**
   * Type of the aggregation buffer
   * @return
   */
  override def bufferSchema: StructType = {
    new StructType().add("bufferAgeType", IntegerType)
  }

  /**
   * Return type of the UDAF
   * @return
   */
  override def dataType: DataType = {
    DataTypes.StringType
  }

  /**
   * Returns true if the function is deterministic; true is fine in most cases.
   * @return
   */
  override def deterministic: Boolean = true

  /**
   * Initializes the aggregation buffer for each group
   * @param buffer
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  /**
   * Update step: called for each new input row of a group to update that group's aggregate value
   * @param buffer
   * @param input
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val num = input.getAs[Int](0)
    buffer(0) = buffer.getAs[Int](0) + num
  }

  /**
   * Merge step: combines the buffers of two partitions
   * @param buffer1
   * @param buffer2
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  /**
   * Computes the final result
   * @param buffer
   * @return
   */
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0).toString
  }
}
--------------------------------------------------------------
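
Note that since Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of the typed Aggregator API registered through functions.udaf. Below is a minimal sketch of the same "sum the ages, return a String" aggregation written as an Aggregator; the class name SumAsStringAgg is only illustrative and does not come from the original code.
--------------------------------------------------------------
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Aggregator[IN, BUF, OUT]: input Int, buffer Int, output String
class SumAsStringAgg extends Aggregator[Int, Int, String] {
  // Initial buffer value for a group
  def zero: Int = 0
  // Add one input value to the buffer
  def reduce(buffer: Int, value: Int): Int = buffer + value
  // Merge buffers coming from different partitions
  def merge(b1: Int, b2: Int): Int = b1 + b2
  // Produce the final result
  def finish(buffer: Int): String = buffer.toString
  // Encoders for the buffer and output types
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[String] = Encoders.STRING
}

// Registration for SQL use (Spark 3.0+):
// spark.udf.register("countUDF", functions.udaf(new SumAsStringAgg()))
--------------------------------------------------------------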


2. Sample usage in a main function
---------------------------------------------------------------
package cn.piesat.test

import org.apache.spark.sql.SparkSession

import scala.collection.mutable.ArrayBuffer

object SparkSQLTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sparkSql").master("local[4]")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer").getOrCreate()
    val sc = spark.sparkContext
    val sqlContext = spark.sqlContext
    val workerRDD = sc.textFile("F://Workers.txt").mapPartitions(itor => {
      val array = new ArrayBuffer[Worker]()
      while (itor.hasNext) {
        // Assumes each line looks like: name,age,<other field>
        val splited = itor.next().split(",")
        array.append(new Worker(splited(0), splited(1).toInt, splited(2)))
      }
      array.toIterator
    })
    import spark.implicits._
    // Register the UDAF
    spark.udf.register("countUDF", new CountUDAF())
    val workDS = workerRDD.toDS()
    workDS.createOrReplaceTempView("worker")
    val resultDF = spark.sql("select countUDF(age) from worker")
    val resultDS = resultDF.as("WO")
    resultDS.show()
    spark.stop()
  }
}
-----------------------------------------------------------------------------------------------
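
The registered UDAF can also be invoked without SQL text, through the DataFrame API. A minimal sketch against the same workDS Dataset and age column used above (via functions.callUDF) might look like this:
--------------------------------------------------------------
import org.apache.spark.sql.functions.callUDF

// Equivalent to: select countUDF(age) from worker
val resultFromApi = workDS.select(callUDF("countUDF", workDS("age")))
resultFromApi.show()
--------------------------------------------------------------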