Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Multi-stage] Support is_leaf_return_final_result agg option #14645

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pinot-common/src/main/proto/plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ message AggregateNode {
repeated int32 filterArgs = 2;
repeated int32 groupKeys = 3;
AggType aggType = 4;
bool leafReturnFinalResult = 5;
}

message FilterNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ public ColumnDataType getIntermediateResultColumnType() {

@Override
public ColumnDataType getFinalResultColumnType() {
// TODO: Revisit if we should change this to BIG_DECIMAL
return ColumnDataType.STRING;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3257,9 +3257,9 @@ public void testExplainPlanQueryV2()
+ " PinotLogicalSortExchange("
+ "distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])\\n"
+ " LogicalProject(count=[$1], name=[$0])\\n"
+ " PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])\\n"
+ " PinotLogicalAggregate(group=[{0}], agg#0=[COUNT($1)], aggType=[FINAL])\\n"
+ " PinotLogicalExchange(distribution=[hash[0]])\\n"
+ " PinotLogicalAggregate(group=[{17}], agg#0=[COUNT()])\\n"
+ " PinotLogicalAggregate(group=[{17}], agg#0=[COUNT()], aggType=[LEAF])\\n"
+ " LogicalTableScan(table=[[default, mytable]])\\n"
+ "\"]]}");
//@formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private PinotHintOptions() {

public static class AggregateOptions {
public static final String IS_PARTITIONED_BY_GROUP_BY_KEYS = "is_partitioned_by_group_by_keys";
public static final String IS_LEAF_RETURN_FINAL_RESULT = "is_leaf_return_final_result";
public static final String SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION = "is_skip_leaf_stage_group_by";

public static final String NUM_GROUPS_LIMIT = "num_groups_limit";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
package org.apache.pinot.calcite.rel.hint;

import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.hint.HintPredicates;
import org.apache.calcite.rel.hint.HintStrategyTable;
import org.apache.calcite.rel.hint.Hintable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.pinot.spi.utils.BooleanUtils;


/**
Expand All @@ -35,38 +36,32 @@ public class PinotHintStrategyTable {
private PinotHintStrategyTable() {
}

public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE =
HintStrategyTable.builder().hintStrategy(PinotHintOptions.AGGREGATE_HINT_OPTIONS, HintPredicates.AGGREGATE)
.hintStrategy(PinotHintOptions.JOIN_HINT_OPTIONS, HintPredicates.JOIN)
.hintStrategy(PinotHintOptions.TABLE_HINT_OPTIONS, HintPredicates.TABLE_SCAN).build();

public static final HintStrategyTable PINOT_HINT_STRATEGY_TABLE = HintStrategyTable.builder()
.hintStrategy(PinotHintOptions.AGGREGATE_HINT_OPTIONS, HintPredicates.AGGREGATE)
.hintStrategy(PinotHintOptions.JOIN_HINT_OPTIONS, HintPredicates.JOIN)
.hintStrategy(PinotHintOptions.TABLE_HINT_OPTIONS, HintPredicates.TABLE_SCAN)
.build();

/**
* Get the first hint that has the specified name.
*/
@Nullable
public static RelHint getHint(Hintable hintable, String hintName) {
return hintable.getHints().stream()
.filter(relHint -> relHint.hintName.equals(hintName))
.findFirst()
.orElse(null);
return getHint(hintable.getHints(), hintName);
}

/**
* Get the first hint that satisfies the predicate.
*/
@Nullable
public static RelHint getHint(Hintable hintable, Predicate<RelHint> predicate) {
return hintable.getHints().stream()
.filter(predicate)
.findFirst()
.orElse(null);
return getHint(hintable.getHints(), predicate);
}

/**
* Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains a specific {@link RelHint} by name.
* Check if a hint-able {@link RelNode} contains a specific {@link RelHint} by name.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint} to be matched
* @return true if it contains the hint
*/
Expand All @@ -79,58 +74,76 @@ public static boolean containsHint(List<RelHint> hintList, String hintName) {
return false;
}

@Nullable
public static RelHint getHint(List<RelHint> hintList, String hintName) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint;
}
}
return null;
}

@Nullable
public static RelHint getHint(List<RelHint> hintList, Predicate<RelHint> predicate) {
for (RelHint hint : hintList) {
if (predicate.test(hint)) {
return hint;
}
}
return null;
}

/**
* Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains an option key for a specific hint name of
* {@link RelHint}.
* Returns the hint options for a specific hint name of {@link RelHint}, or {@code null} if the hint is not present.
*/
@Nullable
public static Map<String, String> getHintOptions(List<RelHint> hintList, String hintName) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions;
}
}
return null;
}

