Skip to content

Commit

Permalink
first accumulate implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Jun 21, 2022
1 parent a24baa4 commit befa642
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 25 deletions.
2 changes: 1 addition & 1 deletion drools-ruleunits/drools-ruleunits-dsl/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>org.drools</groupId>
<artifactId>drools-ruleunits</artifactId>
<version>8.23.0-SNAPSHOT</version>
<version>8.24.0-SNAPSHOT</version>
</parent>

<groupId>org.drools</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.drools.ruleunits.dsl;

import java.util.function.Supplier;

import org.drools.core.base.accumulators.IntegerSumAccumulateFunction;
import org.drools.model.PatternDSL;
import org.drools.model.RuleItemBuilder;
import org.drools.model.Variable;
import org.drools.model.functions.Function1;

import static org.drools.model.DSL.accFunction;
import static org.drools.model.DSL.accumulate;
import static org.drools.model.DSL.declarationOf;

public class Accumulators {

public static <A, B> Accumulator1<A, B> sum(Function1<A, B> bindingFunc) {
return new Accumulator1<>(bindingFunc, IntegerSumAccumulateFunction::new, Integer.class);
}

public static class Accumulator1<A, B> {
private final Function1<A, B> bindingFunc;
private final Supplier<?> accFuncSupplier;
private final Class<?> accClass;

public Accumulator1(Function1<A, B> bindingFunc, Supplier<?> accFuncSupplier, Class<?> accClass) {
this.bindingFunc = bindingFunc;
this.accFuncSupplier = accFuncSupplier;
this.accClass = accClass;
}
}

public static class AccumulatePattern1<A, B> extends RuleFactory.Pattern1<B> {

private final RuleFactory.Pattern1<A> pattern;
private final Accumulator1<A, B> acc;

public AccumulatePattern1(RuleFactory rule, RuleFactory.Pattern1<A> pattern, Accumulator1<A, B> acc) {
super(rule, declarationOf( (Class<B>) acc.accClass ));
this.pattern = pattern;
this.acc = acc;
}

@Override
public RuleItemBuilder toExecModelItem() {
PatternDSL.PatternDef patternDef = (PatternDSL.PatternDef) pattern.toExecModelItem();
Variable<B> boundVar = declarationOf( (Class<B>) acc.accClass );
patternDef.bind(boundVar, acc.bindingFunc);
return accumulate( patternDef, accFunction(acc.accFuncSupplier, boundVar).as(getVariable()) );
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ public <A> Pattern1<A> from(DataSource<A> dataSource) {
return pattern1;
}

public <A, B> Pattern1<B> accumulate(Pattern1<A> pattern, Accumulators.Accumulator1<A, B> acc) {
patterns.remove(pattern);
Pattern1<B> accPattern = new Accumulators.AccumulatePattern1<>(this, pattern, acc);
patterns.add(accPattern);
return accPattern;
}

private <A> Class<A> findDataSourceClass(DataSource<A> dataSource) {
assert(dataSource != null);
for (Field field : unit.getClass().getDeclaredFields()) {
Expand Down Expand Up @@ -83,11 +90,7 @@ public Rule toRule() {
List<RuleItemBuilder> items = new ArrayList<>();

for (PatternDefinition<?> pattern : patterns) {
PatternDSL.PatternDef patternDef = pattern(pattern.getVariable());
for (Constraint constraint : pattern.getConstraints()) {
constraint.addConstraintToPattern(patternDef);
}
items.add(patternDef);
items.add(pattern.toExecModelItem());
}

if (consequence != null) {
Expand Down Expand Up @@ -122,6 +125,14 @@ protected Variable getVariable() {
public <G> void execute(G globalObject, Block1<G> block) {
ruleFactory.execute(globalObject, block);
}

public RuleItemBuilder toExecModelItem() {
PatternDSL.PatternDef patternDef = pattern(getVariable());
for (Constraint constraint : getConstraints()) {
constraint.addConstraintToPattern(patternDef);
}
return patternDef;
}
}

public static class Pattern1<A> extends PatternDefinition<A> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,48 +1,42 @@
package org.drools.ruleunits.dsl;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;

import org.drools.ruleunits.api.DataSource;
import org.drools.ruleunits.api.DataStore;

import static org.drools.model.Index.ConstraintType.EQUAL;
import static org.drools.ruleunits.dsl.Accumulators.sum;

public class AccumulateUnit implements RuleUnitDefinition {

private final DataStore<String> strings;
private final DataStore<Integer> ints;
private final List<String> results = new ArrayList<>();

public AccumulateUnit() {
this(DataSource.createStore(), DataSource.createStore());
this(DataSource.createStore());
}

public AccumulateUnit(DataStore<String> strings, DataStore<Integer> ints) {
public AccumulateUnit(DataStore<String> strings) {
this.strings = strings;
this.ints = ints;
}

public DataStore<String> getStrings() {
return strings;
}

public DataStore<Integer> getInts() {
return ints;
public List<String> getResults() {
return results;
}

@Override
public void defineRules(RulesFactory rulesFactory) {
rulesFactory.addRule()
.from(strings)
.filter(s -> s.substring(0, 1), EQUAL, "A")

;
}

public static void main(String[] args) {
List<String> strings = Arrays.asList("A1", "A123", "B12", "ABCDEF");

int result = strings.stream().filter(s -> s.substring(0,1).equals("A")).reduce(0, (a,s) -> a+s.length(), (a,b) -> a+b);
System.out.println(result);
RuleFactory accRuleFactory = rulesFactory.addRule();
accRuleFactory.accumulate(
accRuleFactory.from(strings).filter(s -> s.substring(0, 1), EQUAL, "A"),
sum(String::length)
)
.execute(results, (r, sum) -> r.add("Sum of length of Strings starting with A is " + sum)); ;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,17 @@ public void testSelfJoin() {
assertEquals(1, unitInstance.fire());
assertEquals("Found 'abc' and 'bcd'", unit.getResults().get(0));
}

@Test
public void testAccumulate() {
AccumulateUnit unit = new AccumulateUnit();
unit.getStrings().add("A1");
unit.getStrings().add("A123");
unit.getStrings().add("B12");
unit.getStrings().add("ABCDEF");

RuleUnitInstance<AccumulateUnit> unitInstance = DSLRuleUnit.instance(unit);
assertEquals(1, unitInstance.fire());
assertEquals("Sum of length of Strings starting with A is 12", unit.getResults().get(0));
}
}

0 comments on commit befa642

Please sign in to comment.