本文主要研究一下 Flink Table 的 select 操作。
/**
 * A table backed by a [[LogicalNode]] plan together with the [[TableEnvironment]] it lives in.
 * This excerpt only shows the `select` path; other members are omitted.
 *
 * @param tableEnv    environment used for function lookup and plan validation
 * @param logicalPlan the logical plan this table wraps
 */
class Table(
    private[flink] val tableEnv: TableEnvironment,
    private[flink] val logicalPlan: LogicalNode) {

  // ...... (other members omitted in this excerpt)

  /**
   * String-based projection.
   *
   * Parses `fields` with `ExpressionParser.parseExpressionList`, rewrites every parsed
   * expression with [[replaceAggFunctionCall]], then delegates to the expression-based
   * `select`.
   *
   * @param fields projection written in the expression string syntax
   * @return the projected table
   */
  def select(fields: String): Table = {
    val parsed = ExpressionParser.parseExpressionList(fields)
    // Turn generic calls that name aggregates into their resolved form before projecting.
    val rewritten = parsed.map(expr => replaceAggFunctionCall(expr, tableEnv))
    select(rewritten: _*)
  }

  /**
   * Rebuilds `field` bottom-up, swapping a generic `Call` for the function-catalog lookup
   * result whenever that result is an `AggFunctionCall`, an `Aggregation` or an
   * `AbstractWindowProperty`.
   *
   * All other calls keep their name but get rewritten arguments. Unary/binary expressions,
   * scalar function calls and array constructors are copied with rewritten children. Leaves
   * and any other expression are returned unchanged.
   *
   * @param field    expression to rewrite
   * @param tableEnv environment whose function catalog is consulted
   * @return the rewritten expression
   */
  def replaceAggFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {
    // Recurse with the same environment.
    def rewrite(e: Expression): Expression = replaceAggFunctionCall(e, tableEnv)

    field match {
      case leaf: LeafExpression =>
        leaf

      case unary: UnaryExpression =>
        unary.makeCopy(Array(rewrite(unary.child)))

      case binary: BinaryExpression =>
        binary.makeCopy(Array(rewrite(binary.left), rewrite(binary.right)))

      // Generic function calls: keep aggregates / window properties as looked up.
      case call @ Call(name, args) =>
        tableEnv.getFunctionCatalog.lookupFunction(name, args) match {
          case agg: AggFunctionCall => agg
          case agg: Aggregation => agg
          case prop: AbstractWindowProperty => prop
          case _ =>
            val rewrittenArgs = args.map(rewrite)
            call.makeCopy(Array(name, rewrittenArgs))
        }

      // Scalar function calls: only the arguments are rewritten.
      case scalarCall @ ScalarFunctionCall(clazz, args) =>
        val rewrittenArgs: Seq[Expression] = args.map(rewrite)
        scalarCall.makeCopy(Array(clazz, rewrittenArgs))

      // Array constructor: rewrite every element.
      case arrayCtor @ ArrayConstructor(_) =>
        val rewrittenElements = arrayCtor.elements.map(rewrite)
        arrayCtor.makeCopy(Array(rewrittenElements))

      // Anything else passes through untouched.
      case other: Expression =>
        other
    }
  }

  /**
   * Expression-based projection.
   *
   * The fields are expanded against the current plan and split into aggregations and window
   * properties. Window properties are rejected.
   *
   * When aggregations are present, the plan becomes
   * `Project(outer) -> Aggregate(no grouping keys) -> Project(input field references)`, and
   * every node is validated as it is built. Otherwise a single validated `Project` is created
   * with each field wrapped in `UnresolvedAlias`, which `Project.resolveExpressions` later
   * turns into named projections.
   *
   * @param fields projected expressions
   * @return the projected table
   * @throws ValidationException if a window property is used on this table
   */
  def select(fields: Expression*): Table = {
    val expandedFields = expandProjectList(fields, logicalPlan, tableEnv)
    val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, tableEnv)

    if (propNames.nonEmpty) {
      throw new ValidationException("Window properties can only be used on windowed tables.")
    }

    if (aggNames.isEmpty) {
      val projection = Project(expandedFields.map(UnresolvedAlias), logicalPlan)
      new Table(tableEnv, projection.validate(tableEnv))
    } else {
      // Outer projection, built from the original fields plus the aggregate/property names.
      val projectsOnAgg = replaceAggregationsAndProperties(
        expandedFields, tableEnv, aggNames, propNames)
      // Field references the aggregation reads from the input plan.
      val projectFields = extractFieldReferences(expandedFields)
      val aggAliases = aggNames.map { case (agg, name) => Alias(agg, name) }.toSeq

      val inputProjection = Project(projectFields, logicalPlan).validate(tableEnv)
      val aggregate = Aggregate(Nil, aggAliases, inputProjection).validate(tableEnv)
      new Table(tableEnv, Project(projectsOnAgg, aggregate).validate(tableEnv))
    }
  }

  // ......
}
/**
 * Base node of the expression tree.
 */
