本文主要研究一下 Flink Table 的 groupBy 操作
class Table(
    private[flink] val tableEnv: TableEnvironment,
    private[flink] val logicalPlan: LogicalNode) {

  //......

  /**
   * Groups this table by fields given in the string expression syntax.
   *
   * The string is parsed with `ExpressionParser.parseExpressionList` and the resulting
   * expressions are handed to the varargs overload.
   *
   * @param fields grouping fields as an expression-list string
   * @return a [[GroupedTable]] on which an aggregating `select` can be applied
   */
  def groupBy(fields: String): GroupedTable =
    groupBy(ExpressionParser.parseExpressionList(fields): _*)

  /**
   * Groups this table by the given expressions.
   *
   * No plan node is created here; the grouping keys are only recorded in the returned
   * [[GroupedTable]], which builds the aggregation when `select` is called.
   *
   * @param fields grouping key expressions
   * @return a [[GroupedTable]] wrapping this table and the grouping keys
   */
  def groupBy(fields: Expression*): GroupedTable =
    new GroupedTable(this, fields)

  //......
}
class GroupedTable(
    private[flink] val table: Table,
    private[flink] val groupKey: Seq[Expression]) {

  /**
   * Selects expressions, typically including aggregations, from this grouped table.
   *
   * The resulting logical plan is built bottom-up, and each node is validated against the
   * table environment as soon as it is created:
   *  1. a `Project` of the field references extracted from the selected expressions and
   *     the group key,
   *  1. an `Aggregate` over `groupKey` that computes every extracted aggregation under its
   *     generated name,
   *  1. a `Project` of the selected expressions rewritten on top of the aggregate.
   *
   * @param fields the expressions to select
   * @return a new [[Table]] wrapping the validated plan
   * @throws ValidationException if a window property is referenced
   */
  def select(fields: Expression*): Table = {
    val env = table.tableEnv

    val expanded = expandProjectList(fields, table.logicalPlan, env)
    val (aggNames, propNames) = extractAggregationsAndProperties(expanded, env)

    // A plain (non-windowed) groupBy has no window to take properties from.
    if (propNames.nonEmpty) {
      throw new ValidationException("Window properties can only be used on windowed tables.")
    }

    // Top projection, expressed against the aggregate's output.
    val projectsOnAgg = replaceAggregationsAndProperties(expanded, env, aggNames, propNames)
    // Input columns needed by the selection or by the grouping keys.
    val projectFields = extractFieldReferences(expanded ++ groupKey)

    val namedAggregations = aggNames.map { case (agg, name) => Alias(agg, name) }.toSeq
    val input = Project(projectFields, table.logicalPlan).validate(env)
    val aggregated = Aggregate(groupKey, namedAggregations, input).validate(env)
    new Table(env, Project(projectsOnAgg, aggregated).validate(env))
  }

  /**
   * Selects expressions given in the string expression syntax.
   *
   * @param fields the expressions to select, as an expression-list string
   * @return a new [[Table]]; see the varargs overload
   */
  def select(fields: String): Table = {
    // get the correct expression for AggFunctionCall
    val resolved = ExpressionParser
      .parseExpressionList(fields)
      .map(replaceAggFunctionCall(_, table.tableEnv))
    select(resolved: _*)
  }
}
/**
 * Logical plan node for a grouped aggregation.
 *
 * @param groupingExpressions  expressions the child's rows are grouped by
 * @param aggregateExpressions named aggregate expressions; `construct` only accepts
 *                             entries of the form `Alias(agg: Aggregation, name, _)`
 * @param child                input plan node
 */
