Skip to content

Commit

Permalink
Improve ReplacePart
Browse files Browse the repository at this point in the history
  • Loading branch information
axkr committed Dec 7, 2024
1 parent b74aa70 commit 1f3864c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6192,22 +6192,24 @@ public IExpr evaluate(IAST ast, int argSize, IExpr[] option, EvalEngine engine,
}
}
if (argSize == 3) {
IExpr arg3 = ast.arg3();
if (arg3.isList()) {
if (arg3.exists(x -> !x.isInteger())) {
IExpr lhs = ast.arg3();
IExpr rhs = ast.arg2();
if (lhs.isList()) {
if (lhs.exists(x -> !x.isInteger())) {
// Position specification `1` in `2` is not a machine sized integer or a list of
// machine-sized integers.
return Errors.printMessage(S.ReplacePart, "psl", F.List(F.C3, ast), engine);
}
} else {
int position = arg3.toIntDefault();
int position = lhs.toIntDefault();
if (position == Integer.MIN_VALUE) {
// Position specification `1` in `2` is not a machine sized integer or a list of
// machine-sized integers.
return Errors.printMessage(S.ReplacePart, "psl", F.List(F.C3, ast), engine);
}
}
return result.replacePart(F.Rule(arg3, ast.arg2()), heads).orElse(result);
// Note: Rubi uses this kind of rule:
return result.replacePart(lhs, rhs, heads).orElse(result);
}
if (ast.arg2().isRuleAST()) {
return ast.arg1().replacePart((IAST) ast.arg2(), heads).orElse(ast.arg1());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6059,7 +6059,7 @@ default IExpr replaceAll(final Map<? extends IExpr, ? extends IExpr> map) {
/**
* Replace all subexpressions with the given rule set. A rule must contain the position of the
* subexpression which should be replaced on the left-hand-side. If no substitution matches, the
* method returns <code>F.NIL</code>.
* method returns {@link F#NIL}.
*
* @param astRules rules of the form <code>position-&gt;y</code> or <code>
* {position1-&gt;b, position2-&gt;d}</code>
Expand All @@ -6078,6 +6078,28 @@ default IExpr replacePart(final IAST astRules, IExpr.COMPARE_TERNARY heads) {
return F.NIL;
}

/**
* All subexpressions whose positions matches the left-hand-side (<code>lhs</code>) are replaced
* with the right-hand-side (<code>rhs</code>). If no substitution matches, the method returns
* {@link F#NIL}.
*
* @param lhs the left-hand-side of the rule
* @param rhs the right-hand-side of the rule
* @param heads if <code>TRUE</code> also replace the heads of expressions
* @return
*/
default IExpr replacePart(final IExpr lhs, IExpr rhs, IExpr.COMPARE_TERNARY heads) {
try {
return this.accept(new VisitorReplacePart(lhs, rhs, heads));
} catch (RuntimeException rex) {
Errors.rethrowsInterruptException(rex);
if (Config.SHOW_STACKTRACE) {
rex.printStackTrace();
}
}
return F.NIL;
}

/**
* Repeatedly replace all (sub-) expressions with the given unary function. If no substitution
* matches, the method returns <code>this</code>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class VisitorReplacePart extends AbstractVisitor {
* <code>0</code> or <code>1</code>, depending if option <code>Heads->True</code> is set or the
* <code>0</code> index position is used in the left-hand-side of a rule.
*/
private int offset;
private int startOffset;

/**
* The current evaluation engine for this thread.
Expand All @@ -43,25 +43,50 @@ public class VisitorReplacePart extends AbstractVisitor {
public VisitorReplacePart(IAST rule, IExpr.COMPARE_TERNARY heads) {
super();
engine = EvalEngine.get();
startOffset = heads == IExpr.COMPARE_TERNARY.TRUE ? 0 : 1;
if (rule.isRuleAST()) {
rule = F.list(rule);
}
if (rule.isListOfRules()) {
IAST list = rule;
this.patternMatcherList = new ArrayList<IPatternMatcher>(list.argSize() + 3);
offset = heads == IExpr.COMPARE_TERNARY.TRUE ? 0 : 1;

for (int i = 1; i < list.size(); i++) {
rule = (IAST) list.get(i);
initPatternMatcher(rule, heads);
initPatternMatcher(rule.arg1(), rule.arg2(), heads);
}
if (heads == COMPARE_TERNARY.FALSE) {
offset = 1;
// if set explicitly to FALSE, then startOffset is always 1, otherwise startOffset may be
// changed in initPatternMatcher()
startOffset = 1;
}
}
}

private void initPatternMatcher(IAST rule, IExpr.COMPARE_TERNARY heads) {
IExpr fromPositions = rule.arg1();
public VisitorReplacePart(IExpr lhs, IExpr rhs, IExpr.COMPARE_TERNARY heads) {
super();
engine = EvalEngine.get();
this.patternMatcherList = new ArrayList<IPatternMatcher>(1);
startOffset = heads == IExpr.COMPARE_TERNARY.TRUE ? 0 : 1;
initPatternMatcher(lhs, rhs, heads);
if (heads == COMPARE_TERNARY.FALSE) {
startOffset = 1;
}
}

/**
* Initialize the pattern matcher. If the left-hand-side is a list of lists of integers, then the
* right-hand-side is matched against the positions in the list.
* <p>
* <b>Note</b>: the {@link #startOffset} is set to <code>0</code> if a position <code>0</code> is
* found.
*
* @param lhs
* @param rhs
* @param heads
*/
private void initPatternMatcher(IExpr lhs, IExpr rhs, IExpr.COMPARE_TERNARY heads) {
IExpr fromPositions = lhs;
try {
// try extracting an int[] array of expressions
if (fromPositions.isList()) {
Expand All @@ -76,11 +101,11 @@ private void initPatternMatcher(IAST rule, IExpr.COMPARE_TERNARY heads) {
throw ReturnException.RETURN_FALSE;
}
if (positions[k - 1] == 0) {
offset = 0;
startOffset = 0;
}
}
IPatternMatcher evalPatternMatcher =
engine.evalPatternMatcher(F.Sequence(positions), rule.arg2());
engine.evalPatternMatcher(F.Sequence(positions), rhs);
this.patternMatcherList.add(evalPatternMatcher);
}
} else {
Expand All @@ -93,42 +118,41 @@ private void initPatternMatcher(IAST rule, IExpr.COMPARE_TERNARY heads) {
throw ReturnException.RETURN_FALSE;
}
if (positions[j - 1] == 0) {
offset = 0;
startOffset = 0;
}
}
IPatternMatcher evalPatternMatcher =
engine.evalPatternMatcher(F.Sequence(positions), rule.arg2());
engine.evalPatternMatcher(F.Sequence(positions), rhs);
this.patternMatcherList.add(evalPatternMatcher);
}
}
} else {
int[] positions = new int[] {rule.arg1().toIntDefault()};
int[] positions = new int[] {lhs.toIntDefault()};
if (positions[0] == Integer.MIN_VALUE) {
throw ReturnException.RETURN_FALSE;
}
if (positions[0] == 0) {
offset = 0;
startOffset = 0;
}
IPatternMatcher evalPatternMatcher =
engine.evalPatternMatcher(F.Sequence(positions), rule.arg2());
IPatternMatcher evalPatternMatcher = engine.evalPatternMatcher(F.Sequence(positions), rhs);
this.patternMatcherList.add(evalPatternMatcher);

}
} catch (ReturnException rex) {
if (fromPositions.isList()) {
IAST list = ((IAST) fromPositions).apply(S.Sequence, 1);
IPatternMatcher evalPatternMatcher = engine.evalPatternMatcher(list, rule.arg2());
IPatternMatcher evalPatternMatcher = engine.evalPatternMatcher(list, rhs);
this.patternMatcherList.add(evalPatternMatcher);
} else {
IPatternMatcher evalPatternMatcher = engine.evalPatternMatcher(fromPositions, rule.arg2());
IPatternMatcher evalPatternMatcher = engine.evalPatternMatcher(fromPositions, rhs);
this.patternMatcherList.add(evalPatternMatcher);
}
}
}

private IExpr visitPatternIndexList(IAST ast, IASTAppendable positions) {
IASTAppendable result = F.NIL;
for (int i = offset; i < ast.size(); i++) {
for (int i = startOffset; i < ast.size(); i++) {
final IInteger position = F.ZZ(i);
for (int j = 0; j < patternMatcherList.size(); j++) {
IPatternMatcher matcher = patternMatcherList.get(j);
Expand Down

0 comments on commit 1f3864c

Please sign in to comment.