Skip to content

Commit

Permalink
[fix](nereids)SimplifyComparisonPredicate rule need special care for …
Browse files Browse the repository at this point in the history
…deicmalv3 and datetimev2 literal (apache#21575)
  • Loading branch information
starocean999 authored Jul 14, 2023
1 parent 7f50c07 commit 7a61953
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,6 @@ public int getFracValue() {
return fracPart.intValue();
}

public void roundCeiling() {
roundCeiling(0);
}

public void roundFloor() {
roundFloor(0);
}

public void roundCeiling(int newScale) {
value = value.setScale(newScale, RoundingMode.CEILING);
type = ScalarType.createDecimalType(((ScalarType) type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,34 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;

import java.math.BigDecimal;

/**
* simplify comparison
* such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
Expand All @@ -56,6 +67,12 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew
Expression left = rewrite(cp.left(), context);
Expression right = rewrite(cp.right(), context);

// decimalv3 type
if (left.getDataType() instanceof DecimalV3Type
&& right.getDataType() instanceof DecimalV3Type) {
return processDecimalV3TypeCoercion(cp, left, right);
}

// date like type
if (left.getDataType() instanceof DateLikeType && right.getDataType() instanceof DateLikeType) {
return processDateLikeTypeCoercion(cp, left, right);
Expand All @@ -68,6 +85,49 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew
}
}

private static Expression processComparisonPredicateDateTimeV2Literal(
ComparisonPredicate comparisonPredicate, Expression left, DateTimeV2Literal right) {
DateTimeV2Type leftType = (DateTimeV2Type) left.getDataType();
DateTimeV2Type rightType = right.getDataType();
if (leftType.getScale() < rightType.getScale()) {
int toScale = leftType.getScale();
if (comparisonPredicate instanceof EqualTo) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, right.roundFloor(toScale));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left, right.roundCeiling(toScale));
}
}
return comparisonPredicate;
}

private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
Expression originalRight = right;
if (left instanceof DateLiteral) {
Expand All @@ -85,6 +145,13 @@ private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expressio
right = migrateToDateTime((DateTimeV2Literal) right);
}
}
if (cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
return processComparisonPredicateDateTimeV2Literal(cp, left,
(DateTimeV2Literal) right);
}
}
// datetime to datev2
if (cast.child().getDataType() instanceof DateType || cast.child().getDataType() instanceof DateV2Type) {
if (right instanceof DateTimeLiteral) {
Expand Down Expand Up @@ -129,6 +196,61 @@ private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expressio
}
}

private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate,
Expression left, Expression right) {
if (left instanceof DecimalV3Literal) {
comparisonPredicate = comparisonPredicate.commute();
Expression temp = left;
left = right;
right = temp;
}

if (left instanceof Cast && right instanceof DecimalV3Literal) {
Cast cast = (Cast) left;
left = cast.child();
DecimalV3Literal literal = (DecimalV3Literal) right;
if (((DecimalV3Type) left.getDataType())
.getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) {
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
BigDecimal newValue = literal.getValue().setScale(toScale);
return comparisonPredicate.withChildren(left,
new DecimalV3Literal(newValue));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
// but current fold constant rule can't handle such complex expr with null literal
// before supporting complex conjuncts with null literal folding rules,
// we use a trick way like this:
return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.of(false);
}
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
BigDecimal newValue = literal.getValue().setScale(toScale);
return comparisonPredicate.withChildren(left,
new DecimalV3Literal(newValue));
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundCeiling(toScale));
}
}
}

return comparisonPredicate;
}

private Expression migrateCastToDateTime(Cast cast) {
//cast( cast(v as date) as datetime) if v is datetime, set left = v
if (cast.child() instanceof Cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ public class DateTimeLiteral extends DateLiteral {
protected static DateTimeFormatter DATETIMEKEY_FORMATTER = null;
protected static DateTimeFormatter DATE_TIME_FORMATTER_TO_MICRO_SECOND = null;
protected static List<DateTimeFormatter> formatterList = null;
protected static final int MAX_MICROSECOND = 999999;

private static final DateTimeLiteral MIN_DATETIME = new DateTimeLiteral(0000, 1, 1, 0, 0, 0);
private static final DateTimeLiteral MAX_DATETIME = new DateTimeLiteral(9999, 12, 31, 23, 59, 59);
private static final int MAX_MICROSECOND = 999999;

private static final Logger LOG = LogManager.getLogger(DateTimeLiteral.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,46 @@ public Expression plusMicroSeconds(int microSeconds) {
.plusNanos(microSeconds * 1000L), getDataType().getScale());
}

/**
* roundCeiling
*/
public DateTimeV2Literal roundCeiling(int newScale) {
long remain = Double.valueOf(microSecond % (Math.pow(10, 6 - newScale))).longValue();
long newMicroSecond = microSecond;
long newSecond = second;
long newMinute = minute;
long newHour = hour;
long newDay = day;
long newMonth = month;
long newYear = year;
if (remain != 0) {
newMicroSecond = Double
.valueOf((microSecond + (Math.pow(10, 6 - newScale)))
/ (int) (Math.pow(10, 6 - newScale)) * (Math.pow(10, 6 - newScale)))
.longValue();
}
if (newMicroSecond > MAX_MICROSECOND) {
newMicroSecond %= newMicroSecond;
DateTimeV2Literal result = (DateTimeV2Literal) this.plusSeconds(1);
newSecond = result.second;
newMinute = result.minute;
newHour = result.hour;
newDay = result.day;
newMonth = result.month;
newYear = result.year;
}
return new DateTimeV2Literal(DateTimeV2Type.of(newScale), newYear, newMonth, newDay,
newHour, newMinute, newSecond, newMicroSecond);
}

public DateTimeV2Literal roundFloor(int newScale) {
long newMicroSecond = Double.valueOf(
microSecond / (int) (Math.pow(10, 6 - newScale)) * (Math.pow(10, 6 - newScale)))
.longValue();
return new DateTimeV2Literal(DateTimeV2Type.of(newScale), year, month, day, hour, minute,
second, newMicroSecond);
}

public static Expression fromJavaDateType(LocalDateTime dateTime) {
return fromJavaDateType(dateTime, 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ public double getDouble() {
return value.doubleValue();
}

public DecimalV3Literal roundCeiling(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.CEILING));
}

public DecimalV3Literal roundFloor(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.FLOOR));
}

private void checkPrecisionAndScale(int precision, int scale, BigDecimal value) throws AnalysisException {
Preconditions.checkNotNull(value);
int realPrecision = value.precision();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,10 @@ public static Expression processCompoundPredicate(CompoundPredicate compoundPred
}
}
);
return compoundPredicate;
List<Expression> children = compoundPredicate.children().stream()
.map(e -> e.getDataType().isNullType() ? new NullLiteral(BooleanType.INSTANCE) : e)
.collect(Collectors.toList());
return compoundPredicate.withChildren(children);
}

/**
Expand Down
Loading

0 comments on commit 7a61953

Please sign in to comment.