Skip to content

Commit

Permalink
[DROOLS-7631] unify coercion checks between plain drl and executable …
Browse files Browse the repository at this point in the history
…model (#6086)
  • Loading branch information
mariofusco authored Sep 18, 2024
1 parent d7ca8b3 commit d04ab13
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toJavaParserType;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toStringLiteral;
import static org.drools.util.ClassUtils.toNonPrimitiveType;
import static org.drools.util.CoercionUtil.areComparisonCompatible;
import static org.drools.util.CoercionUtil.areEqualityCompatible;

public class CoercedExpression {

Expand Down Expand Up @@ -151,9 +153,22 @@ public CoercedExpressionResult coerce() {
coercedLeft = left;
}

checkCoercion(coercedLeft, coercedRight, leftClass, rightClass);
return new CoercedExpressionResult(coercedLeft, coercedRight, rightAsStaticField);
}

private void checkCoercion(TypedExpression coercedLeft, TypedExpression coercedRight, Class<?> leftClass, Class<?> rightClass) {
if (equalityExpr) {
if (!areEqualityCompatible(coercedLeft.getRawClass(), coercedRight.getRawClass())) {
throw new CoercedExpressionException(new InvalidExpressionErrorResult("Equality operation requires compatible types. Found " + leftClass + " and " + rightClass));
}
} else {
if (!areComparisonCompatible(coercedLeft.getRawClass(), coercedRight.getRawClass())) {
throw new CoercedExpressionException(new InvalidExpressionErrorResult("Comparison operation requires compatible types. Found " + leftClass + " and " + rightClass));
}
}
}

private boolean isBoolean(Class<?> leftClass) {
return Boolean.class.isAssignableFrom(leftClass) || boolean.class.isAssignableFrom(leftClass);
}
Expand All @@ -163,12 +178,14 @@ private boolean shouldCoerceBToMap() {
}

private boolean canCoerce() {
final Class<?> leftClass = left.getRawClass();
return canCoerce(left.getRawClass(), right.getRawClass());
}

private static boolean canCoerce(Class<?> leftClass, Class<?> rightClass) {
if (!leftClass.isPrimitive() || !canCoerceLiteralNumberExpr(leftClass)) {
return true;
}

final Class<?> rightClass = right.getRawClass();
return rightClass.isPrimitive()
|| Number.class.isAssignableFrom(rightClass)
|| Boolean.class == rightClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ public void castToShort() {
@Test
public void castMaps() {
final TypedExpression left = expr(THIS_PLACEHOLDER + ".getAge()", Integer.class);
final TypedExpression right = expr("$m.get(\"age\")", java.util.Map.class);
final TypedExpression right = expr("$m.get(\"age\")", Object.class);
final CoercedExpression.CoercedExpressionResult coerce = new CoercedExpression(left, right, false).coerce();
assertThat(coerce.getCoercedRight()).isEqualTo(expr("(java.lang.Integer)$m.get(\"age\")", Map.class));
assertThat(coerce.getCoercedRight()).isEqualTo(expr("$m.get(\"age\")", Object.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
import org.drools.mvel.expr.MVELObjectExpression;
import org.drools.mvel.expr.MvelEvaluator;
import org.drools.mvel.java.JavaForMvelDialectConfiguration;
import org.drools.util.CoercionUtil;
import org.drools.util.MethodUtils;
import org.kie.api.definition.rule.Rule;
import org.mvel2.ConversionHandler;
import org.mvel2.DataConversion;
Expand Down Expand Up @@ -426,8 +428,6 @@ private MVELCompilationUnit buildCompilationUnit( final RuleBuildContext context
}
}

MVELDialect dialect = (MVELDialect) context.getDialect( "mvel" );

MVELCompilationUnit unit = null;

try {
Expand All @@ -439,16 +439,8 @@ private MVELCompilationUnit buildCompilationUnit( final RuleBuildContext context
((ClassObjectType) p.getObjectType()).getClassType() );
}

unit = dialect.getMVELCompilationUnit( (String) predicateDescr.getContent(),
analysis,
previousDeclarations,
localDeclarations,
null,
context,
"drools",
KnowledgeHelper.class,
context.isInXpath(),
MVELCompilationUnit.Scope.CONSTRAINT );
unit = MVELDialect.getMVELCompilationUnit( (String) predicateDescr.getContent(), analysis, previousDeclarations, localDeclarations,
null, context, "drools", KnowledgeHelper.class, context.isInXpath(), MVELCompilationUnit.Scope.CONSTRAINT );
} catch ( final Exception e ) {
copyErrorLocation(e, predicateDescr);
context.addError( new DescrBuildError( context.getParentDescr(),
Expand Down Expand Up @@ -486,48 +478,14 @@ private StringCoercionCompatibilityEvaluator() { }

@Override
public boolean areEqualityCompatible(Class<?> c1, Class<?> c2) {
if (c1 == NullType.class || c2 == NullType.class) {
return true;
}
if (c1 == String.class || c2 == String.class) {
return true;
}
Class<?> boxed1 = convertFromPrimitiveType(c1);
Class<?> boxed2 = convertFromPrimitiveType(c2);
if (boxed1.isAssignableFrom(boxed2) || boxed2.isAssignableFrom(boxed1)) {
return true;
}
if (Number.class.isAssignableFrom(boxed1) && Number.class.isAssignableFrom(boxed2)) {
return true;
}
if (areEqualityCompatibleEnums(boxed1, boxed2)) {
return true;
}
return !Modifier.isFinal(c1.getModifiers()) && !Modifier.isFinal(c2.getModifiers());
}

protected boolean areEqualityCompatibleEnums(final Class<?> boxed1,
final Class<?> boxed2) {
return boxed1.isEnum() && boxed2.isEnum() && boxed1.getName().equals(boxed2.getName())
&& equalEnumConstants(boxed1.getEnumConstants(), boxed2.getEnumConstants());
}

private boolean equalEnumConstants(final Object[] aa,
final Object[] bb) {
if (aa.length != bb.length) {
return false;
}
for (int i = 0; i < aa.length; i++) {
if (!Objects.equals(aa[i].getClass().getName(), bb[i].getClass().getName())) {
return false;
}
}
return true;
return CoercionUtil.areEqualityCompatible(c1 == NullType.class ? MethodUtils.NullType.class : c1,
c2 == NullType.class ? MethodUtils.NullType.class : c2);
}

@Override
public boolean areComparisonCompatible(Class<?> c1, Class<?> c2) {
return areEqualityCompatible(c1, c2);
return CoercionUtil.areComparisonCompatible(c1 == NullType.class ? MethodUtils.NullType.class : c1,
c2 == NullType.class ? MethodUtils.NullType.class : c2);
}
}

Expand Down Expand Up @@ -558,16 +516,8 @@ public TimerExpression buildTimerExpression( String expression, RuleBuildContext
}
Arrays.sort(previousDeclarations, SortDeclarations.instance);

MVELCompilationUnit unit = dialect.getMVELCompilationUnit( expression,
analysis,
previousDeclarations,
null,
null,
context,
"drools",
KnowledgeHelper.class,
false,
MVELCompilationUnit.Scope.EXPRESSION );
MVELCompilationUnit unit = MVELDialect.getMVELCompilationUnit( expression, analysis, previousDeclarations, null, null,
context, "drools", KnowledgeHelper.class, false, MVELCompilationUnit.Scope.EXPRESSION );

MVELObjectExpression expr = new MVELObjectExpression( unit, dialect.getId() );

Expand All @@ -578,9 +528,7 @@ public TimerExpression buildTimerExpression( String expression, RuleBuildContext
return expr;
} catch ( final Exception e ) {
AsmUtil.copyErrorLocation(e, context.getRuleDescr());
context.addError( new DescrBuildError( context.getParentDescr(),
context.getRuleDescr(),
null,
context.addError( new DescrBuildError( context.getParentDescr(), context.getRuleDescr(), null,
"Unable to build expression : " + e.getMessage() + "'" + expression + "'" ) );
return null;
} finally {
Expand All @@ -595,10 +543,8 @@ public AnalysisResult analyzeExpression(Class<?> thisClass, String expr) {
return analyzeExpression( expr, conf, new BoundIdentifiers( thisClass ) );
}

private static MVELAnalysisResult analyzeExpression(String expr,
ParserConfiguration conf,
BoundIdentifiers availableIdentifiers) {
if (expr.trim().length() == 0) {
private static MVELAnalysisResult analyzeExpression(String expr, ParserConfiguration conf, BoundIdentifiers availableIdentifiers) {
if (expr.trim().isEmpty()) {
MVELAnalysisResult result = analyze( (Set<String> ) Collections.EMPTY_SET, availableIdentifiers );
result.setMvelVariables( new HashMap<>() );
result.setTypesafe( true );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@

import org.drools.testcoverage.common.util.KieBaseTestConfiguration;
import org.drools.testcoverage.common.util.KieBaseUtil;
import org.drools.testcoverage.common.util.KieUtil;
import org.drools.testcoverage.common.util.TestParametersUtil;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kie.api.KieBase;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.Message;
import org.kie.api.runtime.KieSession;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -176,4 +179,37 @@ public void testDateCoercionWithNestedOr() {
assertThat(list.size()).isEqualTo(1);
assertThat(list.get(0)).isEqualTo("working");
}

@Test
public void testLocalDateTimeCoercion() {
// DROOLS-7631
String drl = "import java.util.Date\n" +
"import java.time.LocalDateTime\n" +
"global java.util.List list\n" +
"declare DateContainer\n" +
" date: Date\n" +
"end\n" +
"declare LocalDateTimeContainer\n" +
" date: LocalDateTime\n" +
"end\n" +
"\n" +
"rule Init when\n" +
"then\n" +
" insert(new DateContainer(new Date( 1439882189744L )));" +
" insert(new LocalDateTimeContainer( LocalDateTime.now() ));" +
"end\n" +
"\n" +
"rule \"Test rule\"\n" +
"when\n" +
" DateContainer( $date: date )\n" +
" LocalDateTimeContainer( date > $date )\n" +
"then\n" +
" list.add(\"working\");\n" +
"end\n";

KieBuilder kieBuilder = KieUtil.getKieBuilderFromDrls(kieBaseTestConfiguration, false, drl);
List<Message> errors = kieBuilder.getResults().getMessages(Message.Level.ERROR);
assertThat(errors).hasSize(1);
assertThat(errors.get(0).getText()).contains("Comparison operation requires compatible types");
}
}
51 changes: 50 additions & 1 deletion drools-util/src/main/java/org/drools/util/CoercionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@

package org.drools.util;

import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.chrono.ChronoLocalDateTime;
import java.time.temporal.Temporal;
import java.util.Objects;

import org.drools.util.MathUtils;
import static org.drools.util.ClassUtils.convertFromPrimitiveType;

public class CoercionUtil {

Expand Down Expand Up @@ -194,4 +198,49 @@ public static Number coerceToNumber(Number value, Class<?> toClass) {
}
return ret;
}

public static boolean areEqualityCompatible(Class<?> c1, Class<?> c2) {
if (c1 == MethodUtils.NullType.class || c2 == MethodUtils.NullType.class) {
return true;
}
if (c1 == String.class || c2 == String.class) {
return true;
}
if (Temporal.class.isAssignableFrom(c1) && Temporal.class.isAssignableFrom(c2)) {
return true;
}
Class<?> boxed1 = convertFromPrimitiveType(c1);
Class<?> boxed2 = convertFromPrimitiveType(c2);
if (boxed1.isAssignableFrom(boxed2) || boxed2.isAssignableFrom(boxed1)) {
return true;
}
if (Number.class.isAssignableFrom(boxed1) && Number.class.isAssignableFrom(boxed2)) {
return true;
}
if (areEqualityCompatibleEnums(boxed1, boxed2)) {
return true;
}
return !Modifier.isFinal(c1.getModifiers()) && !Modifier.isFinal(c2.getModifiers());
}

protected static boolean areEqualityCompatibleEnums(Class<?> boxed1, Class<?> boxed2) {
return boxed1.isEnum() && boxed2.isEnum() && boxed1.getName().equals(boxed2.getName())
&& equalEnumConstants(boxed1.getEnumConstants(), boxed2.getEnumConstants());
}

private static boolean equalEnumConstants(Object[] aa, Object[] bb) {
if (aa.length != bb.length) {
return false;
}
for (int i = 0; i < aa.length; i++) {
if (!Objects.equals(aa[i].getClass().getName(), bb[i].getClass().getName())) {
return false;
}
}
return true;
}

public static boolean areComparisonCompatible(Class<?> c1, Class<?> c2) {
return areEqualityCompatible(c1, c2);
}
}

0 comments on commit d04ab13

Please sign in to comment.