Skip to content

Commit

Permalink
Split Beta Nodes #5787
Browse files Browse the repository at this point in the history
  • Loading branch information
mdproctor committed Aug 11, 2024
1 parent 5456bc3 commit 3919e4e
Show file tree
Hide file tree
Showing 113 changed files with 2,297 additions and 1,917 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.stmt.Statement;
import org.drools.base.reteoo.NodeTypeEnums;
import org.drools.core.reteoo.Sink;

import static com.github.javaparser.StaticJavaParser.parseStatement;
Expand All @@ -40,7 +41,11 @@ protected Statement propagateMethod(Sink sink) {
if (sinkCanBeInlined(sink)) {
assertStatement = parseStatement("ALPHATERMINALNODE.collectObject();");
} else {
assertStatement = parseStatement("ALPHATERMINALNODE.assertObject(handle, context, wm);");
String g = "";
if (NodeTypeEnums.isBetaNode(sink)) {
g = "getRightInput().";
}
assertStatement = parseStatement("ALPHATERMINALNODE." + g + "assertObject(handle, context, wm);");
}
replaceNameExpr(assertStatement, "ALPHATERMINALNODE", getVariableName(sink));
return assertStatement;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,13 @@ public boolean isEmpty() {
}

@Override
public void doLinkRiaNode(ReteEvaluator reteEvaluator) {
originalSinkPropagator.doLinkRiaNode(reteEvaluator);
public void doLinkSubnetwork(ReteEvaluator reteEvaluator) {
originalSinkPropagator.doLinkSubnetwork(reteEvaluator);
}

@Override
public void doUnlinkRiaNode(ReteEvaluator reteEvaluator) {
originalSinkPropagator.doUnlinkRiaNode(reteEvaluator);
public void doUnlinkSubnetwork(ReteEvaluator reteEvaluator) {
originalSinkPropagator.doUnlinkSubnetwork(reteEvaluator);
}

public abstract void init(Object... args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.stmt.Statement;
import org.drools.base.reteoo.NodeTypeEnums;
import org.drools.core.reteoo.Sink;

import static com.github.javaparser.StaticJavaParser.parseStatement;
Expand All @@ -40,8 +41,16 @@ protected Statement propagateMethod(Sink sink) {
if (sinkCanBeInlined(sink)) {
modifyStatement = parseStatement("ALPHATERMINALNODE.collectObject();");
} else {
modifyStatement = parseStatement("ALPHATERMINALNODE.modifyObject(handle, modifyPreviousTuples, context, wm);");
String g = "";
if (NodeTypeEnums.isBetaNode(sink)) {
g = "getRightInput().";
}

modifyStatement = parseStatement("ALPHATERMINALNODE." + g + "modifyObject(handle, modifyPreviousTuples, context, wm);");
}



replaceNameExpr(modifyStatement, "ALPHATERMINALNODE", getVariableName(sink));
return modifyStatement;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import org.drools.core.reteoo.AlphaNode;
import org.drools.core.reteoo.BetaNode;
import org.drools.core.reteoo.RightInputAdapterNode;
import org.drools.core.reteoo.CompositeObjectSinkAdapter;
import org.drools.core.reteoo.CompositeObjectSinkAdapter.FieldIndex;
import org.drools.core.reteoo.CompositePartitionAwareObjectSinkAdapter;
Expand Down Expand Up @@ -201,8 +202,8 @@ private void traverseSink(ObjectSink sink, NetworkHandler handler) {
traversePropagator(alphaNode.getObjectSinkPropagator(), handler);

handler.endNonHashedAlphaNode(alphaNode);
} else if (NodeTypeEnums.isBetaNode( sink ) ) {
BetaNode betaNode = (BetaNode) sink;
} else if (NodeTypeEnums.isBetaRightNode( sink ) ) {
BetaNode betaNode = ((RightInputAdapterNode) sink).getBetaNode();

handler.startBetaNode(betaNode);
handler.endBetaNode(betaNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public interface NetworkNode extends Serializable {

NetworkNode[] getSinks();

default boolean isRightInputIsRiaNode() {
default boolean inputIsTupleToObjectNode() {
// not ideal, but this was here to allow NetworkNode to be in drools-base
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,8 @@ public interface BaseTerminalNode extends NetworkNode {

void initInferredMask();

BitMask getDeclaredMask();

void setDeclaredMask(BitMask mask);

BitMask getInferredMask();

void setInferredMask(BitMask mask);

BitMask getNegativeMask();

void setNegativeMask(BitMask mask);

RuleImpl getRule();

GroupElement getSubRule();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ public class NodeTypeEnums {
public static final int EndNodeMask = 1 << 8;
public static final int BetaMask = 1 << 9;

public static final int MemoryFactoryMask = 1 << 10;
public static final int BetaRightMask = 1 << 10;

public static final int MemoryFactoryMask = 1 << 11;

public static final int shift = 15; // This must shift the node IDs, enough so their bits are not mutated by the masks.

Expand All @@ -46,8 +48,15 @@ public class NodeTypeEnums {
public static final int WindowNode = (150 << shift) | ObjectSourceMask | ObjectSinkMask | MemoryFactoryMask;

// ObjectSource, LeftTupleSink
public static final int RightInputAdapterNode = (160 << shift) | ObjectSourceMask | TupleSinkMask |
TupleNodeMask | EndNodeMask | MemoryFactoryMask;
public static final int TupleToObjectNode = (160 << shift) | ObjectSourceMask | TupleSinkMask |
TupleNodeMask | EndNodeMask | MemoryFactoryMask;

public static final int JoinRightAdapterNode = (162 << shift) | ObjectSinkMask | BetaRightMask;
public static final int ExistsRightAdapterNode = (164 << shift) | ObjectSinkMask | BetaRightMask;
public static final int NotRightAdapterNode = (166 << shift) | ObjectSinkMask | BetaRightMask;
public static final int AccumulateRightAdapterNode = (168 << shift) | ObjectSinkMask | BetaRightMask;


// LefTTupleSink, LeftTupleNode
public static final int RuleTerminalNode = (180 << shift) | TupleSinkMask | TerminalNodeMask |
TupleNodeMask | EndNodeMask | MemoryFactoryMask;
Expand Down Expand Up @@ -97,8 +106,12 @@ public static boolean isBetaNode(NetworkNode node) {
return (node.getType() & BetaMask) != 0;
}

public static boolean isBetaNodeWithRian(NetworkNode node) {
return isBetaNode(node) && node.isRightInputIsRiaNode();
public static boolean isBetaNodeWithSubnetwork(NetworkNode node) {
return isBetaNode(node) && node.inputIsTupleToObjectNode();
}

public static boolean isBetaNodeWithoutSubnetwork(NetworkNode node) {
return isBetaNode(node) && !node.inputIsTupleToObjectNode();
}

public static boolean isTerminalNode(NetworkNode node) {
Expand Down Expand Up @@ -130,4 +143,8 @@ public static boolean isLeftInputAdapterNode(NetworkNode node) {
return (node.getType() & LeftInputAdapterMask) != 0;
}

public static boolean isBetaRightNode(NetworkNode node) {
return (node.getType() & BetaRightMask) != 0;
}

}
11 changes: 6 additions & 5 deletions drools-core/src/main/java/org/drools/core/common/BaseNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.drools.core.reteoo.Sink;
import org.drools.core.reteoo.TerminalNode;
import org.drools.core.reteoo.builder.BuildContext;
import org.drools.util.bitmask.BitMask;
import org.kie.api.definition.rule.Rule;
import org.drools.base.reteoo.NodeTypeEnums;

Expand Down Expand Up @@ -110,6 +111,8 @@ protected void setStreamMode(boolean streamMode) {
this.streamMode = streamMode;
}

public abstract BaseNode getParent();

/**
* Attaches the node into the network. Usually to the parent <code>ObjectSource</code> or <code>TupleSource</code>
*/
Expand Down Expand Up @@ -174,14 +177,10 @@ public void setPartitionId(BuildContext context, RuleBasePartitionId partitionId
/**
* Associates this node with the give rule
*/
public void addAssociation( Rule rule ) {
public void addAssociation(Rule rule, BuildContext context) {
this.associations.add( rule );
}

public void addAssociation( BuildContext context, Rule rule ) {
addAssociation( rule );
}

/**
* Removes the association to the given rule from the
* associations map.
Expand Down Expand Up @@ -242,4 +241,6 @@ public NetworkNode[] getSinks() {
}


public abstract BitMask getDeclaredMask();
public abstract BitMask getInferredMask();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,28 @@
import org.drools.base.common.NetworkNode;
import org.drools.base.reteoo.NodeTypeEnums;
import org.drools.core.reteoo.AccumulateNode;
import org.drools.core.reteoo.AccumulateRight;
import org.drools.core.reteoo.AlphaTerminalNode;
import org.drools.core.reteoo.AsyncReceiveNode;
import org.drools.core.reteoo.AsyncSendNode;
import org.drools.core.reteoo.ConditionalBranchNode;
import org.drools.core.reteoo.EvalConditionNode;
import org.drools.core.reteoo.ExistsNode;
import org.drools.core.reteoo.ExistsRight;
import org.drools.core.reteoo.FromNode;
import org.drools.core.reteoo.JoinNode;
import org.drools.core.reteoo.JoinRightAdapterNode;
import org.drools.core.reteoo.LeftInputAdapterNode;
import org.drools.core.reteoo.LeftTupleNode;
import org.drools.core.reteoo.LeftTupleSinkNode;
import org.drools.core.reteoo.LeftTupleSource;
import org.drools.core.reteoo.NotNode;
import org.drools.core.reteoo.NotRight;
import org.drools.core.reteoo.ObjectTypeNodeId;
import org.drools.core.reteoo.QueryElementNode;
import org.drools.core.reteoo.QueryTerminalNode;
import org.drools.core.reteoo.ReactiveFromNode;
import org.drools.core.reteoo.RightInputAdapterNode;
import org.drools.core.reteoo.TupleToObjectNode;
import org.drools.core.reteoo.RightTuple;
import org.drools.core.reteoo.RightTupleSink;
import org.drools.core.reteoo.RuleTerminalNode;
Expand All @@ -57,7 +61,7 @@ public static LeftTupleNode getLeftTupleNode(TupleImpl t) {
switch (s.getType()) {
case NodeTypeEnums.RuleTerminalNode : return (RuleTerminalNode) s;
case NodeTypeEnums.QueryTerminalNode: return (QueryTerminalNode) s;
case NodeTypeEnums.RightInputAdapterNode: return (RightInputAdapterNode) s;
case NodeTypeEnums.TupleToObjectNode: return (TupleToObjectNode) s;
case NodeTypeEnums.LeftInputAdapterNode: return (LeftInputAdapterNode) s;
case NodeTypeEnums.AlphaTerminalNode: return (AlphaTerminalNode) s;
case NodeTypeEnums.AccumulateNode : return (AccumulateNode) s;
Expand All @@ -81,10 +85,10 @@ public static LeftTupleNode getLeftTupleNode(TupleImpl t) {
public static RightTupleSink getRightTupleSink(RightTuple t) {
Sink s = t.getSink();
switch (s.getType()) {
case NodeTypeEnums.AccumulateNode : return (AccumulateNode) s;
case NodeTypeEnums.ExistsNode: return (ExistsNode) s;
case NodeTypeEnums.NotNode: return (NotNode) s;
case NodeTypeEnums.JoinNode: return (JoinNode) s;
case NodeTypeEnums.AccumulateRightAdapterNode: return (AccumulateRight) s;
case NodeTypeEnums.ExistsRightAdapterNode: return (ExistsRight) s;
case NodeTypeEnums.NotRightAdapterNode: return (NotRight) s;
case NodeTypeEnums.JoinRightAdapterNode: return (JoinRightAdapterNode) s;
case NodeTypeEnums.WindowNode: return (WindowNode) s;
case NodeTypeEnums.MockBetaNode: return (RightTupleSink) s;
case NodeTypeEnums.MockAlphaNode: return (RightTupleSink) s;
Expand All @@ -97,7 +101,7 @@ public static LeftTupleSinkNode asLeftTupleSink(NetworkNode n) {
switch (n.getType()) {
case NodeTypeEnums.RuleTerminalNode : return (RuleTerminalNode) n;
case NodeTypeEnums.QueryTerminalNode: return (QueryTerminalNode) n;
case NodeTypeEnums.RightInputAdapterNode: return (RightInputAdapterNode) n;
case NodeTypeEnums.TupleToObjectNode: return (TupleToObjectNode) n;
case NodeTypeEnums.AccumulateNode : return (AccumulateNode) n;
case NodeTypeEnums.ExistsNode: return (ExistsNode) n;
case NodeTypeEnums.NotNode: return (NotNode) n;
Expand All @@ -119,22 +123,22 @@ public static LeftTupleSinkNode asLeftTupleSink(NetworkNode n) {
public static ObjectTypeNodeId getLeftInputOtnId(TupleImpl t) {
Sink s = t.getSink();
switch (s.getType()) {
case NodeTypeEnums.RuleTerminalNode : return ((RuleTerminalNode) s).getLeftInputOtnId();
case NodeTypeEnums.QueryTerminalNode: return ((QueryTerminalNode) s).getLeftInputOtnId();
case NodeTypeEnums.RightInputAdapterNode: return ((RightInputAdapterNode) s).getLeftInputOtnId();
case NodeTypeEnums.AccumulateNode : return ((AccumulateNode) s).getLeftInputOtnId();
case NodeTypeEnums.ExistsNode: return ((ExistsNode) s).getLeftInputOtnId();
case NodeTypeEnums.NotNode: return ((NotNode) s).getLeftInputOtnId();
case NodeTypeEnums.JoinNode: return ((JoinNode) s).getLeftInputOtnId();
case NodeTypeEnums.FromNode: return ((FromNode) s).getLeftInputOtnId();
case NodeTypeEnums.EvalConditionNode: return ((EvalConditionNode) s).getLeftInputOtnId();
case NodeTypeEnums.AsyncReceiveNode: return ((AsyncReceiveNode) s).getLeftInputOtnId();
case NodeTypeEnums.AsyncSendNode: return ((AsyncSendNode) s).getLeftInputOtnId();
case NodeTypeEnums.ReactiveFromNode: return ((ReactiveFromNode) s).getLeftInputOtnId();
case NodeTypeEnums.ConditionalBranchNode: return ((ConditionalBranchNode) s).getLeftInputOtnId();
case NodeTypeEnums.QueryElementNode: return ((QueryElementNode) s).getLeftInputOtnId();
case NodeTypeEnums.TimerConditionNode: return ((TimerNode) s).getLeftInputOtnId();
case NodeTypeEnums.MockBetaNode: return ((LeftTupleSource)s).getLeftInputOtnId();
case NodeTypeEnums.RuleTerminalNode : return ((RuleTerminalNode) s).getInputOtnId();
case NodeTypeEnums.QueryTerminalNode: return ((QueryTerminalNode) s).getInputOtnId();
case NodeTypeEnums.TupleToObjectNode: return ((TupleToObjectNode) s).getInputOtnId();
case NodeTypeEnums.AccumulateNode : return ((AccumulateNode) s).getInputOtnId();
case NodeTypeEnums.ExistsNode: return ((ExistsNode) s).getInputOtnId();
case NodeTypeEnums.NotNode: return ((NotNode) s).getInputOtnId();
case NodeTypeEnums.JoinNode: return ((JoinNode) s).getInputOtnId();
case NodeTypeEnums.FromNode: return ((FromNode) s).getInputOtnId();
case NodeTypeEnums.EvalConditionNode: return ((EvalConditionNode) s).getInputOtnId();
case NodeTypeEnums.AsyncReceiveNode: return ((AsyncReceiveNode) s).getInputOtnId();
case NodeTypeEnums.AsyncSendNode: return ((AsyncSendNode) s).getInputOtnId();
case NodeTypeEnums.ReactiveFromNode: return ((ReactiveFromNode) s).getInputOtnId();
case NodeTypeEnums.ConditionalBranchNode: return ((ConditionalBranchNode) s).getInputOtnId();
case NodeTypeEnums.QueryElementNode: return ((QueryElementNode) s).getInputOtnId();
case NodeTypeEnums.TimerConditionNode: return ((TimerNode) s).getInputOtnId();
case NodeTypeEnums.MockBetaNode: return ((LeftTupleSource)s).getInputOtnId();
default:
throw new UnsupportedOperationException("Node does not have an LeftInputOtnId: " + s);
}
Expand All @@ -143,11 +147,11 @@ public static ObjectTypeNodeId getLeftInputOtnId(TupleImpl t) {
public static ObjectTypeNodeId getRightInputOtnId(TupleImpl t) {
Sink s = t.getSink();
switch (s.getType()) {
case NodeTypeEnums.AccumulateNode : return ((AccumulateNode) s).getRightInputOtnId();
case NodeTypeEnums.ExistsNode: return ((ExistsNode) s).getRightInputOtnId();
case NodeTypeEnums.NotNode: return ((NotNode) s).getRightInputOtnId();
case NodeTypeEnums.JoinNode: return ((JoinNode) s).getRightInputOtnId();
case NodeTypeEnums.WindowNode: return ((WindowNode) s).getRightInputOtnId();
case NodeTypeEnums.AccumulateRightAdapterNode: return ((AccumulateRight) s).getInputOtnId();
case NodeTypeEnums.ExistsRightAdapterNode: return ((ExistsRight) s).getInputOtnId();
case NodeTypeEnums.NotRightAdapterNode: return ((NotRight) s).getInputOtnId();
case NodeTypeEnums.JoinRightAdapterNode: return ((JoinRightAdapterNode) s).getInputOtnId();
case NodeTypeEnums.WindowNode: return ((WindowNode) s).getInputOtnId();
default:
throw new UnsupportedOperationException("Node does not have an RightInputOtnId: " + s);
}
Expand All @@ -158,7 +162,7 @@ public static LeftTupleSource getLeftTupleSource(TupleImpl t) {
switch (s.getType()) {
case NodeTypeEnums.RuleTerminalNode : return ((RuleTerminalNode) s).getLeftTupleSource();
case NodeTypeEnums.QueryTerminalNode: return ((QueryTerminalNode) s).getLeftTupleSource();
case NodeTypeEnums.RightInputAdapterNode: return ((RightInputAdapterNode) s).getLeftTupleSource();
case NodeTypeEnums.TupleToObjectNode: return ((TupleToObjectNode) s).getLeftTupleSource();
case NodeTypeEnums.AccumulateNode : return ((AccumulateNode) s).getLeftTupleSource();
case NodeTypeEnums.ExistsNode: return ((ExistsNode) s).getLeftTupleSource();
case NodeTypeEnums.NotNode: return ((NotNode) s).getLeftTupleSource();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ public class KnowledgeBaseImpl implements InternalRuleBase {

public KnowledgeBaseImpl() { }

public KnowledgeBaseImpl(String id) {
this(id, (CompositeBaseConfiguration) RuleBaseFactory.newKnowledgeBaseConfiguration());
}

public KnowledgeBaseImpl(final String id,
final CompositeBaseConfiguration config) {
this.config = config;
Expand Down
Loading

0 comments on commit 3919e4e

Please sign in to comment.