From 014ce4380d59c9d0a54f20422fe286bb4c9a046d Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" Date: Wed, 11 Dec 2024 18:11:56 -0800 Subject: [PATCH] Support is_leaf_return_final_result agg option --- pinot-common/src/main/proto/plan.proto | 1 + .../SumPrecisionAggregationFunction.java | 1 + .../tests/OfflineClusterIntegrationTest.java | 4 +- .../calcite/rel/hint/PinotHintOptions.java | 1 + .../rel/hint/PinotHintStrategyTable.java | 99 +++++++++------- .../rel/logical/PinotLogicalAggregate.java | 43 +++++-- .../PinotAggregateExchangeNodeInsertRule.java | 64 ++++++----- .../query/planner/explain/PlanNodeMerger.java | 3 + .../logical/EquivalentStagesFinder.java | 3 +- .../logical/RelToPlanNodeConverter.java | 3 +- .../query/planner/plannode/AggregateNode.java | 17 ++- .../planner/serde/PlanNodeDeserializer.java | 2 +- .../planner/serde/PlanNodeSerializer.java | 92 +++++++++------ .../pinot/query/QueryCompilationTest.java | 4 +- .../resources/queries/AggregatePlans.json | 48 ++++---- .../resources/queries/BasicQueryPlans.json | 8 +- .../test/resources/queries/GroupByPlans.json | 64 +++++++---- .../src/test/resources/queries/JoinPlans.json | 68 +++++------ .../queries/LiteralEvaluationPlans.json | 8 +- .../test/resources/queries/OrderByPlans.json | 12 +- .../resources/queries/PinotHintablePlans.json | 74 ++++++------ .../test/resources/queries/SetOpPlans.json | 4 +- .../queries/WindowFunctionPlans.json | 108 +++++++++--------- .../runtime/operator/AggregateOperator.java | 4 +- .../operator/MultistageGroupByExecutor.java | 95 ++++++++++----- .../plan/server/ServerPlanRequestVisitor.java | 3 + .../operator/AggregateOperatorTest.java | 3 +- .../operator/MultiStageAccountingTest.java | 2 +- .../test/resources/queries/QueryHints.json | 20 ++++ .../segment/spi/AggregationFunctionType.java | 96 ++++++++++------ 30 files changed, 573 insertions(+), 381 deletions(-) diff --git a/pinot-common/src/main/proto/plan.proto b/pinot-common/src/main/proto/plan.proto index 06b2f0910cfd..49d357307648 100644 --- a/pinot-common/src/main/proto/plan.proto +++ b/pinot-common/src/main/proto/plan.proto @@ -68,6 +68,7 @@ message AggregateNode { repeated int32 filterArgs = 2; repeated int32 groupKeys = 3; AggType aggType = 4; + bool leafReturnFinalResult = 5; } message FilterNode { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java index 2052ea6e154f..24a4a73c44ce 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java @@ -351,6 +351,7 @@ public ColumnDataType getIntermediateResultColumnType() { @Override public ColumnDataType getFinalResultColumnType() { + // TODO: Revisit if we should change this to BIG_DECIMAL return ColumnDataType.STRING; } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java index 9e287cb28473..4c3c50bb8797 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java @@ -3257,9 +3257,9 @@ public void testExplainPlanQueryV2() + " PinotLogicalSortExchange(" + "distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])\\n" + " LogicalProject(count=[$1], name=[$0])\\n" - + " PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])\\n" + + " PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], aggType=[FINAL])\\n" + " PinotLogicalExchange(distribution=[hash[0]])\\n" - + " PinotLogicalAggregate(group=[{17}], agg#0=[COUNT()])\\n" + + " PinotLogicalAggregate(group=[{17}], agg#0=[COUNT()], aggType=[LEAF])\\n" + " LogicalTableScan(table=[[default, mytable]])\\n" + "\"]]}"); //@formatter:on diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java index a753abca4acd..558b2f898539 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java @@ -41,6 +41,7 @@ private PinotHintOptions() { public static class AggregateOptions { public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = "is_partitioned_by_group_by_keys"; + public static final String IS_LEAF_RETURN_FINAL_RESULT = "is_leaf_return_final_result"; public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = "is_skip_leaf_stage_group_by"; public static final String NUM_GROUPS_LIMIT = "num_groups_limit"; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintStrategyTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintStrategyTable.java index 9de47caea8c7..503dcfe80f6f 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintStrategyTable.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintStrategyTable.java @@ -19,13 +19,14 @@ package org.apache.pinot.calcite.rel.hint; import java.util.List; +import java.util.Map; import java.util.function.Predicate; import javax.annotation.Nullable; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.hint.HintPredicates; import org.apache.calcite.rel.hint.HintStrategyTable; import org.apache.calcite.rel.hint.Hintable; import org.apache.calcite.rel.hint.RelHint; -import org.apache.pinot.spi.utils.BooleanUtils; /** @@ -35,21 +36,18 @@ public class PinotHintStrategyTable { private PinotHintStrategyTable() { } - public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE = - HintStrategyTable.builder().hintStrategy(PinotHintOptions.AGGREGATE_HINT_OPTIONS, HintPredicates.AGGREGATE) - .hintStrategy(PinotHintOptions.JOIN_HINT_OPTIONS, HintPredicates.JOIN) - .hintStrategy(PinotHintOptions.TABLE_HINT_OPTIONS, HintPredicates.TABLE_SCAN).build(); - + public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE = HintStrategyTable.builder() + .hintStrategy(PinotHintOptions.AGGREGATE_HINT_OPTIONS, HintPredicates.AGGREGATE) + .hintStrategy(PinotHintOptions.JOIN_HINT_OPTIONS, HintPredicates.JOIN) + .hintStrategy(PinotHintOptions.TABLE_HINT_OPTIONS, HintPredicates.TABLE_SCAN) + .build(); /** * Get the first hint that has the specified name. */ @Nullable public static RelHint getHint(Hintable hintable, String hintName) { - return hintable.getHints().stream() - .filter(relHint -> relHint.hintName.equals(hintName)) - .findFirst() - .orElse(null); + return getHint(hintable.getHints(), hintName); } /** @@ -57,16 +55,13 @@ public static RelHint getHint(Hintable hintable, String hintName) { */ @Nullable public static RelHint getHint(Hintable hintable, Predicate predicate) { - return hintable.getHints().stream() - .filter(predicate) - .findFirst() - .orElse(null); + return getHint(hintable.getHints(), predicate); } /** - * Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains a specific {@link RelHint} by name. + * Check if a hint-able {@link RelNode} contains a specific {@link RelHint} by name. * - * @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}. + * @param hintList hint list from the {@link RelNode}. * @param hintName the name of the {@link RelHint} to be matched * @return true if it contains the hint */ @@ -79,58 +74,76 @@ public static boolean containsHint(List hintList, String hintName) { return false; } + @Nullable + public static RelHint getHint(List hintList, String hintName) { + for (RelHint relHint : hintList) { + if (relHint.hintName.equals(hintName)) { + return relHint; + } + } + return null; + } + + @Nullable + public static RelHint getHint(List hintList, Predicate predicate) { + for (RelHint hint : hintList) { + if (predicate.test(hint)) { + return hint; + } + } + return null; + } + /** - * Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains an option key for a specific hint name of - * {@link RelHint}. + * Returns the hint options for a specific hint name of {@link RelHint}, or {@code null} if the hint is not present. + */ + @Nullable + public static Map getHintOptions(List hintList, String hintName) { + for (RelHint relHint : hintList) { + if (relHint.hintName.equals(hintName)) { + return relHint.kvOptions; + } + } + return null; + } + + /** + * Check if a hint-able {@link RelNode} contains an option key for a specific hint name of {@link RelHint}. * - * @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}. + * @param hintList hint list from the {@link RelNode}. * @param hintName the name of the {@link RelHint}. * @param optionKey the option key to look for in the {@link RelHint#kvOptions}. * @return true if it contains the hint */ public static boolean containsHintOption(List hintList, String hintName, String optionKey) { - for (RelHint relHint : hintList) { - if (relHint.hintName.equals(hintName)) { - return relHint.kvOptions.containsKey(optionKey); - } - } - return false; + Map options = getHintOptions(hintList, hintName); + return options != null && options.containsKey(optionKey); } /** - * Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains an option key for a specific hint name of - * {@link RelHint}, and the value is true via {@link BooleanUtils#toBoolean(Object)}. + * Check if a hint-able {@link RelNode} has an option key as {@code true} for a specific hint name of {@link RelHint}. * - * @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}. + * @param hintList hint list from the {@link RelNode}. * @param hintName the name of the {@link RelHint}. * @param optionKey the option key to look for in the {@link RelHint#kvOptions}. * @return true if it contains the hint */ public static boolean isHintOptionTrue(List hintList, String hintName, String optionKey) { - for (RelHint relHint : hintList) { - if (relHint.hintName.equals(hintName)) { - return relHint.kvOptions.containsKey(optionKey) && BooleanUtils.toBoolean(relHint.kvOptions.get(optionKey)); - } - } - return false; + Map options = getHintOptions(hintList, hintName); + return options != null && Boolean.parseBoolean(options.get(optionKey)); } /** * Retrieve the option value by option key in the {@link RelHint#kvOptions}. the option key is looked up from the - * specified hint name for a hint-able {@link org.apache.calcite.rel.RelNode}. + * specified hint name for a hint-able {@link RelNode}. * - * @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}. + * @param hintList hint list from the {@link RelNode}. * @param hintName the name of the {@link RelHint}. * @param optionKey the option key to look for in the {@link RelHint#kvOptions}. - * @return true if it contains the hint */ @Nullable public static String getHintOption(List hintList, String hintName, String optionKey) { - for (RelHint relHint : hintList) { - if (relHint.hintName.equals(hintName)) { - return relHint.kvOptions.get(optionKey); - } - } - return null; + Map options = getHintOptions(hintList, hintName); + return options != null ? options.get(optionKey) : null; } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java index a727f0cefe1b..241c44703e6b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java @@ -23,6 +23,7 @@ import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.hint.RelHint; @@ -32,17 +33,30 @@ public class PinotLogicalAggregate extends Aggregate { private final AggType _aggType; + private final boolean _leafReturnFinalResult; public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List hints, RelNode input, ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls, - AggType aggType) { + AggType aggType, boolean leafReturnFinalResult) { super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls); _aggType = aggType; + _leafReturnFinalResult = leafReturnFinalResult; } - public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType) { + public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List hints, RelNode input, + ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls, + AggType aggType) { + this(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls, aggType, false); + } + + public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType, + boolean leafReturnFinalResult) { this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), aggRel.getInput(), aggRel.getGroupSet(), - aggRel.getGroupSets(), aggCalls, aggType); + aggRel.getGroupSets(), aggCalls, aggType, leafReturnFinalResult); + } + + public PinotLogicalAggregate(Aggregate aggRel, List aggCalls, AggType aggType) { + this(aggRel, aggCalls, aggType, false); } public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List aggCalls, AggType aggType) { @@ -51,22 +65,37 @@ public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List aggCalls, - AggType aggType) { - this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType); + AggType aggType, boolean leafReturnFinalResult) { + this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType, + leafReturnFinalResult); } public AggType getAggType() { return _aggType; } + public boolean isLeafReturnFinalResult() { + return _leafReturnFinalResult; + } + @Override public PinotLogicalAggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, @Nullable List groupSets, List aggCalls) { - return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType); + return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType, + _leafReturnFinalResult); + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + RelWriter relWriter = super.explainTerms(pw); + relWriter.item("aggType", _aggType); + relWriter.itemIf("leafReturnFinalResult", true, _leafReturnFinalResult); + return relWriter; } @Override public RelNode withHints(List hintList) { - return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType); + return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType, + _leafReturnFinalResult); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index 1c66f9a64890..df11fdb49a2e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -32,7 +33,6 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.Union; -import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; @@ -104,19 +104,21 @@ public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) { @Override public void onMatch(RelOptRuleCall call) { Aggregate aggRel = call.rel(0); - ImmutableList hints = aggRel.getHints(); - // Collation is not supported in leaf stage aggregation. - RelCollation collation = extractWithInGroupCollation(aggRel); boolean hasGroupBy = !aggRel.getGroupSet().isEmpty(); - if (collation != null || (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints, - PinotHintOptions.AGGREGATE_HINT_OPTIONS, - PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION))) { + RelCollation collation = extractWithInGroupCollation(aggRel); + Map hintOptions = + PinotHintStrategyTable.getHintOptions(aggRel.getHints(), PinotHintOptions.AGGREGATE_HINT_OPTIONS); + // Collation is not supported in leaf stage aggregation. + if (collation != null || (hasGroupBy && hintOptions != null && Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)))) { call.transformTo(createPlanWithExchangeDirectAggregation(call, collation)); - } else if (hasGroupBy && PinotHintStrategyTable.isHintOptionTrue(hints, PinotHintOptions.AGGREGATE_HINT_OPTIONS, - PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) { - call.transformTo(new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.DIRECT), AggType.DIRECT)); + } else if (hasGroupBy && hintOptions != null && Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS))) { + call.transformTo(new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT)); } else { - call.transformTo(createPlanWithLeafExchangeFinalAggregate(call)); + boolean leafReturnFinalResult = hintOptions != null && Boolean.parseBoolean( + hintOptions.get(PinotHintOptions.AggregateOptions.IS_LEAF_RETURN_FINAL_RESULT)); + call.transformTo(createPlanWithLeafExchangeFinalAggregate(call, leafReturnFinalResult)); } } @@ -156,23 +158,25 @@ private static PinotLogicalAggregate createPlanWithExchangeDirectAggregation(Rel exchange = PinotLogicalExchange.create(input, distribution); } - return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, AggType.DIRECT), AggType.DIRECT); + return new PinotLogicalAggregate(aggRel, exchange, buildAggCalls(aggRel, AggType.DIRECT, false), AggType.DIRECT); } /** * Aggregate node will be split into LEAF + EXCHANGE + FINAL. * TODO: Add optional INTERMEDIATE stage to reduce hotspot. */ - private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) { + private static PinotLogicalAggregate createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call, + boolean leafReturnFinalResult) { Aggregate aggRel = call.rel(0); // Create a LEAF aggregate. PinotLogicalAggregate leafAggRel = - new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.LEAF), AggType.LEAF); + new PinotLogicalAggregate(aggRel, buildAggCalls(aggRel, AggType.LEAF, leafReturnFinalResult), AggType.LEAF, + leafReturnFinalResult); // Create an EXCHANGE node over the LEAF aggregate. PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAggRel, RelDistributions.hash(ImmutableIntList.range(0, aggRel.getGroupCount()))); // Create a FINAL aggregate over the EXCHANGE. - return convertAggFromIntermediateInput(call, exchange, AggType.FINAL); + return convertAggFromIntermediateInput(call, exchange, AggType.FINAL, leafReturnFinalResult); } /** @@ -212,12 +216,14 @@ private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { relBuilder.project(projects); final ImmutableBitSet newGroupSet = Mappings.apply(mapping, aggregate.getGroupSet()); - final List newGroupSets = - aggregate.getGroupSets().stream().map(bitSet -> Mappings.apply(mapping, bitSet)) - .collect(ImmutableList.toImmutableList()); - final List newAggCallList = - aggregate.getAggCallList().stream().map(aggCall -> relBuilder.aggregateCall(aggCall, mapping)) - .collect(ImmutableList.toImmutableList()); + final List newGroupSets = aggregate.getGroupSets() + .stream() + .map(bitSet -> Mappings.apply(mapping, bitSet)) + .collect(ImmutableList.toImmutableList()); + final List newAggCallList = aggregate.getAggCallList() + .stream() + .map(aggCall -> relBuilder.aggregateCall(aggCall, mapping)) + .collect(ImmutableList.toImmutableList()); final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets); relBuilder.aggregate(groupKey, newAggCallList); @@ -225,7 +231,7 @@ private static RelNode generateProjectUnderAggregate(RelOptRuleCall call) { } private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleCall call, - PinotLogicalExchange exchange, AggType aggType) { + PinotLogicalExchange exchange, AggType aggType, boolean leafReturnFinalResult) { Aggregate aggRel = call.rel(0); RelNode input = aggRel.getInput(); List projects = findImmediateProjects(input); @@ -259,13 +265,14 @@ private static PinotLogicalAggregate convertAggFromIntermediateInput(RelOptRuleC } } } - aggCalls.add(buildAggCall(exchange, orgAggCall, rexList, groupCount, aggType)); + aggCalls.add(buildAggCall(exchange, orgAggCall, rexList, groupCount, aggType, leafReturnFinalResult)); } - return new PinotLogicalAggregate(aggRel, exchange, ImmutableBitSet.range(groupCount), aggCalls, aggType); + return new PinotLogicalAggregate(aggRel, exchange, ImmutableBitSet.range(groupCount), aggCalls, aggType, + leafReturnFinalResult); } - private static List buildAggCalls(Aggregate aggRel, AggType aggType) { + private static List buildAggCalls(Aggregate aggRel, AggType aggType, boolean leafReturnFinalResult) { RelNode input = aggRel.getInput(); List projects = findImmediateProjects(input); List orgAggCalls = aggRel.getAggCallList(); @@ -291,7 +298,7 @@ private static List buildAggCalls(Aggregate aggRel, AggType aggTy } } } - aggCalls.add(buildAggCall(input, orgAggCall, rexList, aggRel.getGroupCount(), aggType)); + aggCalls.add(buildAggCall(input, orgAggCall, rexList, aggRel.getGroupCount(), aggType, leafReturnFinalResult)); } return aggCalls; } @@ -300,7 +307,7 @@ private static List buildAggCalls(Aggregate aggRel, AggType aggTy // - DISTINCT is resolved here // - argList is replaced with rexList private static AggregateCall buildAggCall(RelNode input, AggregateCall orgAggCall, List rexList, - int numGroups, AggType aggType) { + int numGroups, AggType aggType, boolean leafReturnFinalResult) { SqlAggFunction orgAggFunction = orgAggCall.getAggregation(); String functionName = orgAggFunction.getName(); SqlKind kind = orgAggFunction.getKind(); @@ -319,7 +326,8 @@ private static AggregateCall buildAggCall(RelNode input, AggregateCall orgAggCal // Override the intermediate result type inference if it is provided if (aggType.isOutputIntermediateFormat()) { AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); - returnTypeInference = functionType.getIntermediateReturnTypeInference(); + returnTypeInference = leafReturnFinalResult ? functionType.getFinalReturnTypeInference() + : functionType.getIntermediateReturnTypeInference(); } // When the output is not intermediate format, or intermediate result type inference is not provided (intermediate // result type the same as final result type), use the explicit return type diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java index aa2e44173b4e..611d4417259b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java @@ -144,6 +144,9 @@ public PlanNode visitAggregate(AggregateNode node, PlanNode context) { if (node.getAggType() != otherNode.getAggType()) { return null; } + if (node.isLeafReturnFinalResult() != otherNode.isLeafReturnFinalResult()) { + return null; + } List children = mergeChildren(node, context); if (children == null) { return null; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java index 28bca306cd5c..55813264ffb0 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java @@ -194,7 +194,8 @@ public Boolean visitAggregate(AggregateNode node1, PlanNode node2) { return areBaseNodesEquivalent(node1, node2) && Objects.equals(node1.getAggCalls(), that.getAggCalls()) && Objects.equals(node1.getFilterArgs(), that.getFilterArgs()) && Objects.equals(node1.getGroupKeys(), that.getGroupKeys()) - && node1.getAggType() == that.getAggType(); + && node1.getAggType() == that.getAggType() + && node1.isLeafReturnFinalResult() == that.isLeafReturnFinalResult(); } @Override diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index bebf3abf3981..38170116126a 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -263,7 +263,8 @@ private AggregateNode convertLogicalAggregate(PinotLogicalAggregate node) { filterArgs.add(aggregateCall.filterArg); } return new AggregateNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()), NodeHint.fromRelHints(node.getHints()), - convertInputs(node.getInputs()), functionCalls, filterArgs, node.getGroupSet().asList(), node.getAggType()); + convertInputs(node.getInputs()), functionCalls, filterArgs, node.getGroupSet().asList(), node.getAggType(), + node.isLeafReturnFinalResult()); } private ProjectNode convertLogicalProject(LogicalProject node) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java index a12b049ca280..be4a6d9fb87d 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java @@ -29,14 +29,17 @@ public class AggregateNode extends BasePlanNode { private final List _filterArgs; private final List _groupKeys; private final AggType _aggType; + private final boolean _leafReturnFinalResult; public AggregateNode(int stageId, DataSchema dataSchema, NodeHint nodeHint, List inputs, - List aggCalls, List filterArgs, List groupKeys, AggType aggType) { + List aggCalls, List filterArgs, List groupKeys, AggType aggType, + boolean leafReturnFinalResult) { super(stageId, dataSchema, nodeHint, inputs); _aggCalls = aggCalls; _filterArgs = filterArgs; _groupKeys = groupKeys; _aggType = aggType; + _leafReturnFinalResult = leafReturnFinalResult; } public List getAggCalls() { @@ -55,6 +58,10 @@ public AggType getAggType() { return _aggType; } + public boolean isLeafReturnFinalResult() { + return _leafReturnFinalResult; + } + @Override public String explain() { return "AGGREGATE_" + _aggType; @@ -67,7 +74,8 @@ public T visit(PlanNodeVisitor visitor, C context) { @Override public PlanNode withInputs(List inputs) { - return new AggregateNode(_stageId, _dataSchema, _nodeHint, inputs, _aggCalls, _filterArgs, _groupKeys, _aggType); + return new AggregateNode(_stageId, _dataSchema, _nodeHint, inputs, _aggCalls, _filterArgs, _groupKeys, _aggType, + _leafReturnFinalResult); } @Override @@ -83,12 +91,13 @@ public boolean equals(Object o) { } AggregateNode that = (AggregateNode) o; return Objects.equals(_aggCalls, that._aggCalls) && Objects.equals(_filterArgs, that._filterArgs) && Objects.equals( - _groupKeys, that._groupKeys) && _aggType == that._aggType; + _groupKeys, that._groupKeys) && _aggType == that._aggType + && _leafReturnFinalResult == that._leafReturnFinalResult; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, _aggType); + return Objects.hash(super.hashCode(), _aggCalls, _filterArgs, _groupKeys, _aggType, _leafReturnFinalResult); } /** diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java index dca8cb18954a..abd474ebce3e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java @@ -87,7 +87,7 @@ private static AggregateNode deserializeAggregateNode(Plan.PlanNode protoNode) { return new AggregateNode(protoNode.getStageId(), extractDataSchema(protoNode), extractNodeHint(protoNode), extractInputs(protoNode), convertFunctionCalls(protoAggregateNode.getAggCallsList()), protoAggregateNode.getFilterArgsList(), protoAggregateNode.getGroupKeysList(), - convertAggType(protoAggregateNode.getAggType())); + convertAggType(protoAggregateNode.getAggType()), protoAggregateNode.getLeafReturnFinalResult()); } private static FilterNode deserializeFilterNode(Plan.PlanNode protoNode) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java index 00a21c05e954..65ccb13b2cae 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java @@ -53,7 +53,8 @@ private PlanNodeSerializer() { } public static Plan.PlanNode process(PlanNode planNode) { - Plan.PlanNode.Builder builder = Plan.PlanNode.newBuilder().setStageId(planNode.getStageId()) + Plan.PlanNode.Builder builder = Plan.PlanNode.newBuilder() + .setStageId(planNode.getStageId()) .setDataSchema(convertDataSchema(planNode.getDataSchema())) .setNodeHint(convertNodeHint(planNode.getNodeHint())); planNode.visit(SerializationVisitor.INSTANCE, builder); @@ -91,10 +92,13 @@ private SerializationVisitor() { @Override public Void visitAggregate(AggregateNode node, Plan.PlanNode.Builder builder) { - Plan.AggregateNode aggregateNode = - Plan.AggregateNode.newBuilder().addAllAggCalls(convertFunctionCalls(node.getAggCalls())) - .addAllFilterArgs(node.getFilterArgs()).addAllGroupKeys(node.getGroupKeys()) - .setAggType(convertAggType(node.getAggType())).build(); + Plan.AggregateNode aggregateNode = Plan.AggregateNode.newBuilder() + .addAllAggCalls(convertFunctionCalls(node.getAggCalls())) + .addAllFilterArgs(node.getFilterArgs()) + .addAllGroupKeys(node.getGroupKeys()) + .setAggType(convertAggType(node.getAggType())) + .setLeafReturnFinalResult(node.isLeafReturnFinalResult()) + .build(); builder.setAggregateNode(aggregateNode); return null; } @@ -102,42 +106,51 @@ public Void visitAggregate(AggregateNode node, Plan.PlanNode.Builder builder) { @Override public Void visitFilter(FilterNode node, Plan.PlanNode.Builder builder) { Plan.FilterNode filterNode = Plan.FilterNode.newBuilder() - .setCondition(RexExpressionToProtoExpression.convertExpression(node.getCondition())).build(); + .setCondition(RexExpressionToProtoExpression.convertExpression(node.getCondition())) + .build(); builder.setFilterNode(filterNode); return null; } @Override public Void visitJoin(JoinNode node, Plan.PlanNode.Builder builder) { - Plan.JoinNode joinNode = - Plan.JoinNode.newBuilder().setJoinType(convertJoinType(node.getJoinType())).addAllLeftKeys(node.getLeftKeys()) - .addAllRightKeys(node.getRightKeys()) - .addAllNonEquiConditions(convertExpressions(node.getNonEquiConditions())) - .setJoinStrategy(convertJoinStrategy(node.getJoinStrategy())).build(); + Plan.JoinNode joinNode = Plan.JoinNode.newBuilder() + .setJoinType(convertJoinType(node.getJoinType())) + .addAllLeftKeys(node.getLeftKeys()) + .addAllRightKeys(node.getRightKeys()) + .addAllNonEquiConditions(convertExpressions(node.getNonEquiConditions())) + .setJoinStrategy(convertJoinStrategy(node.getJoinStrategy())) + .build(); builder.setJoinNode(joinNode); return null; } @Override public Void visitMailboxReceive(MailboxReceiveNode node, Plan.PlanNode.Builder builder) { - Plan.MailboxReceiveNode mailboxReceiveNode = - Plan.MailboxReceiveNode.newBuilder().setSenderStageId(node.getSenderStageId()) - .setExchangeType(convertExchangeType(node.getExchangeType())) - .setDistributionType(convertDistributionType(node.getDistributionType())).addAllKeys(node.getKeys()) - .addAllCollations(convertCollations(node.getCollations())).setSort(node.isSort()) - .setSortedOnSender(node.isSortedOnSender()).build(); + Plan.MailboxReceiveNode mailboxReceiveNode = Plan.MailboxReceiveNode.newBuilder() + .setSenderStageId(node.getSenderStageId()) + .setExchangeType(convertExchangeType(node.getExchangeType())) + .setDistributionType(convertDistributionType(node.getDistributionType())) + .addAllKeys(node.getKeys()) + .addAllCollations(convertCollations(node.getCollations())) + .setSort(node.isSort()) + .setSortedOnSender(node.isSortedOnSender()) + .build(); builder.setMailboxReceiveNode(mailboxReceiveNode); return null; } @Override public Void visitMailboxSend(MailboxSendNode node, Plan.PlanNode.Builder builder) { - Plan.MailboxSendNode mailboxSendNode = - Plan.MailboxSendNode.newBuilder().setReceiverStageId(node.getReceiverStageId()) - .setExchangeType(convertExchangeType(node.getExchangeType())) - .setDistributionType(convertDistributionType(node.getDistributionType())).addAllKeys(node.getKeys()) - .setPrePartitioned(node.isPrePartitioned()).addAllCollations(convertCollations(node.getCollations())) - .setSort(node.isSort()).build(); + Plan.MailboxSendNode mailboxSendNode = Plan.MailboxSendNode.newBuilder() + .setReceiverStageId(node.getReceiverStageId()) + .setExchangeType(convertExchangeType(node.getExchangeType())) + .setDistributionType(convertDistributionType(node.getDistributionType())) + .addAllKeys(node.getKeys()) + .setPrePartitioned(node.isPrePartitioned()) + .addAllCollations(convertCollations(node.getCollations())) + .setSort(node.isSort()) + .build(); builder.setMailboxSendNode(mailboxSendNode); return null; } @@ -160,9 +173,11 @@ public Void visitSetOp(SetOpNode node, Plan.PlanNode.Builder builder) { @Override public Void visitSort(SortNode node, Plan.PlanNode.Builder builder) { - Plan.SortNode sortNode = - Plan.SortNode.newBuilder().addAllCollations(convertCollations(node.getCollations())).setFetch(node.getFetch()) - .setOffset(node.getOffset()).build(); + Plan.SortNode sortNode = Plan.SortNode.newBuilder() + .addAllCollations(convertCollations(node.getCollations())) + .setFetch(node.getFetch()) + .setOffset(node.getOffset()) + .build(); builder.setSortNode(sortNode); return null; } @@ -185,10 +200,15 @@ public Void visitValue(ValueNode node, Plan.PlanNode.Builder builder) { @Override public Void visitWindow(WindowNode node, Plan.PlanNode.Builder builder) { - Plan.WindowNode windowNode = Plan.WindowNode.newBuilder().addAllAggCalls(convertFunctionCalls(node.getAggCalls())) - .addAllKeys(node.getKeys()).addAllCollations(convertCollations(node.getCollations())) - .setWindowFrameType(convertWindowFrameType(node.getWindowFrameType())).setLowerBound(node.getLowerBound()) - .setUpperBound(node.getUpperBound()).addAllConstants(convertLiterals(node.getConstants())).build(); + Plan.WindowNode windowNode = Plan.WindowNode.newBuilder() + .addAllAggCalls(convertFunctionCalls(node.getAggCalls())) + .addAllKeys(node.getKeys()) + .addAllCollations(convertCollations(node.getCollations())) + .setWindowFrameType(convertWindowFrameType(node.getWindowFrameType())) + .setLowerBound(node.getLowerBound()) + .setUpperBound(node.getUpperBound()) + .addAllConstants(convertLiterals(node.getConstants())) + .build(); builder.setWindowNode(windowNode); return null; } @@ -200,10 +220,8 @@ public Void visitExchange(ExchangeNode exchangeNode, Plan.PlanNode.Builder conte @Override public Void visitExplained(ExplainedNode node, Plan.PlanNode.Builder builder) { - Plan.ExplainNode explainNode = Plan.ExplainNode.newBuilder() - .setTitle(node.getTitle()) - .putAllAttributes(node.getAttributes()) - .build(); + Plan.ExplainNode explainNode = + Plan.ExplainNode.newBuilder().setTitle(node.getTitle()).putAllAttributes(node.getAttributes()).build(); builder.setExplainNode(explainNode); return null; } @@ -339,9 +357,11 @@ private static Plan.NullDirection convertNullDirection(RelFieldCollation.NullDir private static List convertCollations(List collations) { List protoCollations = new ArrayList<>(collations.size()); for (RelFieldCollation collation : collations) { - protoCollations.add(Plan.Collation.newBuilder().setIndex(collation.getFieldIndex()) + protoCollations.add(Plan.Collation.newBuilder() + .setIndex(collation.getFieldIndex()) .setDirection(convertDirection(collation.direction)) - .setNullDirection(convertNullDirection(collation.nullDirection)).build()); + .setNullDirection(convertNullDirection(collation.nullDirection)) + .build()); } return protoCollations; } diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java index c43f54344029..0be64ad20f12 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java @@ -85,9 +85,9 @@ public void testAggregateCaseToFilter() { assertEquals(explain, "Execution Plan\n" + "LogicalProject(EXPR$0=[CASE(=($1, 0), null:BIGINT, $0)])\n" - + " PinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)], agg#1=[COUNT($1)])\n" + + " PinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)], agg#1=[COUNT($1)], aggType=[FINAL])\n" + " PinotLogicalExchange(distribution=[hash])\n" - + " PinotLogicalAggregate(group=[{}], agg#0=[COUNT() FILTER $0], agg#1=[COUNT()])\n" + + " PinotLogicalAggregate(group=[{}], agg#0=[COUNT() FILTER $0], agg#1=[COUNT()], aggType=[LEAF])\n" + " LogicalProject($f1=[=($0, _UTF-8'a')])\n" + " LogicalTableScan(table=[[default, a]])\n"); //@formatter:on diff --git a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json index 02c91a64904c..7b0134df8bcc 100644 --- a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json +++ b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json @@ -7,18 +7,18 @@ "output": [ "Execution Plan", "\nLogicalProject(avg=[CAST(/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)):DECIMAL(1000, 0)])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalJoin(condition=[>=($0, $2)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[random])", "\n LogicalProject(col3=[$2], col4=[$3])", "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[broadcast])", "\n LogicalProject(EXPR$0=[CAST(/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)):DECIMAL(1000, 0)])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -29,9 +29,9 @@ "output": [ "Execution Plan", "\nLogicalProject(avg=[CAST(/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)):DECIMAL(1000, 0)])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -43,9 +43,9 @@ "output": [ "Execution Plan", "\nLogicalProject(avg=[CAST(/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)):DECIMAL(1000, 0)], sum=[CASE(=($1, 0), null:DECIMAL(1000, 0), $0)], max=[$2])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()], agg#2=[MAX($3)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()], agg#2=[MAX($3)], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -57,9 +57,9 @@ "output": [ "Execution Plan", "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:BIGINT, $0)):DOUBLE, $1)], count=[$1])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -71,9 +71,9 @@ "output": [ "Execution Plan", "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:BIGINT, $0)], EXPR$1=[$1])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -84,9 +84,9 @@ "output": [ "Execution Plan", "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:BIGINT, $0)], EXPR$1=[$1])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -98,9 +98,9 @@ "output": [ "Execution Plan", "\nLogicalProject(sum=[CASE(=($1, 0), null:BIGINT, $0)], count=[$1])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -112,9 +112,9 @@ "output": [ "Execution Plan", "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:BIGINT, $0)):DOUBLE, $1)], count=[$1])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -126,9 +126,9 @@ "output": [ "Execution Plan", "\nLogicalProject(sum=[CASE(=($1, 0), null:BIGINT, $0)], count=[$1])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -139,19 +139,19 @@ "sql": "EXPLAIN PLAN FOR with teamOne as (select /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ col2, percentile(col3, 50) as sum_of_runs from a group by col2), teamTwo as (select /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ col2, percentile(col3, 50) as sum_of_runs from a group by col2), all as (select col2, sum_of_runs from teamOne union all select col2, sum_of_runs from teamTwo) select /*+ aggOption(is_skip_leaf_stage_group_by='true') */ col2, percentile(sum_of_runs, 50) from all group by col2", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)], aggType=[LEAF])", "\n LogicalUnion(all=[true])", "\n PinotLogicalExchange(distribution=[hash[0, 1, 2]])", "\n LogicalProject(col2=[$0], sum_of_runs=[$1], $f2=[50])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2], $f2=[50])", "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0, 1, 2]])", "\n LogicalProject(col2=[$0], sum_of_runs=[$1], $f2=[50])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[PERCENTILE($1, 50)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2], $f2=[50])", "\n LogicalTableScan(table=[[default, a]])", diff --git a/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json b/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json index 7c5f0de3d051..8b947e146080 100644 --- a/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json +++ b/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json @@ -88,9 +88,9 @@ "output": [ "Execution Plan", "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:DECIMAL(1000, 500), $0)])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalProject($f0=[CAST(CASE(>($2, 10), _UTF-8'1':VARCHAR CHARACTER SET \"UTF-8\", >($2, 20), _UTF-8'2':VARCHAR CHARACTER SET \"UTF-8\", >($2, 30), _UTF-8'3':VARCHAR CHARACTER SET \"UTF-8\", >($2, 40), _UTF-8'4':VARCHAR CHARACTER SET \"UTF-8\", >($2, 50), _UTF-8'5':VARCHAR CHARACTER SET \"UTF-8\", _UTF-8'0':VARCHAR CHARACTER SET \"UTF-8\")):DECIMAL(1000, 500) NOT NULL])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -102,9 +102,9 @@ "output": [ "Execution Plan", "\nLogicalProject(sumCol3=[$1], EXPR$1=[$0])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[LEAF])", "\n LogicalProject(EXPR$1=[ARRAY_TO_MV($6)], col3=[$2])", "\n LogicalTableScan(table=[[default, e]])", "\n" diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json index 2de1a6c93cf4..63a69f5e8ecb 100644 --- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json +++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json @@ -6,9 +6,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a GROUP BY a.col1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -19,9 +19,9 @@ "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)], agg#3=[MIN($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)], agg#3=[MIN($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -31,9 +31,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -45,9 +45,9 @@ "notes": "TODO: Needs follow up. Project should only keep a.col1 since the other columns are pushed to the filter, but it currently keeps them all", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -59,9 +59,9 @@ "output": [ "Execution Plan", "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -74,9 +74,9 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -89,9 +89,9 @@ "Execution Plan", "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)], aggType=[LEAF])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -102,7 +102,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a GROUP BY a.col1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -115,7 +115,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -128,7 +128,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -140,7 +140,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -153,7 +153,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -167,7 +167,7 @@ "notes": "TODO: Needs follow up. Project should only keep a.col1 since the other columns are pushed to the filter, but it currently keeps them all", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[COUNT()])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -181,7 +181,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])", @@ -196,7 +196,7 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -211,7 +211,7 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1])", "\n LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), =(/(CAST($1):DOUBLE NOT NULL, $4), 5))])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -226,13 +226,29 @@ "Execution Plan", "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] + }, + { + "description": "SQL hint based group by optimization with partitioned aggregated values", + "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_leaf_return_final_result='true') */ col1, COUNT(DISTINCT col2) AS cnt FROM a WHERE a.col3 >= 0 GROUP BY col1 ORDER BY cnt DESC LIMIT 10", + "output": [ + "Execution Plan", + "\nLogicalSort(sort0=[$1], dir0=[DESC], offset=[0], fetch=[10])", + "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", + "\n LogicalSort(sort0=[$1], dir0=[DESC], fetch=[10])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[FINAL], leafReturnFinalResult=[true])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[DISTINCTCOUNT($1)], aggType=[LEAF], leafReturnFinalResult=[true])", + "\n LogicalFilter(condition=[>=($2, 0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n" + ] } ] } diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json index 18f19ee5b245..fb63399fac71 100644 --- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json +++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json @@ -115,9 +115,9 @@ "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0])", @@ -151,9 +151,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, a.col3 FROM a JOIN b ON a.col1 = b.col1 WHERE a.col3 >= 0 GROUP BY a.col2, a.col3", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0, 1}])", + "\nPinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{1, 2}])", + "\n PinotLogicalAggregate(group=[{1, 2}], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $3)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", @@ -170,9 +170,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, a.col3 as value3 FROM a JOIN b ON a.col1 = b.col1 WHERE a.col3 >= 0 GROUP BY a.col2, a.col3", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0, 1}])", + "\nPinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{1, 2}])", + "\n PinotLogicalAggregate(group=[{1, 2}], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $3)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", @@ -287,25 +287,25 @@ "\n LogicalFilter(condition=[=($1, _UTF-8'test')])", "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[LEAF])", "\n LogicalProject(col3=[$2], $f1=[true])", "\n LogicalFilter(condition=[=($0, _UTF-8'foo')])", "\n LogicalTableScan(table=[[default, b]])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], $f1=[$1])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[LEAF])", "\n LogicalProject(col3=[$2], $f1=[true])", "\n LogicalFilter(condition=[=($0, _UTF-8'bar')])", "\n LogicalTableScan(table=[[default, b]])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], $f1=[$1])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[LEAF])", "\n LogicalProject(col3=[$2], $f1=[true])", "\n LogicalFilter(condition=[=($0, _UTF-8'foobar')])", "\n LogicalTableScan(table=[[default, b]])", @@ -324,9 +324,9 @@ "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], EXPR$0=[*(0.5:DECIMAL(2, 1), $2)])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, b]])", "\n" ] @@ -342,9 +342,9 @@ "\n LogicalProject(col1=[$0], col2=[$1], col4=[$3])", "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalProject(col1=[$0], col2=[$1], $f0=[*(0.5:DECIMAL(2, 1), $2)])", "\n LogicalTableScan(table=[[default, b]])", "\n" @@ -362,9 +362,9 @@ "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], EXPR$0=[CAST(/($2, $3)):DECIMAL(12, 1) NOT NULL])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT($3)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT($3)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalProject(col1=[$0], col2=[$1], $f0=[*(0.5:DECIMAL(2, 1), $2)])", "\n LogicalTableScan(table=[[default, b]])", "\n" @@ -382,9 +382,9 @@ "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], EXPR$0=[*(0.5:DECIMAL(2, 1), /(CAST($2):DOUBLE NOT NULL, $3))])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT($3)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT($3)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, b]])", "\n" ] @@ -400,9 +400,9 @@ "\n LogicalProject(col1=[$0], col2=[$1])", "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{1}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, b]])", "\n" ] @@ -455,9 +455,9 @@ "sql": "EXPLAIN PLAN FOR WITH tmp1 AS ( SELECT * FROM a WHERE col2 NOT IN ('foo', 'bar') ) SELECT COUNT(*) FROM a WHERE col2 IN (SELECT col1 FROM tmp1) AND col3 IN (SELECT col3 from b WHERE col3 < 100)", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)])", + "\nPinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $1)], joinType=[semi])", "\n LogicalProject(col3=[$1])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", @@ -481,9 +481,9 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1])", "\n LogicalFilter(condition=[>($2, 10)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalJoin(condition=[=($1, $2)], joinType=[semi])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", @@ -525,9 +525,9 @@ "sql": "EXPLAIN PLAN FOR SELECT count(*) FROM a WHERE a.col1 = 'foo' AND col2 = 'xylo' AND a.col4 = 12 AND a.col5 = false AND col3 NOT IN (SELECT col3 FROM b WHERE col1='foo') AND col3 NOT IN (SELECT col3 FROM b WHERE col1='bar') AND col3 NOT IN (SELECT col3 FROM b WHERE col1='foobar') AND col3 IN (SELECT col3 FROM b WHERE col1 = 'fork')", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)])", + "\nPinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $1)], joinType=[semi])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0])", @@ -547,25 +547,25 @@ "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], $f1=[$1])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[LEAF])", "\n LogicalProject(col3=[$2], $f1=[true])", "\n LogicalFilter(condition=[=($0, _UTF-8'foo')])", "\n LogicalTableScan(table=[[default, b]])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], $f1=[$1])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[LEAF])", "\n LogicalProject(col3=[$2], $f1=[true])", "\n LogicalFilter(condition=[=($0, _UTF-8'bar')])", "\n LogicalTableScan(table=[[default, b]])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], $f1=[$1])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[MIN($1)], aggType=[LEAF])", "\n LogicalProject(col3=[$2], $f1=[true])", "\n LogicalFilter(condition=[=($0, _UTF-8'foobar')])", "\n LogicalTableScan(table=[[default, b]])", diff --git a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json index d9ba67e9a42d..098daef73199 100644 --- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json +++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json @@ -245,9 +245,9 @@ "sql": "EXPLAIN PLAN FOR SELECT count(*) FROM a WHERE col1 > ToEpochDays(fromDateTime('1970-01-15', 'yyyy-MM-dd'))", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)])", + "\nPinotLogicalAggregate(group=[{}], agg#0=[COUNT($0)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[>(CAST($0):BIGINT NOT NULL, 14)])", "\n LogicalTableScan(table=[[default, a]])", "\n" @@ -288,9 +288,9 @@ "output": [ "Execution Plan", "\nLogicalProject(__ts=[1692057600000:BIGINT])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[LEAF])", "\n LogicalProject(__ts=[1692057600000:BIGINT])", "\n LogicalTableScan(table=[[default, a]])", "\n" diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json index f0bd14ea7404..157b9e13624a 100644 --- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json +++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json @@ -79,9 +79,9 @@ "Execution Plan", "\nLogicalSort(sort0=[$0], dir0=[ASC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -93,7 +93,7 @@ "Execution Plan", "\nLogicalSort(sort0=[$0], dir0=[ASC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -107,9 +107,9 @@ "Execution Plan", "\nLogicalSort(sort0=[$0], dir0=[ASC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -121,7 +121,7 @@ "Execution Plan", "\nLogicalSort(sort0=[$0], dir0=[ASC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", diff --git a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json index dc615f774fce..f26a1330169b 100644 --- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json +++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json @@ -27,7 +27,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[DIRECT])", "\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0])", @@ -80,9 +80,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) AND a.col2 IN (select col1 FROM c WHERE c.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", "\n LogicalJoin(condition=[=($0, $3)], joinType=[semi])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", @@ -103,7 +103,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col1, SUM(a.col3) FROM a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col1=[$0], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -119,9 +119,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a WHERE a.col1 IN (SELECT col2 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{1}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $3)], joinType=[semi])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -138,7 +138,7 @@ "output": [ "Execution Plan", "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])", @@ -153,7 +153,7 @@ "Execution Plan", "\nLogicalProject(col2=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$3])", "\n LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)], aggType=[DIRECT])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", @@ -166,7 +166,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, COUNT(*), SUM(a.col3), SUM(a.col1) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col2", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], aggType=[DIRECT])", "\n LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])", "\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])", "\n LogicalTableScan(table=[[default, a]])", @@ -199,9 +199,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{1}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -211,9 +211,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -223,9 +223,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ JOIN b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ ON a.col2 = b.col1 WHERE b.col3 > 0 GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2])", @@ -242,9 +242,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, b.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ JOIN b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ ON a.col2 = b.col1 WHERE b.col3 > 0 GROUP BY 1, 2", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])", + "\nPinotLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 3}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0, 3}], agg#0=[$SUM0($1)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2])", @@ -261,9 +261,9 @@ "sql": "EXPLAIN PLAN FOR SELECT b.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ JOIN b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ ON a.col2 = b.col1 WHERE b.col3 > 0 GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{3}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{3}], agg#0=[$SUM0($1)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[inner])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col2=[$1], col3=[$2])", @@ -280,9 +280,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -298,9 +298,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -316,9 +316,9 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -334,18 +334,18 @@ "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0 GROUP BY 1 HAVING COUNT(*) > 1) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", "\n LogicalProject(col1=[$0])", "\n LogicalFilter(condition=[>($1, 1)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalFilter(condition=[>($2, 0)])", "\n LogicalTableScan(table=[[default, b]])", "\n" @@ -358,9 +358,9 @@ "Execution Plan", "\nLogicalProject(col1=[$0], EXPR$1=[$1])", "\n LogicalFilter(condition=[>($2, 5)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], aggType=[LEAF])", "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -378,9 +378,9 @@ "Execution Plan", "\nLogicalSort(sort0=[$1], dir0=[DESC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -396,7 +396,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ tableOptions(partition_function='hashcode', partition_key='col1', partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -412,7 +412,7 @@ "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM a /*+ tableOptions(partition_function='hashcode', partition_key='col2', partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) GROUP BY 1", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -430,7 +430,7 @@ "Execution Plan", "\nLogicalProject(col2=[$0], EXPR$1=[$1])", "\n LogicalFilter(condition=[>($2, 5)])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], aggType=[DIRECT])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", @@ -448,7 +448,7 @@ "Execution Plan", "\nLogicalSort(sort0=[$1], dir0=[DESC])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[DIRECT])", "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", "\n LogicalProject(col2=[$1], col3=[$2])", "\n LogicalTableScan(table=[[default, a]])", diff --git a/pinot-query-planner/src/test/resources/queries/SetOpPlans.json b/pinot-query-planner/src/test/resources/queries/SetOpPlans.json index 0cea3a6429a6..004ad5f098de 100644 --- a/pinot-query-planner/src/test/resources/queries/SetOpPlans.json +++ b/pinot-query-planner/src/test/resources/queries/SetOpPlans.json @@ -41,9 +41,9 @@ "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a UNION SELECT col1, col2 FROM b UNION SELECT col1, col2 FROM c", "output": [ "Execution Plan", - "\nPinotLogicalAggregate(group=[{0, 1}])", + "\nPinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 1}])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[LEAF])", "\n LogicalUnion(all=[true])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", "\n LogicalUnion(all=[true])", diff --git a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json index 6fd3cbb19332..8568a8f3cef5 100644 --- a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json +++ b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json @@ -231,9 +231,9 @@ "\nLogicalProject($0=[$1])", "\n LogicalWindow(window#0=[window(aggs [MIN($0)])])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -246,9 +246,9 @@ "\nLogicalProject(col1=[$0], $1=[$2])", "\n LogicalWindow(window#0=[window(aggs [MIN($1)])])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{0, 1}])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}])", + "\n PinotLogicalAggregate(group=[{0, 2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -261,9 +261,9 @@ "\nLogicalProject(EXPR$0=[$1], $1=[$2])", "\n LogicalWindow(window#0=[window(aggs [MIN($0)])])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -277,9 +277,9 @@ "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -295,9 +295,9 @@ "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -313,9 +313,9 @@ "\n LogicalWindow(window#0=[window(rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])", "\n PinotLogicalExchange(distribution=[hash])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -485,9 +485,9 @@ "\nLogicalProject($0=[$1], $1=[$2])", "\n LogicalWindow(window#0=[window(aggs [MIN($0), SUM($0)])])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -500,9 +500,9 @@ "\nLogicalProject(col1=[$0], $1=[$2], $2=[$3])", "\n LogicalWindow(window#0=[window(aggs [MIN($1), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{0, 1}])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}])", + "\n PinotLogicalAggregate(group=[{0, 2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -515,9 +515,9 @@ "\nLogicalProject(EXPR$0=[$1], $1=[$2], $2=[$3])", "\n LogicalWindow(window#0=[window(aggs [MIN($0), MAX($0)])])", "\n PinotLogicalExchange(distribution=[hash])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -531,9 +531,9 @@ "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -549,9 +549,9 @@ "\n LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -801,9 +801,9 @@ "\nLogicalProject($0=[$1])", "\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -816,9 +816,9 @@ "\nLogicalProject(col1=[$0], $1=[$2])", "\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($1)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0, 1}])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}])", + "\n PinotLogicalAggregate(group=[{0, 2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -831,9 +831,9 @@ "\nLogicalWindow(window#0=[window(partition {0} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col1=[$0])", - "\n PinotLogicalAggregate(group=[{0, 1}])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}])", + "\n PinotLogicalAggregate(group=[{0, 2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -846,9 +846,9 @@ "\nLogicalProject(EXPR$0=[$1], $1=[$2])", "\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -862,9 +862,9 @@ "\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -880,9 +880,9 @@ "\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -1139,9 +1139,9 @@ "\nLogicalProject($0=[$1], $1=[$2])", "\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($0), COUNT($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -1154,9 +1154,9 @@ "\nLogicalProject(col1=[$0], EXPR$1=[$2], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])", "\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($1), SUM($1), COUNT($1)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0, 1}])", + "\n PinotLogicalAggregate(group=[{0, 1}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}])", + "\n PinotLogicalAggregate(group=[{0, 2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -1169,9 +1169,9 @@ "\nLogicalProject(EXPR$0=[$1], $1=[$2], $2=[$3])", "\n LogicalWindow(window#0=[window(partition {0} aggs [MIN($0), SUM($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])", + "\n PinotLogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -1185,9 +1185,9 @@ "\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0), MAX($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -1203,9 +1203,9 @@ "\n LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0), MAX($0)])])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n LogicalProject(col3=[$0], EXPR$0=[CAST($0):DOUBLE NOT NULL])", - "\n PinotLogicalAggregate(group=[{0}])", + "\n PinotLogicalAggregate(group=[{0}], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{2}])", + "\n PinotLogicalAggregate(group=[{2}], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -3347,9 +3347,9 @@ "\nLogicalProject(col1=[$0], EXPR$1=[$2], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])", "\n LogicalWindow(window#0=[window(order by [2 DESC, 0] aggs [SUM($1), COUNT($1)])])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[2 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -3362,9 +3362,9 @@ "\nLogicalWindow(window#0=[window(order by [1 DESC, 0] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])", "\n LogicalProject(col1=[$0], EXPR$1=[$2])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -3377,9 +3377,9 @@ "\nLogicalWindow(window#0=[window(order by [1 DESC, 0] aggs [DENSE_RANK(), RANK()])])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])", "\n LogicalProject(col1=[$0], EXPR$1=[$2])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -3392,9 +3392,9 @@ "\nLogicalProject(col1=[$0], EXPR$1=[$2], $2=[$3])", "\n LogicalWindow(window#0=[window(partition {0} order by [2 DESC, 0] aggs [MAX($1)])])", "\n PinotLogicalSortExchange(distribution=[hash[0]], collation=[[2 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])", + "\n PinotLogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0, 1]])", - "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0, 2}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -3449,9 +3449,9 @@ "\nLogicalFilter(condition=[=($2, 1)])", "\n LogicalWindow(window#0=[window(order by [1 DESC] aggs [RANK()])])", "\n PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], aggType=[FINAL])", "\n PinotLogicalExchange(distribution=[hash[0]])", - "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()])", + "\n PinotLogicalAggregate(group=[{0}], agg#0=[COUNT()], aggType=[LEAF])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java index 38ff7d2d5c12..a9ce6064b886 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java @@ -91,6 +91,8 @@ public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inp // Initialize the appropriate executor. List groupKeys = node.getGroupKeys(); AggregateNode.AggType aggType = node.getAggType(); + // TODO: Allow leaf return final result for non-group-by queries + boolean leafReturnFinalResult = node.isLeafReturnFinalResult(); if (groupKeys.isEmpty()) { _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, filterArgIds, maxFilterArgId, aggType, _resultSchema); @@ -98,7 +100,7 @@ public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inp } else { _groupByExecutor = new MultistageGroupByExecutor(getGroupKeyIds(groupKeys), aggFunctions, filterArgIds, maxFilterArgId, aggType, - _resultSchema, context.getOpChainMetadata(), node.getNodeHint()); + leafReturnFinalResult, _resultSchema, context.getOpChainMetadata(), node.getNodeHint()); _aggregationExecutor = null; } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java index 41501f69383e..701f098182c9 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java @@ -56,6 +56,7 @@ public class MultistageGroupByExecutor { private final int[] _filterArgIds; private final int _maxFilterArgId; private final AggType _aggType; + private final boolean _leafReturnFinalResult; private final DataSchema _resultSchema; private final int _numGroupsLimit; private final boolean _filteredAggregationsSkipEmptyGroups; @@ -69,13 +70,14 @@ public class MultistageGroupByExecutor { private final GroupIdGenerator _groupIdGenerator; public MultistageGroupByExecutor(int[] groupKeyIds, AggregationFunction[] aggFunctions, int[] filterArgIds, - int maxFilterArgId, AggType aggType, DataSchema resultSchema, Map opChainMetadata, - @Nullable PlanNode.NodeHint nodeHint) { + int maxFilterArgId, AggType aggType, boolean leafReturnFinalResult, DataSchema resultSchema, + Map opChainMetadata, @Nullable PlanNode.NodeHint nodeHint) { _groupKeyIds = groupKeyIds; _aggFunctions = aggFunctions; _filterArgIds = filterArgIds; _maxFilterArgId = maxFilterArgId; _aggType = aggType; + _leafReturnFinalResult = leafReturnFinalResult; _resultSchema = resultSchema; int maxInitialResultHolderCapacity = getMaxInitialResultHolderCapacity(opChainMetadata, nodeHint); _numGroupsLimit = getNumGroupsLimit(opChainMetadata, nodeHint); @@ -180,15 +182,20 @@ public List getResult() { private Object getResultValue(int functionId, int groupId) { AggregationFunction aggFunction = _aggFunctions[functionId]; switch (_aggType) { - case LEAF: - return aggFunction.extractGroupByResult(_aggregateResultHolders[functionId], groupId); + case LEAF: { + Object intermediateResult = aggFunction.extractGroupByResult(_aggregateResultHolders[functionId], groupId); + return _leafReturnFinalResult ? aggFunction.extractFinalResult(intermediateResult) : intermediateResult; + } case INTERMEDIATE: return _mergeResultHolder.get(groupId)[functionId]; - case FINAL: - return aggFunction.extractFinalResult(_mergeResultHolder.get(groupId)[functionId]); - case DIRECT: - Object intermediate = aggFunction.extractGroupByResult(_aggregateResultHolders[functionId], groupId); - return aggFunction.extractFinalResult(intermediate); + case FINAL: { + Object mergedResult = _mergeResultHolder.get(groupId)[functionId]; + return _leafReturnFinalResult ? mergedResult : aggFunction.extractFinalResult(mergedResult); + } + case DIRECT: { + Object intermediateResult = aggFunction.extractGroupByResult(_aggregateResultHolders[functionId], groupId); + return aggFunction.extractFinalResult(intermediateResult); + } default: throw new IllegalStateException("Unsupported aggType: " + _aggType); } @@ -263,30 +270,60 @@ private void processMerge(TransferableBlock block) { for (int i = 0; i < numFunctions; i++) { intermediateResults[i] = AggregateOperator.getIntermediateResults(_aggFunctions[i], block); } - for (int i = 0; i < numRows; i++) { - int groupByKey = groupByKeys[i]; - if (groupByKey == GroupKeyGenerator.INVALID_ID) { - continue; - } - Object[] mergedResults; - if (_mergeResultHolder.size() == groupByKey) { - mergedResults = new Object[numFunctions]; - _mergeResultHolder.add(mergedResults); - } else { - mergedResults = _mergeResultHolder.get(groupByKey); + if (_leafReturnFinalResult) { + for (int i = 0; i < numRows; i++) { + int groupByKey = groupByKeys[i]; + if (groupByKey == GroupKeyGenerator.INVALID_ID) { + continue; + } + Comparable[] mergedResults; + if (_mergeResultHolder.size() == groupByKey) { + mergedResults = new Comparable[numFunctions]; + _mergeResultHolder.add(mergedResults); + } else { + mergedResults = (Comparable[]) _mergeResultHolder.get(groupByKey); + } + for (int j = 0; j < numFunctions; j++) { + AggregationFunction aggFunction = _aggFunctions[j]; + Comparable finalResult = (Comparable) intermediateResults[j][i]; + // Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge. + // TODO: Fix it + if (finalResult == null) { + continue; + } + if (mergedResults[j] == null) { + mergedResults[j] = finalResult; + } else { + mergedResults[j] = aggFunction.mergeFinalResult(mergedResults[j], finalResult); + } + } } - for (int j = 0; j < numFunctions; j++) { - AggregationFunction aggFunction = _aggFunctions[j]; - Object intermediateResult = intermediateResults[j][i]; - // Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge. - // TODO: Fix it - if (intermediateResult == null) { + } else { + for (int i = 0; i < numRows; i++) { + int groupByKey = groupByKeys[i]; + if (groupByKey == GroupKeyGenerator.INVALID_ID) { continue; } - if (mergedResults[j] == null) { - mergedResults[j] = intermediateResult; + Object[] mergedResults; + if (_mergeResultHolder.size() == groupByKey) { + mergedResults = new Object[numFunctions]; + _mergeResultHolder.add(mergedResults); } else { - mergedResults[j] = aggFunction.merge(mergedResults[j], intermediateResult); + mergedResults = _mergeResultHolder.get(groupByKey); + } + for (int j = 0; j < numFunctions; j++) { + AggregationFunction aggFunction = _aggFunctions[j]; + Object intermediateResult = intermediateResults[j][i]; + // Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge. + // TODO: Fix it + if (intermediateResult == null) { + continue; + } + if (mergedResults[j] == null) { + mergedResults[j] = intermediateResult; + } else { + mergedResults[j] = aggFunction.merge(mergedResults[j], intermediateResult); + } } } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java index 9de79d15e263..bd58b7f64f04 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java @@ -80,6 +80,9 @@ public Void visitAggregate(AggregateNode node, ServerPlanRequestContext context) if (node.getAggType() == AggregateNode.AggType.DIRECT) { pinotQuery.putToQueryOptions(CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT, "true"); + } else if (node.isLeafReturnFinalResult()) { + pinotQuery.putToQueryOptions( + CommonConstants.Broker.Request.QueryOptionKey.SERVER_RETURN_FINAL_RESULT_KEY_UNPARTITIONED, "true"); } // there cannot be any more modification of PinotQuery post agg, thus this is the last one possible. context.setLeafStageBoundaryNode(node); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index 1c7edcae6ca1..f7f56e0ccb6e 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -272,7 +272,8 @@ private static RexExpression.FunctionCall getSum(RexExpression arg) { private AggregateOperator getOperator(DataSchema resultSchema, List aggCalls, List filterArgs, List groupKeys, PlanNode.NodeHint nodeHint) { return new AggregateOperator(OperatorTestUtil.getTracingContext(), _input, - new AggregateNode(-1, resultSchema, nodeHint, List.of(), aggCalls, filterArgs, groupKeys, AggType.DIRECT)); + new AggregateNode(-1, resultSchema, nodeHint, List.of(), aggCalls, filterArgs, groupKeys, AggType.DIRECT, + false)); } private AggregateOperator getOperator(DataSchema resultSchema, List aggCalls, diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java index 5b38dcdab58a..fc7ebba0b4cb 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MultiStageAccountingTest.java @@ -152,7 +152,7 @@ private static MultiStageOperator getAggregateOperator() { new DataSchema(new String[]{"group", "sum"}, new DataSchema.ColumnDataType[]{INT, DOUBLE}); return new AggregateOperator(OperatorTestUtil.getTracingContext(), input, new AggregateNode(-1, resultSchema, PlanNode.NodeHint.EMPTY, List.of(), aggCalls, filterArgs, groupKeys, - AggregateNode.AggType.DIRECT)); + AggregateNode.AggType.DIRECT, false)); } private static MultiStageOperator getHashJoinOperator() { diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json b/pinot-query-runtime/src/test/resources/queries/QueryHints.json index 22e464f28b77..e7c2ca375700 100644 --- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json +++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json @@ -77,6 +77,22 @@ "description": "Group by partition column with partition parallelism and GROUP BY hint", "sql": "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ {tbl1}.num, COUNT(*), COUNT(DISTINCT {tbl1}.name) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4', partition_parallelism='2') */ GROUP BY {tbl1}.num" }, + { + "description": "Group by non-partition column but aggregate on partition column", + "sql": "SELECT {tbl1}.name, COUNT(*), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ GROUP BY {tbl1}.name" + }, + { + "description": "Group by non-partition column but aggregate on partition column with partition parallelism", + "sql": "SELECT {tbl1}.name, COUNT(*), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4', partition_parallelism='2') */ GROUP BY {tbl1}.name" + }, + { + "description": "Group by non-partition column but aggregate on partition column with hint", + "sql": "SELECT /*+ aggOptions(is_leaf_return_final_result='true') */ {tbl1}.name, COUNT(*), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ GROUP BY {tbl1}.name" + }, + { + "description": "Group by non-partition column but aggregate on partition column with partition parallelism and hint", + "sql": "SELECT /*+ aggOptions(is_leaf_return_final_result='true') */ {tbl1}.name, COUNT(*), SUM({tbl1}.num), COUNT(DISTINCT {tbl1}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4', partition_parallelism='2') */ GROUP BY {tbl1}.name" + }, { "description": "Skip leaf stage aggregation with GROUP BY hint", "sql": "SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ {tbl1}.name, COUNT(*), SUM({tbl1}.num), MIN({tbl1}.num), COUNT(DISTINCT {tbl1}.num) FROM {tbl1} WHERE {tbl1}.num >= 0 GROUP BY {tbl1}.name" @@ -101,6 +117,10 @@ "description": "Colocated JOIN with partition column and group by non-partitioned column", "sql": "SELECT {tbl1}.name, SUM({tbl2}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num GROUP BY {tbl1}.name" }, + { + "description": "Colocated JOIN with partition column and group by non-partitioned column but aggregate on partition column", + "sql": "SELECT /*+ aggOptions(is_leaf_return_final_result='true') */ {tbl1}.name, SUM({tbl2}.num), COUNT(DISTINCT {tbl2}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num GROUP BY {tbl1}.name" + }, { "description": "Colocated JOIN with partition column and group by non-partitioned column with stage parallelism", "sql": "SET stageParallelism=2; SELECT {tbl1}.name, SUM({tbl2}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num GROUP BY {tbl1}.name" diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index e92a60865dce..09fb71c968c6 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -51,28 +51,29 @@ public enum AggregationFunctionType { // Aggregation functions for single-valued columns COUNT("count"), // TODO: min/max only supports NUMERIC in Pinot, where Calcite supports COMPARABLE_ORDERED - MIN("min", SqlTypeName.DOUBLE), - MAX("max", SqlTypeName.DOUBLE), - SUM("sum", SqlTypeName.DOUBLE), - SUM0("$sum0", SqlTypeName.DOUBLE), + MIN("min", SqlTypeName.DOUBLE, SqlTypeName.DOUBLE), + MAX("max", SqlTypeName.DOUBLE, SqlTypeName.DOUBLE), + SUM("sum", SqlTypeName.DOUBLE, SqlTypeName.DOUBLE), + SUM0("$sum0", SqlTypeName.DOUBLE, SqlTypeName.DOUBLE), SUMPRECISION("sumPrecision", ReturnTypes.explicit(SqlTypeName.DECIMAL), OperandTypes.ANY, SqlTypeName.OTHER), - AVG("avg", SqlTypeName.OTHER), - MODE("mode", SqlTypeName.OTHER), + AVG("avg", SqlTypeName.OTHER, SqlTypeName.DOUBLE), + MODE("mode", SqlTypeName.OTHER, SqlTypeName.DOUBLE), FIRSTWITHTIME("firstWithTime", ReturnTypes.ARG0, OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), SqlTypeName.OTHER), LASTWITHTIME("lastWithTime", ReturnTypes.ARG0, OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), SqlTypeName.OTHER), - MINMAXRANGE("minMaxRange", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), + MINMAXRANGE("minMaxRange", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER, SqlTypeName.DOUBLE), /** * for all distinct count family functions: * (1) distinct_count only supports single argument; * (2) count(distinct ...) support multi-argument and will be converted into DISTINCT + COUNT */ - DISTINCTCOUNT("distinctCount", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER), - DISTINCTSUM("distinctSum", ReturnTypes.AGG_SUM, OperandTypes.NUMERIC, SqlTypeName.OTHER), + DISTINCTCOUNT("distinctCount", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER, SqlTypeName.INTEGER), + DISTINCTSUM("distinctSum", ReturnTypes.AGG_SUM, OperandTypes.NUMERIC, SqlTypeName.OTHER, SqlTypeName.DOUBLE), DISTINCTAVG("distinctAvg", ReturnTypes.DOUBLE, OperandTypes.NUMERIC, SqlTypeName.OTHER), - DISTINCTCOUNTBITMAP("distinctCountBitmap", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER), + DISTINCTCOUNTBITMAP("distinctCountBitmap", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER, + SqlTypeName.INTEGER), SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER), DISTINCTCOUNTHLL("distinctCountHLL", ReturnTypes.BIGINT, @@ -105,7 +106,7 @@ public enum AggregationFunctionType { DISTINCTCOUNTRAWCPCSKETCH("distinctCountRawCPCSketch", ReturnTypes.VARCHAR, OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.ANY), i -> i == 1), SqlTypeName.OTHER), - PERCENTILE("percentile", ReturnTypes.ARG0, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER), + PERCENTILE("percentile", ReturnTypes.ARG0, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER, SqlTypeName.DOUBLE), PERCENTILEEST("percentileEst", ReturnTypes.BIGINT, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER), PERCENTILERAWEST("percentileRawEst", ReturnTypes.VARCHAR, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER), PERCENTILETDIGEST("percentileTDigest", ReturnTypes.DOUBLE, @@ -129,12 +130,12 @@ public enum AggregationFunctionType { HISTOGRAM("histogram", new ArrayReturnTypeInference(SqlTypeName.DOUBLE), OperandTypes.VARIADIC, SqlTypeName.OTHER), - COVARPOP("covarPop", SqlTypeName.OTHER), - COVARSAMP("covarSamp", SqlTypeName.OTHER), - VARPOP("varPop", SqlTypeName.OTHER), - VARSAMP("varSamp", SqlTypeName.OTHER), - STDDEVPOP("stdDevPop", SqlTypeName.OTHER), - STDDEVSAMP("stdDevSamp", SqlTypeName.OTHER), + COVARPOP("covarPop", SqlTypeName.OTHER, SqlTypeName.DOUBLE), + COVARSAMP("covarSamp", SqlTypeName.OTHER, SqlTypeName.DOUBLE), + VARPOP("varPop", SqlTypeName.OTHER, SqlTypeName.DOUBLE), + VARSAMP("varSamp", SqlTypeName.OTHER, SqlTypeName.DOUBLE), + STDDEVPOP("stdDevPop", SqlTypeName.OTHER, SqlTypeName.DOUBLE), + STDDEVSAMP("stdDevSamp", SqlTypeName.OTHER, SqlTypeName.DOUBLE), SKEWNESS("skewness", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), KURTOSIS("kurtosis", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), @@ -160,15 +161,15 @@ public enum AggregationFunctionType { PINOTPARENTAGGEXPRMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + EXPRMAX.getName(), ReturnTypes.explicit(SqlTypeName.OTHER), OperandTypes.VARIADIC, SqlTypeName.OTHER), PINOTCHILDAGGEXPRMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + EXPRMIN.getName(), - ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER), + ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER, SqlTypeName.BIGINT), PINOTCHILDAGGEXPRMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + EXPRMAX.getName(), - ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER), + ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER, SqlTypeName.BIGINT), // Array aggregate functions ARRAYAGG("arrayAgg", ReturnTypes.TO_ARRAY, OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.BOOLEAN), i -> i == 2), SqlTypeName.OTHER), - LISTAGG("listAgg", SqlTypeName.OTHER), + LISTAGG("listAgg", SqlTypeName.OTHER, SqlTypeName.VARCHAR), SUMARRAYLONG("sumArrayLong", new ArrayReturnTypeInference(SqlTypeName.BIGINT), OperandTypes.ARRAY, SqlTypeName.OTHER), SUMARRAYDOUBLE("sumArrayDouble", new ArrayReturnTypeInference(SqlTypeName.DOUBLE), OperandTypes.ARRAY, @@ -183,16 +184,17 @@ public enum AggregationFunctionType { SqlTypeName.OTHER), // Aggregation functions for multi-valued columns - COUNTMV("countMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.BIGINT), - MINMV("minMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.DOUBLE), - MAXMV("maxMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.DOUBLE), - SUMMV("sumMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.DOUBLE), + COUNTMV("countMV", ReturnTypes.BIGINT, OperandTypes.ARRAY), + MINMV("minMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY), + MAXMV("maxMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY), + SUMMV("sumMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY), AVGMV("avgMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), MINMAXRANGEMV("minMaxRangeMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), - DISTINCTCOUNTMV("distinctCountMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTCOUNTMV("distinctCountMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.OTHER, SqlTypeName.INTEGER), DISTINCTSUMMV("distinctSumMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), DISTINCTAVGMV("distinctAvgMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), - DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.OTHER, + SqlTypeName.INTEGER), DISTINCTCOUNTHLLMV("distinctCountHLLMV", ReturnTypes.BIGINT, OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV", ReturnTypes.VARCHAR, @@ -203,7 +205,7 @@ public enum AggregationFunctionType { OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), PERCENTILEMV("percentileMV", ReturnTypes.DOUBLE, OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), SqlTypeName.OTHER), - PERCENTILEESTMV("percentileEstMV", ReturnTypes.DOUBLE, + PERCENTILEESTMV("percentileEstMV", ReturnTypes.BIGINT, OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), SqlTypeName.OTHER), PERCENTILERAWESTMV("percentileRawEstMV", ReturnTypes.VARCHAR, OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), SqlTypeName.OTHER), @@ -220,9 +222,9 @@ public enum AggregationFunctionType { OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), SqlTypeName.OTHER); - private static final Set NAMES = - Arrays.stream(values()).flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toLowerCase())) - .collect(Collectors.toSet()); + private static final Set NAMES = Arrays.stream(values()) + .flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toLowerCase())) + .collect(Collectors.toSet()); private final String _name; @@ -230,29 +232,48 @@ public enum AggregationFunctionType { // When returnTypeInference is provided, the function will be registered as a USER_DEFINED_FUNCTION private final SqlReturnTypeInference _returnTypeInference; private final SqlOperandTypeChecker _operandTypeChecker; - // override options for Pinot aggregate rules to insert intermediate results that are non-standard than return type. + // Override intermediate result type if it is not the same as standard return type of the function. private final SqlReturnTypeInference _intermediateReturnTypeInference; + // Override final result type if it is not the same as standard return type of the function. + private final SqlReturnTypeInference _finalReturnTypeInference; AggregationFunctionType(String name) { - this(name, null, null, (SqlReturnTypeInference) null); + this(name, null, null, (SqlReturnTypeInference) null, null); } AggregationFunctionType(String name, SqlTypeName intermediateReturnType) { - this(name, null, null, ReturnTypes.explicit(intermediateReturnType)); + this(name, null, null, ReturnTypes.explicit(intermediateReturnType), null); + } + + AggregationFunctionType(String name, SqlTypeName intermediateReturnType, SqlTypeName finalReturnType) { + this(name, null, null, ReturnTypes.explicit(intermediateReturnType), ReturnTypes.explicit(finalReturnType)); + } + + AggregationFunctionType(String name, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker) { + this(name, returnTypeInference, operandTypeChecker, (SqlReturnTypeInference) null, null); } AggregationFunctionType(String name, SqlReturnTypeInference returnTypeInference, SqlOperandTypeChecker operandTypeChecker, SqlTypeName intermediateReturnType) { - this(name, returnTypeInference, operandTypeChecker, ReturnTypes.explicit(intermediateReturnType)); + this(name, returnTypeInference, operandTypeChecker, ReturnTypes.explicit(intermediateReturnType), null); + } + + AggregationFunctionType(String name, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker, SqlTypeName intermediateReturnType, SqlTypeName finalReturnType) { + this(name, returnTypeInference, operandTypeChecker, ReturnTypes.explicit(intermediateReturnType), + ReturnTypes.explicit(finalReturnType)); } AggregationFunctionType(String name, @Nullable SqlReturnTypeInference returnTypeInference, @Nullable SqlOperandTypeChecker operandTypeChecker, - @Nullable SqlReturnTypeInference intermediateReturnTypeInference) { + @Nullable SqlReturnTypeInference intermediateReturnTypeInference, + @Nullable SqlReturnTypeInference finalReturnTypeInference) { _name = name; _returnTypeInference = returnTypeInference; _operandTypeChecker = operandTypeChecker; _intermediateReturnTypeInference = intermediateReturnTypeInference; + _finalReturnTypeInference = finalReturnTypeInference; } public String getName() { @@ -274,6 +295,11 @@ public SqlReturnTypeInference getIntermediateReturnTypeInference() { return _intermediateReturnTypeInference; } + @Nullable + public SqlReturnTypeInference getFinalReturnTypeInference() { + return _finalReturnTypeInference; + } + public static boolean isAggregationFunction(String functionName) { if (NAMES.contains(functionName)) { return true;