Skip to content

Commit

Permalink
Support is_leaf_return_final_result agg option (#14645)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored Dec 13, 2024
1 parent 5e36800 commit 3677671
Show file tree
Hide file tree
Showing 30 changed files with 573 additions and 381 deletions.
1 change: 1 addition & 0 deletions pinot-common/src/main/proto/plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ message AggregateNode {
repeated int32 filterArgs = 2;
repeated int32 groupKeys = 3;
AggType aggType = 4;
bool leafReturnFinalResult = 5;
}

message FilterNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ public ColumnDataType getIntermediateResultColumnType() {

@Override
public ColumnDataType getFinalResultColumnType() {
// TODO: Revisit if we should change this to BIG_DECIMAL
return ColumnDataType.STRING;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3257,9 +3257,9 @@ public void testExplainPlanQueryV2()
+ " PinotLogicalSortExchange("
+ "distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])\\n"
+ " LogicalProject(count=[$1], name=[$0])\\n"
+ " PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])\\n"
+ " PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], aggType=[FINAL])\\n"
+ " PinotLogicalExchange(distribution=[hash[0]])\\n"
+ " PinotLogicalAggregate(group=[{17}], agg#0=[COUNT()])\\n"
+ " PinotLogicalAggregate(group=[{17}], agg#0=[COUNT()], aggType=[LEAF])\\n"
+ " LogicalTableScan(table=[[default, mytable]])\\n"
+ "\"]]}");
//@formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private PinotHintOptions() {

public static class AggregateOptions {
public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = "is_partitioned_by_group_by_keys";
public static final String IS_LEAF_RETURN_FINAL_RESULT = "is_leaf_return_final_result";
public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = "is_skip_leaf_stage_group_by";

public static final String NUM_GROUPS_LIMIT = "num_groups_limit";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
package org.apache.pinot.calcite.rel.hint;

import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.hint.HintPredicates;
import org.apache.calcite.rel.hint.HintStrategyTable;
import org.apache.calcite.rel.hint.Hintable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.pinot.spi.utils.BooleanUtils;


/**
Expand All @@ -35,38 +36,32 @@ public class PinotHintStrategyTable {
private PinotHintStrategyTable() {
}

public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE =
HintStrategyTable.builder().hintStrategy(PinotHintOptions.AGGREGATE_HINT_OPTIONS, HintPredicates.AGGREGATE)
.hintStrategy(PinotHintOptions.JOIN_HINT_OPTIONS, HintPredicates.JOIN)
.hintStrategy(PinotHintOptions.TABLE_HINT_OPTIONS, HintPredicates.TABLE_SCAN).build();

public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE = HintStrategyTable.builder()
.hintStrategy(PinotHintOptions.AGGREGATE_HINT_OPTIONS, HintPredicates.AGGREGATE)
.hintStrategy(PinotHintOptions.JOIN_HINT_OPTIONS, HintPredicates.JOIN)
.hintStrategy(PinotHintOptions.TABLE_HINT_OPTIONS, HintPredicates.TABLE_SCAN)
.build();

/**
* Get the first hint that has the specified name.
*/
@Nullable
public static RelHint getHint(Hintable hintable, String hintName) {
return hintable.getHints().stream()
.filter(relHint -> relHint.hintName.equals(hintName))
.findFirst()
.orElse(null);
return getHint(hintable.getHints(), hintName);
}

/**
* Get the first hint that satisfies the predicate.
*/
@Nullable
public static RelHint getHint(Hintable hintable, Predicate<RelHint> predicate) {
return hintable.getHints().stream()
.filter(predicate)
.findFirst()
.orElse(null);
return getHint(hintable.getHints(), predicate);
}

/**
* Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains a specific {@link RelHint} by name.
* Check if a hint-able {@link RelNode} contains a specific {@link RelHint} by name.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint} to be matched
* @return true if it contains the hint
*/
Expand All @@ -79,58 +74,76 @@ public static boolean containsHint(List<RelHint> hintList, String hintName) {
return false;
}

@Nullable
public static RelHint getHint(List<RelHint> hintList, String hintName) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint;
}
}
return null;
}

@Nullable
public static RelHint getHint(List<RelHint> hintList, Predicate<RelHint> predicate) {
for (RelHint hint : hintList) {
if (predicate.test(hint)) {
return hint;
}
}
return null;
}