abstract class Expression extends TreeNode[Expression] {

  /**
   * The [[TypeInformation]] obtained by evaluating this expression.
   * It may not be available until the expression is valid.
   */
  private[flink] def resultType: TypeInformation[_]

  /**
   * Single-pass, post-order validity: every child must be valid and this node's own
   * [[validateInput]] check must succeed. Computed once (lazy).
   */
  private[flink] lazy val valid: Boolean = childrenValid && validateInput().isSuccess

  /** True when no child is invalid. */
  private[flink] def childrenValid: Boolean = !children.exists(!_.valid)

  /**
   * Checks this node's input types, input count or other node-specific properties.
   *
   * Returns `ValidationSuccess` when the check passes, or a `ValidationFailure` carrying an
   * explanatory message. Intended to be called only once `childrenValid` holds. The default
   * implementation always succeeds.
   */
  private[flink] def validateInput(): ValidationResult = ValidationSuccess

  /**
   * Translates this expression into its Calcite counterpart, a `RexNode`.
   * The default implementation throws [[UnsupportedOperationException]].
   */
  private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    throw new UnsupportedOperationException(
      s"${this.getClass.getName} cannot be transformed to RexNode"
    )

  /**
   * Structural equality.
   *
   * Both nodes must have the same runtime class, and their product elements must match
   * pairwise: nested expressions via `checkEquals`, nested sequences element by element, and
   * everything else via `==`.
   */
  private[flink] def checkEquals(other: Expression): Boolean = {
    def sameElements(xs: Seq[Any], ys: Seq[Any]): Boolean =
      xs.length == ys.length && xs.zip(ys).forall {
        case (x: Expression, y: Expression) => x.checkEquals(y)
        case (x: Seq[_], y: Seq[_]) => sameElements(x, y)
        case (x, y) => x == y
      }

    this.getClass == other.getClass &&
      sameElements(this.productIterator.toSeq, other.productIterator.toSeq)
  }
}

/** Expression with exactly two children, `left` then `right`. */
abstract class BinaryExpression extends Expression {
  private[flink] def left: Expression
  private[flink] def right: Expression
  private[flink] def children = Seq(left, right)
}

/** Expression with a single child. */
abstract class UnaryExpression extends Expression {
  private[flink] def child: Expression
  private[flink] def children = Seq(child)
}

/** Expression without children. */
abstract class LeafExpression extends Expression {
  private[flink] val children = Nil
}
/**
 * Logical projection of `projectList` over `child`.
 *
 * @param projectList   projected named expressions
 * @param child         input node
 * @param explicitAlias when true, `Alias` wrappers are kept in the expressions handed to
 *                      Calcite; otherwise they are stripped in [[construct]]
 */