case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalNode) extends UnaryNode {

  /**
   * Output attributes: all grouping expressions followed by all aggregate expressions.
   * An expression that is not a [[NamedExpression]] is aliased by its own `toString`.
   */
  override def output: Seq[Attribute] = {
    (groupingExpressions ++ aggregateExpressions) map {
      case ne: NamedExpression => ne.toAttribute
      case e => Alias(e, e.toString).toAttribute
    }
  }

  /**
   * Translates this node into Calcite.
   *
   * The child is constructed on the builder first. `RelBuilder.aggregate` is then called
   * with a group key made of the grouping expressions converted to RexNodes, and with one
   * aggregate call per aliased [[Aggregation]], named after its alias.
   */
  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    relBuilder.aggregate(
      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
      aggregateExpressions.map {
        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
        // Presumably unreachable because callers alias every aggregation
        // (GroupedTable.select wraps each one in an Alias).
        case _ => throw new RuntimeException("This should never happen.")
      }.asJava)
  }

  /**
   * Resolves this node with `super.validate`, then checks every aggregate expression and
   * every grouping expression. Violations are reported through `failValidation`.
   *
   * @return the resolved [[Aggregate]]
   */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    // Put into implicit scope for the helpers called below; which of them require it is
    // not visible in this excerpt.
    implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
    val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
    val groupingExprs = resolvedAggregate.groupingExpressions
    val aggregateExprs = resolvedAggregate.aggregateExpressions
    aggregateExprs.foreach(validateAggregateExpression)
    groupingExprs.foreach(validateGroupingExpression)

    // Checks one select-list (sub-)expression. The first matching case wins, so the order
    // of the cases below is significant.
    def validateAggregateExpression(expr: Expression): Unit = expr match {
      // DISTINCT must wrap exactly one aggregation, which is then checked in turn.
      case distinctExpr: DistinctAgg =>
        distinctExpr.child match {
          case _: DistinctAgg => failValidation(
            "Chained distinct operators are not supported!")
          case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
          case _ => failValidation(
            "Distinct operator can only be applied to aggregation expressions!")
        }
      // check aggregate function
      // (functions whose SQL aggregate requires OVER are rejected in a plain group-by)
      case aggExpr: Aggregation if aggExpr.getSqlAggFunction.requiresOver =>
        failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
      // check no nested aggregation exists.
      // (each argument subtree is walked with preOrderVisit; any Aggregation found fails)
      case aggExpr: Aggregation =>
        aggExpr.children.foreach { child =>
          child.preOrderVisit {
            case agg: Aggregation => failValidation(
              "It's not allowed to use an aggregate function as " +
                "input of another aggregate function")
            case _ => // OK
          }
        }
      // A bare attribute is only legal if it equals one of the grouping expressions.
      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
        failValidation(
          s"expression '$a' is invalid because it is neither" +
            " present in group by nor an aggregate function")
      // Any expression equal to a grouping expression is accepted as a whole.
      case e if groupingExprs.exists(_.checkEquals(e)) => // OK
      // Anything else: check its children with the same rules.
      case e => e.children.foreach(validateAggregateExpression)
    }

    // Grouping keys must have a key type (per the message: hashable and comparable).
    def validateGroupingExpression(expr: Expression): Unit = {
      if (!expr.resultType.isKeyType) {
        failValidation(
          s"expression $expr cannot be used as a grouping expression " +
            "because it's not a valid key type which must be hashable and comparable")
      }
    }

    resolvedAggregate
  }
}
/**
 * Excerpt of {@code RelBuilder} showing the {@code groupKey} factory overloads; other
 * members are elided (marked {@code //......}).
 */
public class RelBuilder {
  protected final RelOptCluster cluster;
  protected final RelOptSchema relOptSchema;
  private final RelFactories.FilterFactory filterFactory;
  private final RelFactories.ProjectFactory projectFactory;
  private final RelFactories.AggregateFactory aggregateFactory;
  private final RelFactories.SortFactory sortFactory;
  private final RelFactories.ExchangeFactory exchangeFactory;
  private final RelFactories.SortExchangeFactory sortExchangeFactory;
  private final RelFactories.SetOpFactory setOpFactory;
  private final RelFactories.JoinFactory joinFactory;
  private final RelFactories.SemiJoinFactory semiJoinFactory;
  private final RelFactories.CorrelateFactory correlateFactory;
  private final RelFactories.ValuesFactory valuesFactory;
  private final RelFactories.TableScanFactory scanFactory;
  private final RelFactories.MatchFactory matchFactory;
  private final Deque<Frame> stack = new ArrayDeque<>();
  private final boolean simplify;
  private final RexSimplify simplifier;

  //......

  /**
   * Creates an empty group key.
   *
   * <p>With no grouping columns, an aggregate over this key is a "GROUP BY ()".
   */
  public GroupKey groupKey() {
    return groupKey(ImmutableList.of());
  }

  /** Creates a group key from the given expressions (copied into an immutable list). */
  public GroupKey groupKey(RexNode... nodes) {
    return groupKey(ImmutableList.copyOf(nodes));
  }

