Skip to content

Commit

Permalink
[DROOLS-7517] fix wrong node sharing in PrototypeDSL when using custo…
Browse files Browse the repository at this point in the history
…m PrototypeExpressions (apache#5405)

* [DROOLS-7517] fix wrong node sharing in PrototypeDSL when using custom PrototypeExpressions

* centralize generation of exprId
  • Loading branch information
mariofusco committed Jul 27, 2023
1 parent c61f57a commit e79505a
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,20 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope

Prototype prototype = getPrototype();
Function1<PrototypeFact, Object> leftExtractor = left.asFunction(prototype);
AlphaIndex alphaIndex = null;

Set<String> reactOnFields = new HashSet<>();
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

expr(createExprId(left, operator, right),
asPredicate1(leftExtractor, operator, right.asFunction(prototype)),
createAlphaIndex(left, operator, right, prototype, leftExtractor),
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );

return this;
}

private static AlphaIndex createAlphaIndex(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right, Prototype prototype, Function1<PrototypeFact, Object> leftExtractor) {
if (left.getIndexingKey().isPresent() && right instanceof PrototypeExpression.FixedValue && operator instanceof Index.ConstraintType) {
String fieldName = left.getIndexingKey().get();
Index.ConstraintType constraintType = (Index.ConstraintType) operator;
Expand All @@ -140,20 +153,10 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope

Class<Object> fieldClass = (Class<Object>) (field != null && field.isTyped() ? field.getType() : value != null ? value.getClass() : null);
if (fieldClass != null) {
alphaIndex = alphaIndexedBy(fieldClass, constraintType, getFieldIndex(prototype, fieldName, field), leftExtractor, value);
return alphaIndexedBy(fieldClass, constraintType, getFieldIndex(prototype, fieldName, field), leftExtractor, value);
}
}

Set<String> reactOnFields = new HashSet<>();
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

expr("expr:" + left + ":" + operator + ":" + right,
asPredicate1(leftExtractor, operator, right.asFunction(prototype)),
alphaIndex,
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );

return this;
return null;
}

private static int getFieldIndex(Prototype prototype, String fieldName, Prototype.Field field) {
Expand All @@ -174,14 +177,20 @@ public PrototypePatternDef expr(PrototypeExpression left, ConstraintOperator ope
reactOnFields.addAll(left.getImpactedFields());
reactOnFields.addAll(right.getImpactedFields());

expr("expr:" + left + ":" + operator + ":" + right,
expr(createExprId(left, operator, right),
other, asPredicate2(left.asFunction(prototype), operator, right.asFunction(otherPrototype)),
createBetaIndex(left, operator, right, prototype, otherPrototype),
reactOn( reactOnFields.toArray(new String[reactOnFields.size()])) );

return this;
}

private static String createExprId(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right) {
Object leftId = left.getIndexingKey().orElse(left.toString());
Object rightId = right instanceof PrototypeExpression.FixedValue ? ((PrototypeExpression.FixedValue) right).getValue() : right;
return "expr:" + leftId + ":" + operator + ":" + rightId;
}

private BetaIndex createBetaIndex(PrototypeExpression left, ConstraintOperator operator, PrototypeExpression right, Prototype prototype, Prototype otherPrototype) {
if (left.getIndexingKey().isPresent() && operator instanceof Index.ConstraintType && right.getIndexingKey().isPresent()) {
String fieldName = left.getIndexingKey().get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,10 @@
*/
package org.drools.model.codegen.execmodel;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;

import org.drools.core.ClockType;
import org.drools.base.facttemplates.Event;
import org.drools.base.facttemplates.Fact;
import org.drools.base.facttemplates.FactTemplateObjectType;
import org.drools.core.ClockType;
import org.drools.core.reteoo.CompositeObjectSinkAdapter;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.ObjectTypeNode;
Expand All @@ -37,13 +28,15 @@
import org.drools.model.Index;
import org.drools.model.Model;
import org.drools.model.Prototype;
import org.drools.model.PrototypeExpression;
import org.drools.model.PrototypeFact;
import org.drools.model.PrototypeVariable;
import org.drools.model.Query;
import org.drools.model.Rule;
import org.drools.model.Variable;
import org.drools.model.codegen.execmodel.domain.Person;
import org.drools.model.codegen.execmodel.domain.Result;
import org.drools.model.functions.Function1;
import org.drools.model.impl.ModelImpl;
import org.drools.modelcompiler.KieBaseBuilder;
import org.junit.Test;
Expand All @@ -58,6 +51,17 @@
import org.kie.api.runtime.rule.ViewChangedEventListener;
import org.kie.api.time.SessionPseudoClock;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.drools.model.DSL.accFunction;
Expand Down Expand Up @@ -1455,4 +1459,83 @@ public void bigDecimalEqualityWithDifferentScale_shouldBeEqual() {
Collection<Result> results = getObjectsIntoList(ksession, Result.class);
assertThat(results).contains(new Result("Mark"));
}

@Test
public void testMixFieldNameAndPrototypeExpr() {
// DROOLS-7517
Prototype personFact = prototype( "org.drools.FactPerson", "name", "age" );

PrototypeVariable markV = variable( personFact );

Rule r1 = rule( "R1" )
.build(
protoPattern(markV)
.expr( "name", Index.ConstraintType.EQUAL, "Mark" ),
on(markV).execute((drools, p) ->
drools.insert(new Result("R1"))
)
);

Rule r2 = rule( "R2" )
.build(
protoPattern(markV)
.expr( "name", Index.ConstraintType.EQUAL, "Mario" ),
on(markV).execute((drools, p) ->
drools.insert(new Result("R2"))
)
);

Rule r3 = rule( "R3" )
.build(
protoPattern(markV)
.expr( new MyFieldExpression("name"), Index.ConstraintType.EQUAL, fixedValue("Mark") ),
on(markV).execute((drools, p) ->
drools.insert(new Result("R3"))
)
);

Model model = new ModelImpl().addRule( r1 ).addRule( r2 ).addRule( r3 );
KieBase kieBase = KieBaseBuilder.createKieBaseFromModel( model );

KieSession ksession = kieBase.newKieSession();

Fact mark = createMapBasedFact(personFact);
mark.set( "name", "Mark" );
mark.set( "age", 40 );

FactHandle fh = ksession.insert( mark );
assertThat(ksession.fireAllRules()).isEqualTo(2);

Collection<Result> results = getObjectsIntoList(ksession, Result.class);
assertThat(results).containsExactlyInAnyOrder(new Result("R1"), new Result("R3"));
}

class MyFieldExpression implements PrototypeExpression {

private final String fieldName;

MyFieldExpression(String fieldName) {
this.fieldName = fieldName;
}

@Override
public Function1<PrototypeFact, Object> asFunction(Prototype prototype) {
return prototype.getFieldValueExtractor(fieldName)::apply;
}

@Override
public Optional<String> getIndexingKey() {
return Optional.of(fieldName);
}

@Override
public String toString() {
return "MyFieldExpression{" + fieldName + "}";
}

@Override
public Collection<String> getImpactedFields() {
return Collections.singletonList(fieldName);
}
}
}

0 comments on commit e79505a

Please sign in to comment.