This article takes a look at Flink Table's AggregateFunction.
```java
/**
 * Accumulator for WeightedAvg.
 */
public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}

/**
 * Weighted Average user-defined aggregate function.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return 0L;
        } else {
            return acc.sum / acc.count;
        }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }

    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }

    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}

// register function
BatchTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());

// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
```
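Once registered, the function should presumably also be callable from the Java Table API via expression strings; a minimal sketch (assuming `userScores` is a table registered in the catalog, and reusing the `tEnv` and `wAvg` names from above):

```java
// hypothetical Table API equivalent of the SQL query above
Table result = tEnv.scan("userScores")
    .groupBy("user")
    .select("user, wAvg(points, level) as avgPoints");
```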
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/functions/AggregateFunction.scala
```scala
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
    * Creates and init the Accumulator for this [[AggregateFunction]].
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC

  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T

  /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false

  /**
    * Returns the TypeInformation of the AggregateFunction's result.
    *
    * @return The TypeInformation of the AggregateFunction's result or null if the result type
    *         should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null

  /**
    * Returns the TypeInformation of the AggregateFunction's accumulator.
    *
    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null
}
```
Of the methods declared here, a subclass must implement createAccumulator and getValue; requiresOver, getResultType, and getAccumulatorType all come with default implementations.
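As a minimal sketch of that mandatory surface (a hypothetical SimpleCount function, not taken from the Flink sources; note that an accumulate method is also expected even though the abstract class does not declare it, as discussed further below):

```java
/** Hypothetical accumulator: just a running count. */
public static class CountAccum {
    public long count = 0L;
}

/** Minimal AggregateFunction: only the mandatory methods are provided. */
public static class SimpleCount extends AggregateFunction<Long, CountAccum> {

    @Override
    public CountAccum createAccumulator() {
        // fresh accumulator with the initial value
        return new CountAccum();
    }

    @Override
    public Long getValue(CountAccum acc) {
        // materialize the (possibly partial) aggregation result
        return acc.count;
    }

    // not declared on the abstract class, but required by the planner,
    // which resolves it from the concrete subclass at code-generation time
    public void accumulate(CountAccum acc, Long value) {
        acc.count += 1;
    }
}
```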
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala
```scala
class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
  extends AbstractRichFunction
    with GroupCombineFunction[Row, Row]
    with MapPartitionFunction[Row, Row]
    with Compiler[GeneratedAggregations]
    with Logging {

  private var output: Row = _
  private var accumulators: Row = _
  private var function: GeneratedAggregations = _

  override def open(config: Configuration) {
    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
      s"Code:\n$genAggregations.code")
    val clazz = compile(
      getRuntimeContext.getUserCodeClassLoader,
      genAggregations.name,
      genAggregations.code)
    LOG.debug("Instantiating AggregateHelper.")
    function = clazz.newInstance()

    output = function.createOutputRow()
    accumulators = function.createAccumulators()
  }

  override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {

    // reset accumulators
    function.resetAccumulator(accumulators)

    val iterator = values.iterator()

    var record: Row = null
    while (iterator.hasNext) {
      record = iterator.next()

      // accumulate
      function.accumulate(accumulators, record)
    }

    // set group keys and accumulators to output
    function.setAggregationResults(accumulators, output)
    function.setForwardedFields(record, output)

    out.collect(output)
  }

  override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
    combine(values, out)
  }
}
```
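DataSetPreAggFunction compiles and instantiates the generated GeneratedAggregations class in open(). combine() first resets the accumulators, then accumulates every input row, and finally writes the aggregation results together with the forwarded grouping-key fields into the output row before collecting it; mapPartition() simply delegates to combine().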
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
```scala
abstract class GeneratedAggregations extends Function {

  /**
    * Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for initialization work. By default, this method does nothing.
    *
    * @param ctx The runtime context.
    */
  def open(ctx: RuntimeContext)

  /**
    * Sets the results of the aggregations (partial or final) to the output row.
    * Final results are computed with the aggregation function.
    * Partial results are the accumulators themselves.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param output       output results collected in a row
    */
  def setAggregationResults(accumulators: Row, output: Row)

  /**
    * Copies forwarded fields, such as grouping keys, from input row to output row.
    *
    * @param input  input values bundled in a row
    * @param output output results collected in a row
    */
  def setForwardedFields(input: Row, output: Row)

  /**
    * Accumulates the input values to the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def accumulate(accumulators: Row, input: Row)

  /**
    * Retracts the input values from the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def retract(accumulators: Row, input: Row)

  /**
    * Initializes the accumulators and save them to a accumulators row.
    *
    * @return a row of accumulators which contains the aggregated results
    */
  def createAccumulators(): Row

  /**
    * Creates an output row object with the correct arity.
    *
    * @return an output row object with the correct arity.
    */
  def createOutputRow(): Row

  /**
    * Merges two rows of accumulators into one row.
    *
    * @param a First row of accumulators
    * @param b The other row of accumulators
    * @return A row with the merged accumulators of both input rows.
    */
  def mergeAccumulatorsPair(a: Row, b: Row): Row

  /**
    * Resets all the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    */
  def resetAccumulator(accumulators: Row)

  /**
    * Cleanup for the accumulators.
    */
  def cleanup()

  /**
    * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for clean up work. By default, this method does nothing.
    */
  def close()
}
```
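GeneratedAggregations is the base class of the code-generated helper that runtime operators such as DataSetPreAggFunction work against: it operates on a Row of accumulators and Rows of input/output values, and its accumulate, retract, mergeAccumulatorsPair, and resetAccumulator methods dispatch to the corresponding methods of the user-defined AggregateFunction instances. This is also why accumulate and the optional methods can be picked up even though the abstract AggregateFunction class does not declare them: the code generator resolves them on the concrete subclass.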
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
```scala
object AggregateUtil {

  type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
  type JavaList[T] = java.util.List[T]

  //......

  /**
    * Create functions to compute a
    * [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
    * If all aggregation functions support pre-aggregation, a pre-aggregation function and the
    * respective output type are generated as well.
    */
  private[flink] def createDataSetAggregateFunctions(
      generator: AggregationCodeGenerator,
      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
      inputType: RelDataType,
      inputFieldTypeInfo: Seq[TypeInformation[_]],
      outputType: RelDataType,
      groupings: Array[Int],
      tableConfig: TableConfig)
    : (Option[DataSetPreAggFunction],
       Option[TypeInformation[Row]],
       Either[DataSetAggFunction, DataSetFinalAggFunction]) = {

    val needRetract = false
    val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
      namedAggregates.map(_.getKey),
      inputType,
      needRetract,
      tableConfig)

    val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
      namedAggregates,
      groupings,
      inputType,
      outputType
    )

    val aggOutFields = aggOutMapping.map(_._1)

    if (doAllSupportPartialMerge(aggregates)) {

      // compute preaggregation type
      val preAggFieldTypes = gkeyOutMapping.map(_._2)
        .map(inputType.getFieldList.get(_).getType)
        .map(FlinkTypeFactory.toTypeInfo) ++ accTypes
      val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)

      val genPreAggFunction = generator.generateAggregations(
        "DataSetAggregatePrepareMapHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggregates.indices.map(_ + groupings.length).toArray,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = true,
        groupings,
        None,
        groupings.length + aggregates.length,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )

      // compute mapping of forwarded grouping keys
      val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
        val gkeyOutFields = gkeyOutMapping.map(_._1)
        val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
        gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
        mapping
      } else {
        new Array[Int](0)
      }

      val genFinalAggFunction = generator.generateAggregations(
        "DataSetAggregateFinalHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        gkeyMapping,
        Some(aggregates.indices.map(_ + groupings.length).toArray),
        outputType.getFieldCount,
        needRetract,
        needMerge = true,
        needReset = true,
        None
      )

      (
        Some(new DataSetPreAggFunction(genPreAggFunction)),
        Some(preAggRowType),
        Right(new DataSetFinalAggFunction(genFinalAggFunction))
      )
    } else {
      val genFunction = generator.generateAggregations(
        "DataSetAggregateHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        groupings,
        None,
        outputType.getFieldCount,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )

      (
        None,
        None,
        Left(new DataSetAggFunction(genFunction))
      )
    }
  }

  //......
}
```
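In createDataSetAggregateFunctions, the Calcite aggregate calls are first translated into Flink aggregate functions via transformToAggregateFunctions. If doAllSupportPartialMerge holds for all of them, two functions are generated: a DataSetPreAggFunction that emits the accumulators as partial results (partialResults = true, needMerge = false) and a DataSetFinalAggFunction that merges those partial results (needMerge = true); otherwise a single DataSetAggFunction is generated (needMerge = false). Note that every branch passes needReset = true, which lines up with the requirement below that dataset grouping aggregates implement resetAccumulator.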
To summarize: AggregateFunction extends UserDefinedFunction and declares the createAccumulator, getValue, requiresOver, getResultType, and getAccumulatorType methods; of these, subclasses must implement createAccumulator and getValue. There is also an accumulate method that is not declared on the abstract class but still has to be defined and implemented by the subclass; it takes the ACC plus the input values (T, ...) as parameters and returns void. Beyond that, the three methods retract, merge, and resetAccumulator are optional and have to be defined and implemented depending on how the function is used (query shapes for these cases are sketched below):

- for datastream bounded OVER aggregate operations, retract is required; it takes the ACC plus the input values as parameters and returns void;
- for datastream session window grouping aggregate as well as dataset grouping aggregate operations, merge is required; it takes an ACC and a java.lang.Iterable<ACC> as its two parameters and returns void;
- for dataset grouping aggregate operations, resetAccumulator is required; it takes the ACC as its parameter and returns void.
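As a sketch of the query shapes listed above (hypothetical code, assuming a StreamTableEnvironment named streamTableEnv with wAvg registered and a userScores table that has a rowtime attribute):

```java
// bounded OVER aggregate on a datastream table: the planner requires retract()
Table overAgg = streamTableEnv.sqlQuery(
    "SELECT user, " +
    "  wAvg(points, level) OVER (" +
    "    PARTITION BY user ORDER BY rowtime " +
    "    ROWS BETWEEN 5 PRECEDING AND CURRENT ROW) AS avgPoints " +
    "FROM userScores");

// session window grouping aggregate: the planner requires merge()
Table sessionAgg = streamTableEnv.sqlQuery(
    "SELECT user, wAvg(points, level) AS avgPoints " +
    "FROM userScores " +
    "GROUP BY user, SESSION(rowtime, INTERVAL '10' MINUTE)");
```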