diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java index 639616244ec0ab..c1efb82bac6efc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java @@ -326,7 +326,17 @@ private Map getExprMinMaxValues(UnknownValue valueDesc) for (int i = 1; i < sourceValues.size(); i++) { // process in sourceValues[i] Map minMaxValues = getExprMinMaxValues(sourceValues.get(i)); - for (Map.Entry entry : minMaxValues.entrySet()) { + // Merge the values of sourceValues[i] into result, + // while keeping the values' relative order from sourceValues[i]. + // For example, if a and b are in sourceValues[i] but not in result, then during merging + // a and b will each be assigned a new exprOrderIndex (using nextExprOrderIndex). + // If a's exprOrderIndex < b's exprOrderIndex in sourceValues[i], + // then make sure a's new exprOrderIndex < b's new exprOrderIndex in result, + // so that their relative order is preserved. 
+ List> minMaxValueList = minMaxValues.entrySet().stream() .sorted((a, b) -> Integer.compare(a.getValue().exprOrderIndex, b.getValue().exprOrderIndex)) .collect(Collectors.toList()); + for (Map.Entry entry : minMaxValueList) { Expression expr = entry.getKey(); MinMaxValue value = result.get(expr); MinMaxValue otherValue = entry.getValue(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index 74099df1123f25..a94444e160709b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -25,10 +25,12 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.util.ExpressionUtils; @@ -130,10 +132,22 @@ private ValueDesc simplify(ExpressionRewriteContext context, Expression originExpr, List predicates, BinaryOperator op, boolean isAnd) { + boolean convertIsNullToEmptyValue = isAnd && predicates.stream().anyMatch(expr -> expr instanceof NullLiteral); Multimap groupByReference = Multimaps.newListMultimap(new LinkedHashMap<>(), ArrayList::new); for (Expression predicate : predicates) { - ValueDesc valueDesc = predicate.accept(this, context); + // EmptyValue(a) = 
IsNull(a) and null, it doesn't equals to IsNull(a). + // Only when the AND expression contains at least one null literal among its conjuncts + // is EmptyValue(a) equivalent to IsNull(a). + // So for the expression and(IsNull(a), IsNull(b), ..., null), a and b can be converted to EmptyValue. + // What's more, if a is not nullable, then EmptyValue(a) is always equivalent to IsNull(a), + // but that case is not handled here; IsNull(a) should be folded to FALSE by another rule. + ValueDesc valueDesc = null; + if (convertIsNullToEmptyValue && predicate instanceof IsNull) { + valueDesc = new EmptyValue(context, ((IsNull) predicate).child(), predicate); + } else { + valueDesc = predicate.accept(this, context); + } List valueDescs = (List) groupByReference.get(valueDesc.reference); valueDescs.add(valueDesc); } @@ -461,6 +475,11 @@ public boolean isAnd() { @Override public ValueDesc union(ValueDesc other) { + // for RangeValue/DiscreteValue/UnknownValue, when unioning with an EmptyValue, + // delegate to EmptyValue.union(this) => this + if (other instanceof EmptyValue) { + return other.union(this); + } Expression originExpr = FoldConstantRuleOnFE.evaluate( ExpressionUtils.or(toExpr, other.toExpr), context); return new UnknownValue(context, originExpr, @@ -469,6 +488,11 @@ public ValueDesc union(ValueDesc other) { @Override public ValueDesc intersect(ValueDesc other) { + // for RangeValue/DiscreteValue/UnknownValue, when intersecting with an EmptyValue, + // delegate to EmptyValue.intersect(this) => EmptyValue + if (other instanceof EmptyValue) { + return other.intersect(this); + } Expression originExpr = FoldConstantRuleOnFE.evaluate( ExpressionUtils.and(toExpr, other.toExpr), context); return new UnknownValue(context, originExpr, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 628d94d1dccafd..4666342943a85c 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -25,22 +25,16 @@ import org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; import org.apache.doris.nereids.trees.expressions.InPredicate; -import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; -import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Or; -import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; -import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.BoundType; @@ -52,7 +46,6 @@ import java.util.Iterator; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * This class implements the function to simplify expression range. 
@@ -108,11 +101,7 @@ private Expression getExpression(ValueDesc value) { private Expression getExpression(EmptyValue value) { Expression reference = value.getReference(); - if (reference.nullable()) { - return new And(new IsNull(reference), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.FALSE; - } + return ExpressionUtils.falseOrNull(reference); } private Expression getExpression(RangeValue value) { @@ -136,11 +125,7 @@ private Expression getExpression(RangeValue value) { if (!result.isEmpty()) { return ExpressionUtils.and(result); } else { - if (reference.nullable()) { - return new Or(new Not(new IsNull(reference)), new NullLiteral(BooleanType.INSTANCE)); - } else { - return BooleanLiteral.TRUE; - } + return ExpressionUtils.trueOrNull(reference); } } @@ -167,8 +152,15 @@ private Expression getExpression(UnknownValue value) { if (sourceValues.isEmpty()) { return originExpr; } - List sourceExprs = sourceValues.stream().map(sourceValue -> getExpression(sourceValue)) - .collect(Collectors.toList()); + List sourceExprs = Lists.newArrayListWithExpectedSize(sourceValues.size()); + for (ValueDesc sourceValue : sourceValues) { + Expression expr = getExpression(sourceValue); + if (value.isAnd()) { + sourceExprs.addAll(ExpressionUtils.extractConjunction(expr)); + } else { + sourceExprs.addAll(ExpressionUtils.extractDisjunction(expr)); + } + } Expression result = value.isAnd() ? 
ExpressionUtils.and(sourceExprs) : ExpressionUtils.or(sourceExprs); result = FoldConstantRuleOnFE.evaluate(result, value.getExpressionRewriteContext()); // ATTN: we must return original expr, because OrToIn is implemented with MutableState, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index 296cae5fe4023b..0fcc8043944fef 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -277,7 +277,7 @@ void testDeadLoop() { } @Test - void testOrAddMinMax() { + void testAddMinMax() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp( AddMinMax.INSTANCE @@ -336,4 +336,44 @@ void testOrAddMinMax() { "(AA in (timestamp '2024-01-01 02:00:00',timestamp '2024-01-02 02:00:00',timestamp '2024-01-03 02:00:00') or AA < timestamp '2024-01-01 01:00:00' ) and AA <= timestamp '2024-01-03 02:00:00'"); } + + @Test + void testSimplifyRangeAndAddMinMax() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyRange.INSTANCE, + AddMinMax.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("ISNULL(TA)", "ISNULL(TA)"); + assertRewriteAfterTypeCoercion("ISNULL(TA) and null", "ISNULL(TA) and null"); + assertRewriteAfterTypeCoercion("ISNULL(TA) and ISNULL(TA)", "ISNULL(TA)"); + assertRewriteAfterTypeCoercion("ISNULL(TA) or ISNULL(TA)", "ISNULL(TA)"); + assertRewriteAfterTypeCoercion("ISNULL(TA) and TA between 20 and 10", "ISNULL(TA) and null"); + // assertRewriteAfterTypeCoercion("ISNULL(TA) and TA > 10", "ISNULL(TA) and null"); // should be, but not supported now + assertRewriteAfterTypeCoercion("ISNULL(TA) and TA > 10 and null", "ISNULL(TA) and null"); + assertRewriteAfterTypeCoercion("ISNULL(TA) or TA > 10", "ISNULL(TA) or TA > 10"); + // 
assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) and TA between 20 and 10", "TA IS NULL AND NULL"); // should be, but not supported because the and is flattened + assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) and TA is null and null", "TA IS NULL AND NULL"); + assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) or TA between 20 and 10", "TA < 30 or TA > 40"); + + assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA between 60 and 50", + "(TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40"); + // should be, but not supported yet, because 'TA is null and null' => UnknownValue(EmptyValue(TA) and null) + //assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA is null and null", + // "(TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40"); + assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA is null and null", + "(TA <= 20 or TA >= 30 or TA is null and null) and TA >= 10 and TA <= 40"); + assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA is null", + "TA >= 10 and TA <= 20 or TA >= 30 and TA <= 40 or TA is null"); + assertRewriteAfterTypeCoercion("ISNULL(TB) and (TA between 10 and 20 or TA between 30 and 40 or TA between 60 and 50)", + "ISNULL(TB) and ((TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40)"); + assertRewriteAfterTypeCoercion("ISNULL(TB) and (TA between 10 and 20 or TA between 30 and 40 or TA is null)", + "ISNULL(TB) and (TA >= 10 and TA <= 20 or TA >= 30 and TA <= 40 or TA is null)"); + assertRewriteAfterTypeCoercion("TB between 20 and 10 and (TA between 10 and 20 or TA between 30 and 40 or TA between 60 and 50)", + "TB IS NULL AND NULL and (TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40"); + assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10 and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50 and TB between 60 and 50", + "(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null and null and 
TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40"); + } }