/**
* Check if a hint-able {@link RelNode} contains an option key for a specific hint name of {@link RelHint}.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint}.
* @param optionKey the option key to look for in the {@link RelHint#kvOptions}.
* @return true if it contains the hint
*/
public static boolean containsHintOption(List<RelHint> hintList, String hintName, String optionKey) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions.containsKey(optionKey);
}
}
return false;
Map<String, String> options = getHintOptions(hintList, hintName);
return options != null && options.containsKey(optionKey);
}

/**
* Check if a hint-able {@link org.apache.calcite.rel.RelNode} contains an option key for a specific hint name of
* {@link RelHint}, and the value is true via {@link BooleanUtils#toBoolean(Object)}.
* Check if a hint-able {@link RelNode} has an option key as {@code true} for a specific hint name of {@link RelHint}.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint}.
* @param optionKey the option key to look for in the {@link RelHint#kvOptions}.
* @return true if it contains the hint
*/
public static boolean isHintOptionTrue(List<RelHint> hintList, String hintName, String optionKey) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions.containsKey(optionKey) && BooleanUtils.toBoolean(relHint.kvOptions.get(optionKey));
}
}
return false;
Map<String, String> options = getHintOptions(hintList, hintName);
return options != null && Boolean.parseBoolean(options.get(optionKey));
}

/**
* Retrieve the option value by option key in the {@link RelHint#kvOptions}. the option key is looked up from the
* specified hint name for a hint-able {@link org.apache.calcite.rel.RelNode}.
* specified hint name for a hint-able {@link RelNode}.
*
* @param hintList hint list from the {@link org.apache.calcite.rel.RelNode}.
* @param hintList hint list from the {@link RelNode}.
* @param hintName the name of the {@link RelHint}.
* @param optionKey the option key to look for in the {@link RelHint#kvOptions}.
* @return true if it contains the hint
*/
@Nullable
public static String getHintOption(List<RelHint> hintList, String hintName, String optionKey) {
for (RelHint relHint : hintList) {
if (relHint.hintName.equals(hintName)) {
return relHint.kvOptions.get(optionKey);
}
}
return null;
Map<String, String> options = getHintOptions(hintList, hintName);
return options != null ? options.get(optionKey) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
Expand All @@ -32,17 +33,30 @@

public class PinotLogicalAggregate extends Aggregate {
private final AggType _aggType;
private final boolean _leafReturnFinalResult;

public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List<RelHint> hints, RelNode input,
ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls,
AggType aggType) {
AggType aggType, boolean leafReturnFinalResult) {
super(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls);
_aggType = aggType;
_leafReturnFinalResult = leafReturnFinalResult;
}

public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, AggType aggType) {
public PinotLogicalAggregate(RelOptCluster cluster, RelTraitSet traitSet, List<RelHint> hints, RelNode input,
ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls,
AggType aggType) {
this(cluster, traitSet, hints, input, groupSet, groupSets, aggCalls, aggType, false);
}

public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, AggType aggType,
boolean leafReturnFinalResult) {
this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), aggRel.getInput(), aggRel.getGroupSet(),
aggRel.getGroupSets(), aggCalls, aggType);
aggRel.getGroupSets(), aggCalls, aggType, leafReturnFinalResult);
}

public PinotLogicalAggregate(Aggregate aggRel, List<AggregateCall> aggCalls, AggType aggType) {
this(aggRel, aggCalls, aggType, false);
}

public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List<AggregateCall> aggCalls, AggType aggType) {
Expand All @@ -51,22 +65,37 @@ public PinotLogicalAggregate(Aggregate aggRel, RelNode input, List<AggregateCall
}

public PinotLogicalAggregate(Aggregate aggRel, RelNode input, ImmutableBitSet groupSet, List<AggregateCall> aggCalls,
AggType aggType) {
this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType);
AggType aggType, boolean leafReturnFinalResult) {
this(aggRel.getCluster(), aggRel.getTraitSet(), aggRel.getHints(), input, groupSet, null, aggCalls, aggType,
leafReturnFinalResult);
}

public AggType getAggType() {
return _aggType;
}

public boolean isLeafReturnFinalResult() {
return _leafReturnFinalResult;
}

@Override
public PinotLogicalAggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet,
@Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType);
return new PinotLogicalAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls, _aggType,
_leafReturnFinalResult);
}

@Override
public RelWriter explainTerms(RelWriter pw) {
RelWriter relWriter = super.explainTerms(pw);
relWriter.item("aggType", _aggType);
relWriter.itemIf("leafReturnFinalResult", true, _leafReturnFinalResult);
return relWriter;
}

@Override
public RelNode withHints(List<RelHint> hintList) {
return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType);
return new PinotLogicalAggregate(getCluster(), traitSet, hintList, input, groupSet, groupSets, aggCalls, _aggType,
_leafReturnFinalResult);
}
}
Loading
Loading