diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java index c36349e5e5ce3..b719a6e73734b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SortExpressionExtractor.java @@ -28,9 +28,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL; -import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; -import static java.util.Collections.singletonList; import static java.util.Comparator.comparing; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; @@ -66,8 +63,7 @@ public static Optional extractSortExpression(Set List sortExpressionCandidates = ImmutableList.copyOf(filterConjuncts.stream() .filter(DeterminismEvaluator::isDeterministic) .map(visitor::process) - .filter(Optional::isPresent) - .map(Optional::get) + .flatMap(List::stream) .collect(toMap(SortExpressionContext::getSortExpression, identity(), SortExpressionExtractor::merge)) .values()); @@ -88,7 +84,7 @@ private static SortExpressionContext merge(SortExpressionContext left, SortExpre } private static class SortExpressionVisitor - extends IrVisitor, Void> + extends IrVisitor, Void> { private final Set buildSymbols; @@ -98,13 +94,13 @@ public SortExpressionVisitor(Set buildSymbols) } @Override - protected Optional visitExpression(Expression expression, Void context) + protected List visitExpression(Expression expression, Void context) { - return Optional.empty(); + return List.of(); } @Override - protected Optional visitComparison(Comparison comparison, Void context) + protected List visitComparison(Comparison comparison, Void context) { return switch (comparison.operator()) { case GREATER_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> { @@ -115,22 +111,22 @@ protected Optional visitComparison(Comparison comparison, hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.right()); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { - yield sortChannel.map(symbolReference -> new SortExpressionContext(symbolReference, singletonList(comparison))); + yield ImmutableList.of(new SortExpressionContext(sortChannel.get(), ImmutableList.of(comparison))); } - yield Optional.empty(); + yield List.of(); } - default -> Optional.empty(); + default -> List.of(); }; } @Override - protected Optional visitBetween(Between node, Void context) + protected List visitBetween(Between node, Void context) { - Optional result = visitComparison(new Comparison(GREATER_THAN_OR_EQUAL, node.value(), node.min()), context); - if (result.isPresent()) { - return result; - } - return visitComparison(new Comparison(LESS_THAN_OR_EQUAL, node.value(), node.max()), context); + // Handle both side of BETWEEN as `GREATER_THAN_OR_EQUAL` expression and `LESS_THAN_OR_EQUAL` expression. + return ImmutableList.builder() + .addAll(visitComparison(new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, node.value(), node.min()), context)) + .addAll(visitComparison(new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, node.value(), node.max()), context)) + .build(); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java index 8717386f1b97b..74665295d79f9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestSortExpressionExtractor.java @@ -93,9 +93,14 @@ public void testGetSortExpression() new Comparison(GREATER_THAN, new Reference(BIGINT, "b2"), new Reference(BIGINT, "p2"))); assertGetSortExpression( - new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"), new Reference(BIGINT, "b2")), - "b1", - new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"))); + new Logical(AND, ImmutableList.of( + new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"), new Reference(BIGINT, "b2")), + new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), + new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "p2"), new Constant(BIGINT, 1L)))))), + "b2", + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "p1"), new Reference(BIGINT, "b2")), + new Comparison(LESS_THAN, new Reference(BIGINT, "b2"), + new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "p2"), new Constant(BIGINT, 1L))))); assertGetSortExpression( new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "p2"), new Reference(BIGINT, "b1")), @@ -105,7 +110,8 @@ public void testGetSortExpression() assertGetSortExpression( new Between(new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"), new Reference(BIGINT, "p2")), "b1", - new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1"))); + new Comparison(GREATER_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), + new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p2"))); assertGetSortExpression( new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN, new Reference(BIGINT, "b1"), new Reference(BIGINT, "p1")), new Between(new Reference(BIGINT, "p1"), new Reference(BIGINT, "b1"), new Reference(BIGINT, "b2")))),