case class Project(
    projectList: Seq[NamedExpression],
    child: LogicalNode,
    explicitAlias: Boolean = false) extends UnaryNode {

  override def output: Seq[Attribute] = projectList.map(_.toAttribute)

  /**
   * Resolves expressions via the parent implementation, then names every `UnresolvedAlias`.
   *
   * - An already named inner expression is used as-is.
   * - A still-invalid inner expression stays unresolved.
   * - `Cast(named, type)` becomes `"<name>-<type>"`.
   * - A `GetCompositeField` uses its alias name, falling back to `_c<index>`.
   * - Anything else becomes `_c<index>`.
   *
   * Any projection that is not an `UnresolvedAlias` at this point raises a
   * [[RuntimeException]].
   */
  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    val resolved = super.resolveExpressions(tableEnv).asInstanceOf[Project]

    def nameProjection(expr: NamedExpression, idx: Int): NamedExpression = expr match {
      case unresolved @ UnresolvedAlias(inner) =>
        inner match {
          case named: NamedExpression => named
          case invalid if !invalid.valid => unresolved
          case cast @ Cast(named: NamedExpression, targetType) =>
            Alias(cast, s"${named.name}-$targetType")
          case field: GetCompositeField =>
            Alias(field, field.aliasName().getOrElse(s"_c$idx"))
          case other =>
            Alias(other, s"_c$idx")
        }
      case _ =>
        throw new RuntimeException("This should never be called and probably points to a bug.")
    }

    val namedProjections = resolved.projectList.zipWithIndex.map {
      case (expr, idx) => nameProjection(expr, idx)
    }
    Project(namedProjections, child, explicitAlias)
  }

  /**
   * Validates via the parent implementation and then rejects duplicate output names.
   *
   * Names are taken from explicit `Alias`es and forwarded `ResolvedFieldReference`s, in list
   * order. The first repeated name fails validation. Other projections are not checked.
   */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val validated = super.validate(tableEnv).asInstanceOf[Project]

    def register(seen: Set[String], name: String): Set[String] =
      if (seen(name)) failValidation(s"Duplicate field name $name.") else seen + name

    validated.projectList.foldLeft(Set.empty[String]) {
      case (seen, alias: Alias) => register(seen, alias.name) // explicit name
      case (seen, field: ResolvedFieldReference) => register(seen, field.name) // forwarded field
      case (seen, _) => seen
    }

    validated
  }

  /**
   * Builds `child`, then adds a Calcite projection named after `projectList`, passing
   * `force = true` to `RelBuilder.project`.
   */
  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)

    // Calcite should not see AS expressions in a final RexNode unless they were requested.
    val exprs: Seq[Expression] =
      if (explicitAlias) {
        projectList
      } else {
        projectList.map {
          case Alias(e: Expression, _, _) => e
          case e: Expression => e
        }
      }

    val rexNodes = exprs.map(_.toRexNode(relBuilder))
    val fieldNames = projectList.map(_.name)
    relBuilder.project(rexNodes.asJava, fieldNames.asJava, true)
  }
}
/**
 * Logical aggregation of `child`, grouped by `groupingExpressions`.
 *
 * @param groupingExpressions  grouping keys (may be empty)
 * @param aggregateExpressions aggregates; [[construct]] requires each to be
 *                             `Alias(Aggregation, name, _)`
 * @param child                input node
 */