  /**
   * Creates a group key from the given expressions.
   *
   * <p>No grouping sets are recorded ({@code nodeLists} is {@code null}). When
   * {@code aggregate} sees a {@code null} {@code nodeLists}, it uses the whole key as the
   * single grouping set. The indicator flag is {@code false} and there is no alias.
   */
  public GroupKey groupKey(Iterable<? extends RexNode> nodes) {
    return new GroupKeyImpl(ImmutableList.copyOf(nodes), false, null, null);
  }

  /**
   * Creates a group key with grouping sets.
   *
   * @param nodes     expressions of the whole group key
   * @param nodeLists one list of expressions per grouping set; {@code aggregate} later
   *                  rejects any set that is not a subset of {@code nodes}
   */
  public GroupKey groupKey(Iterable<? extends RexNode> nodes,
      Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
    return groupKey_(nodes, false, nodeLists);
  }

  /**
   * Creates a group key of fields identified by ordinal.
   *
   * <p>The ordinals are resolved with {@code fields(...)}, presumably against the
   * builder's current input.
   */
  public GroupKey groupKey(int... fieldOrdinals) {
    return groupKey(fields(ImmutableIntList.of(fieldOrdinals)));
  }

  /**
   * Creates a group key of fields identified by name.
   *
   * <p>The names are resolved with {@code fields(...)}, presumably against the builder's
   * current input.
   */
  public GroupKey groupKey(String... fieldNames) {
    return groupKey(fields(ImmutableList.copyOf(fieldNames)));
  }

  /**
   * Creates a group key from a bit set of field ordinals, whose only grouping set is the
   * group set itself.
   */
  public GroupKey groupKey(@Nonnull ImmutableBitSet groupSet) {
    return groupKey(groupSet, ImmutableList.of(groupSet));
  }

  /**
   * Creates a group key from a bit set of field ordinals plus explicit grouping sets, also
   * given as bit sets.
   */
  public GroupKey groupKey(ImmutableBitSet groupSet,
      @Nonnull Iterable<? extends ImmutableBitSet> groupSets) {
    return groupKey_(groupSet, false, ImmutableList.copyOf(groupSets));
  }

  /**
   * Converts bit-set based group key and grouping sets into field references, then
   * delegates to the expression-based helper.
   *
   * @throws IllegalArgumentException if {@code groupSet} refers to an ordinal beyond the
   *     field count of {@code peek()}. The check uses {@code length()}, presumably one past
   *     the highest set bit as in {@code java.util.BitSet}.
   * @throws NullPointerException if {@code groupSets} is {@code null}
   */
  private GroupKey groupKey_(ImmutableBitSet groupSet, boolean indicator,
      @Nonnull ImmutableList<ImmutableBitSet> groupSets) {
    if (groupSet.length() > peek().getRowType().getFieldCount()) {
      throw new IllegalArgumentException("out of bounds: " + groupSet);
    }
    Objects.requireNonNull(groupSets);
    final ImmutableList<RexNode> nodes =
        fields(ImmutableIntList.of(groupSet.toArray()));
    final List<ImmutableList<RexNode>> nodeLists =
        Util.transform(groupSets,
            bitSet -> fields(ImmutableIntList.of(bitSet.toArray())));
    return groupKey_(nodes, indicator, nodeLists);
  }

  /**
   * Builds the {@code GroupKeyImpl}, copying the key and every grouping-set list into
   * immutable lists. No alias is set.
   */
  private GroupKey groupKey_(Iterable<? extends RexNode> nodes,
      boolean indicator,
      Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
    final ImmutableList.Builder<ImmutableList<RexNode>> builder =
        ImmutableList.builder();
    for (Iterable<? extends RexNode> nodeList : nodeLists) {
      builder.add(ImmutableList.copyOf(nodeList));
    }
    return new GroupKeyImpl(ImmutableList.copyOf(nodes), indicator, builder.build(), null);
  }

  //......
}
/** Handle describing the grouping columns of an aggregate, as returned by {@code groupKey}. */
public interface GroupKey {
  /**
   * Assigns an alias to this group key.
   *
   * <p>Used to assign field names in the {@code group} operation.
   */
  GroupKey alias(String alias);
}

/** Implementation of {@link GroupKey}; all state is final. */
protected static class GroupKeyImpl implements GroupKey {
  /** Expressions of the whole group key; never null. */
  final ImmutableList<RexNode> nodes;
  /** Indicator flag; the constructor asserts it is false. */
  final boolean indicator;
  /** One expression list per grouping set, or {@code null} when none were specified. */
  final ImmutableList<ImmutableList<RexNode>> nodeLists;
  /** Alias for the key, or {@code null}. */
  final String alias;

