diff --git a/pinot-common/src/main/proto/plan.proto b/pinot-common/src/main/proto/plan.proto index 49d357307648..e3b2bbf65482 100644 --- a/pinot-common/src/main/proto/plan.proto +++ b/pinot-common/src/main/proto/plan.proto @@ -69,6 +69,8 @@ message AggregateNode { repeated int32 groupKeys = 3; AggType aggType = 4; bool leafReturnFinalResult = 5; + repeated Collation collations = 6; + int32 limit = 7; } message FilterNode { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java index 558b2f898539..82f80da4bee2 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java @@ -43,6 +43,7 @@ public static class AggregateOptions { public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = "is_partitioned_by_group_by_keys"; public static final String IS_LEAF_RETURN_FINAL_RESULT = "is_leaf_return_final_result"; public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = "is_skip_leaf_stage_group_by"; + public static final String ENABLE_GROUP_TRIM = "is_enable_group_trim"; public static final String NUM_GROUPS_LIMIT = "num_groups_limit"; public static final String MAX_INITIAL_RESULT_HOLDER_CAPACITY = "max_initial_result_holder_capacity"; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java index 241c44703e6b..b382f9e63598 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java @@ -22,6 +22,7 @@ import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelWriter; import org.apache.calcite.rel.core.Aggregate; @@ -35,39 +36,36 @@ public class PinotLogicalAggregate extends Aggregate { private final AggType _aggType; private final boolean _leafReturnFinalResult; + // The following fields are for group trimming purpose, and are extracted from the Sort on top of this Aggregate. + private final List _collations; + private final int _limit; + public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List hints, RelNode input, ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls, - AggType aggType, boolean leafReturnFinalResult) { + AggType aggType, boolean leafReturnFinalResult, @Nullable List collations, int limit) { super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls); _aggType = aggType; _leafReturnFinalResult = leafReturnFinalResult; + _collations = collations; + _limit = limit; } - public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List hints, RelNode input, - ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls, - AggType aggType) { - this(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls, aggType, false); - } - - public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType, - boolean leafReturnFinalResult) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), aggRel.getInput(), aggRel.getGroupSet(), - aggRel.getGroupSets(), aggCalls, aggType, leafReturnFinalResult); + public PinotLogicalAggregate(Aggregate aggRel, RelNode input, ImmutableBitSet groupSet, + @Nullable List groupSets, List aggCalls, AggType aggType, + boolean leafReturnFinalResult, @Nullable List collations, int limit) { + this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, groupSets, aggCalls, aggType, + leafReturnFinalResult, collations, limit); } - public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType) { - this(aggRel, aggCalls, aggType, false); - } - - public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List aggCalls, AggType aggType) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, aggRel.getGroupSet(), - aggRel.getGroupSets(), aggCalls, aggType); + public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List aggCalls, AggType aggType, + boolean leafReturnFinalResult, @Nullable List collations, int limit) { + this(aggRel, input, aggRel.getGroupSet(), aggRel.getGroupSets(), aggCalls, aggType, + leafReturnFinalResult, collations, limit); } public PinotLogicalAggregate(Aggregate aggRel, RelNode input, ImmutableBitSet groupSet, List aggCalls, - AggType aggType, boolean leafReturnFinalResult) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType, - leafReturnFinalResult); + AggType aggType, boolean leafReturnFinalResult, @Nullable List collations, int limit) { + this(aggRel, input, groupSet, null, aggCalls, aggType, leafReturnFinalResult, collations, limit); } public AggType getAggType() { @@ -78,11 +76,20 @@ public boolean isLeafReturnFinalResult() { return _leafReturnFinalResult; } + @Nullable + public List getCollations() { + return _collations; + } + + public int getLimit() { + return _limit; + } + @Override public PinotLogicalAggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls) { return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType, - _leafReturnFinalResult); + _leafReturnFinalResult, _collations, _limit); } @Override @@ -90,12 +97,14 @@ public RelWriter explainTerms(RelWriter pw) { RelWriter relWriter = super.explainTerms(pw); relWriter.item("aggType", _aggType); relWriter.itemIf("leafReturnFinalResult", true, _leafReturnFinalResult); + relWriter.itemIf("collations", _collations, _collations != null); + relWriter.itemIf("limit", _limit, _limit > 0); return relWriter; } @Override public RelNode withHints(List hintList) { return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType, - _leafReturnFinalResult); + _leafReturnFinalResult, _collations, _limit); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index df11fdb49a2e..561dcc9c281c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -28,10 +28,12 @@ import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelDistributions; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.core.Union; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; @@ -83,48 +85,148 @@ * - COUNT(*)__LEAF produces TUPLE[ SUM(1), GROUP_BY_KEY ] * - COUNT(*)__FINAL produces TUPLE[ SUM(COUNT(*)__LEAF), GROUP_BY_KEY ] */ -public class PinotAggregateExchangeNodeInsertRule extends RelOptRule { - public static final PinotAggregateExchangeNodeInsertRule INSTANCE = - new PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY); - - public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) { - // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with - // PinotLogicalAggregate, and the rule won't be applied again. - super(operand(LogicalAggregate.class, any()), factory, null); +public class PinotAggregateExchangeNodeInsertRule { + + public static class SortProjectAggregate extends RelOptRule { + public static final SortProjectAggregate INSTANCE = new SortProjectAggregate(PinotRuleUtils.PINOT_REL_FACTORY); + + private SortProjectAggregate(RelBuilderFactory factory) { + // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with + // PinotLogicalAggregate, and the rule won't be applied again. + super(operand(Sort.class, operand(Project.class, operand(LogicalAggregate.class, any()))), factory, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + // Apply this rule for group-by queries with enable group trim hint. + LogicalAggregate aggRel = call.rel(2); + if (aggRel.getGroupSet().isEmpty()) { + return; + } + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + if (hintOptions == null || !Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.ENABLE_GROUP_TRIM))) { + return; + } + + Sort sortRel = call.rel(0); + Project projectRel = call.rel(1); + List projects = projectRel.getProjects(); + List collations = sortRel.getCollation().getFieldCollations(); + if (collations.isEmpty()) { + // Cannot enable group trim without sort key. + return; + } + List newCollations = new ArrayList<>(collations.size()); + for (RelFieldCollation fieldCollation : collations) { + RexNode project = projects.get(fieldCollation.getFieldIndex()); + if (project instanceof RexInputRef) { + newCollations.add(fieldCollation.withFieldIndex(((RexInputRef) project).getIndex())); + } else { + // Cannot enable group trim when the sort key is not a direct reference to the input. + return; + } + } + int limit = 0; + if (sortRel.fetch != null) { + limit = RexLiteral.intValue(sortRel.fetch); + } + if (limit <= 0) { + // Cannot enable group trim without limit. + return; + } + + PinotLogicalAggregate newAggRel = createPlan(call, aggRel, true, hintOptions, newCollations, limit); + RelNode newProjectRel = projectRel.copy(projectRel.getTraitSet(), List.of(newAggRel)); + call.transformTo(sortRel.copy(sortRel.getTraitSet(), List.of(newProjectRel))); + } } - /** - * Split the AGG into 3 plan fragments, all with the same AGG type (in some cases the final agg name may be different) - * Pinot internal plan fragment optimization can use the info of the input data type to infer whether it should - * generate the "final-stage AGG operator" or "intermediate-stage AGG operator" or "leaf-stage AGG operator" - * - * @param call the {@link RelOptRuleCall} on match. - * @see org.apache.pinot.core.query.aggregation.function.AggregationFunction - */ - @Override - public void onMatch(RelOptRuleCall call) { - Aggregate aggRel = call.rel(0); - boolean hasGroupBy = !aggRel.getGroupSet().isEmpty(); - RelCollation collation = extractWithInGroupCollation(aggRel); - Map hintOptions = - PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); - // Collation is not supported in leaf stage aggregation. - if (collation != null || (hasGroupBy && hintOptions != null && Boolean.parseBoolean( + public static class SortAggregate extends RelOptRule { + public static final SortAggregate INSTANCE = new SortAggregate(PinotRuleUtils.PINOT_REL_FACTORY); + + private SortAggregate(RelBuilderFactory factory) { + // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with + // PinotLogicalAggregate, and the rule won't be applied again. + super(operand(Sort.class, operand(LogicalAggregate.class, any())), factory, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + // Apply this rule for group-by queries with enable group trim hint. + LogicalAggregate aggRel = call.rel(1); + if (aggRel.getGroupSet().isEmpty()) { + return; + } + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + if (hintOptions == null || !Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.ENABLE_GROUP_TRIM))) { + return; + } + + Sort sortRel = call.rel(0); + List collations = sortRel.getCollation().getFieldCollations(); + if (collations.isEmpty()) { + // Cannot enable group trim without sort key. + return; + } + int limit = 0; + if (sortRel.fetch != null) { + limit = RexLiteral.intValue(sortRel.fetch); + } + if (limit <= 0) { + // Cannot enable group trim without limit. + return; + } + + PinotLogicalAggregate newAggRel = createPlan(call, aggRel, true, hintOptions, collations, limit); + call.transformTo(sortRel.copy(sortRel.getTraitSet(), List.of(newAggRel))); + } + } + + public static class WithoutSort extends RelOptRule { + public static final WithoutSort INSTANCE = new WithoutSort(PinotRuleUtils.PINOT_REL_FACTORY); + + private WithoutSort(RelBuilderFactory factory) { + // NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with + // PinotLogicalAggregate, and the rule won't be applied again. + super(operand(LogicalAggregate.class, any()), factory, null); + } + + @Override + public void onMatch(RelOptRuleCall call) { + Aggregate aggRel = call.rel(0); + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + call.transformTo( + createPlan(call, aggRel, !aggRel.getGroupSet().isEmpty(), hintOptions != null ? hintOptions : Map.of(), null, + 0)); + } + } + + private static PinotLogicalAggregate createPlan(RelOptRuleCall call, Aggregate aggRel, boolean hasGroupBy, + Map hintOptions, @Nullable List collations, int limit) { + // WITHIN GROUP collation is not supported in leaf stage aggregation. + RelCollation withinGroupCollation = extractWithinGroupCollation(aggRel); + if (withinGroupCollation != null || (hasGroupBy && Boolean.parseBoolean( hintOptions.get(PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)))) { - call.transformTo(createPlanWithExchangeDirectAggregation(call, collation)); - } else if (hasGroupBy && hintOptions != null && Boolean.parseBoolean( + return createPlanWithExchangeDirectAggregation(call, aggRel, withinGroupCollation, collations, limit); + } else if (hasGroupBy && Boolean.parseBoolean( hintOptions.get(PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS))) { - call.transformTo(new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT)); + return new PinotLogicalAggregate(aggRel, aggRel.getInput(), buildAggCalls(aggRel, AggType.DIRECT, false), + AggType.DIRECT, false, collations, limit); } else { - boolean leafReturnFinalResult = hintOptions != null && Boolean.parseBoolean( - hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT)); - call.transformTo(createPlanWithLeafExchangeFinalAggregate(call, leafReturnFinalResult)); + boolean leafReturnFinalResult = + Boolean.parseBoolean(hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT)); + return createPlanWithLeafExchangeFinalAggregate(aggRel, leafReturnFinalResult, collations, limit); } } // TODO: Currently it only handles one WITHIN GROUP collation across all AggregateCalls. @Nullable - private static RelCollation extractWithInGroupCollation(Aggregate aggRel) { + private static RelCollation extractWithinGroupCollation(Aggregate aggRel) { for (AggregateCall aggCall : aggRel.getAggCallList()) { RelCollation collation = aggCall.getCollation(); if (!collation.getFieldCollations().isEmpty()) { @@ -138,55 +240,54 @@ private static RelCollation extractWithInGroupCollation(Aggregate aggRel) { * Use this group by optimization to skip leaf stage aggregation when aggregating at leaf level is not desired. Many * situation could be wasted effort to do group-by on leaf, eg: when cardinality of group by column is very high. */ - private static PinotLogicalAggregate createPlanWithExchangeDirectAggregation(RelOptRuleCall call, - @Nullable RelCollation collation) { - Aggregate aggRel = call.rel(0); + private static PinotLogicalAggregate createPlanWithExchangeDirectAggregation(RelOptRuleCall call, Aggregate aggRel, + @Nullable RelCollation withinGroupCollation, @Nullable List collations, int limit) { RelNode input = aggRel.getInput(); // Create Project when there's none below the aggregate. if (!(PinotRuleUtils.unboxRel(input) instanceof Project)) { - aggRel = (Aggregate) generateProjectUnderAggregate(call); + aggRel = (Aggregate) generateProjectUnderAggregate(call, aggRel); input = aggRel.getInput(); } ImmutableBitSet groupSet = aggRel.getGroupSet(); RelDistribution distribution = RelDistributions.hash(groupSet.asList()); RelNode exchange; - if (collation != null) { + if (withinGroupCollation != null) { // Insert a LogicalSort node between exchange and aggregate whe collation exists. - exchange = PinotLogicalSortExchange.create(input, distribution, collation, false, true); + exchange = PinotLogicalSortExchange.create(input, distribution, withinGroupCollation, false, true); } else { exchange = PinotLogicalExchange.create(input, distribution); } - return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT); + return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT, + false, collations, limit); } /** * Aggregate node will be split into LEAF + EXCHANGE + FINAL. * TODO: Add optional INTERMEDIATE stage to reduce hotspot. */ - private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call, - boolean leafReturnFinalResult) { - Aggregate aggRel = call.rel(0); + private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(Aggregate aggRel, + boolean leafReturnFinalResult, @Nullable List collations, int limit) { // Create a LEAF aggregate. PinotLogicalAggregate leafAggRel = - new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.LEAF, leafReturnFinalResult), AggType.LEAF, - leafReturnFinalResult); + new PinotLogicalAggregate(aggRel, aggRel.getInput(), buildAggCalls(aggRel, AggType.LEAF, leafReturnFinalResult), + AggType.LEAF, leafReturnFinalResult, collations, limit); // Create an EXCHANGE node over the LEAF aggregate. PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel, RelDistributions.hash(ImmutableIntList.range(0, aggRel.getGroupCount()))); // Create a FINAL aggregate over the EXCHANGE. - return convertAggFromIntermediateInput(call, exchange, AggType.FINAL, leafReturnFinalResult); + return convertAggFromIntermediateInput(aggRel, exchange, AggType.FINAL, leafReturnFinalResult, collations, limit); } /** * The following is copied from {@link AggregateExtractProjectRule#onMatch(RelOptRuleCall)} with modification to take * aggregate input as input. */ - private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { - final Aggregate aggregate = call.rel(0); + private static RelNode generateProjectUnderAggregate(RelOptRuleCall call, Aggregate aggregate) { // --------------- MODIFIED --------------- final RelNode input = aggregate.getInput(); + // final Aggregate aggregate = call.rel(0); // final RelNode input = call.rel(1); // ------------- END MODIFIED ------------- @@ -230,9 +331,8 @@ private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { return relBuilder.build(); } - private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleCall call, - PinotLogicalExchange exchange, AggType aggType, boolean leafReturnFinalResult) { - Aggregate aggRel = call.rel(0); + private static PinotLogicalAggregate convertAggFromIntermediateInput(Aggregate aggRel, PinotLogicalExchange exchange, + AggType aggType, boolean leafReturnFinalResult, @Nullable List collations, int limit) { RelNode input = aggRel.getInput(); List projects = findImmediateProjects(input); @@ -269,7 +369,7 @@ private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleC } return new PinotLogicalAggregate(aggRel, exchange, ImmutableBitSet.range(groupCount), aggCalls, aggType, - leafReturnFinalResult); + leafReturnFinalResult, collations, limit); } private static List buildAggCalls(Aggregate aggRel, AggType aggType, boolean leafReturnFinalResult) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java index fdb75ee78f19..630574b18398 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java @@ -137,7 +137,9 @@ private PinotQueryRuleSets() { PinotSingleValueAggregateRemoveRule.INSTANCE, PinotJoinExchangeNodeInsertRule.INSTANCE, - PinotAggregateExchangeNodeInsertRule.INSTANCE, + PinotAggregateExchangeNodeInsertRule.SortProjectAggregate.INSTANCE, + PinotAggregateExchangeNodeInsertRule.SortAggregate.INSTANCE, + PinotAggregateExchangeNodeInsertRule.WithoutSort.INSTANCE, PinotWindowExchangeNodeInsertRule.INSTANCE, PinotSetOpExchangeNodeInsertRule.INSTANCE, diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index a20b2479d4f0..4f82de9ff995 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -96,8 +96,7 @@ public static List convertAggregateList(List groupByList return expressions; } - public static List convertOrderByList(SortNode node, PinotQuery pinotQuery) { - List collations = node.getCollations(); + public static List convertOrderByList(List collations, PinotQuery pinotQuery) { List orderByExpressions = new ArrayList<>(collations.size()); for (RelFieldCollation collation : collations) { orderByExpressions.add(convertOrderBy(collation, pinotQuery)); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java index 611d4417259b..6ae02da45fc9 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java @@ -147,6 +147,12 @@ public PlanNode visitAggregate(AggregateNode node, PlanNode context) { if (node.isLeafReturnFinalResult() != otherNode.isLeafReturnFinalResult()) { return null; } + if (!node.getCollations().equals(otherNode.getCollations())) { + return null; + } + if (node.getLimit() != otherNode.getLimit()) { + return null; + } List children = mergeChildren(node, context); if (children == null) { return null; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java index 55813264ffb0..61cf5d5be626 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java @@ -195,7 +195,9 @@ public Boolean visitAggregate(AggregateNode node1, PlanNode node2) { && Objects.equals(node1.getFilterArgs(), that.getFilterArgs()) && Objects.equals(node1.getGroupKeys(), that.getGroupKeys()) && node1.getAggType() == that.getAggType() - && node1.isLeafReturnFinalResult() == that.isLeafReturnFinalResult(); + && node1.isLeafReturnFinalResult() == that.isLeafReturnFinalResult() + && Objects.equals(node1.getCollations(), that.getCollations()) + && node1.getLimit() == that.getLimit(); } @Override diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index 38170116126a..3f5ab2261e0c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -264,7 +264,7 @@ private AggregateNode convertLogicalAggregate(PinotLogicalAggregate node) { } return new AggregateNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()), NodeHint.fromRelHints(node.getHints()), convertInputs(node.getInputs()), functionCalls, filterArgs, node.getGroupSet().asList(), node.getAggType(), - node.isLeafReturnFinalResult()); + node.isLeafReturnFinalResult(), node.getCollations(), node.getLimit()); } private ProjectNode convertLogicalProject(LogicalProject node) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java index be4a6d9fb87d..323c58cce225 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java @@ -20,6 +20,8 @@ import java.util.List; import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; @@ -31,15 +33,21 @@ public class AggregateNode extends BasePlanNode { private final AggType _aggType; private final boolean _leafReturnFinalResult; + // The following fields are for group trimming purpose, and are extracted from the Sort on top of this Aggregate. + private final List _collations; + private final int _limit; + public AggregateNode(int stageId, DataSchema dataSchema, NodeHint nodeHint, List inputs, List aggCalls, List filterArgs, List groupKeys, AggType aggType, - boolean leafReturnFinalResult) { + boolean leafReturnFinalResult, @Nullable List collations, int limit) { super(stageId, dataSchema, nodeHint, inputs); _aggCalls = aggCalls; _filterArgs = filterArgs; _groupKeys = groupKeys; _aggType = aggType; _leafReturnFinalResult = leafReturnFinalResult; + _collations = collations != null ? collations : List.of(); + _limit = limit; } public List getAggCalls() { @@ -62,6 +70,14 @@ public boolean isLeafReturnFinalResult() { return _leafReturnFinalResult; } + public List getCollations() { + return _collations; + } + + public int getLimit() { + return _limit; + } + @Override public String explain() { return "AGGREGATE_" + _aggType; @@ -75,7 +91,7 @@ public T visit(PlanNodeVisitor visitor, C context) { @Override public PlanNode withInputs(List inputs) { return new AggregateNode(_stageId, _dataSchema, _nodeHint, inputs, _aggCalls, _filterArgs, _groupKeys, _aggType, - _leafReturnFinalResult); + _leafReturnFinalResult, _collations, _limit); } @Override @@ -90,14 +106,15 @@ public boolean equals(Object o) { return false; } AggregateNode that = (AggregateNode) o; - return Objects.equals(_aggCalls, that._aggCalls) && Objects.equals(_filterArgs, that._filterArgs) && Objects.equals( - _groupKeys, that._groupKeys) && _aggType == that._aggType - && _leafReturnFinalResult == that._leafReturnFinalResult; + return _leafReturnFinalResult == that._leafReturnFinalResult && _limit == that._limit && Objects.equals(_aggCalls, + that._aggCalls) && Objects.equals(_filterArgs, that._filterArgs) && Objects.equals(_groupKeys, that._groupKeys) + && _aggType == that._aggType && Objects.equals(_collations, that._collations); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, _aggType, _leafReturnFinalResult); + return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, _aggType, _leafReturnFinalResult, + _collations, _limit); } /** diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java index abd474ebce3e..0f6851418925 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java @@ -87,7 +87,8 @@ private static AggregateNode deserializeAggregateNode(Plan.PlanNode protoNode) { return new AggregateNode(protoNode.getStageId(), extractDataSchema(protoNode), extractNodeHint(protoNode), extractInputs(protoNode), convertFunctionCalls(protoAggregateNode.getAggCallsList()), protoAggregateNode.getFilterArgsList(), protoAggregateNode.getGroupKeysList(), - convertAggType(protoAggregateNode.getAggType()), protoAggregateNode.getLeafReturnFinalResult()); + convertAggType(protoAggregateNode.getAggType()), protoAggregateNode.getLeafReturnFinalResult(), + convertCollations(protoAggregateNode.getCollationsList()), protoAggregateNode.getLimit()); } private static FilterNode deserializeFilterNode(Plan.PlanNode protoNode) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java index 65ccb13b2cae..e7862173e749 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java @@ -98,6 +98,8 @@ public Void visitAggregate(AggregateNode node, Plan.PlanNode.Builder builder) { .addAllGroupKeys(node.getGroupKeys()) .setAggType(convertAggType(node.getAggType())) .setLeafReturnFinalResult(node.isLeafReturnFinalResult()) + .addAllCollations(convertCollations(node.getCollations())) + .setLimit(node.getLimit()) .build(); builder.setAggregateNode(aggregateNode); return null; diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json index 63a69f5e8ecb..8baa87a06f5d 100644 --- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json +++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json @@ -249,6 +249,39 @@ "\n LogicalTableScan(table=[[default, a]])", "\n" ] + }, + { + "description": "SQL hint based group by optimization with partitioned aggregated values and group trim enabled", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_leaf_return_final_result='true', is_enable_group_trim='true') */ col1, COUNT(DISTINCT col2) AS cnt FROM a WHERE a.col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10", + "output": [ + "Execution Plan", + "\nLogicalSort(sort0=[$1], dir0=[DESC], offset=[0], fetch=[10])", + "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", + "\n LogicalSort(sort0=[$1], dir0=[DESC], fetch=[10])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], leafReturnFinalResult=[true], collations=[[1 DESC]], limit=[10])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], leafReturnFinalResult=[true], collations=[[1 DESC]], limit=[10])", + "\n LogicalFilter(condition=[>=($2, 0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n" + ] + }, + { + "description": "SQL hint based group by optimization with group trim enabled without returning group key", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_enable_group_trim='true') */ COUNT(DISTINCT col2) AS cnt FROM a WHERE a.col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10", + "output": [ + "Execution Plan", + "\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0], fetch=[10])", + "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", + "\n LogicalSort(sort0=[$0], dir0=[DESC], fetch=[10])", + "\n LogicalProject(cnt=[$1])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], collations=[[1 DESC]], limit=[10])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], collations=[[1 DESC]], limit=[10])", + "\n LogicalFilter(condition=[>=($2, 0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n" + ] } ] } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java index bd58b7f64f04..254aa2430c93 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType; import org.apache.pinot.common.datablock.DataBlock; import org.apache.pinot.common.request.DataSource; @@ -71,22 +72,27 @@ static void walkPlanNode(PlanNode node, ServerPlanRequestContext context) { public Void visitAggregate(AggregateNode node, ServerPlanRequestContext context) { if (visit(node.getInputs().get(0), context)) { PinotQuery pinotQuery = context.getPinotQuery(); - if (pinotQuery.getGroupByList() == null) { - List groupByList = CalciteRexExpressionParser.convertInputRefs(node.getGroupKeys(), pinotQuery); + List groupByList = CalciteRexExpressionParser.convertInputRefs(node.getGroupKeys(), pinotQuery); + if (!groupByList.isEmpty()) { pinotQuery.setGroupByList(groupByList); - pinotQuery.setSelectList( - CalciteRexExpressionParser.convertAggregateList(groupByList, node.getAggCalls(), node.getFilterArgs(), - pinotQuery)); - if (node.getAggType() == AggregateNode.AggType.DIRECT) { - pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT, - "true"); - } else if (node.isLeafReturnFinalResult()) { - pinotQuery.putToQueryOptions( - CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED, "true"); - } - // there cannot be any more modification of PinotQuery post agg, thus this is the last one possible. - context.setLeafStageBoundaryNode(node); } + pinotQuery.setSelectList( + CalciteRexExpressionParser.convertAggregateList(groupByList, node.getAggCalls(), node.getFilterArgs(), + pinotQuery)); + if (node.getAggType() == AggregateNode.AggType.DIRECT) { + pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT, "true"); + } else if (node.isLeafReturnFinalResult()) { + pinotQuery.putToQueryOptions( + CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED, "true"); + } + List collations = node.getCollations(); + int limit = node.getLimit(); + if (!collations.isEmpty() && limit > 0) { + pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations, pinotQuery)); + pinotQuery.setLimit(limit); + } + // There cannot be any more modification of PinotQuery post agg, thus this is the last one possible. + context.setLeafStageBoundaryNode(node); } return null; } @@ -193,8 +199,9 @@ public Void visitSort(SortNode node, ServerPlanRequestContext context) { if (visit(node.getInputs().get(0), context)) { PinotQuery pinotQuery = context.getPinotQuery(); if (pinotQuery.getOrderByList() == null) { - if (!node.getCollations().isEmpty()) { - pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(node, pinotQuery)); + List collations = node.getCollations(); + if (!collations.isEmpty()) { + pinotQuery.setOrderByList(CalciteRexExpressionParser.convertOrderByList(collations, pinotQuery)); } if (node.getFetch() >= 0) { pinotQuery.setLimit(node.getFetch()); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index f7f56e0ccb6e..b2e73f226a3a 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -273,7 +273,7 @@ private AggregateOperator getOperator(DataSchema resultSchema, List filterArgs, List groupKeys, PlanNode.NodeHint nodeHint) { return new AggregateOperator(OperatorTestUtil.getTracingContext(), _input, new AggregateNode(-1, resultSchema, nodeHint, List.of(), aggCalls, filterArgs, groupKeys, AggType.DIRECT, - false)); + false, null, 0)); } private AggregateOperator getOperator(DataSchema resultSchema, List aggCalls, diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java index fc7ebba0b4cb..05ccf5762191 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java @@ -152,7 +152,7 @@ private static MultiStageOperator getAggregateOperator() { new DataSchema(new String[]{"group", "sum"}, new DataSchema.ColumnDataType[]{INT, DOUBLE}); return new AggregateOperator(OperatorTestUtil.getTracingContext(), input, new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), aggCalls, filterArgs, groupKeys, - AggregateNode.AggType.DIRECT, false)); + AggregateNode.AggType.DIRECT, false, null, 0)); } private static MultiStageOperator getHashJoinOperator() { diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json b/pinot-query-runtime/src/test/resources/queries/QueryHints.json index e7c2ca375700..ba7f2ec01960 100644 --- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json +++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json @@ -321,6 +321,10 @@ "description": "aggregate with skip intermediate stage hint (via hint option is_partitioned_by_group_by_keys)", "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.val) FROM {tbl1} WHERE {tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num" }, + { + "description": "aggregate with skip intermediate stage hint and group trim enabled", + "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true', is_enable_group_trim='true') */ {tbl1}.num, COUNT(*), SUM({tbl1}.val), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.val) FROM {tbl1} WHERE {tbl1}.val >= 0 AND {tbl1}.name != 'a' GROUP BY {tbl1}.num ORDER BY COUNT(*) DESC LIMIT 1" + }, { "description": "join with pre-partitioned left and right tables", "sql": "SELECT {tbl1}.num, {tbl1}.val, {tbl2}.data FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num WHERE {tbl2}.data > 0"