From 7a61953d179d663737246517e869d006bc74b269 Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Fri, 14 Jul 2023 23:05:14 +0800 Subject: [PATCH] [fix](nereids)SimplifyComparisonPredicate rule need special care for deicmalv3 and datetimev2 literal (#21575) --- .../apache/doris/analysis/DecimalLiteral.java | 8 -- .../rules/SimplifyComparisonPredicate.java | 122 ++++++++++++++++++ .../expressions/literal/DateTimeLiteral.java | 2 +- .../literal/DateTimeV2Literal.java | 40 ++++++ .../expressions/literal/DecimalV3Literal.java | 12 ++ .../doris/nereids/util/TypeCoercionUtils.java | 5 +- .../RoundLiteralInBinaryPredicatesRule.java | 106 ++++++++------- .../test_simplify_comparison.groovy | 73 +++++++++++ 8 files changed, 312 insertions(+), 56 deletions(-) create mode 100644 regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index 47e98e11426aec..0d781bff7b8d2a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -333,14 +333,6 @@ public int getFracValue() { return fracPart.intValue(); } - public void roundCeiling() { - roundCeiling(0); - } - - public void roundFloor() { - roundFloor(0); - } - public void roundCeiling(int newScale) { value = value.setScale(newScale, RoundingMode.CEILING); type = ScalarType.createDecimalType(((ScalarType) type) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index c775f3b774970e..2c9976c2cb0fa2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -20,23 +20,34 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DateV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.DateLikeType; +import java.math.BigDecimal; + /** * simplify comparison * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral @@ -56,6 +67,12 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew Expression left = rewrite(cp.left(), context); Expression right = rewrite(cp.right(), context); + // decimalv3 type + if (left.getDataType() instanceof DecimalV3Type + && right.getDataType() instanceof DecimalV3Type) { + return processDecimalV3TypeCoercion(cp, left, right); + } + // date like type if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) { return processDateLikeTypeCoercion(cp, left, right); @@ -68,6 +85,49 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew } } + private static Expression processComparisonPredicateDateTimeV2Literal( + ComparisonPredicate comparisonPredicate, Expression left, DateTimeV2Literal right) { + DateTimeV2Type leftType = (DateTimeV2Type) left.getDataType(); + DateTimeV2Type rightType = right.getDataType(); + if (leftType.getScale() < rightType.getScale()) { + int toScale = leftType.getScale(); + if (comparisonPredicate instanceof EqualTo) { + long originValue = right.getMicroSecond(); + right = right.roundCeiling(toScale); + if (right.getMicroSecond() == originValue) { + return comparisonPredicate.withChildren(left, right); + } else { + if (left.nullable()) { + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.of(false); + } + } + } else if (comparisonPredicate instanceof NullSafeEqual) { + long originValue = right.getMicroSecond(); + right = right.roundCeiling(toScale); + if (right.getMicroSecond() == originValue) { + return comparisonPredicate.withChildren(left, right); + } else { + return BooleanLiteral.of(false); + } + } else if (comparisonPredicate instanceof GreaterThan + || comparisonPredicate instanceof LessThanEqual) { + return comparisonPredicate.withChildren(left, right.roundFloor(toScale)); + } else if (comparisonPredicate instanceof LessThan + || comparisonPredicate instanceof GreaterThanEqual) { + return comparisonPredicate.withChildren(left, right.roundCeiling(toScale)); + } + } + return comparisonPredicate; + } + private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) { Expression originalRight = right; if (left instanceof DateLiteral) { @@ -85,6 +145,13 @@ private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expressio right = migrateToDateTime((DateTimeV2Literal) right); } } + if (cast.child().getDataType() instanceof DateTimeV2Type) { + if (right instanceof DateTimeV2Literal) { + left = cast.child(); + return processComparisonPredicateDateTimeV2Literal(cp, left, + (DateTimeV2Literal) right); + } + } // datetime to datev2 if (cast.child().getDataType() instanceof DateType || cast.child().getDataType() instanceof DateV2Type) { if (right instanceof DateTimeLiteral) { @@ -129,6 +196,61 @@ private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expressio } } + private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, + Expression left, Expression right) { + if (left instanceof DecimalV3Literal) { + comparisonPredicate = comparisonPredicate.commute(); + Expression temp = left; + left = right; + right = temp; + } + + if (left instanceof Cast && right instanceof DecimalV3Literal) { + Cast cast = (Cast) left; + left = cast.child(); + DecimalV3Literal literal = (DecimalV3Literal) right; + if (((DecimalV3Type) left.getDataType()) + .getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) { + int toScale = ((DecimalV3Type) left.getDataType()).getScale(); + if (comparisonPredicate instanceof EqualTo) { + try { + BigDecimal newValue = literal.getValue().setScale(toScale); + return comparisonPredicate.withChildren(left, + new DecimalV3Literal(newValue)); + } catch (ArithmeticException e) { + if (left.nullable()) { + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.of(false); + } + } + } else if (comparisonPredicate instanceof NullSafeEqual) { + try { + BigDecimal newValue = literal.getValue().setScale(toScale); + return comparisonPredicate.withChildren(left, + new DecimalV3Literal(newValue)); + } catch (ArithmeticException e) { + return BooleanLiteral.of(false); + } + } else if (comparisonPredicate instanceof GreaterThan + || comparisonPredicate instanceof LessThanEqual) { + return comparisonPredicate.withChildren(left, literal.roundFloor(toScale)); + } else if (comparisonPredicate instanceof LessThan + || comparisonPredicate instanceof GreaterThanEqual) { + return comparisonPredicate.withChildren(left, literal.roundCeiling(toScale)); + } + } + } + + return comparisonPredicate; + } + private Expression migrateCastToDateTime(Cast cast) { //cast( cast(v as date) as datetime) if v is datetime, set left = v if (cast.child() instanceof Cast diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java index 56a76fb70f16f4..b82f0a8c94eb0f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeLiteral.java @@ -51,10 +51,10 @@ public class DateTimeLiteral extends DateLiteral { protected static DateTimeFormatter DATETIMEKEY_FORMATTER = null; protected static DateTimeFormatter DATE_TIME_FORMATTER_TO_MICRO_SECOND = null; protected static List formatterList = null; + protected static final int MAX_MICROSECOND = 999999; private static final DateTimeLiteral MIN_DATETIME = new DateTimeLiteral(0000, 1, 1, 0, 0, 0); private static final DateTimeLiteral MAX_DATETIME = new DateTimeLiteral(9999, 12, 31, 23, 59, 59); - private static final int MAX_MICROSECOND = 999999; private static final Logger LOG = LogManager.getLogger(DateTimeLiteral.class); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeV2Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeV2Literal.java index b3152b0bc15db4..ef09399a5aa831 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeV2Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateTimeV2Literal.java @@ -122,6 +122,46 @@ public Expression plusMicroSeconds(int microSeconds) { .plusNanos(microSeconds * 1000L), getDataType().getScale()); } + /** + * roundCeiling + */ + public DateTimeV2Literal roundCeiling(int newScale) { + long remain = Double.valueOf(microSecond % (Math.pow(10, 6 - newScale))).longValue(); + long newMicroSecond = microSecond; + long newSecond = second; + long newMinute = minute; + long newHour = hour; + long newDay = day; + long newMonth = month; + long newYear = year; + if (remain != 0) { + newMicroSecond = Double + .valueOf((microSecond + (Math.pow(10, 6 - newScale))) + / (int) (Math.pow(10, 6 - newScale)) * (Math.pow(10, 6 - newScale))) + .longValue(); + } + if (newMicroSecond > MAX_MICROSECOND) { + newMicroSecond %= newMicroSecond; + DateTimeV2Literal result = (DateTimeV2Literal) this.plusSeconds(1); + newSecond = result.second; + newMinute = result.minute; + newHour = result.hour; + newDay = result.day; + newMonth = result.month; + newYear = result.year; + } + return new DateTimeV2Literal(DateTimeV2Type.of(newScale), newYear, newMonth, newDay, + newHour, newMinute, newSecond, newMicroSecond); + } + + public DateTimeV2Literal roundFloor(int newScale) { + long newMicroSecond = Double.valueOf( + microSecond / (int) (Math.pow(10, 6 - newScale)) * (Math.pow(10, 6 - newScale))) + .longValue(); + return new DateTimeV2Literal(DateTimeV2Type.of(newScale), year, month, day, hour, minute, + second, newMicroSecond); + } + public static Expression fromJavaDateType(LocalDateTime dateTime) { return fromJavaDateType(dateTime, 0); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java index fe6040ca85ecd0..bc36c75436dd37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java @@ -72,6 +72,18 @@ public double getDouble() { return value.doubleValue(); } + public DecimalV3Literal roundCeiling(int newScale) { + return new DecimalV3Literal(DecimalV3Type + .createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale), + value.setScale(newScale, RoundingMode.CEILING)); + } + + public DecimalV3Literal roundFloor(int newScale) { + return new DecimalV3Literal(DecimalV3Type + .createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale), + value.setScale(newScale, RoundingMode.FLOOR)); + } + private void checkPrecisionAndScale(int precision, int scale, BigDecimal value) throws AnalysisException { Preconditions.checkNotNull(value); int realPrecision = value.precision(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index b53c4d2724c255..05de35a3887e2a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -757,7 +757,10 @@ public static Expression processCompoundPredicate(CompoundPredicate compoundPred } } ); - return compoundPredicate; + List children = compoundPredicate.children().stream() + .map(e -> e.getDataType().isNullType() ? new NullLiteral(BooleanType.INSTANCE) : e) + .collect(Collectors.toList()); + return compoundPredicate.withChildren(children); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java index 8ef6ef7a37fd70..017827a019c81a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java @@ -21,11 +21,14 @@ import org.apache.doris.analysis.BinaryPredicate; import org.apache.doris.analysis.BinaryPredicate.Operator; import org.apache.doris.analysis.BoolLiteral; +import org.apache.doris.analysis.CompoundPredicate; import org.apache.doris.analysis.DateLiteral; import org.apache.doris.analysis.DecimalLiteral; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.IsNullPredicate; +import org.apache.doris.analysis.NullLiteral; import org.apache.doris.catalog.ScalarType; +import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import java.math.BigDecimal; @@ -46,45 +49,48 @@ private Expr rewriteDecimalLiteral(Expr expr) { && ((ScalarType) expr0.getType()).getScalarScale() < ((ScalarType) expr1.getType()).getScalarScale()) { int toScale = ((ScalarType) expr0.getType()).getScalarScale(); - try { - switch (op) { - case EQ: { - BigDecimal originValue = literal.getValue(); - literal.roundCeiling(); - if (literal.getValue().equals(originValue.setScale(toScale))) { - expr.setChild(1, literal); - return expr; - } else { - return new BoolLiteral(false); - } - } - case NE: { - BigDecimal originValue = literal.getValue(); - literal.roundCeiling(toScale); - if (literal.getValue().equals(originValue.setScale(toScale))) { - expr.setChild(1, literal); - return expr; + switch (op) { + case EQ: + case NE: { + try { + BigDecimal newValue = literal.getValue().setScale(toScale); + expr.setChild(1, new DecimalLiteral(newValue)); + return expr; + } catch (ArithmeticException e) { + if (expr0.isNullable()) { + // TODO: the ideal way is to return an If expr like: + // List innerIfExprs = Lists.newArrayList(); + // innerIfExprs.add(new IsNullPredicate(expr0, false)); + // innerIfExprs.add(NullLiteral.create(Type.BOOLEAN)); + // innerIfExprs + // .add(op == Operator.EQ ? new BoolLiteral(false) : new BoolLiteral(true)); + // return new FunctionCallExpr("if", innerIfExprs); + // but current fold constant rule can't handle such complex expr with null literal + // so we use a trick way like this: + Expr newExpr = new CompoundPredicate(CompoundPredicate.Operator.AND, + new IsNullPredicate(expr0, false), NullLiteral.create(Type.BOOLEAN)); + return op == Operator.EQ ? newExpr + : new CompoundPredicate(CompoundPredicate.Operator.NOT, + newExpr, null); } else { - return new IsNullPredicate(expr0, true); + return op == Operator.EQ ? new BoolLiteral(false) : new BoolLiteral(true); } } - case GT: - case LE: { - literal.roundFloor(toScale); - expr.setChild(1, literal); - return expr; - } - case LT: - case GE: { - literal.roundCeiling(toScale); - expr.setChild(1, literal); - return expr; - } - default: - return expr; } - } catch (ArithmeticException e) { - return new BoolLiteral(false); + case GT: + case LE: { + literal.roundFloor(toScale); + expr.setChild(1, literal); + return expr; + } + case LT: + case GE: { + literal.roundCeiling(toScale); + expr.setChild(1, literal); + return expr; + } + default: + return expr; } } } @@ -101,16 +107,7 @@ private Expr rewriteDateLiteral(Expr expr) { if (expr0.getType().isDatetimeV2() && expr1 instanceof DateLiteral && expr1.getType().isDatetimeV2()) { DateLiteral literal = (DateLiteral) expr1; switch (op) { - case EQ: { - long originValue = literal.getMicrosecond(); - literal.roundCeiling(((ScalarType) expr0.getType()).getScalarScale()); - if (literal.getMicrosecond() == originValue) { - expr.setChild(1, literal); - return expr; - } else { - return new BoolLiteral(false); - } - } + case EQ: case NE: { long originValue = literal.getMicrosecond(); literal.roundCeiling(((ScalarType) expr0.getType()).getScalarScale()); @@ -118,7 +115,24 @@ private Expr rewriteDateLiteral(Expr expr) { expr.setChild(1, literal); return expr; } else { - return new IsNullPredicate(expr0, true); + if (expr0.isNullable()) { + // TODO: the ideal way is to return an If expr like: + // List innerIfExprs = Lists.newArrayList(); + // innerIfExprs.add(new IsNullPredicate(expr0, false)); + // innerIfExprs.add(NullLiteral.create(Type.BOOLEAN)); + // innerIfExprs + // .add(op == Operator.EQ ? new BoolLiteral(false) : new BoolLiteral(true)); + // return new FunctionCallExpr("if", innerIfExprs); + // but current fold constant rule can't handle such complex expr with null literal + // so we use a trick way like this: + Expr newExpr = new CompoundPredicate(CompoundPredicate.Operator.AND, + new IsNullPredicate(expr0, false), NullLiteral.create(Type.BOOLEAN)); + return op == Operator.EQ ? newExpr + : new CompoundPredicate(CompoundPredicate.Operator.NOT, newExpr, + null); + } else { + return op == Operator.EQ ? new BoolLiteral(false) : new BoolLiteral(true); + } } } case GT: diff --git a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy new file mode 100644 index 00000000000000..c252e8b86c19e0 --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_simplify_comparison") { + sql "set enable_nereids_planner=true" + sql 'set enable_fallback_to_original_planner=false;' + sql 'drop table if exists log_items_test' + sql """CREATE TABLE IF NOT EXISTS `log_items_test` ( + a DATETIME NOT NULL, + b decimal(10,2) + ) ENGINE=OLAP + UNIQUE KEY (`a`) + DISTRIBUTED BY HASH(`a`) BUCKETS 120 + PROPERTIES ( + "replication_num" = "1", + "in_memory" = "false", + "compression" = "LZ4", + "storage_cooldown_time" = "9999-12-31 23:59:59", + "enable_unique_key_merge_on_write" = "true" + );""" + sql """insert into log_items_test values( "2023-06-06", 111.11 );""" + + explain { + sql "verbose select * from log_items_test where a < '2023-06-15 23:59:59.999' and b < 111.111;" + notContains "CAST" + contains "< 111.12" + contains "< '2023-06-16 00:00:00'" + } + + explain { + sql "verbose select * from log_items_test where a <= '2023-06-15 23:59:59.999' and b <= 111.111;" + notContains "CAST" + contains "<= 111.11" + contains "<= '2023-06-15 23:59:59'" + } + + explain { + sql "verbose select * from log_items_test where a = '2023-06-15 23:59:59.999' and b = 111.111;" + notContains "CAST" + notContains "111.12" + notContains "2023-06-16 00:00:00" + notContains "111.11" + notContains "2023-06-15 23:59:59" + } + + explain { + sql "verbose select * from log_items_test where a > '2023-06-15 23:59:59.999' and b > 111.111;" + notContains "CAST" + contains "> 111.11" + contains "> '2023-06-15 23:59:59'" + } + + explain { + sql "verbose select * from log_items_test where a >= '2023-06-15 23:59:59.999' and b >= 111.111;" + notContains "CAST" + contains ">= 111.12" + contains ">= '2023-06-16 00:00:00'" + } +} \ No newline at end of file