  GroupKeyImpl(ImmutableList<RexNode> nodes, boolean indicator,
      ImmutableList<ImmutableList<RexNode>> nodeLists, String alias) {
    this.nodes = Objects.requireNonNull(nodes);
    // Callers are expected to pass false; only enforced when assertions are enabled.
    assert !indicator;
    this.indicator = indicator;
    this.nodeLists = nodeLists;
    this.alias = alias;
  }

  /** Renders the key expressions, followed by {@code " as <alias>"} when aliased. */
  @Override public String toString() {
    if (alias == null) {
      return nodes.toString();
    }
    return nodes + " as " + alias;
  }

  /**
   * Returns a key with the given alias; returns this instance when the alias is unchanged.
   */
  public GroupKey alias(String alias) {
    if (Objects.equals(this.alias, alias)) {
      return this;
    }
    return new GroupKeyImpl(nodes, indicator, nodeLists, alias);
  }
}
/**
 * Excerpt of {@code RelBuilder} showing the {@code aggregate} overloads; other members are
 * elided (marked {@code //......}).
 */
public class RelBuilder {
  protected final RelOptCluster cluster;
  protected final RelOptSchema relOptSchema;
  private final RelFactories.FilterFactory filterFactory;
  private final RelFactories.ProjectFactory projectFactory;
  private final RelFactories.AggregateFactory aggregateFactory;
  private final RelFactories.SortFactory sortFactory;
  private final RelFactories.ExchangeFactory exchangeFactory;
  private final RelFactories.SortExchangeFactory sortExchangeFactory;
  private final RelFactories.SetOpFactory setOpFactory;
  private final RelFactories.JoinFactory joinFactory;
  private final RelFactories.SemiJoinFactory semiJoinFactory;
  private final RelFactories.CorrelateFactory correlateFactory;
  private final RelFactories.ValuesFactory valuesFactory;
  private final RelFactories.TableScanFactory scanFactory;
  private final RelFactories.MatchFactory matchFactory;
  private final Deque<Frame> stack = new ArrayDeque<>();
  private final boolean simplify;
  private final RexSimplify simplifier;

  //......

  /**
   * Creates an {@link Aggregate} with an array of calls.
   *
   * <p>Delegates to {@link #aggregate(GroupKey, Iterable)}.
   */
  public RelBuilder aggregate(GroupKey groupKey, AggCall... aggCalls) {
    return aggregate(groupKey, ImmutableList.copyOf(aggCalls));
  }

  /**
   * Creates an {@link Aggregate} from already-built {@link AggregateCall}s.
   *
   * <p>Each call is wrapped in an {@code AggCallImpl2}. The main overload uses the wrapped
   * {@code AggregateCall} as-is, without registering its operands.
   */
  public RelBuilder aggregate(GroupKey groupKey, List<AggregateCall> aggregateCalls) {
    return aggregate(groupKey, Lists.transform(aggregateCalls, AggCallImpl2::new));
  }

