Skip to content

Commit

Permalink
[DROOLS-7517] fix wrong node sharing in PrototypeDSL when using custo…
Browse files Browse the repository at this point in the history
…m PrototypeExpressions
  • Loading branch information
mariofusco committed Jul 26, 2023
1 parent fa31166 commit 39f45ce
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,16 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope

Prototype prototype = getPrototype();
Function1<PrototypeFact, Object> leftExtractor = left.asFunction(prototype);

AlphaIndex alphaIndex = null;
String exprId = "expr:" + left + ":" + operator + ":" + right;

if (left.getIndexingKey().isPresent() && right instanceof PrototypeExpression.FixedValue && operator instanceof Index.ConstraintType) {
String fieldName = left.getIndexingKey().get();
Index.ConstraintType constraintType = (Index.ConstraintType) operator;
Prototype.Field field = prototype.getField(fieldName);
Object value = ((PrototypeExpression.FixedValue) right).getValue();
exprId = "expr:" + fieldName + ":" + operator + ":" + value;

Class<Object> fieldClass = (Class<Object>) (field != null && field.isTyped() ? field.getType() : value != null ? value.getClass() : null);
if (fieldClass != null) {
Expand All @@ -148,7 +152,7 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

expr("expr:" + left + ":" + operator + ":" + right,
expr(exprId,
asPredicate1(leftExtractor, operator, right.asFunction(prototype)),
alphaIndex,
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );
Expand All @@ -174,7 +178,7 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

expr("expr:" + left + ":" + operator + ":" + right,
expr("expr:" + left.getIndexingKey().orElse(left.toString()) + ":" + operator + ":" + right,
other, asPredicate2(left.asFunction(prototype), operator, right.asFunction(otherPrototype)),
createBetaIndex(left, operator, right, prototype, otherPrototype),
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,10 @@
*/
package org.drools.model.codegen.execmodel;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;

import org.drools.core.ClockType;
import org.drools.base.facttemplates.Event;
import org.drools.base.facttemplates.Fact;
import org.drools.base.facttemplates.FactTemplateObjectType;
import org.drools.core.ClockType;
import org.drools.core.reteoo.CompositeObjectSinkAdapter;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.ObjectTypeNode;
Expand All @@ -37,13 +28,15 @@
import org.drools.model.Index;
import org.drools.model.Model;
import org.drools.model.Prototype;
import org.drools.model.PrototypeExpression;
import org.drools.model.PrototypeFact;
import org.drools.model.PrototypeVariable;
import org.drools.model.Query;
import org.drools.model.Rule;
import org.drools.model.Variable;
import org.drools.model.codegen.execmodel.domain.Person;
import org.drools.model.codegen.execmodel.domain.Result;
import org.drools.model.functions.Function1;
import org.drools.model.impl.ModelImpl;
import org.drools.modelcompiler.KieBaseBuilder;
import org.junit.Test;
Expand All @@ -58,6 +51,17 @@
import org.kie.api.runtime.rule.ViewChangedEventListener;
import org.kie.api.time.SessionPseudoClock;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.drools.model.DSL.accFunction;
Expand Down Expand Up @@ -1455,4 +1459,83 @@ public void bigDecimalEqualityWithDifferentScale_shouldBeEqual() {
Collection<Result> results = getObjectsIntoList(ksession, Result.class);
assertThat(results).contains(new Result("Mark"));
}

@Test
public void testMixFieldNameAndPrototypeExpr() {
// DROOLS-7517
Prototype personFact = prototype( "org.drools.FactPerson", "name", "age" );

PrototypeVariable markV = variable( personFact );

Rule r1 = rule( "R1" )
.build(
protoPattern(markV)
.expr( "name", Index.ConstraintType.EQUAL, "Mark" ),
on(markV).execute((drools, p) ->
drools.insert(new Result("R1"))
)
);

Rule r2 = rule( "R2" )
.build(
protoPattern(markV)
.expr( "name", Index.ConstraintType.EQUAL, "Mario" ),
on(markV).execute((drools, p) ->
drools.insert(new Result("R2"))
)
);

Rule r3 = rule( "R3" )
.build(
protoPattern(markV)
.expr( new MyFieldExpression("name"), Index.ConstraintType.EQUAL, fixedValue("Mark") ),
on(markV).execute((drools, p) ->
drools.insert(new Result("R3"))
)
);

Model model = new ModelImpl().addRule( r1 ).addRule( r2 ).addRule( r3 );
KieBase kieBase = KieBaseBuilder.createKieBaseFromModel( model );

KieSession ksession = kieBase.newKieSession();

Fact mark = createMapBasedFact(personFact);
mark.set( "name", "Mark" );
mark.set( "age", 40 );

FactHandle fh = ksession.insert( mark );
assertThat(ksession.fireAllRules()).isEqualTo(2);

Collection<Result> results = getObjectsIntoList(ksession, Result.class);
assertThat(results).containsExactlyInAnyOrder(new Result("R1"), new Result("R3"));
}

class MyFieldExpression implements PrototypeExpression {

private final String fieldName;

MyFieldExpression(String fieldName) {
this.fieldName = fieldName;
}

@Override
public Function1<PrototypeFact, Object> asFunction(Prototype prototype) {
return prototype.getFieldValueExtractor(fieldName)::apply;
}

@Override
public Optional<String> getIndexingKey() {
return Optional.of(fieldName);
}

@Override
public String toString() {
return "MyFieldExpression{" + fieldName + "}";
}

@Override
public Collection<String> getImpactedFields() {
return Collections.singletonList(fieldName);
}
}
}

0 comments on commit 39f45ce

Please sign in to comment.