diff --git a/drools-model/drools-canonical-model/src/main/java/org/drools/model/PrototypeDSL.java b/drools-model/drools-canonical-model/src/main/java/org/drools/model/PrototypeDSL.java index aea97784cce..67d2b7bedc9 100644 --- a/drools-model/drools-canonical-model/src/main/java/org/drools/model/PrototypeDSL.java +++ b/drools-model/drools-canonical-model/src/main/java/org/drools/model/PrototypeDSL.java @@ -131,7 +131,20 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope Prototype prototype = getPrototype(); Function1 leftExtractor = left.asFunction(prototype); - AlphaIndex alphaIndex = null; + + Set reactOnFields = new HashSet<>(); + reactOnFields.addAll(left.getImpactedFields()); + reactOnFields.addAll(right.getImpactedFields()); + + expr(createExprId(left, operator, right), + asPredicate1(leftExtractor, operator, right.asFunction(prototype)), + createAlphaIndex(left, operator, right, prototype, leftExtractor), + reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) ); + + return this; + } + + private static AlphaIndex createAlphaIndex(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right, Prototype prototype, Function1 leftExtractor) { if (left.getIndexingKey().isPresent() && right instanceof PrototypeExpression.FixedValue && operator instanceof Index.ConstraintType) { String fieldName = left.getIndexingKey().get(); Index.ConstraintType constraintType = (Index.ConstraintType) operator; @@ -140,20 +153,10 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope Class fieldClass = (Class) (field != null && field.isTyped() ? field.getType() : value != null ? value.getClass() : null); if (fieldClass != null) { - alphaIndex = alphaIndexedBy(fieldClass, constraintType, getFieldIndex(prototype, fieldName, field), leftExtractor, value); + return alphaIndexedBy(fieldClass, constraintType, getFieldIndex(prototype, fieldName, field), leftExtractor, value); } } - - Set reactOnFields = new HashSet<>(); - reactOnFields.addAll(left.getImpactedFields()); - reactOnFields.addAll(right.getImpactedFields()); - - expr("expr:" + left + ":" + operator + ":" + right, - asPredicate1(leftExtractor, operator, right.asFunction(prototype)), - alphaIndex, - reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) ); - - return this; + return null; } private static int getFieldIndex(Prototype prototype, String fieldName, Prototype.Field field) { @@ -174,7 +177,7 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope reactOnFields.addAll(left.getImpactedFields()); reactOnFields.addAll(right.getImpactedFields()); - expr("expr:" + left + ":" + operator + ":" + right, + expr(createExprId(left, operator, right), other, asPredicate2(left.asFunction(prototype), operator, right.asFunction(otherPrototype)), createBetaIndex(left, operator, right, prototype, otherPrototype), reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) ); @@ -182,6 +185,12 @@ other, asPredicate2(left.asFunction(prototype), operator, right.asFunction(other return this; } + private static String createExprId(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right) { + Object leftId = left.getIndexingKey().orElse(left.toString()); + Object rightId = right instanceof PrototypeExpression.FixedValue ? ((PrototypeExpression.FixedValue) right).getValue() : right; + return "expr:" + leftId + ":" + operator + ":" + rightId; + } + private BetaIndex createBetaIndex(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right, Prototype prototype, Prototype otherPrototype) { if (left.getIndexingKey().isPresent() && operator instanceof Index.ConstraintType && right.getIndexingKey().isPresent()) { String fieldName = left.getIndexingKey().get(); diff --git a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/FactTemplateTest.java b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/FactTemplateTest.java index a59bbca9717..a425b8248da 100644 --- a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/FactTemplateTest.java +++ b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/FactTemplateTest.java @@ -15,19 +15,10 @@ */ package org.drools.model.codegen.execmodel; -import java.math.BigDecimal; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.function.BiPredicate; -import java.util.stream.Collectors; - -import org.drools.core.ClockType; import org.drools.base.facttemplates.Event; import org.drools.base.facttemplates.Fact; import org.drools.base.facttemplates.FactTemplateObjectType; +import org.drools.core.ClockType; import org.drools.core.reteoo.CompositeObjectSinkAdapter; import org.drools.core.reteoo.EntryPointNode; import org.drools.core.reteoo.ObjectTypeNode; @@ -37,6 +28,7 @@ import org.drools.model.Index; import org.drools.model.Model; import org.drools.model.Prototype; +import org.drools.model.PrototypeExpression; import org.drools.model.PrototypeFact; import org.drools.model.PrototypeVariable; import org.drools.model.Query; @@ -44,6 +36,7 @@ import org.drools.model.Variable; import org.drools.model.codegen.execmodel.domain.Person; import org.drools.model.codegen.execmodel.domain.Result; +import org.drools.model.functions.Function1; import org.drools.model.impl.ModelImpl; import org.drools.modelcompiler.KieBaseBuilder; import org.junit.Test; @@ -58,6 +51,17 @@ import org.kie.api.runtime.rule.ViewChangedEventListener; import org.kie.api.time.SessionPseudoClock; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.function.BiPredicate; +import java.util.stream.Collectors; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; import static org.drools.model.DSL.accFunction; @@ -1455,4 +1459,83 @@ public void bigDecimalEqualityWithDifferentScale_shouldBeEqual() { Collection results = getObjectsIntoList(ksession, Result.class); assertThat(results).contains(new Result("Mark")); } + + @Test + public void testMixFieldNameAndPrototypeExpr() { + // DROOLS-7517 + Prototype personFact = prototype( "org.drools.FactPerson", "name", "age" ); + + PrototypeVariable markV = variable( personFact ); + + Rule r1 = rule( "R1" ) + .build( + protoPattern(markV) + .expr( "name", Index.ConstraintType.EQUAL, "Mark" ), + on(markV).execute((drools, p) -> + drools.insert(new Result("R1")) + ) + ); + + Rule r2 = rule( "R2" ) + .build( + protoPattern(markV) + .expr( "name", Index.ConstraintType.EQUAL, "Mario" ), + on(markV).execute((drools, p) -> + drools.insert(new Result("R2")) + ) + ); + + Rule r3 = rule( "R3" ) + .build( + protoPattern(markV) + .expr( new MyFieldExpression("name"), Index.ConstraintType.EQUAL, fixedValue("Mark") ), + on(markV).execute((drools, p) -> + drools.insert(new Result("R3")) + ) + ); + + Model model = new ModelImpl().addRule( r1 ).addRule( r2 ).addRule( r3 ); + KieBase kieBase = KieBaseBuilder.createKieBaseFromModel( model ); + + KieSession ksession = kieBase.newKieSession(); + + Fact mark = createMapBasedFact(personFact); + mark.set( "name", "Mark" ); + mark.set( "age", 40 ); + + FactHandle fh = ksession.insert( mark ); + assertThat(ksession.fireAllRules()).isEqualTo(2); + + Collection results = getObjectsIntoList(ksession, Result.class); + assertThat(results).containsExactlyInAnyOrder(new Result("R1"), new Result("R3")); + } + + class MyFieldExpression implements PrototypeExpression { + + private final String fieldName; + + MyFieldExpression(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public Function1 asFunction(Prototype prototype) { + return prototype.getFieldValueExtractor(fieldName)::apply; + } + + @Override + public Optional getIndexingKey() { + return Optional.of(fieldName); + } + + @Override + public String toString() { + return "MyFieldExpression{" + fieldName + "}"; + } + + @Override + public Collection getImpactedFields() { + return Collections.singletonList(fieldName); + } + } }