case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalNode) extends UnaryNode {

  /**
   * Grouping expressions followed by aggregates, as attributes. An unnamed grouping
   * expression is first aliased by its `toString`.
   */
  override def output: Seq[Attribute] =
    (groupingExpressions ++ aggregateExpressions).map {
      case named: NamedExpression => named.toAttribute
      case unnamed => Alias(unnamed, unnamed.toString).toAttribute
    }

  /**
   * Builds `child`, then a Calcite aggregate whose group key comes from
   * `groupingExpressions`. Each aggregate expression becomes an aggregate call named after its
   * alias; any other shape raises a [[RuntimeException]].
   */
  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    val groupKey = relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava)
    val aggCalls = aggregateExpressions.map {
      case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
      case _ => throw new RuntimeException("This should never happen.")
    }
    relBuilder.aggregate(groupKey, aggCalls.asJava)
  }

  /**
   * Validates via the parent implementation, then checks the aggregate expressions first and
   * the grouping expressions second.
   *
   * Aggregate expressions fail on:
   * - chained DISTINCT, or DISTINCT applied to a non-aggregate;
   * - aggregates whose SQL function requires an OVER clause;
   * - an aggregate nested inside another aggregate's inputs;
   * - an attribute that is neither a grouping expression nor inside an aggregate.
   *
   * Grouping expressions must have a key type (`isKeyType`).
   */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
    val validated = super.validate(tableEnv).asInstanceOf[Aggregate]
    val groupingExprs = validated.groupingExpressions

    // Structural membership test against the grouping keys.
    def isGrouped(expr: Expression): Boolean = groupingExprs.exists(_.checkEquals(expr))

    def validateAggregateExpression(expr: Expression): Unit = expr match {
      case distinctExpr: DistinctAgg =>
        distinctExpr.child match {
          case _: DistinctAgg =>
            failValidation("Chained distinct operators are not supported!")
          case aggExpr: Aggregation =>
            validateAggregateExpression(aggExpr)
          case _ =>
            failValidation("Distinct operator can only be applied to aggregation expressions!")
        }

      // Window functions must come with an OVER clause.
      case aggExpr: Aggregation if aggExpr.getSqlAggFunction.requiresOver =>
        failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")

      // No aggregate may appear anywhere below another aggregate.
      case aggExpr: Aggregation =>
        aggExpr.children.foreach { arg =>
          arg.preOrderVisit {
            case _: Aggregation =>
              failValidation(
                "It's not allowed to use an aggregate function as " +
                  "input of another aggregate function")
            case _ => // OK
          }
        }

      case attr: Attribute if !isGrouped(attr) =>
        failValidation(
          s"expression '$attr' is invalid because it is neither" +
            " present in group by nor an aggregate function")

      case grouped if isGrouped(grouped) => // OK

      case other =>
        other.children.foreach(validateAggregateExpression)
    }

    def validateGroupingExpression(expr: Expression): Unit =
      if (!expr.resultType.isKeyType) {
        failValidation(
          s"expression $expr cannot be used as a grouping expression " +
            "because it's not a valid key type which must be hashable and comparable")
      }

    validated.aggregateExpressions.foreach(validateAggregateExpression)
    groupingExprs.foreach(validateGroupingExpression)

    validated
  }
}
/**
 * Base class of logical plan nodes. Provides expression resolution, validation and
 * translation to a Calcite `RelNode`.
 */
abstract class LogicalNode extends TreeNode[LogicalNode] {

  /** Attributes produced by this node; parents resolve field references against them. */
  def output: Seq[Attribute]

  /**
   * Resolves the expressions held by this node in two post-order passes.
   *
   *  1. Each `UnresolvedFieldReference` is looked up with [[resolveReference]] and left
   *     unchanged when nothing is found. Each `Call` whose children are already valid is
   *     replaced by the function-catalog lookup result.
   *  2. For each [[InputTypeSpec]] whose children are valid, any child whose type differs
   *     from the expected type but can be safely cast (`TypeCoercion.canSafelyCast`) is
   *     wrapped in a [[Cast]]. The node is copied only if at least one child was wrapped.
   */
  def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    // resolve references and function calls
    val exprResolved = expressionPostOrderTransform {
      case u @ UnresolvedFieldReference(name) =>
        // try resolve a field
        resolveReference(tableEnv, name).getOrElse(u)
      // The binder is `args`, not `children`, so it does not shadow this node's `children`.
      case c @ Call(name, args) if c.childrenValid =>
        tableEnv.getFunctionCatalog.lookupFunction(name, args)
    }

