From 1df38b9ba50d63aab7ea446c03bb9ab5b9b1e601 Mon Sep 17 00:00:00 2001 From: axexlck Date: Tue, 2 Jul 2024 20:42:29 +0200 Subject: [PATCH] WIP #1009 Improved DecisionTree for PatternTest --- .../core/decisiontree/DecisionTree.java | 99 ++++++++++++++++--- 1 file changed, 83 insertions(+), 16 deletions(-) diff --git a/symja_android_library/tools/src/main/java/org/matheclipse/core/decisiontree/DecisionTree.java b/symja_android_library/tools/src/main/java/org/matheclipse/core/decisiontree/DecisionTree.java index 856adec79..821b803ff 100644 --- a/symja_android_library/tools/src/main/java/org/matheclipse/core/decisiontree/DecisionTree.java +++ b/symja_android_library/tools/src/main/java/org/matheclipse/core/decisiontree/DecisionTree.java @@ -44,6 +44,7 @@ import org.matheclipse.core.eval.util.SourceCodeProperties; import org.matheclipse.core.expression.F; import org.matheclipse.core.expression.Pattern; +import org.matheclipse.core.expression.S; import org.matheclipse.core.generic.GenericPair; import org.matheclipse.core.interfaces.IAST; import org.matheclipse.core.interfaces.IASTAppendable; @@ -112,7 +113,7 @@ public final static List putDownRule(final int setSymbol, } private static boolean insertRule(DecisionTree[] dts, IAST lhs, IExpr rhs) { - if (!lhs.forAll(x -> x.isFreeOfPatterns() || (x.isPattern() && !x.isPatternDefault()), 0)) { + if (!lhs.forAll(x -> isCompilableRule(x), 0)) { return false; } if (lhs.forAll(x -> x.isFreeOfPatterns(), 0)) { @@ -179,6 +180,34 @@ private static boolean insertRule(DecisionTree[] dts, IAST lhs, IExpr rhs) { set.add(node); } } + } else if (arg.isAST(S.PatternTest, 3)) { + IAST patternTest = (IAST) arg; + // Pattern p = (Pattern) patternTest.arg1(); + // IExpr testExpr = patternTest.arg2(); + if (node != null) { + DecisionTree subNet = node.decisionTree(); + if (subNet == null) { + subNet = new DecisionTree(); + node.decisionTree = subNet; + net = subNet; + } + } + TreeSet set = net.get(i); + if (set == null) { + set = new TreeSet(); + node = new DiscriminationNode(patternTest, null, null); + set.add(node); + net.put(i, set); + } else { + node = new DiscriminationNode(patternTest, null, null); + DiscriminationNode floor = set.floor(node); + if (floor != null && floor.equals(node)) { + node = floor; + net = node.decisionTree; + } else { + set.add(node); + } + } } } if (node != null) { @@ -192,6 +221,23 @@ private static boolean insertRule(DecisionTree[] dts, IAST lhs, IExpr rhs) { return true; } + public static boolean isCompilableRule(IExpr x) { + if (x.isFreeOfPatterns()) { + return true; + } + if (x.isPattern() && !x.isPatternDefault()) { + return true; + } + if (x.isAST(S.PatternTest, 3) && x.first().isPattern()) { + IPattern pattern = (IPattern) x.first(); + IExpr patternTest = x.second(); + if (patternTest.isFreeOfPatterns() && !pattern.isPatternDefault()) { + return true; + } + } + return false; + } + /** * Experimental. Don't use it. * @@ -254,6 +300,7 @@ private static IExpr toJavaMethodRecursive(DecisionTree dn, StringBuilder buf, IExpr expr = node.expr(); CharSequence patternValueVar = null; ISymbol patternSymbol = null; + IExpr exprTest = null; IPattern pattern = null; if (expr.isPattern()) { patternEval = true; @@ -266,13 +313,35 @@ private static IExpr toJavaMethodRecursive(DecisionTree dn, StringBuilder buf, break; } } + } else if (expr.isAST(S.PatternTest, 3)) { + IAST patternTest = (IAST) expr; + patternEval = true; + pattern = (IPattern) patternTest.arg1(); + exprTest = patternTest.arg2(); + patternSymbol = pattern.getSymbol(); + for (int i = 0; i < patternIndexMap.size(); i++) { + GenericPair pair = patternIndexMap.get(i); + if (pair.getFirst().equals(patternSymbol)) { + patternValueVar = pair.getSecond(); + break; + } + } } else { patternValueVar = toJava(expr); } try { if (patternValueVar == null) { buf.append("IPattern " + x + " = (IPattern)" + toJava(pattern) + ";\n"); - buf.append("if (" + x + ".isConditionMatched(" + arg + ",null)) {\n"); + if (exprTest == null) { + buf.append("if (" + x + ".isConditionMatched(" // + + arg + ",null)) {\n"); + } else { + String t = EvalEngine.uniqueName("t"); + buf.append("IExpr " + t + " = " + toJava(exprTest) + ";\n"); + buf.append("if (engine.evalTrue(" + t + "," + arg + ") &&" // + + x + ".isConditionMatched(" // + + arg + ",null)) {\n"); + } buf.append("patternIndexMap.push(new GenericPair(" + arg + ", " + x + ".getSymbol()));\n"); buf.append("try {\n"); @@ -343,7 +412,7 @@ public static void main(String[] args) { IExpr rhs2 = Times(C1D2, Power(Exp(Times(CI, p)), CN1), Sqrt(Times(C3, Power(C2Pi, CN1))), Sin(t)); // SphericalHarmonicY(1, 1, t_, p_) := (-1/2)*E^(I*p)*Sqrt(3/(2*Pi))*Sin(t), - IAST lhs3 = SphericalHarmonicY(C1, C1, t_, p_); + IAST lhs3 = SphericalHarmonicY(C1, C1, t_, F.PatternTest(p_, S.IntegerQ)); IExpr rhs3 = Times(CN1D2, Exp(Times(CI, p)), Sqrt(Times(C3, Power(C2Pi, CN1))), Sin(t)); IAST lhs3a = SphericalHarmonicY(C1, C3, t_, t_); IExpr rhs3a = Times(CN1D2, Exp(Times(CI, t)), Sqrt(Times(C3, Power(C2Pi, CN1))), Sin(t)); @@ -548,7 +617,8 @@ public static IExpr match5(IAST evalLHS, EvalEngine engine) { try { IExpr a21 = evalLHS.get(4); IPattern x22 = F.p_; - if (x22.isConditionMatched(a21, null)) { + IExpr t23 = F.IntegerQ; + if (engine.evalTrue(t23, a21) && x22.isConditionMatched(a21, null)) { patternIndexMap.push(new GenericPair(a21, x22.getSymbol())); try { result = PatternMatcherAndEvaluator.evalInternal(evalLHS, @@ -570,16 +640,16 @@ public static IExpr match5(IAST evalLHS, EvalEngine engine) { } } - IExpr x23 = F.C3; - if (x23.equals(a12)) { - IExpr a24 = evalLHS.get(3); - IPattern x25 = F.t_; - if (x25.isConditionMatched(a24, null)) { - patternIndexMap.push(new GenericPair(a24, x25.getSymbol())); + IExpr x24 = F.C3; + if (x24.equals(a12)) { + IExpr a25 = evalLHS.get(3); + IPattern x26 = F.t_; + if (x26.isConditionMatched(a25, null)) { + patternIndexMap.push(new GenericPair(a25, x26.getSymbol())); try { - IExpr a26 = evalLHS.get(4); - IExpr x27 = a24; - if (x27.equals(a26)) { + IExpr a27 = evalLHS.get(4); + IExpr x28 = a25; + if (x28.equals(a27)) { result = PatternMatcherAndEvaluator.evalInternal(evalLHS, F.Times(F.CN1D2, F.Exp(F.Times(F.CI, t)), @@ -604,7 +674,4 @@ public static IExpr match5(IAST evalLHS, EvalEngine engine) { return F.NIL; } - - - }