/**
* Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains an option key for a specific hint name of
* {@link RelHint}.
* Returns the hint options for a specific hint name of {@link RelHint}, or {@code null} if the hint is not present.
*/
@Nullable
public static Map<String, String> getHintOptions(List<RelHint> hintList, String hintName) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions;
}
}
return null;
}

/**
* Check if a hint-able {@link RelNode} contains an option key for a specific hint name of {@link RelHint}.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint}.
* @param optionKey the option key to look for in the {@link RelHint#kvOptions}.
* @return true if it contains the hint
*/
public static boolean containsHintOption(List<RelHint> hintList, String hintName, String optionKey) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions.containsKey(optionKey);
}
}
return false;
Map<String, String> options = getHintOptions(hintList, hintName);
return options != null && options.containsKey(optionKey);
}

/**
* Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains an option key for a specific hint name of
* {@link RelHint}, and the value is true via {@link BooleanUtils#toBoolean(Object)}.
* Check if a hint-able {@link RelNode} has an option key as {@code true} for a specific hint name of {@link RelHint}.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint}.
* @param optionKey the option key to look for in the {@link RelHint#kvOptions}.
* @return true if it contains the hint
*/
public static boolean isHintOptionTrue(List<RelHint> hintList, String hintName, String optionKey) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions.containsKey(optionKey) && BooleanUtils.toBoolean(relHint.kvOptions.get(optionKey));
}
}
return false;
Map<String, String> options = getHintOptions(hintList, hintName);
return options != null && Boolean.parseBoolean(options.get(optionKey));
}

/**
* Retrieve the option value by option key in the {@link RelHint#kvOptions}. the option key is looked up from the
* specified hint name for a hint-able {@link org.apache.calcite.rel.RelNode}.
* specified hint name for a hint-able {@link RelNode}.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint}.
* @param optionKey the option key to look for in the {@link RelHint#kvOptions}.
* @return true if it contains the hint
*/
@Nullable
public static String getHintOption(List<RelHint> hintList, String hintName, String optionKey) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions.get(optionKey);
}
}
return null;
Map<String, String> options = getHintOptions(hintList, hintName);
return options != null ? options.get(optionKey) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
Expand All @@ -32,17 +33,30 @@

public class PinotLogicalAggregate extends Aggregate {
private final AggType _aggType;
private final boolean _leafReturnFinalResult;

public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List<RelHint> hints, RelNode input,
ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls,
AggType aggType) {
AggType aggType, boolean leafReturnFinalResult) {
super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls);
_aggType = aggType;
_leafReturnFinalResult = leafReturnFinalResult;
}

public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, AggType aggType) {
public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List<RelHint> hints, RelNode input,
ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls,
AggType aggType) {
this(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls, aggType, false);
}

public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, AggType aggType,
boolean leafReturnFinalResult) {
this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), aggRel.getInput(), aggRel.getGroupSet(),
aggRel.getGroupSets(), aggCalls, aggType);
aggRel.getGroupSets(), aggCalls, aggType, leafReturnFinalResult);
}

public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, AggType aggType) {
this(aggRel, aggCalls, aggType, false);
}

public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List<AggregateCall> aggCalls, AggType aggType) {
Expand All @@ -51,22 +65,37 @@ public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List<AggregateCall
}

public PinotLogicalAggregate(Aggregate aggRel, RelNode input, ImmutableBitSet groupSet, List<AggregateCall> aggCalls,
AggType aggType) {
this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType);
AggType aggType, boolean leafReturnFinalResult) {
this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType,
leafReturnFinalResult);
}

public AggType getAggType() {
return _aggType;
}

public boolean isLeafReturnFinalResult() {
return _leafReturnFinalResult;
}

@Override
public PinotLogicalAggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet,
@Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType);
return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType,
_leafReturnFinalResult);
}

@Override
public RelWriter explainTerms(RelWriter pw) {
RelWriter relWriter = super.explainTerms(pw);
relWriter.item("aggType", _aggType);
relWriter.itemIf("leafReturnFinalResult", true, _leafReturnFinalResult);
return relWriter;
}

@Override
public RelNode withHints(List<RelHint> hintList) {
return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType);
return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType,
_leafReturnFinalResult);
}
}
Loading

0 comments on commit 3677671

Please sign in to comment.