  /**
   * Creates an {@link Aggregate} with a list of calls on top of the current input.
   *
   * <p>Outline:
   * <ol>
   *   <li>Register the input's fields, the group-key expressions and the operands/filters
   *       of each {@code AggCallImpl}. Expressions not already present are appended, so
   *       input fields keep their ordinals.</li>
   *   <li>With no aggregate calls (and no indicator), try to skip the Aggregate entirely,
   *       based on row-count and uniqueness metadata.</li>
   *   <li>Project the registered expressions, build the {@code AggregateCall}s, and create
   *       the Aggregate through {@code aggregateFactory}.</li>
   *   <li>Push a frame whose fields are the group fields, then the indicator fields (if
   *       any), then one field per aggregate call.</li>
   * </ol>
   *
   * @param groupKey must be a {@code GroupKeyImpl}; it is cast unconditionally
   * @param aggCalls aggregate calls ({@code AggCallImpl} or {@code AggCallImpl2})
   * @return this builder, or the result of {@code project} in the unique-columns shortcut
   * @throws IllegalArgumentException if a grouping set is not a subset of the group key,
   *     or if an aggregate function does not allow the requested DISTINCT or FILTER
   */
  public RelBuilder aggregate(GroupKey groupKey, Iterable<AggCall> aggCalls) {
    // Seed the registrar with every input field and its name. Ordinals 0..n-1 therefore
    // denote the input fields; newly registered expressions get the following ordinals.
    final Registrar registrar = new Registrar();
    registrar.extraNodes.addAll(fields());
    registrar.names.addAll(peek().getRowType().getFieldNames());
    final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey;
    // Ordinals of the group-key expressions in the (later) projected input.
    final ImmutableBitSet groupSet =
        ImmutableBitSet.of(registrar.registerExpressions(groupKey_.nodes));
    // Shortcuts when there are no aggregate calls and the Aggregate may be redundant.
    // "break label" leaves this block and falls through to building a real Aggregate.
    label:
    if (Iterables.isEmpty(aggCalls) && !groupKey_.indicator) {
      final RelMetadataQuery mq = peek().getCluster().getMetadataQuery();
      if (groupSet.isEmpty()) {
        final Double minRowCount = mq.getMinRowCount(peek());
        if (minRowCount == null || minRowCount < 1D) {
          // We can't remove "GROUP BY ()" if there's a chance the rel could be
          // empty.
          break label;
        }
      }
      // Applies only if the group key registered no new expressions, i.e. every key
      // expression is already an input field.
      if (registrar.extraNodes.size() == fields().size()) {
        final Boolean unique = mq.areColumnsUnique(peek(), groupSet);
        if (unique != null && unique) {
          // Rel is already unique.
          return project(fields(groupSet.asList()));
        }
      }
      final Double maxRowCount = mq.getMaxRowCount(peek());
      if (maxRowCount != null && maxRowCount <= 1D) {
        // If there is at most one row, rel is already unique.
        // NOTE(review): unlike the branch above, this returns the input unchanged instead
        // of projecting it to the group key. The row type can therefore contain non-key
        // columns and lack any computed key expressions. Check whether upstream Calcite
        // later revised this shortcut.
        return this;
      }
    }
    final ImmutableList<ImmutableBitSet> groupSets;
    if (groupKey_.nodeLists != null) {
      // Explicit grouping sets. The TreeSet orders them by ImmutableBitSet.ORDERING and
      // drops duplicates.
      final int sizeBefore = registrar.extraNodes.size();
      final SortedSet<ImmutableBitSet> groupSetSet =
          new TreeSet<>(ImmutableBitSet.ORDERING);
      for (ImmutableList<RexNode> nodeList : groupKey_.nodeLists) {
        final ImmutableBitSet groupSet2 =
            ImmutableBitSet.of(registrar.registerExpressions(nodeList));
        if (!groupSet.contains(groupSet2)) {
          throw new IllegalArgumentException("group set element " + nodeList
              + " must be a subset of group key");
        }
        groupSetSet.add(groupSet2);
      }
      groupSets = ImmutableList.copyOf(groupSetSet);
      // NOTE(review): a newly registered expression gets an ordinal outside groupSet, so
      // the subset check above should already have thrown. This looks like a safety net.
      if (registrar.extraNodes.size() > sizeBefore) {
        throw new IllegalArgumentException(
            "group sets contained expressions not in group key: "
                + registrar.extraNodes.subList(sizeBefore,
                registrar.extraNodes.size()));
      }
    } else {
      // No explicit grouping sets: the group key itself is the only grouping set.
      groupSets = ImmutableList.of(groupSet);
    }
    // Register operands and filters of AggCallImpl calls so that they become columns of the
    // projected input. AggCallImpl2 calls are not registered; their argument ordinals are
    // used unchanged, which presumes they refer to input fields (those keep their ordinals).
    for (AggCall aggCall : aggCalls) {
      if (aggCall instanceof AggCallImpl) {
        final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
        registrar.registerExpressions(aggCall1.operands);
        if (aggCall1.filter != null) {
          registrar.registerExpression(aggCall1.filter);
        }
      }
    }
    // Project all registered expressions on the input, then pop that frame. Its rel 'r'
    // becomes the input of the Aggregate.
    project(registrar.extraNodes);
    rename(registrar.names);
    final Frame frame = stack.pop();
    final RelNode r = frame.rel;
    final List<AggregateCall> aggregateCalls = new ArrayList<>();
    for (AggCall aggCall : aggCalls) {
      final AggregateCall aggregateCall;
      if (aggCall instanceof AggCallImpl) {
        final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
        // Already registered above, so presumably this only looks the ordinals up.
        final List<Integer> args =
            registrar.registerExpressions(aggCall1.operands);
        // -1 encodes "no filter".
        final int filterArg = aggCall1.filter == null ? -1
            : registrar.registerExpression(aggCall1.filter);
        if (aggCall1.distinct && !aggCall1.aggFunction.isQuantifierAllowed()) {
          throw new IllegalArgumentException("DISTINCT not allowed");
        }
        if (aggCall1.filter != null && !aggCall1.aggFunction.allowsFilter()) {
          throw new IllegalArgumentException("FILTER not allowed");
        }
        // Collation from the call's order keys. ASCENDING is passed as the direction,
        // presumably a default for keys that do not specify one.
        RelCollation collation =
            RelCollations.of(aggCall1.orderKeys
                .stream()
                .map(orderKey ->
                    collation(orderKey, RelFieldCollation.Direction.ASCENDING,
                        null, Collections.emptyList()))
                .collect(Collectors.toList()));
        aggregateCall =
            AggregateCall.create(aggCall1.aggFunction, aggCall1.distinct,
                aggCall1.approximate, args, filterArg, collation,
                groupSet.cardinality(), r, null, aggCall1.alias);
      } else {
        aggregateCall = ((AggCallImpl2) aggCall).aggregateCall;
      }
      aggregateCalls.add(aggregateCall);
    }
    assert ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets) : groupSets;
    for (ImmutableBitSet set : groupSets) {
      assert groupSet.contains(set);
    }
    RelNode aggregate = aggregateFactory.createAggregate(r,
        groupKey_.indicator, groupSet, groupSets, aggregateCalls);