    // insert implicit casts where the expected input type allows it
    exprResolved.expressionPostOrderTransform {
      case ips: InputTypeSpec if ips.childrenValid =>
        var changed: Boolean = false
        val newChildren = ips.expectedTypes.zip(ips.children).map { case (tpe, child) =>
          val childType = child.resultType
          if (childType != tpe && TypeCoercion.canSafelyCast(childType, tpe)) {
            changed = true
            Cast(child, tpe)
          } else {
            child
          }
        }.toArray[AnyRef]
        if (changed) ips.makeCopy(newChildren) else ips
    }
  }

  /** Builds the Calcite `RelNode` for this node via [[construct]]. */
  final def toRelNode(relBuilder: RelBuilder): RelNode = construct(relBuilder).build()

  /** Pushes this node (and, as implemented by subclasses, its inputs) onto `relBuilder`. */
  protected[logical] def construct(relBuilder: RelBuilder): RelBuilder

  /**
   * Resolves this node's expressions and then checks them, returning the resolved node.
   *
   * Validation fails, via [[failValidation]], on the first (post-order) attribute that is
   * still invalid, or on any expression whose `validateInput()` reports a failure. An
   * unresolved attribute named "null" gets a hint about typed null literals.
   *
   * @throws ValidationException on the first validation failure
   */
  def validate(tableEnv: TableEnvironment): LogicalNode = {
    val resolvedNode = resolveExpressions(tableEnv)
    resolvedNode.expressionPostOrderTransform {
      case a: Attribute if !a.valid =>
        val from = children.flatMap(_.output).map(_.name).mkString(", ")
        // give helpful error message for null literals
        if (a.name == "null") {
          failValidation(s"Cannot resolve field [${a.name}] given input [$from]. If you want to " +
            s"express a null literal, use 'Null(TYPE)' for typed null expressions. " +
            s"For example: Null(INT)")
        } else {
          failValidation(s"Cannot resolve field [${a.name}] given input [$from].")
        }
      case e: Expression if e.validateInput().isFailure =>
        failValidation(s"Expression $e failed on input check: " +
          s"${e.validateInput().asInstanceOf[ValidationFailure].message}")
    }
  }

  /**
   * Resolves `name` to a [[NamedExpression]] using the output of all child nodes.
   *
   * Fields are matched case-insensitively.
   * - Exactly one match returns that attribute, renamed to `name`.
   * - More than one match fails validation as ambiguous.
   * - With no match, `name` is looked up as a table via `tableEnv.scanInternal` and wrapped
   *   in a `TableReference` if found.
   *
   * @return the resolved expression, or `None` if neither a field nor a table matches
   * @throws ValidationException if the field reference is ambiguous
   */
  def resolveReference(tableEnv: TableEnvironment, name: String): Option[NamedExpression] = {
    val fieldCandidates = children.flatMap(_.output).filter(_.name.equalsIgnoreCase(name))

    if (fieldCandidates.length > 1) {
      failValidation(s"Reference $name is ambiguous.")
    } else if (fieldCandidates.nonEmpty) {
      Some(fieldCandidates.head.withName(name))
    } else {
      // no field matched: try to resolve a table
      tableEnv.scanInternal(Array(name)).map(table => TableReference(name, table))
    }
  }

  /**
   * Runs [[postOrderTransform]] with `rule` on all expressions present in this logical node.
   *
   * Expressions are reached directly, inside `Some`, inside traversables, or through
   * `Resolvable` arguments. The node is copied only if at least one expression changed
   * (as judged by `fastEquals`); otherwise `this` is returned.
   *
   * @param rule the rule to be applied to every expression in this logical node.
   */
  def expressionPostOrderTransform(rule: PartialFunction[Expression, Expression]): LogicalNode = {
    var changed = false

    // Transforms one expression tree. Keeps the original instance when the result is
    // fastEquals to it; otherwise records that something changed.
    def transformExpression(e: Expression): Expression = {
      val newExpr = e.postOrderTransform(rule)
      if (newExpr.fastEquals(e)) {
        e
      } else {
        changed = true
        newExpr
      }
    }

    val newArgs = productIterator.map {
      case e: Expression => transformExpression(e)
      case Some(e: Expression) => Some(transformExpression(e))
      case seq: Traversable[_] => seq.map {
        case e: Expression => transformExpression(e)
        case other => other
      }
      case r: Resolvable[_] => r.resolveExpressions(e => transformExpression(e))
      case other: AnyRef => other
    }.toArray

    if (changed) makeCopy(newArgs) else this
  }

  /** Aborts validation by throwing a [[ValidationException]] with `msg`. */
  protected def failValidation(msg: String): Nothing = {
    throw new ValidationException(msg)
  }
}

/** Logical node without inputs. */
abstract class LeafNode extends LogicalNode {
  override def children: Seq[LogicalNode] = Nil
}

/** Logical node with a single input. */
abstract class UnaryNode extends LogicalNode {
  def child: LogicalNode

  override def children: Seq[LogicalNode] = child :: Nil
}

/** Logical node with two inputs, `left` then `right`. */
abstract class BinaryNode extends LogicalNode {
  def left: LogicalNode

  def right: LogicalNode

  override def children: Seq[LogicalNode] = left :: right :: Nil
}