2. 示例:Averageexpress
3. 類型安全的自定義函數apache
spark中咱們定義一個函數,須要繼承 UserDefinedAggregateFunction這個抽象類,實現這個抽象類中所定義的方法,這是一個模板設計模式? 我只要實現抽象類的中方法,具體的全部的計算步驟由內部完成。而咱們能夠看一下UserDefinedAggregateFunction這個抽象類。設計模式
package org.apache.spark.sql.expressions
@org.apache.spark.annotation.InterfaceStability.Stable
abstract class UserDefinedAggregateFunction() extends scala.AnyRef with scala.Serializable { def inputSchema : org.apache.spark.sql.types.StructType def bufferSchema : org.apache.spark.sql.types.StructType def dataType : org.apache.spark.sql.types.DataType def deterministic : scala.Boolean def initialize(buffer : org.apache.spark.sql.expressions.MutableAggregationBuffer) : scala.Unit def update(buffer : org.apache.spark.sql.expressions.MutableAggregationBuffer, input : org.apache.spark.sql.Row) : scala.Unit def merge(buffer1 : org.apache.spark.sql.expressions.MutableAggregationBuffer, buffer2 : org.apache.spark.sql.Row) : scala.Unit def evaluate(buffer : org.apache.spark.sql.Row) : scala.Any @scala.annotation.varargs def apply(exprs : org.apache.spark.sql.Column*) : org.apache.spark.sql.Column = { /* compiled code */ } @scala.annotation.varargs def distinct(exprs : org.apache.spark.sql.Column*) : org.apache.spark.sql.Column = { /* compiled code */ } }
也就是說對於這幾個函數,咱們只要依次實現他們的功能,其他的交給spark就能夠了。數組
首先新建一個Object類MyAvage類,繼承UserDefinedAggregateFunction。下面對每個函數的實現進行解釋。緩存
def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
這個規定了輸入數據的數據結構安全
def bufferSchema: StructType = { StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) }
這個規定了緩存區的數據結構session
def dataType: DataType = DoubleType
這個規定了返回值的數據類型數據結構
def deterministic: Boolean = true def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0L buffer(1) = 0L }
進行初始化,這裏要說明一下,官網中提到:app
// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
// standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides // the opportunity to update its values. Note that arrays and maps inside the buffer are still // immutable.
這裏翻譯一下:
咱們爲咱們的緩衝區設置初始值,咱們不只能夠設置數字,還可使用index getBoolen等去改變他的值,可是咱們須要知道的是,在這個緩衝區中,數組和map依然是不可變的。
其實最後一句我也是不太明白,等我之後若是能研究並理解這句話,再回來補充吧。
def update(buffer: MutableAggregationBuffer, input: Row): Unit = { if (!input.isNullAt(0)) { buffer(0) = buffer.getLong(0) + input.getLong(0) buffer(1) = buffer.getLong(1) + 1 } }
這個是重要的update函數,對於平均值,咱們能夠不斷迭代輸入的值進行累加。buffer(0)統計總和,buffer(1)統計長度。
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) }
在作完update後spark 須要將結果進行merge到咱們的區域,所以有一個merge 進行覆蓋buffer
def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
這是將最終的結果進行計算。
在寫完這個類之後咱們在咱們的sparksession裏面進行編寫測試案例。
spark.sparkContext.textFile("file:///Users/4pa/Desktop/people.txt") .map(_.split(",")) .map(agg=>Person(agg(0),agg(1).trim.toInt)) .toDF().createOrReplaceTempView("people") spark.udf.register("myAverage",Myaverage) val udfRes = spark.sql("select name,myAverage(age) as avgAge from people group by name") udfRes.show()
從上面咱們能夠看出來,這種自定義函數不是類型安全的,所以可否實現一個安全的自定義函數呢?
我的以爲最好的例子仍是官網給的例子,具體的解釋都已經給了出來,思路其實和上面是同樣的,只不過定義了兩個caseclass,用於類型的驗證。
case class Employee(name: String, salary: Long) case class Average(var sum: Long, var count: Long) object MyAverage extends Aggregator[Employee, Average, Double] { // 初始化 def zero: Average = Average(0L, 0L) // 這個其實有點map-reduce的意思,只不過是對一個類的reduce,第一個值是和,第二個是總數 def reduce(buffer: Average, employee: Employee): Average = { buffer.sum += employee.salary buffer.count += 1 buffer } // 實現緩衝區的一個覆蓋 def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum b1.count += b2.count b1 } // 計算最終數值 def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count // Specifies the Encoder for the intermediate value type def bufferEncoder: Encoder[Average] = Encoders.product // 指定返回類型 def outputEncoder: Encoder[Double] = Encoders.scalaDouble }