Skip to content

Commit

Permalink
[Fix](nereids) fix normalize repeat alias rewrite (apache#38166)
Browse files Browse the repository at this point in the history
Induced by apache#34196.

In NormalizeRepeat, when NormalizeToSlot is called, aggregate function
parameters, grouping scalar function parameters, and all expressions in
grouping sets (including columns and column aliases) are pushed down to
the lower-level project output.
In the previous PR apache#34196, the context was split into two, but the two
contexts were not consistent. It is possible that the triplets in one
context save (id, c1, id as c1), and the triplets in the other context
save (id, id, id).
This causes id as c1 to be pushed down, but there is a reference to id
in the upper-level LogicalRepeat, which causes the slot to be not found.
This pr has been modified.
If the same slot in the projection column has different aliases,

for example,

select id as c1, id, id as c3, grouping(id) from table1 group by grouping sets((id, value2),(id));

then id as c1 (using the first alias) will be pushed down to the
project.
In both the LogicalRepeat operator and the LogicalAggregate operator, c1
is referenced as the input slot, and id and c3 will not be used as input
slots.

before NormalizeRepeat:
LogicalResultSink[32] ( outputExprs=[c1#3, id#0, c3#4, __grouping_3#5] )
+--LogicalRepeat ( groupingSets=[[id#0, value2#2], [id#0]], outputExpressions=[id#0 AS `c1`#3, id#0, id#0 AS `c3`#4, Grouping(id#0) AS `Grouping(id)`#5] )
   +--LogicalOlapScan (qualified=table1)

After NormalizeRepeat:
LogicalResultSink[33] (outputExprs=[c1#3, id#0, c3#4, __grouping_3#5])
+--LogicalAggregate[30] (groupByExpr=[c1#3, value2#2, GROUPING_ID#7, GROUPING_PREFIX_c1#6 originExpression=Grouping(c1#3)], outputExpr=[c1#3, c1#3 AS `id`#0, c1#3 AS `c3`#4, GROUPING_PREFIX_c1#6 originExpression=Grouping(c1#3) AS `GROUPING_PREFIX_c1`#5], hasRepeat=true )
   +--LogicalRepeat (groupingSets=[[c1#3, value2#2], [c1#3]], outputExpressions=[c1#3, value2#2, GROUPING_ID#7, GROUPING_PREFIX_c1#6 originExpression=Grouping(c1#3)] )
      +--LogicalProject[28] (projects=[id#0 AS `c1`#3, value2#2])
         +--LogicalOlapScan (qualified=table1)
  • Loading branch information
feiniaofeiafei authored Jul 26, 2024
1 parent 53ceaeb commit 7090433
Show file tree
Hide file tree
Showing 4 changed files with 684 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -133,9 +135,9 @@ private static LogicalAggregate<Plan> normalizeRepeat(LogicalRepeat<Plan> repeat
Set<Expression> needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat);
NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr);
Map<Expression, NormalizeToSlotTriplet> groupingExprMap = groupingExprContext.getNormalizeToSlotMap();
Set<Alias> existsAlias = getExistsAlias(repeat, groupingExprMap);
Map<Expression, Alias> existsAlias = getExistsAlias(repeat, groupingExprMap);
Set<Expression> needToSlotsArgs = collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(repeat);
NormalizeToSlotContext argsContext = NormalizeToSlotContext.buildContext(existsAlias, needToSlotsArgs);
NormalizeToSlotContext argsContext = buildContextWithAlias(repeat, existsAlias, needToSlotsArgs);

// normalize grouping sets to List<List<Slot>>
ImmutableList.Builder<List<Slot>> normalizedGroupingSetBuilder = ImmutableList.builder();
Expand Down Expand Up @@ -254,12 +256,27 @@ private static Plan pushDownProject(Set<NamedExpression> pushedExprs, Plan origi
/** buildContext */
public static NormalizeToSlotContext buildContext(Repeat<? extends Plan> repeat,
Set<? extends Expression> sourceExpressions) {
Set<Alias> aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance);
List<Alias> aliases = ExpressionUtils.collectToList(repeat.getOutputExpressions(), Alias.class::isInstance);
Map<Expression, Alias> existsAliasMap = Maps.newLinkedHashMap();
for (Alias existsAlias : aliases) {
if (existsAliasMap.containsKey(existsAlias.child())) {
continue;
}
existsAliasMap.put(existsAlias.child(), existsAlias);
}

Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap();
for (Expression expression : sourceExpressions) {
Optional<NormalizeToSlotTriplet> pushDownTriplet =
toGroupingSetExpressionPushDownTriplet(expression, existsAliasMap.get(expression));
pushDownTriplet.ifPresent(
normalizeToSlotTriplet -> normalizeToSlotMap.put(expression, normalizeToSlotTriplet));
}
return new NormalizeToSlotContext(normalizeToSlotMap);
}

private static NormalizeToSlotContext buildContextWithAlias(Repeat<? extends Plan> repeat,
Map<Expression, Alias> existsAliasMap, Collection<? extends Expression> sourceExpressions) {
List<Expression> groupingSetExpressions = ExpressionUtils.flatExpressions(repeat.getGroupingSets());
Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap();
for (Expression expression : sourceExpressions) {
Expand All @@ -270,10 +287,8 @@ public static NormalizeToSlotContext buildContext(Repeat<? extends Plan> repeat,
pushDownTriplet = Optional.of(
NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression)));
}

if (pushDownTriplet.isPresent()) {
normalizeToSlotMap.put(expression, pushDownTriplet.get());
}
pushDownTriplet.ifPresent(
normalizeToSlotTriplet -> normalizeToSlotMap.put(expression, normalizeToSlotTriplet));
}
return new NormalizeToSlotContext(normalizeToSlotMap);
}
Expand Down Expand Up @@ -304,18 +319,23 @@ private static Expression normalizeAggFuncChildrenAndGroupingScalarFunc(Normaliz
}
}

private static Set<Alias> getExistsAlias(LogicalRepeat<Plan> repeat,
private static Map<Expression, Alias> getExistsAlias(LogicalRepeat<Plan> repeat,
Map<Expression, NormalizeToSlotTriplet> groupingExprMap) {
Set<Alias> existsAlias = Sets.newHashSet();
Set<Alias> aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance);
existsAlias.addAll(aliases);
Map<Expression, Alias> existsAliasMap = new HashMap<>();
for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) {
if (triplet.pushedExpr instanceof Alias) {
Alias alias = (Alias) triplet.pushedExpr;
existsAlias.add(alias);
existsAliasMap.put(triplet.originExpr, alias);
}
}
List<Alias> aliases = ExpressionUtils.collectToList(repeat.getOutputExpressions(), Alias.class::isInstance);
for (Alias alias : aliases) {
if (existsAliasMap.containsKey(alias.child())) {
continue;
}
existsAliasMap.put(alias.child(), alias);
}
return existsAlias;
return existsAliasMap;
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,15 @@ public static <E> Set<E> collect(Collection<? extends Expression> expressions,
return set.build();
}

public static <E> List<E> collectToList(Collection<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
ImmutableList.Builder<E> list = ImmutableList.builder();
for (Expression expr : expressions) {
list.addAll(expr.collectToList(predicate));
}
return list.build();
}

/**
* extract uniform slot for the given predicate, such as a = 1 and b = 2
*/
Expand Down
Loading

0 comments on commit 7090433

Please sign in to comment.