diff --git a/drools-core/src/main/java/org/drools/core/util/index/IndexUtil.java b/drools-core/src/main/java/org/drools/core/util/index/IndexUtil.java index 4dcb6aea287..b48d47a9c07 100644 --- a/drools-core/src/main/java/org/drools/core/util/index/IndexUtil.java +++ b/drools-core/src/main/java/org/drools/core/util/index/IndexUtil.java @@ -197,8 +197,8 @@ private static void sortRangeIndexable(BetaNodeFieldConstraint[] constraints, bo indexable[0] = true; } - private static boolean isEqualIndexable(BetaNodeFieldConstraint constraint) { - return constraint instanceof IndexableConstraint && ((IndexableConstraint)constraint).getConstraintType() == ConstraintType.EQUAL; + static boolean isEqualIndexable(BetaNodeFieldConstraint constraint) { + return constraint instanceof IndexableConstraint && ((IndexableConstraint)constraint).getConstraintType() == ConstraintType.EQUAL && !isBigDecimalEqualityConstraint((IndexableConstraint)constraint); } private static void swap(BetaNodeFieldConstraint[] constraints, int p1, int p2) { @@ -334,7 +334,7 @@ public static ConstraintType getType(Constraint constraint) { public static class Factory { public static BetaMemory createBetaMemory(RuleBaseConfiguration config, short nodeType, BetaNodeFieldConstraint... constraints) { - if (config.getCompositeKeyDepth() < 1 || containsBigDecimalEqualityConstraint(constraints)) { + if (config.getCompositeKeyDepth() < 1) { return new BetaMemory( config.isSequential() ? null : new TupleList(), new TupleList(), createContext(constraints), @@ -348,17 +348,8 @@ public static BetaMemory createBetaMemory(RuleBaseConfiguration config, short no nodeType ); } - private static boolean containsBigDecimalEqualityConstraint(BetaNodeFieldConstraint[] constraints) { - for (BetaNodeFieldConstraint constraint : constraints) { - if (constraint instanceof IndexableConstraint && isBigDecimalEqualityConstraint((IndexableConstraint) constraint)) { - return true; - } - } - return false; - } - private static TupleMemory createRightMemory(RuleBaseConfiguration config, IndexSpec indexSpec) { - if ( !config.isIndexRightBetaMemory() || !indexSpec.constraintType.isIndexable() ) { + if ( !config.isIndexRightBetaMemory() || !indexSpec.constraintType.isIndexable() || indexSpec.indexes.length == 0 ) { return new TupleList(); } @@ -377,7 +368,7 @@ private static TupleMemory createLeftMemory(RuleBaseConfiguration config, IndexS if (config.isSequential()) { return null; } - if ( !config.isIndexLeftBetaMemory() || !indexSpec.constraintType.isIndexable() ) { + if ( !config.isIndexLeftBetaMemory() || !indexSpec.constraintType.isIndexable() || indexSpec.indexes.length == 0 ) { return new TupleList(); } @@ -417,11 +408,13 @@ private void init(short nodeType, BetaNodeFieldConstraint[] constraints, RuleBas if (constraintType == ConstraintType.EQUAL) { List indexList = new ArrayList<>(); - indexList.add(((IndexableConstraint)constraints[firstIndexableConstraint]).getFieldIndex()); + if (isEqualIndexable(constraints[firstIndexableConstraint])) { + indexList.add(((IndexableConstraint)constraints[firstIndexableConstraint]).getFieldIndex()); + } // look for other EQUAL constraint to eventually add them to the index for (int i = firstIndexableConstraint+1; i < constraints.length && indexList.size() < keyDepth; i++) { - if ( ConstraintType.getType(constraints[i]) == ConstraintType.EQUAL && ! ((IndexableConstraint) constraints[i]).isUnification() ) { + if ( isEqualIndexable(constraints[i]) && ! ((IndexableConstraint) constraints[i]).isUnification() ) { indexList.add(((IndexableConstraint)constraints[i]).getFieldIndex()); } } diff --git a/drools-core/src/test/java/org/drools/core/util/index/IndexUtilTest.java b/drools-core/src/test/java/org/drools/core/util/index/IndexUtilTest.java new file mode 100644 index 00000000000..f700b1dcdc8 --- /dev/null +++ b/drools-core/src/test/java/org/drools/core/util/index/IndexUtilTest.java @@ -0,0 +1,412 @@ +/* + * Copyright 2023 Red Hat, Inc. and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.core.util.index; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.math.BigInteger; + +import org.drools.core.RuleBaseConfiguration; +import org.drools.core.base.ValueType; +import org.drools.core.common.InternalFactHandle; +import org.drools.core.common.InternalWorkingMemory; +import org.drools.core.reteoo.BetaMemory; +import org.drools.core.reteoo.NodeTypeEnums; +import org.drools.core.rule.ContextEntry; +import org.drools.core.rule.Declaration; +import org.drools.core.rule.IndexableConstraint; +import org.drools.core.spi.BetaNodeFieldConstraint; +import org.drools.core.spi.Constraint; +import org.drools.core.spi.FieldValue; +import org.drools.core.spi.InternalReadAccessor; +import org.drools.core.spi.Tuple; +import org.drools.core.spi.TupleValueExtractor; +import org.drools.core.util.AbstractHashTable.DoubleCompositeIndex; +import org.drools.core.util.AbstractHashTable.FieldIndex; +import org.drools.core.util.AbstractHashTable.Index; +import org.junit.Test; +import org.kie.internal.conf.IndexPrecedenceOption; + +import static org.assertj.core.api.Assertions.assertThat; + +public class IndexUtilTest { + + @Test + public void isEqualIndexable() { + FakeBetaNodeFieldConstraint intEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.PINTEGER_TYPE)); + assertThat(IndexUtil.isEqualIndexable(intEqualsConstraint)).isTrue(); + + FakeBetaNodeFieldConstraint intLessThanConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.LESS_THAN, new FakeInternalReadAccessor(ValueType.PINTEGER_TYPE)); + assertThat(IndexUtil.isEqualIndexable(intLessThanConstraint)).isFalse(); + + FakeBetaNodeFieldConstraint bigDecimalEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.BIG_DECIMAL_TYPE)); + assertThat(IndexUtil.isEqualIndexable(bigDecimalEqualsConstraint)).as("BigDecimal equality cannot be indexed because of scale").isFalse(); + } + + @Test + public void createBetaMemoryWithIntEquals_shouldBeTupleIndexHashTable() { + RuleBaseConfiguration config = new RuleBaseConfiguration(); + FakeBetaNodeFieldConstraint intEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.PINTEGER_TYPE)); + BetaMemory betaMemory = IndexUtil.Factory.createBetaMemory(config, NodeTypeEnums.JoinNode, intEqualsConstraint); + assertThat(betaMemory.getLeftTupleMemory()).isInstanceOf(TupleIndexHashTable.class); + assertThat(betaMemory.getRightTupleMemory()).isInstanceOf(TupleIndexHashTable.class); + } + + @Test + public void createBetaMemoryWithBigDecimalEquals_shouldNotBeTupleIndexHashTable() { + RuleBaseConfiguration config = new RuleBaseConfiguration(); + FakeBetaNodeFieldConstraint bigDecimalEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.BIG_DECIMAL_TYPE)); + BetaMemory betaMemory = IndexUtil.Factory.createBetaMemory(config, NodeTypeEnums.JoinNode, bigDecimalEqualsConstraint); + assertThat(betaMemory.getLeftTupleMemory()).isInstanceOf(TupleList.class); + assertThat(betaMemory.getRightTupleMemory()).isInstanceOf(TupleList.class); + } + + @Test + public void createBetaMemoryWithBigDecimalEqualsAndOtherIndexableConstraints_shouldBeTupleIndexHashTableButBigDecimalIsNotIndexed() { + RuleBaseConfiguration config = new RuleBaseConfiguration(); + FakeBetaNodeFieldConstraint intEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.PINTEGER_TYPE)); + FakeBetaNodeFieldConstraint bigDecimalEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.BIG_DECIMAL_TYPE)); + FakeBetaNodeFieldConstraint stringEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.STRING_TYPE)); + BetaMemory betaMemory = IndexUtil.Factory.createBetaMemory(config, NodeTypeEnums.JoinNode, intEqualsConstraint, bigDecimalEqualsConstraint, stringEqualsConstraint); + + // BigDecimal is not included in Indexes + assertThat(betaMemory.getLeftTupleMemory()).isInstanceOf(TupleIndexHashTable.class); + Index leftIndex = ((TupleIndexHashTable) betaMemory.getLeftTupleMemory()).getIndex(); + assertThat(leftIndex).isInstanceOf(DoubleCompositeIndex.class); + FieldIndex leftFieldIndex0 = leftIndex.getFieldIndex(0); + assertThat(leftFieldIndex0.getLeftExtractor().getValueType()).isEqualTo(ValueType.PINTEGER_TYPE); + FieldIndex leftFieldIndex1 = leftIndex.getFieldIndex(1); + assertThat(leftFieldIndex1.getLeftExtractor().getValueType()).isEqualTo(ValueType.STRING_TYPE); + + assertThat(betaMemory.getRightTupleMemory()).isInstanceOf(TupleIndexHashTable.class); + Index rightIndex = ((TupleIndexHashTable) betaMemory.getRightTupleMemory()).getIndex(); + assertThat(rightIndex).isInstanceOf(DoubleCompositeIndex.class); + FieldIndex rightFieldIndex0 = rightIndex.getFieldIndex(0); + assertThat(rightFieldIndex0.getRightExtractor().getValueType()).isEqualTo(ValueType.PINTEGER_TYPE); + FieldIndex rightFieldIndex1 = rightIndex.getFieldIndex(1); + assertThat(rightFieldIndex1.getRightExtractor().getValueType()).isEqualTo(ValueType.STRING_TYPE); + } + + @Test + public void isIndexableForNodeWithIntAndString() { + RuleBaseConfiguration config = new RuleBaseConfiguration(); + FakeBetaNodeFieldConstraint intEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.PINTEGER_TYPE)); + FakeBetaNodeFieldConstraint stringEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.STRING_TYPE)); + BetaNodeFieldConstraint[] constraints = new FakeBetaNodeFieldConstraint[]{intEqualsConstraint, stringEqualsConstraint}; + boolean[] indexed = IndexUtil.isIndexableForNode(IndexPrecedenceOption.EQUALITY_PRIORITY, NodeTypeEnums.JoinNode, config.getCompositeKeyDepth(), constraints, config); + assertThat(indexed).containsExactly(true, true); + } + + @Test + public void isIndexableForNodeWithIntAndBigDecimalAndString() { + RuleBaseConfiguration config = new RuleBaseConfiguration(); + FakeBetaNodeFieldConstraint intEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.PINTEGER_TYPE)); + FakeBetaNodeFieldConstraint bigDecimalEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.BIG_DECIMAL_TYPE)); + FakeBetaNodeFieldConstraint stringEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.STRING_TYPE)); + BetaNodeFieldConstraint[] constraints = new FakeBetaNodeFieldConstraint[]{intEqualsConstraint, bigDecimalEqualsConstraint, stringEqualsConstraint}; + boolean[] indexed = IndexUtil.isIndexableForNode(IndexPrecedenceOption.EQUALITY_PRIORITY, NodeTypeEnums.JoinNode, config.getCompositeKeyDepth(), constraints, config); + assertThat(indexed).as("BigDecimal is sorted to the last").containsExactly(true, true, false); + } + + @Test + public void isIndexableForNodeWithBigDecimal() { + RuleBaseConfiguration config = new RuleBaseConfiguration(); + FakeBetaNodeFieldConstraint bigDecimalEqualsConstraint = new FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType.EQUAL, new FakeInternalReadAccessor(ValueType.BIG_DECIMAL_TYPE)); + BetaNodeFieldConstraint[] constraints = new FakeBetaNodeFieldConstraint[]{bigDecimalEqualsConstraint}; + boolean[] indexed = IndexUtil.isIndexableForNode(IndexPrecedenceOption.EQUALITY_PRIORITY, NodeTypeEnums.JoinNode, config.getCompositeKeyDepth(), constraints, config); + assertThat(indexed).as("BigDecimal is not indexed").containsExactly(false); + } + + static class FakeBetaNodeFieldConstraint implements BetaNodeFieldConstraint, + IndexableConstraint { + + private IndexUtil.ConstraintType constraintType; + private InternalReadAccessor fieldExtractor; + + public FakeBetaNodeFieldConstraint() {} + + public FakeBetaNodeFieldConstraint(IndexUtil.ConstraintType constraintType, InternalReadAccessor fieldExtractor) { + this.constraintType = constraintType; + this.fieldExtractor = fieldExtractor; + } + + @Override + public Declaration[] getRequiredDeclarations() { + return null; + } + + @Override + public void replaceDeclaration(Declaration oldDecl, Declaration newDecl) {} + + @Override + public Constraint clone() { + return null; + } + + @Override + public ConstraintType getType() { + return null; + } + + @Override + public boolean isTemporal() { + return false; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException {} + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {} + + @Override + public boolean isAllowedCachedLeft(ContextEntry context, InternalFactHandle handle) { + return false; + } + + @Override + public boolean isAllowedCachedRight(Tuple tuple, ContextEntry context) { + return false; + } + + @Override + public ContextEntry createContextEntry() { + return null; + } + + @Override + public BetaNodeFieldConstraint cloneIfInUse() { + return null; + } + + @Override + public boolean isUnification() { + return false; + } + + @Override + public boolean isIndexable(short nodeType, RuleBaseConfiguration config) { + return false; + } + + @Override + public org.drools.core.util.index.IndexUtil.ConstraintType getConstraintType() { + return constraintType; + } + + @Override + public FieldValue getField() { + return null; + } + + @Override + public FieldIndex getFieldIndex() { + return new FieldIndex(fieldExtractor, new Declaration("$p1", fieldExtractor, null)); + } + + @Override + public InternalReadAccessor getFieldExtractor() { + return fieldExtractor; + } + + @Override + public TupleValueExtractor getIndexExtractor() { + return null; + } + } + + static class FakeInternalReadAccessor implements InternalReadAccessor { + + private final ValueType valueType; + + private FakeInternalReadAccessor(ValueType valueType) { + this.valueType = valueType; + } + + @Override + public Object getValue(Object object) { + return null; + } + + @Override + public BigDecimal getBigDecimalValue(Object object) { + return null; + } + + @Override + public BigInteger getBigIntegerValue(Object object) { + return null; + } + + @Override + public char getCharValue(Object object) { + return 0; + } + + @Override + public int getIntValue(Object object) { + return 0; + } + + @Override + public byte getByteValue(Object object) { + return 0; + } + + @Override + public short getShortValue(Object object) { + return 0; + } + + @Override + public long getLongValue(Object object) { + return 0; + } + + @Override + public float getFloatValue(Object object) { + return 0; + } + + @Override + public double getDoubleValue(Object object) { + return 0; + } + + @Override + public boolean getBooleanValue(Object object) { + return false; + } + + @Override + public boolean isNullValue(Object object) { + return false; + } + + @Override + public ValueType getValueType() { + return valueType; + } + + @Override + public Class getExtractToClass() { + return null; + } + + @Override + public String getExtractToClassName() { + return null; + } + + @Override + public Method getNativeReadMethod() { + return null; + } + + @Override + public String getNativeReadMethodName() { + return null; + } + + @Override + public int getHashCode(Object object) { + return 0; + } + + @Override + public int getIndex() { + return 0; + } + + @Override + public Object getValue(InternalWorkingMemory workingMemory, Object object) { + return null; + } + + @Override + public BigDecimal getBigDecimalValue(InternalWorkingMemory workingMemory, Object object) { + return null; + } + + @Override + public BigInteger getBigIntegerValue(InternalWorkingMemory workingMemory, Object object) { + return null; + } + + @Override + public char getCharValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public int getIntValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public byte getByteValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public short getShortValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public long getLongValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public float getFloatValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public double getDoubleValue(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public boolean getBooleanValue(InternalWorkingMemory workingMemory, Object object) { + return false; + } + + @Override + public boolean isNullValue(InternalWorkingMemory workingMemory, Object object) { + return false; + } + + @Override + public int getHashCode(InternalWorkingMemory workingMemory, Object object) { + return 0; + } + + @Override + public boolean isGlobal() { + return false; + } + + @Override + public boolean isSelfReference() { + return false; + } + } +} diff --git a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/IndexingTest.java b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/IndexingTest.java index 9823eef4bcd..dafa0380a8f 100644 --- a/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/IndexingTest.java +++ b/drools-test-coverage/test-compiler-integration/src/test/java/org/drools/compiler/integrationtests/IndexingTest.java @@ -26,6 +26,7 @@ import org.drools.ancompiler.CompiledNetwork; import org.drools.core.base.ClassObjectType; import org.drools.core.base.DroolsQuery; +import org.drools.core.common.BetaConstraints; import org.drools.core.common.DoubleNonIndexSkipBetaConstraints; import org.drools.core.common.EmptyBetaConstraints; import org.drools.core.common.InternalFactHandle; @@ -42,7 +43,6 @@ import org.drools.core.reteoo.NotNode; import org.drools.core.reteoo.ObjectSinkPropagator; import org.drools.core.reteoo.ObjectTypeNode; -import org.drools.core.reteoo.ReteDumper; import org.drools.core.reteoo.RightTuple; import org.drools.core.util.FastIterator; import org.drools.core.util.index.TupleIndexHashTable; @@ -292,7 +292,6 @@ public void testIndexingOnQueryUnificationWithNot() { final KieBase kbase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("indexing-test", kieBaseTestConfiguration, drl); final StatefulKnowledgeSessionImpl wm = (StatefulKnowledgeSessionImpl) kbase.newKieSession(); - ReteDumper.dumpRete( wm ); try { final List nodes = ((KnowledgeBaseImpl) kbase).getRete().getObjectTypeNodes(); ObjectTypeNode node = null; @@ -1016,4 +1015,81 @@ public void testBetaIndexWithBigDecimalDifferentScale() { ksession.dispose(); } } + + @Test + public void betaIndexWithBigDecimalAndInt() { + String constraints = "salary == $p1.salary, age == $p1.age"; + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 30, new BigDecimal("10")), true, 1); + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 28, new BigDecimal("10")), false, 1); + } + + @Test + public void betaIndexWithIntAndBigDecimal() { + String constraints = "age == $p1.age, salary == $p1.salary"; + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 30, new BigDecimal("10")), true, 1); + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 28, new BigDecimal("10")), false, 1); + } + + @Test + public void betaIndexWithIntAndBigDecimalAndString() { + String constraints = "age == $p1.age, salary == $p1.salary, likes == $p1.likes"; + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10"), "dog"), new Person("Paul", 30, new BigDecimal("10"), "dog"), true, 2); + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10"), "dog"), new Person("Paul", 30, new BigDecimal("10"), "cat"), false, 2); + } + + @Test + public void betaIndexWithIntInequalityAndBigDecimal() { + String constraints = "age > $p1.age, salary == $p1.salary"; + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 40, new BigDecimal("10")), true, 0); + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 28, new BigDecimal("10")), false, 0); + } + + @Test + public void betaIndexWithBigDecimalOnly() { + String constraints = "salary == $p1.salary"; + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 28, new BigDecimal("10")), true, 0); + betaIndexWithBigDecimalWithAdditionalBetaConstraint(constraints, new Person("John", 30, new BigDecimal("10")), new Person("Paul", 28, new BigDecimal("20")), false, 0); + } + + private void betaIndexWithBigDecimalWithAdditionalBetaConstraint(String constraints, Person firstPerson, Person secondPerson, boolean shouldMatch, int expectedIndexCount) { + final String drl = + "package org.drools.compiler.test\n" + + "import " + Person.class.getCanonicalName() + "\n" + + "global java.util.List list\n" + + "rule R1\n" + + " when\n" + + " $p1 : Person( name == \"John\" )\n" + + " $p2 : Person( name == \"Paul\", " + constraints + " )\n" + + " then\n" + + " list.add(\"R1\");\n" + + "end"; + + final KieBase kbase = KieBaseUtil.getKieBaseFromKieModuleFromDrl("indexing-test", kieBaseTestConfiguration, drl); + + assertBetaIndex(kbase, Person.class, expectedIndexCount); + + KieSession ksession = kbase.newKieSession(); + + try { + List list = new ArrayList<>(); + ksession.setGlobal("list", list); + ksession.insert(firstPerson); + ksession.insert(secondPerson); + ksession.fireAllRules(); + + if (shouldMatch) { + assertThat(list).as("These constraints should match : " + constraints).containsExactly("R1"); + } else { + assertThat(list).as("These constraints should not match : " + constraints).isEmpty(); + } + } finally { + ksession.dispose(); + } + } + + private void assertBetaIndex(KieBase kbase, Class clazz, int expectedIndexCount) { + final JoinNode joinNode = KieUtil.getJoinNode(kbase, clazz); + BetaConstraints betaConstraints = joinNode.getRawConstraints(); + assertThat(betaConstraints.getIndexCount()).as("IndexCount represents how many constrains are indexed").isEqualTo(expectedIndexCount); + } } diff --git a/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/model/Person.java b/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/model/Person.java index 61baedcf1b0..c8fd0696823 100644 --- a/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/model/Person.java +++ b/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/model/Person.java @@ -79,6 +79,13 @@ public Person(final String name, final int age, final BigDecimal salary) { this.salary = salary; } + public Person(final String name, final int age, final BigDecimal salary, String likes) { + this.name = name; + this.age = age; + this.salary = salary; + this.likes = likes; + } + public int getId() { return id; } diff --git a/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/util/KieUtil.java b/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/util/KieUtil.java index 4320c25b795..e566454b425 100644 --- a/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/util/KieUtil.java +++ b/drools-test-coverage/test-suite/src/test/java/org/drools/testcoverage/common/util/KieUtil.java @@ -18,6 +18,7 @@ import java.io.StringReader; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -25,8 +26,13 @@ import org.drools.compiler.kie.builder.impl.DrlProject; import org.drools.core.base.ClassObjectType; +import org.drools.core.common.BaseNode; import org.drools.core.impl.KnowledgeBaseImpl; +import org.drools.core.reteoo.EntryPointNode; +import org.drools.core.reteoo.JoinNode; +import org.drools.core.reteoo.ObjectSinkNode; import org.drools.core.reteoo.ObjectTypeNode; +import org.drools.core.reteoo.Sink; import org.kie.api.KieBase; import org.kie.api.KieServices; import org.kie.api.builder.KieBuilder; @@ -337,6 +343,32 @@ public static ObjectTypeNode getObjectTypeNode(final KieBase kbase, final Class< return null; } + // This method returns the first JoinNode found which meets the factClass + public static JoinNode getJoinNode(final KieBase kbase, final Class factClass) { + Collection entryPointNodes = ((KnowledgeBaseImpl) kbase).getRete().getEntryPointNodes().values(); + for (EntryPointNode entryPointNode : entryPointNodes) { + JoinNode joinNode = findNode(entryPointNode, JoinNode.class); + if (joinNode.getObjectTypeNode().getObjectType().getClassType().equals(factClass)) { + return joinNode; + } + } + return null; + } + + private static T findNode(BaseNode node, Class nodeClass) { + if (node.getClass().equals(nodeClass)) { + return (T)node; + } else { + Sink[] sinks = node.getSinks(); + for (Sink sink : sinks) { + if (sink instanceof BaseNode) { + return findNode((BaseNode)sink, nodeClass); + } + } + return null; + } + } + private KieUtil() { // Creating instances of util classes should not be possible. }