    // build field list
    final ImmutableList.Builder<Field> fields = ImmutableList.builder();
    final List<RelDataTypeField> aggregateFields =
        aggregate.getRowType().getFieldList();
    int i = 0;
    // first, group fields
    for (Integer groupField : groupSet.asList()) {
      RexNode node = registrar.extraNodes.get(groupField);
      final SqlKind kind = node.getKind();
      switch (kind) {
      case INPUT_REF:
        // Plain input column: reuse the projected frame's Field. The reference's index is
        // valid there because input fields kept their ordinals in the projection.
        fields.add(frame.fields.get(((RexInputRef) node).getIndex()));
        break;
      default:
        // Computed key: a new Field named after the Aggregate's i-th output column.
        String name = aggregateFields.get(i).getName();
        RelDataTypeField fieldType =
            new RelDataTypeFieldImpl(name, i, node.getType());
        fields.add(new Field(ImmutableSet.of(), fieldType));
        break;
      }
      i++;
    }
    // second, indicator fields (copy from aggregate rel type)
    if (groupKey_.indicator) {
      for (int j = 0; j < groupSet.cardinality(); ++j) {
        final RelDataTypeField field = aggregateFields.get(i);
        final RelDataTypeField fieldType =
            new RelDataTypeFieldImpl(field.getName(), i, field.getType());
        fields.add(new Field(ImmutableSet.of(), fieldType));
        i++;
      }
    }
    // third, aggregate fields. retain `i' as field index
    for (int j = 0; j < aggregateCalls.size(); ++j) {
      final AggregateCall call = aggregateCalls.get(j);
      final RelDataTypeField fieldType =
          new RelDataTypeFieldImpl(aggregateFields.get(i + j).getName(), i + j,
              call.getType());
      fields.add(new Field(ImmutableSet.of(), fieldType));
    }
    stack.push(new Frame(aggregate, fields.build()));
    return this;
  }

  //......
}
public class RelFactories { //...... public static final AggregateFactory DEFAULT_AGGREGATE_FACTORY = new AggregateFactoryImpl(); public interface AggregateFactory { /** Creates an aggregate. */ RelNode createAggregate(RelNode input, boolean indicator, ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls); } private static class AggregateFactoryImpl implements AggregateFactory { @SuppressWarnings("deprecation") public RelNode createAggregate(RelNode input, boolean indicator, ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { return LogicalAggregate.create(input, indicator, groupSet, groupSets, aggCalls); } } //...... }
public final class LogicalAggregate extends Aggregate { //...... public static LogicalAggregate create(final RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { return create_(input, false, groupSet, groupSets, aggCalls); } @Deprecated // to be removed before 2.0 public static LogicalAggregate create(final RelNode input, boolean indicator, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { return create_(input, indicator, groupSet, groupSets, aggCalls); } private static LogicalAggregate create_(final RelNode input, boolean indicator, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { final RelOptCluster cluster = input.getCluster(); final RelTraitSet traitSet = cluster.traitSetOf(Convention.NONE); return new LogicalAggregate(cluster, traitSet, input, indicator, groupSet, groupSets, aggCalls); } //...... }