diff --git a/mathics/builtin/list/constructing.py b/mathics/builtin/list/constructing.py
index 014f467d5..72c673b40 100644
--- a/mathics/builtin/list/constructing.py
+++ b/mathics/builtin/list/constructing.py
@@ -439,7 +439,7 @@ def eval(self, expr, patterns, f, evaluation: Evaluation):
def listener(e, tag):
result = False
for pattern, items in sown:
- if pattern.does_match(tag, evaluation):
+ if pattern.does_match(tag, {"evaluation": evaluation}):
for item in items:
if item[0].sameQ(tag):
item[1].append(e)
diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py
index addd2a600..6c53df3a6 100644
--- a/mathics/builtin/patterns.py
+++ b/mathics/builtin/patterns.py
@@ -37,9 +37,9 @@
The attributes 'Flat', 'Orderless', and 'OneIdentity' affect pattern matching.
"""
-from typing import Callable, List, Optional as OptionalType, Tuple, Union
+from typing import List, Optional as OptionalType, Tuple, Union
-from mathics.core.atoms import Integer, Number, Rational, Real, String
+from mathics.core.atoms import Integer, Integer2, Number, Rational, Real, String
from mathics.core.attributes import (
A_HOLD_ALL,
A_HOLD_FIRST,
@@ -54,6 +54,7 @@
PatternError,
PatternObject,
PostfixOperator,
+ Test,
)
from mathics.core.element import BaseElement, EvalMixin
from mathics.core.evaluation import Evaluation
@@ -63,7 +64,14 @@
from mathics.core.pattern import BasePattern, StopGenerator
from mathics.core.rules import Rule
from mathics.core.symbols import Atom, Symbol, SymbolList, SymbolTrue
-from mathics.core.systemsymbols import SymbolBlank, SymbolDefault, SymbolDispatch
+from mathics.core.systemsymbols import (
+ SymbolBlank,
+ SymbolDefault,
+ SymbolDispatch,
+ SymbolInfinity,
+ SymbolRule,
+ SymbolRuleDelayed,
+)
from mathics.eval.parts import python_levelspec
# This tells documentation how to sort this module
@@ -85,11 +93,6 @@ class Rule_(BinaryOperator):
= a + b + d
>> {x,x^2,y} /. x->3
= {3, 9, y}
- """
-
- # TODO: An error message should appear when Rule is called with a wrong
- # number of arguments
- """
>> a /. Rule[1, 2, 3] -> t
: Rule called with 3 arguments; 2 arguments are expected.
= a
@@ -102,6 +105,13 @@ class Rule_(BinaryOperator):
needs_verbatim = True
summary_text = "a replacement rule"
+ def eval_rule(self, elems, evaluation):
+ """Rule[elems___]"""
+ num_parms = len(elems.get_sequence())
+ if num_parms != 2:
+ evaluation.message("Rule", "argrx", "Rule", Integer(num_parms), Integer2)
+ return None
+
class RuleDelayed(BinaryOperator):
"""
@@ -123,6 +133,15 @@ class RuleDelayed(BinaryOperator):
operator = ":>"
summary_text = "a rule that keeps the replacement unevaluated"
+ def eval_rule_delayed(self, elems, evaluation):
+ """RuleDelayed[elems___]"""
+ num_parms = len(elems.get_sequence())
+ if num_parms != 2:
+ evaluation.message(
+ "RuleDelayed", "argrx", "RuleDelayed", Integer(num_parms), Integer2
+ )
+ return None
+
# TODO: disentangle me
def create_rules(
@@ -130,20 +149,24 @@ def create_rules(
expr: Expression,
name: str,
evaluation: Evaluation,
- extra_args: List = [],
+ extra_args: OptionalType[List] = None,
) -> Tuple[Union[List[Rule], BaseElement], bool]:
"""
- This function implements `Replace`, `ReplaceAll`, `ReplaceRepeated` and `ReplaceList` eval methods.
- `name` controls which of these methods is implemented. These methods applies the rule / list of rules
- `rules_expr` over the expression `expr`, using the evaluation context `evaluation`.
-
- The result is a tuple of two elements. If the second element is `True`, then the first element is the result of the method.
+ This function implements `Replace`, `ReplaceAll`, `ReplaceRepeated`
+ and `ReplaceList` eval methods.
+ `name` controls which of these methods is implemented. These methods
+ applies the rule / list of rules
+ `rules_expr` over the expression `expr`, using the evaluation context
+ `evaluation`.
+
+ The result is a tuple of two elements. If the second element is `True`,
+ then the first element is the result of the method.
If `False`, the first element of the tuple is a list of rules.
"""
if isinstance(rules_expr, Dispatch):
return rules_expr.rules, False
- elif rules_expr.has_form("Dispatch", None):
+ if rules_expr.has_form("Dispatch", None):
return Dispatch(rules_expr.elements, evaluation)
if rules_expr.has_form("List", None):
@@ -164,6 +187,8 @@ def create_rules(
break
if all_lists:
+ if extra_args is None:
+ extra_args = []
return (
ListExpression(
*[
@@ -173,28 +198,28 @@ def create_rules(
),
True,
)
- else:
- evaluation.message(name, "rmix", rules_expr)
+
+ evaluation.message(name, "rmix", rules_expr)
+ return None, True
+
+ result = []
+ for rule in rules:
+ if rule.head not in (SymbolRule, SymbolRuleDelayed):
+ evaluation.message(name, "reps", rule)
return None, True
- else:
- result = []
- for rule in rules:
- if rule.get_head_name() not in ("System`Rule", "System`RuleDelayed"):
- evaluation.message(name, "reps", rule)
- return None, True
- elif len(rule.elements) != 2:
- evaluation.message(
- # TODO: shorten names here
- rule.get_head_name(),
- "argrx",
- rule.get_head_name(),
- 3,
- 2,
- )
- return None, True
- else:
- result.append(Rule(rule.elements[0], rule.elements[1]))
- return result, False
+ if len(rule.elements) != 2:
+ evaluation.message(
+ # TODO: shorten names here
+ rule.get_head_name(),
+ "argrx",
+ rule.get_head_name(),
+ 3,
+ 2,
+ )
+ return None, True
+
+ result.append(Rule(rule.elements[0], rule.elements[1]))
+ return result, False
class Replace(Builtin):
@@ -272,7 +297,7 @@ def eval_levelspec(self, expr, rules, ls, evaluation, options):
heads = self.get_option(options, "Heads", evaluation) is SymbolTrue
- result, applied = expr.do_apply_rules(
+ result, _ = expr.do_apply_rules(
rules,
evaluation,
level=0,
@@ -285,6 +310,8 @@ def eval_levelspec(self, expr, rules, ls, evaluation, options):
except PatternError:
evaluation.message("Replace", "reps", rules)
+ return None
+
class ReplaceAll(BinaryOperator):
"""
@@ -348,10 +375,11 @@ def eval(self, expr, rules, evaluation: Evaluation):
rules, ret = create_rules(rules, expr, "ReplaceAll", evaluation)
if ret:
return rules
- result, applied = expr.do_apply_rules(rules, evaluation)
+ result, _ = expr.do_apply_rules(rules, evaluation)
return result
except PatternError:
evaluation.message("Replace", "reps", rules)
+ return None
class ReplaceRepeated(BinaryOperator):
@@ -450,7 +478,10 @@ class ReplaceList(Builtin):
- 'ReplaceList[$expr$, $rules$]'
-
- returns a list of all possible results of applying $rules$
+
- returns a list of all possible results when applying $rules$ \
+ to $expr$.
+
- 'ReplaceList[$expr$, $rules$, $n$]'
+
- returns a list of at most $n$ results when applying $rules$ \
to $expr$.
@@ -485,20 +516,28 @@ class ReplaceList(Builtin):
summary_text = "list of possible replacement results"
def eval(
- self, expr: BaseElement, rules: BaseElement, max: Number, evaluation: Evaluation
+ self,
+ expr: BaseElement,
+ rules: BaseElement,
+ maxidx: Number,
+ evaluation: Evaluation,
) -> OptionalType[BaseElement]:
- "ReplaceList[expr_, rules_, max_:Infinity]"
+ "ReplaceList[expr_, rules_, maxidx_:Infinity]"
- if max.get_name() == "System`Infinity":
+ # TODO: the below handles Infinity getting added as a
+ # default argument, when it is passed explitly, e.g.
+ # ReplaceList[expr, {}, Infinity], then Infinity
+ # comes in as DirectedInfinity[1].
+ if maxidx == SymbolInfinity:
max_count = None
else:
- max_count = max.get_int_value()
+ max_count = maxidx.get_int_value()
if max_count is None or max_count < 0:
evaluation.message("ReplaceList", "innf", 3)
- return
+ return None
try:
rules, ret = create_rules(
- rules, expr, "ReplaceList", evaluation, extra_args=[max]
+ rules, expr, "ReplaceList", evaluation, extra_args=[maxidx]
)
except PatternError:
evaluation.message("Replace", "reps", rules)
@@ -507,12 +546,12 @@ def eval(
if ret:
return rules
- list = []
+ list_result = []
for rule in rules:
result = rule.apply(expr, evaluation, return_list=True, max_list=max_count)
- list.extend(result)
+ list_result.extend(result)
- return ListExpression(*list)
+ return ListExpression(*list_result)
class PatternTest(BinaryOperator, PatternObject):
@@ -545,7 +584,7 @@ class PatternTest(BinaryOperator, PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(PatternTest, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
# This class has an important effect in the general performance,
# since all the rules that requires specify the type of patterns
# call it. Then, for simple checks like `NumberQ` or `NumericQ`
@@ -576,7 +615,10 @@ def init(
if match_function:
self.match = match_function
- def match_atom(self, yield_func, expression, vars, evaluation, **kwargs):
+ def match_atom(self, expression: Expression, pattern_context: dict):
+ """Match function for AtomQ"""
+ yield_func = pattern_context["yield_func"]
+
def yield_match(vars_2, rest):
items = expression.get_sequence()
# Here we use a `for` loop instead an all over iterator
@@ -588,9 +630,15 @@ def yield_match(vars_2, rest):
else:
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ # TODO: clarify why we need to use copy here.
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_string(self, expression: Expression, pattern_context: dict):
+ """Match function for StringQ"""
+ yield_func = pattern_context["yield_func"]
- def match_string(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
for item in items:
@@ -599,9 +647,14 @@ def yield_match(vars_2, rest):
else:
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_numberq(self, expression: Expression, pattern_context: dict):
+ """Match function for NumberQ"""
+ yield_func = pattern_context["yield_func"]
- def match_numberq(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
for item in items:
@@ -610,9 +663,15 @@ def yield_match(vars_2, rest):
else:
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_numericq(self, expression: Expression, pattern_context: dict):
+ """Match function for NumericQ"""
+ yield_func = pattern_context["yield_func"]
+ evaluation = pattern_context["evaluation"]
- def match_numericq(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
for item in items:
@@ -621,9 +680,14 @@ def yield_match(vars_2, rest):
else:
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_real_numberq(self, expression: Expression, pattern_context: dict):
+ """Match function for RealValuedNumberQ"""
+ yield_func = pattern_context["yield_func"]
- def match_real_numberq(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
for item in items:
@@ -632,9 +696,14 @@ def yield_match(vars_2, rest):
else:
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_positive(self, expression: Expression, pattern_context: dict):
+ """Match function for PositiveQ"""
+ yield_func = pattern_context["yield_func"]
- def match_positive(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
if all(
@@ -643,9 +712,14 @@ def yield_match(vars_2, rest):
):
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_negative(self, expression: Expression, pattern_context: dict):
+ """Match function for NegativeQ"""
+ yield_func = pattern_context["yield_func"]
- def match_negative(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
if all(
@@ -654,9 +728,14 @@ def yield_match(vars_2, rest):
):
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_nonpositive(self, expression: Expression, pattern_context: dict):
+ """Match function for NonPositiveQ"""
+ yield_func = pattern_context["yield_func"]
- def match_nonpositive(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
if all(
@@ -665,9 +744,14 @@ def yield_match(vars_2, rest):
):
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
+
+ def match_nonnegative(self, expression: Expression, pattern_context: dict):
+ """Match function for NonNegativeQ"""
+ yield_func = pattern_context["yield_func"]
- def match_nonnegative(self, yield_func, expression, vars, evaluation, **kwargs):
def yield_match(vars_2, rest):
items = expression.get_sequence()
if all(
@@ -676,35 +760,41 @@ def yield_match(vars_2, rest):
):
yield_func(vars_2, None)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
def quick_pattern_test(self, candidate, test, evaluation: Evaluation):
+ """Pattern test for some other special cases"""
if test == "System`NegativePowerQ":
return (
candidate.has_form("Power", 2)
and isinstance(candidate.elements[1], (Integer, Rational, Real))
and candidate.elements[1].value < 0
)
- elif test == "System`NotNegativePowerQ":
+ if test == "System`NotNegativePowerQ":
return not (
candidate.has_form("Power", 2)
and isinstance(candidate.elements[1], (Integer, Rational, Real))
and candidate.elements[1].value < 0
)
- else:
- from mathics.core.builtin import Test
-
- builtin = None
- builtin = evaluation.definitions.get_definition(test)
- if builtin:
- builtin = builtin.builtin
- if builtin is not None and isinstance(builtin, Test):
- return builtin.test(candidate)
+
+ builtin = None
+ builtin = evaluation.definitions.get_definition(test)
+ if builtin:
+ builtin = builtin.builtin
+ if builtin is not None and isinstance(builtin, Test):
+ return builtin.test(candidate)
return None
- def match(self, yield_func, expression, vars, evaluation, **kwargs):
- # def match(self, yield_func, expression, vars, evaluation, **kwargs):
- # for vars_2, rest in self.pattern.match(expression, vars, evaluation):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match expression with PatternTest"""
+ evaluation = pattern_context["evaluation"]
+ vars_dict = pattern_context["vars_dict"]
+ yield_func = pattern_context["yield_func"]
+
+ # def match(self, yield_func, expression, vars_dict, evaluation, **kwargs):
+ # for vars_2, rest in self.pattern.match(expression, vars_dict, evaluation):
def yield_match(vars_2, rest):
testname = self.test_name
items = expression.get_sequence()
@@ -713,25 +803,31 @@ def yield_match(vars_2, rest):
quick_test = self.quick_pattern_test(item, testname, evaluation)
if quick_test is False:
break
- elif quick_test is True:
+ if quick_test is True:
continue
# raise StopGenerator
- else:
- test_expr = Expression(self.test, item)
- test_value = test_expr.evaluate(evaluation)
- if test_value is not SymbolTrue:
- break
- # raise StopGenerator
+ test_expr = Expression(self.test, item)
+ test_value = test_expr.evaluate(evaluation)
+ if test_value is not SymbolTrue:
+ break
+ # raise StopGenerator
else:
yield_func(vars_2, None)
# try:
- self.pattern.match(yield_match, expression, vars, evaluation)
+ self.pattern.match(
+ expression,
+ {
+ "yield_func": yield_match,
+ "vars_dict": vars_dict,
+ "evaluation": evaluation,
+ },
+ )
# except StopGenerator:
# pass
- def get_match_count(self, vars={}):
- return self.pattern.get_match_count(vars)
+ def get_match_count(self, vars_dict: OptionalType[dict] = None):
+ return self.pattern.get_match_count(vars_dict)
class Alternatives(BinaryOperator, PatternObject):
@@ -763,31 +859,34 @@ class Alternatives(BinaryOperator, PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(Alternatives, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.alternatives = [
BasePattern.create(element, evaluation=evaluation)
for element in expr.elements
]
- def match(self, yield_func, expression, vars, evaluation, **kwargs):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with Alternatives"""
for alternative in self.alternatives:
- # for new_vars, rest in alternative.match(
- # expression, vars, evaluation):
- # yield_func(new_vars, rest)
- alternative.match(yield_func, expression, vars, evaluation)
-
- def get_match_count(self, vars={}):
- range = None
+ # for new_vars_dict, rest in alternative.match(
+ # expression, vars_dict, evaluation):
+ # yield_func(new_vars_dict, rest)
+ alternative.match(expression, pattern_context)
+
+ def get_match_count(
+ self, vars_dict: OptionalType[dict] = None
+ ) -> Union[None, int, tuple]:
+ range_lst = None
for alternative in self.alternatives:
- sub = alternative.get_match_count(vars)
- if range is None:
- range = list(sub)
+ sub = alternative.get_match_count(vars_dict)
+ if range_lst is None:
+ range_lst = tuple(sub)
else:
- if sub[0] < range[0]:
- range[0] = sub[0]
- if range[1] is None or sub[1] > range[1]:
- range[1] = sub[1]
- return range
+ if sub[0] < range_lst[0]:
+ range_lst[0] = sub[0]
+ if range_lst[1] is None or sub[1] > range_lst[1]:
+ range_lst[1] = sub[1]
+ return tuple(range_lst)
class _StopGeneratorExcept(StopGenerator):
@@ -826,23 +925,28 @@ class Except(PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(Except, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.c = BasePattern.create(expr.elements[0], evaluation=evaluation)
if len(expr.elements) == 2:
self.p = BasePattern.create(expr.elements[1], evaluation=evaluation)
else:
self.p = BasePattern.create(Expression(SymbolBlank), evaluation=evaluation)
- def match(self, yield_func, expression, vars, evaluation, **kwargs):
- def except_yield_func(vars, rest):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with Exception Pattern"""
+
+ def except_yield_func(vars_dict, rest):
raise _StopGeneratorExcept(True)
+ new_pattern_context = pattern_context.copy()
+ new_pattern_context["yield_func"] = except_yield_func
+
try:
- self.c.match(except_yield_func, expression, vars, evaluation)
+ self.c.match(expression, new_pattern_context)
except _StopGeneratorExcept:
pass
else:
- self.p.match(yield_func, expression, vars, evaluation)
+ self.p.match(expression, pattern_context)
class Verbatim(PatternObject):
@@ -874,12 +978,16 @@ class Verbatim(PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(Verbatim, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.content = expr.elements[0]
- def match(self, yield_func, expression, vars, evaluation, **kwargs):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with Verbatim Pattern"""
+ vars_dict = pattern_context["vars_dict"]
+ yield_func = pattern_context["yield_func"]
+
if self.content.sameQ(expression):
- yield_func(vars, None)
+ yield_func(vars_dict, None)
class HoldPattern(PatternObject):
@@ -910,14 +1018,14 @@ class HoldPattern(PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(HoldPattern, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
- def match(self, yield_func, expression, vars, evaluation, **kwargs):
- # for new_vars, rest in self.pattern.match(
- # expression, vars, evaluation):
- # yield new_vars, rest
- self.pattern.match(yield_func, expression, vars, evaluation)
+ def match(self, expression: Expression, pattern_context: dict):
+ # for new_vars_dict, rest in self.pattern.match(
+ # expression, vars_dict, evaluation):
+ # yield new_vars_dict, rest
+ self.pattern.match(expression, pattern_context)
class Pattern(PatternObject):
@@ -976,7 +1084,11 @@ class Pattern(PatternObject):
}
rules = {
- "MakeBoxes[Verbatim[Pattern][symbol_Symbol, blank_Blank|blank_BlankSequence|blank_BlankNullSequence], f:StandardForm|TraditionalForm|InputForm|OutputForm]": "MakeBoxes[symbol, f] <> MakeBoxes[blank, f]",
+ (
+ "MakeBoxes[Verbatim[Pattern][symbol_Symbol, blank_Blank|"
+ "blank_BlankSequence|blank_BlankNullSequence], "
+ "f:StandardForm|TraditionalForm|InputForm|OutputForm]"
+ ): "MakeBoxes[symbol, f] <> MakeBoxes[blank, f]",
# 'StringForm["`1``2`", HoldForm[symbol], blank]',
}
@@ -996,51 +1108,62 @@ def init(
varname = expr.elements[0].get_name()
if varname is None or varname == "":
self.error("patvar", expr)
- super(Pattern, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.varname = varname
self.pattern = BasePattern.create(expr.elements[1], evaluation=evaluation)
def __repr__(self):
return "" % repr(self.pattern)
- def get_match_count(self, vars={}):
- return self.pattern.get_match_count(vars)
+ def get_match_count(
+ self, vars_dict: OptionalType[dict] = None
+ ) -> Union[int, tuple]:
+ return self.pattern.get_match_count(vars_dict)
+
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with a (named) pattern"""
+ yield_func = pattern_context["yield_func"]
+ vars_dict = pattern_context["vars_dict"]
- def match(self, yield_func, expression, vars_dict, evaluation, **kwargs):
existing = vars_dict.get(self.varname, None)
if existing is None:
- new_vars = vars_dict.copy()
- new_vars[self.varname] = expression
+ new_vars_dict = vars_dict.copy()
+ new_vars_dict[self.varname] = expression
+ pattern_context = pattern_context.copy()
+ pattern_context["vars_dict"] = new_vars_dict
# for vars_2, rest in self.pattern.match(
- # expression, new_vars, evaluation):
+ # expression, new_vars_dict, evaluation):
# yield vars_2, rest
- if type(self.pattern) is OptionsPattern:
+ if isinstance(self.pattern, OptionsPattern):
self.pattern.match(
- yield_func, expression, new_vars, evaluation, **kwargs
+ expression=expression, pattern_context=pattern_context
)
else:
- self.pattern.match(yield_func, expression, new_vars, evaluation)
+ self.pattern.match(
+ expression=expression, pattern_context=pattern_context
+ )
else:
if existing.sameQ(expression):
yield_func(vars_dict, None)
- def get_match_candidates(
- self, elements: tuple, expression, attributes, evaluation, vars_dict=None
- ):
- if vars_dict is None:
- vars_dict = {}
+ def get_match_candidates(self, elements: tuple, pattern_context: dict) -> tuple:
+ """
+ Return a sub-tuple of elements that match with
+ the pattern.
+ Optional parameters provide information
+ about the context where the elements and the
+ patterns come from.
+ """
+ vars_dict = pattern_context.get("vars_dict", {})
+
existing = vars_dict.get(self.varname, None)
if existing is None:
- return self.pattern.get_match_candidates(
- elements, expression, attributes, evaluation, vars_dict
- )
- else:
- # Treat existing variable as verbatim
- verbatim_expr = Expression(SymbolVerbatim, existing)
- verbatim = Verbatim(verbatim_expr)
- return verbatim.get_match_candidates(
- elements, expression, attributes, evaluation, vars_dict
- )
+ return self.pattern.get_match_candidates(elements, pattern_context)
+
+ # Treat existing variable as verbatim
+ verbatim_expr = Expression(SymbolVerbatim, existing)
+ verbatim = Verbatim(verbatim_expr)
+ return verbatim.get_match_candidates(elements, pattern_context)
class Optional(BinaryOperator, PatternObject):
@@ -1099,24 +1222,21 @@ class Optional(BinaryOperator, PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(Optional, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
if len(expr.elements) == 2:
self.default = expr.elements[1]
else:
self.default = None
- def match(
- self,
- yield_func,
- expression,
- vars,
- evaluation,
- head=None,
- element_index=None,
- element_count=None,
- **kwargs,
- ):
+ def match(self, expression: Expression, pattern_context: dict):
+ head = pattern_context.get("head", None)
+ evaluation = pattern_context["evaluation"]
+ element_index = pattern_context.get("element_index", None)
+ element_count = pattern_context.get("element_count", None)
+ vars_dict = pattern_context["vars_dict"]
+ yield_func = pattern_context["yield_func"]
+
if expression.has_form("Sequence", 0):
if self.default is None:
if head is None: # head should be given by match_element!
@@ -1135,11 +1255,18 @@ def match(
default = self.default
expression = default
- # for vars_2, rest in self.pattern.match(expression, vars, evaluation):
+ # for vars_2, rest in self.pattern.match(expression, vars_dict, evaluation):
# yield vars_2, rest
- self.pattern.match(yield_func, expression, vars, evaluation)
+ self.pattern.match(
+ expression,
+ {
+ "yield_func": yield_func,
+ "vars_dict": vars_dict,
+ "evaluation": evaluation,
+ },
+ )
- def get_match_count(self, vars={}):
+ def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple:
return (0, 1)
@@ -1149,6 +1276,10 @@ def get_default_value(
k: OptionalType[int] = None,
n: OptionalType[int] = None,
):
+ """
+ Get the default value associated to a name, and optionally,
+ to a position in the expression.
+ """
pos = []
if k is not None:
pos.append(k)
@@ -1191,7 +1322,7 @@ def __new__(cls, *args, **kwargs):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(_Blank, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
if expr.elements:
self.head = expr.elements[0]
else:
@@ -1237,20 +1368,16 @@ class Blank(_Blank):
}
summary_text = "match to any single expression"
- def match(
- self,
- yield_func: Callable,
- expression: Expression,
- vars: dict,
- evaluation: Evaluation,
- **kwargs,
- ):
+ def match(self, expression: Expression, pattern_context: dict):
+ vars_dict = pattern_context["vars_dict"]
+ yield_func = pattern_context["yield_func"]
+
if not expression.has_form("Sequence", 0):
if self.head is not None:
if expression.get_head().sameQ(self.head):
- yield_func(vars, None)
+ yield_func(vars_dict, None)
else:
- yield_func(vars, None)
+ yield_func(vars_dict, None)
class BlankSequence(_Blank):
@@ -1292,14 +1419,9 @@ class BlankSequence(_Blank):
}
summary_text = "match to a non-empty sequence of elements"
- def match(
- self,
- yield_func: Callable,
- expression: Expression,
- vars: dict,
- evaluation: Evaluation,
- **kwargs,
- ):
+ def match(self, expression: Expression, pattern_context: dict):
+ vars_dict = pattern_context["vars_dict"]
+ yield_func = pattern_context["yield_func"]
elements = expression.get_sequence()
if not elements:
return
@@ -1310,11 +1432,11 @@ def match(
ok = False
break
if ok:
- yield_func(vars, None)
+ yield_func(vars_dict, None)
else:
- yield_func(vars, None)
+ yield_func(vars_dict, None)
- def get_match_count(self, vars={}):
+ def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple:
return (1, None)
@@ -1341,14 +1463,10 @@ class BlankNullSequence(_Blank):
}
summary_text = "match to a sequence of zero or more elements"
- def match(
- self,
- yield_func: Callable,
- expression: Expression,
- vars: dict,
- evaluation: Evaluation,
- **kwargs,
- ):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with a BlankNullSequence"""
+ vars_dict = pattern_context["vars_dict"]
+ yield_func = pattern_context["yield_func"]
elements = expression.get_sequence()
if self.head:
ok = True
@@ -1357,11 +1475,11 @@ def match(
ok = False
break
if ok:
- yield_func(vars, None)
+ yield_func(vars_dict, None)
else:
- yield_func(vars, None)
+ yield_func(vars_dict, None)
- def get_match_count(self, vars={}):
+ def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple:
return (0, None)
@@ -1398,12 +1516,12 @@ class Repeated(PostfixOperator, PatternObject):
def init(
self,
expr: Expression,
- min: int = 1,
+ min_idx: int = 1,
evaluation: OptionalType[Evaluation] = None,
):
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
self.max = None
- self.min = min
+ self.min = min_idx
if len(expr.elements) == 2:
element_1 = expr.elements[1]
allnumbers = not any(
@@ -1417,32 +1535,43 @@ def init(
else:
self.error("range", 2, expr)
- def match(self, yield_func, expression, vars, evaluation, **kwargs):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with Repeated[...]"""
+ yield_func = pattern_context["yield_func"]
+ vars_dict = pattern_context["vars_dict"]
+ evaluation = pattern_context["evaluation"]
elements = expression.get_sequence()
if len(elements) < self.min:
return
if self.max is not None and len(elements) > self.max:
return
- def iter(yield_iter, rest_elements, vars):
+ def iter_fn(yield_iter, rest_elements, vars_dict):
if rest_elements:
- # for new_vars, rest in self.pattern.match(rest_elements[0],
- # vars, evaluation):
- def yield_match(new_vars, rest):
- # for sub_vars, sub_rest in iter(rest_elements[1:],
+ # for new_vars_dict, rest in self.pattern.match(rest_elements[0],
+ # vars_dict, evaluation):
+ def yield_match(new_vars_dict, rest):
+ # for sub_vars_dict, sub_rest in iter(rest_elements[1:],
# new_vars):
- # yield sub_vars, rest
- iter(yield_iter, rest_elements[1:], new_vars)
+ # yield sub_vars_dict, rest
+ iter_fn(yield_iter, rest_elements[1:], new_vars_dict)
- self.pattern.match(yield_match, rest_elements[0], vars, evaluation)
+ self.pattern.match(
+ rest_elements[0],
+ {
+ "yield_func": yield_match,
+ "vars_dict": vars_dict,
+ "evaluation": evaluation,
+ },
+ )
else:
- yield_iter(vars, None)
+ yield_iter(vars_dict, None)
- # for vars, rest in iter(elements, vars):
- # yield_func(vars, rest)
- iter(yield_func, elements, vars)
+ # for vars_dict, rest in iter(elements, vars):
+ # yield_func(vars_dict, rest)
+ iter_fn(yield_func, elements, vars_dict)
- def get_match_count(self, vars={}):
+ def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple:
return (self.min, self.max)
@@ -1467,7 +1596,7 @@ class RepeatedNull(Repeated):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(RepeatedNull, self).init(expr, min=0, evaluation=evaluation)
+ super().init(expr, min_idx=0, evaluation=evaluation)
class Shortest(Builtin):
@@ -1548,7 +1677,7 @@ class Condition(BinaryOperator, PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(Condition, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
self.test = expr.elements[1]
# if (expr.elements[0].get_head_name() == "System`Condition" and
# len(expr.elements[0].elements) == 2):
@@ -1557,23 +1686,22 @@ def init(
# else:
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
- def match(
- self,
- yield_func: Callable,
- expression: Expression,
- vars: dict,
- evaluation: Evaluation,
- **kwargs,
- ):
- # for new_vars, rest in self.pattern.match(expression, vars,
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with Condition pattern"""
+ # for new_vars_dict, rest in self.pattern.match(expression, vars_dict,
# evaluation):
- def yield_match(new_vars, rest):
- test_expr = self.test.replace_vars(new_vars)
+ evaluation = pattern_context["evaluation"]
+ yield_func = pattern_context["yield_func"]
+
+ def yield_match(new_vars_dict, rest):
+ test_expr = self.test.replace_vars(new_vars_dict)
test_result = test_expr.evaluate(evaluation)
if test_result is SymbolTrue:
- yield_func(new_vars, rest)
+ yield_func(new_vars_dict, rest)
- self.pattern.match(yield_match, expression, vars, evaluation)
+ pattern_context = pattern_context.copy()
+ pattern_context["yield_func"] = yield_match
+ self.pattern.match(expression, pattern_context)
class OptionsPattern(PatternObject):
@@ -1622,7 +1750,7 @@ class OptionsPattern(PatternObject):
def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
- super(OptionsPattern, self).init(expr, evaluation=evaluation)
+ super().init(expr, evaluation=evaluation)
try:
self.defaults = expr.elements[0]
except IndexError:
@@ -1630,16 +1758,12 @@ def init(
# function. Set to not None in self.match
self.defaults = None
- def match(
- self,
- yield_func: Callable,
- expression: Expression,
- vars: dict,
- evaluation: Evaluation,
- **kwargs,
- ):
+ def match(self, expression: Expression, pattern_context: dict):
+ """Match with an OptionsPattern"""
+ head = pattern_context.get("head", None)
+ evaluation = pattern_context["evaluation"]
if self.defaults is None:
- self.defaults = kwargs.get("head")
+ self.defaults = head
if self.defaults is None:
# we end up here with OptionsPattern that do not have any
# default options defined, e.g. with this code:
@@ -1664,28 +1788,27 @@ def match(
if option_values is None:
return
values.update(option_values)
- new_vars = vars.copy()
+ new_vars_dict = pattern_context["vars_dict"].copy()
for name, value in values.items():
- new_vars["_option_" + name] = value
- yield_func(new_vars, None)
+ new_vars_dict["_option_" + name] = value
+ pattern_context["yield_func"](new_vars_dict, None)
- def get_match_count(self, vars: dict = {}):
+ def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple:
return (0, None)
def get_match_candidates(
- self,
- elements: Tuple[BaseElement],
- expression: Expression,
- attributes: int,
- evaluation: Evaluation,
- vars: dict = {},
- ):
+ self, elements: Tuple[BaseElement], pattern_context: dict
+ ) -> tuple:
+ """
+ Return the sub-tuple of elements that matches with the pattern.
+ """
+
def _match(element: Expression):
return element.has_form(("Rule", "RuleDelayed"), 2) or element.has_form(
"List", None
)
- return [element for element in elements if _match(element)]
+ return tuple((element for element in elements if _match(element)))
class Dispatch(Atom):
@@ -1697,7 +1820,7 @@ def __init__(self, rulelist: Expression, evaluation: Evaluation) -> None:
self._elements = None
self._head = SymbolDispatch
- def get_sort_key(self) -> tuple:
+ def get_sort_key(self, pattern_sort: bool = False) -> tuple:
return self.src.get_sort_key()
def get_atom_name(self):
@@ -1785,15 +1908,14 @@ def eval_create(
# WMA does not raise this message: just leave it unevaluated,
# and raise an error when the dispatch rule is used.
evaluation.message("Dispatch", "invrpl", rule)
- return
+ return None
try:
return Dispatch(flatten_list, evaluation)
except Exception:
- return
+ return None
def eval_normal(self, dispatch: Dispatch, evaluation: Evaluation) -> ListExpression:
"""Normal[dispatch_Dispatch]"""
if isinstance(dispatch, Dispatch):
return dispatch.src
- else:
- return dispatch.elements[0]
+ return dispatch.elements[0]
diff --git a/mathics/core/builtin.py b/mathics/core/builtin.py
index abd5df3e3..f95f4058c 100644
--- a/mathics/core/builtin.py
+++ b/mathics/core/builtin.py
@@ -1198,11 +1198,11 @@ def get_lookup_name(self) -> str:
return self.get_name()
def get_match_candidates(
- self, elements: Tuple[BaseElement], expression, attributes, evaluation, vars={}
+ self, elements: Tuple[BaseElement], pattern_context: dict
) -> Tuple[BaseElement]:
return elements
- def get_match_count(self, vars={}):
+ def get_match_count(self, vars_dict: dict = {}):
return (1, 1)
def get_sort_key(self, pattern_sort=False) -> tuple:
diff --git a/mathics/core/pattern.py b/mathics/core/pattern.py
index c96166b28..2d2bcbcb6 100644
--- a/mathics/core/pattern.py
+++ b/mathics/core/pattern.py
@@ -1,19 +1,22 @@
# cython: language_level=3
# cython: profile=False
# -*- coding: utf-8 -*-
-"""Core to Mathics3 is are patterns which match symbolic expressions. A pattern are built up in a custon pattern notation.
+"""Core to Mathics3 is are patterns which match symbolic expressions. Patterns
+are built up in a custon pattern notation.
The parts of a pattern are called "Pattern Objects".
-While there is a built-in function which allows users to match parts of expressions, patterns are also used in applying of transformation
+While there is a built-in function which allows users to match parts of
+expressions, patterns are also used in applying of transformation
rules and deciding functions that get applied.
-See also: mathics.core.rules and https://reference.wolfram.com/language/tutorial/PatternsAndTransformationRules.html
+See also: mathics.core.rules and
+https://reference.wolfram.com/language/tutorial/PatternsAndTransformationRules.html
"""
from abc import ABC
from itertools import chain
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
from mathics.core.atoms import Integer
from mathics.core.attributes import A_FLAT, A_ONE_IDENTITY, A_ORDERLESS
@@ -236,17 +239,7 @@ def has_form(self, *args):
"""Compare the expression against a form"""
return self.expr.has_form(*args)
- def match(
- self,
- yield_func: Callable,
- expression: BaseElement,
- vars_dict: dict,
- evaluation: Evaluation,
- head: Optional[Symbol] = None,
- element_index: Optional[int] = None,
- element_count: Optional[int] = None,
- fully: bool = True,
- ):
+ def match(self, expression: BaseElement, pattern_context: dict):
"""
Check if the expression matches the pattern (self).
If it does, calls `yield_func`.
@@ -263,19 +256,14 @@ def match(
"""
raise NotImplementedError
- def does_match(
- self,
- expression: BaseElement,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- fully: bool = True,
- ) -> bool:
+ def does_match(self, expression: BaseElement, pattern_context: dict) -> bool:
"""returns True if `expression` matches self or we have
reached the end of the matches, and False if it does not.
"""
+ evaluation: Evaluation = pattern_context["evaluation"]
+ vars_dict: Optional[dict] = pattern_context.setdefault("vars_dict", {})
+ fully: bool = pattern_context.get("fully", True)
- if vars_dict is None:
- vars_dict = {}
# for sub_vars, rest in self.match( # nopep8
# expression, vars, evaluation, fully=fully):
# return True
@@ -284,38 +272,37 @@ def yield_match(sub_vars, rest):
raise StopGenerator_Pattern(True)
try:
- self.match(yield_match, expression, vars_dict, evaluation, fully=fully)
+ self.match(
+ expression=expression,
+ pattern_context={
+ "yield_func": yield_match,
+ "vars_dict": vars_dict,
+ "evaluation": evaluation,
+ "fully": fully,
+ },
+ )
except StopGenerator_Pattern as exc:
return exc.value
return False
def get_match_candidates(
- self,
- elements: Tuple[BaseElement],
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- ):
+ self, elements: Tuple[BaseElement], pattern_context: dict
+ ) -> tuple:
"""
- Get the candidates that matches with the pattern.
+ Get the a sub-tuple of elements that are candidates
+ matching with the pattern.
+
+ Optional parameters provide information
+ about the context where the elements and the
+ patterns come from.
"""
return tuple()
def get_match_candidates_count(
- self,
- elements: Tuple[BaseElement],
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- ):
+ self, elements: Tuple[BaseElement], pattern_context: dict
+ ) -> Union[int, tuple]:
"""Return the number of candidates that match with the pattern."""
- return len(
- self.get_match_candidates(
- elements, expression, attributes, evaluation, vars_dict
- )
- )
+ return len(self.get_match_candidates(elements, pattern_context))
def sameQ(self, other: BaseElement) -> bool:
"""Mathics SameQ"""
@@ -342,61 +329,42 @@ def __repr__(self):
def match_symbol(
self,
- yield_func: Callable,
expression: BaseElement,
- vars_dict: dict,
- evaluation: Evaluation,
- head: Optional[Symbol] = None,
- element_index: Optional[int] = None,
- element_count: Optional[int] = None,
- fully: bool = True,
+ pattern_context,
):
"""Match against a symbol"""
+ assert isinstance(expression, BaseElement)
if expression is self.atom:
- yield_func(vars_dict, None)
+ pattern_context["yield_func"](pattern_context["vars_dict"], None)
def get_match_symbol_candidates(
- self,
- elements: tuple,
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- ):
- """Find the candidates that matches with the pattern"""
- return [element for element in elements if element is self.atom]
+ self, elements: tuple, pattern_context: dict
+ ) -> tuple:
+ """Find the sub-tuple of elements that matches with the pattern"""
+ return tuple((element for element in elements if element is self.atom))
- def match(
- self,
- yield_func: Callable,
- expression: BaseElement,
- vars_dict: dict,
- evaluation: Evaluation,
- head: Optional[Symbol] = None,
- element_index: Optional[int] = None,
- element_count: Optional[int] = None,
- fully: bool = True,
- ):
+ def match(self, expression: BaseElement, pattern_context: dict):
"""Try to match the patterh with the expression."""
+
if isinstance(expression, Atom) and expression.sameQ(self.atom):
# yield vars, None
- yield_func(vars_dict, None)
+ pattern_context["yield_func"](pattern_context["vars_dict"], None)
def get_match_candidates(
- self,
- elements: Tuple[BaseElement],
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- ):
- return [
- element
- for element in elements
- if (isinstance(element, Atom) and element.sameQ(self.atom))
- ]
+ self, elements: Tuple[BaseElement], pattern_context: dict
+ ) -> tuple:
+ """
+ Return a sub-tuple of elements that matches with the pattern.
+ """
+ return tuple(
+ (
+ element
+ for element in elements
+ if (isinstance(element, Atom) and element.sameQ(self.atom))
+ )
+ )
- def get_match_count(self, vars_dict: Optional[dict] = None):
+ def get_match_count(self, vars_dict: Optional[dict] = None) -> Tuple[int, int]:
"""The number of matches"""
return (1, 1)
@@ -434,6 +402,7 @@ def __init__(
def __set_pattern_attributes__(self, attributes):
if attributes is None or self.attributes is not None:
+ self.get_pre_choices = self._get_pre_choices
return
self.attributes = attributes
@@ -447,18 +416,13 @@ def __set_pattern_attributes__(self, attributes):
element.isliteral for element in self.elements
)
- def match(
- self,
- yield_func: Callable,
- expression: BaseElement,
- vars_dict: dict,
- evaluation: Evaluation,
- head: Optional[Symbol] = None,
- element_index: Optional[int] = None,
- element_count: Optional[int] = None,
- fully: bool = True,
- ):
+ def match(self, expression: BaseElement, pattern_context: dict):
"""Try to match the pattern against an Expression"""
+ evaluation = pattern_context["evaluation"]
+ yield_func = pattern_context["yield_func"]
+ vars_dict = pattern_context["vars_dict"]
+ fully = pattern_context.get("fully", True)
+
evaluation.check_stopped()
if self.isliteral:
if expression.sameQ(self.expr):
@@ -475,48 +439,38 @@ def match(
if not A_FLAT & attributes:
fully = True
- parms = {
- "attributes": attributes,
- "evaluation": evaluation,
- "expression": expression,
- "fully": fully,
- "self": self,
- "vars_dict": vars_dict,
- "yield_func": yield_func,
- }
+ parms = pattern_context.copy()
+ parms["fully"] = fully
+ parms["attributes"] = attributes
+ parms.setdefault("head", None)
+ parms.setdefault("element_index", None)
+ parms.setdefault("element_count", None)
if not isinstance(expression, Atom):
try:
basic_match_expression(
+ self,
+ expression,
parms,
)
except StopGenerator_ExpressionPattern_match:
return
if A_ONE_IDENTITY & attributes:
- match_expression_with_one_identity(
- parms,
- element_index=element_index,
- element_count=element_count,
- head=head,
- )
+ match_expression_with_one_identity(self, expression=expression, parms=parms)
- def get_pre_choices(
- self,
- yield_choice: Callable,
- expression: BaseElement,
- attributes: int,
- vars_dict: dict,
+ def _get_pre_choices(
+ self, expression: BaseElement, yield_choice: Callable, pattern_context: dict
):
"""
If not Orderless, call yield_choice with vars as the parameter.
"""
+ attributes = pattern_context.get("attributes")
+ assert isinstance(attributes, int)
if A_ORDERLESS & attributes:
- get_pre_choices_orderless(
- self, yield_choice, expression, attributes, vars_dict
- )
+ get_pre_choices_orderless(self, expression, pattern_context)
else:
- yield_choice(vars_dict)
+ pattern_context["yield_choice"](pattern_context["vars_dict"])
def filter_elements(self, head_name: str):
"""Filter the elements with a given head_name"""
@@ -528,23 +482,28 @@ def filter_elements(self, head_name: str):
def __repr__(self):
return f""
- def get_match_count(self, vars_dict: Optional[dict] = None):
+ def get_match_count(self, vars_dict: Optional[dict] = None) -> Tuple[int, int]:
"""the number of matches"""
return (1, 1)
- def get_wrappings(
- self,
- yield_func: Callable,
- items: Tuple,
- max_count: Optional[int],
- expression: Expression,
- attributes: int,
- include_flattened: bool = True,
- ):
- """Get the possible wrappings"""
+ def get_wrappings(self, yield_func: Callable, items: Tuple, pattern_context: dict):
+ """
+ Get the possible wrappings
+
+ If items has length 1, apply yield_func to the unique element.
+ Otherwise, apply it to a sequence. If the expression has the
+ attribute `Orderless`, apply it to all the possible orders.
+ Finally , if the expression is `Flat`, and the parameter `include_flattened`
+ is `True`, apply yield_func to the expression with the head of the original
+ expression applied to the original sequence.
+ """
if len(items) == 1:
yield_func(items[0])
else:
+ max_count: Optional[int] = pattern_context["max_count"]
+ expression: Expression = pattern_context["expression"]
+ attributes: int = pattern_context["attributes"]
+ include_flattened: bool = pattern_context.get("include_flattened", True)
if max_count is None or len(items) <= max_count:
if A_ORDERLESS & attributes:
for perm in permutations(items):
@@ -555,26 +514,25 @@ def get_wrappings(
sequence = Expression(SymbolSequence, *items)
sequence.pattern_sequence = True
yield_func(sequence)
+ # TODO: check if this should not be applied to each possible
+ # orders if A_ORDERLESS.
if A_FLAT & attributes and include_flattened:
yield_func(Expression(expression.get_head(), *items))
def match_element(
self,
- yield_func: Callable,
element: BaseElement,
- rest_elements: Tuple,
- rest_expression: Tuple[List, List],
- vars_dict: dict,
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- element_index: int = 1,
- element_count: Optional[int] = None,
- first: bool = False,
- fully: bool = True,
- depth: int = 1,
+ pattern_context,
):
"""Try to match an element."""
+ attributes: int = pattern_context["attributes"]
+ evaluation: Evaluation = pattern_context["evaluation"]
+ expression: BaseElement = pattern_context["expression"]
+ first: bool = pattern_context.setdefault("first", False)
+ fully: bool = pattern_context.setdefault("fully", True)
+ vars_dict: dict = pattern_context["vars_dict"]
+ rest_expression: tuple = pattern_context["rest_expression"]
+ rest_elements: tuple = pattern_context["rest_elements"]
if rest_expression is None:
rest_expression = ([], [])
@@ -582,11 +540,7 @@ def match_element(
match_count = element.get_match_count(vars_dict)
element_candidates = element.get_match_candidates(
- tuple(rest_expression[1]), # element.candidates,
- expression,
- attributes,
- evaluation,
- vars_dict,
+ tuple(rest_expression[1]), pattern_context # element.candidates,
)
if len(element_candidates) < match_count[0]:
@@ -642,24 +596,14 @@ def match_element(
*set_lengths,
)
- parms = {
- "attributes": attributes,
- "depth": depth + 1,
- "element": element,
- "element_count": element_count,
- "element_index": element_index,
- "evaluation": evaluation,
- "expression": expression,
- "fully": fully,
- "match_count": match_count,
- "next_index": element_index + 1,
- "pattern": self,
- "rest_elements": rest_elements,
- "rest_expression": rest_expression,
- "try_flattened": try_flattened,
- "yield_func": yield_func,
- "vars": vars_dict,
- }
+ parms = pattern_context.copy()
+ parms["depth"] = parms.get("depth", 1) + 1
+ parms["next_index"] = parms.setdefault("element_index", 1) + 1
+ parms["pattern"] = self
+ parms["try_flattened"] = try_flattened
+ parms["match_count"] = match_count
+ parms["element"] = element
+
if rest_elements:
parms["next_element"] = rest_elements[0]
parms["next_rest_elements"] = rest_elements[1:]
@@ -668,44 +612,43 @@ def match_element(
expression_pattern_match_element_process_items(items, items_rest, parms)
def get_match_candidates(
- self,
- elements: Tuple[BaseElement],
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- ):
+ self, elements: Tuple[BaseElement], pattern_context
+ ) -> tuple:
"""
Finds possible elements that could match the pattern, ignoring future
pattern variable definitions, but taking into account already fixed
variables.
"""
# TODO: fixed_vars!
-
- return [
- element
- for element in elements
- if self.does_match(element, evaluation, vars_dict)
- ]
+ evaluation: Evaluation = pattern_context["evaluation"]
+ vars_dict: Optional[dict] = pattern_context.setdefault("vars_dict", {})
+ return tuple(
+ (
+ element
+ for element in elements
+ if self.does_match(
+ element, {"evaluation": evaluation, "vars_dict": vars_dict}
+ )
+ )
+ )
def get_match_candidates_count(
- self,
- elements: Tuple[BaseElement],
- expression: BaseElement,
- attributes: int,
- evaluation: Evaluation,
- vars_dict: Optional[dict] = None,
- ):
+ self, elements: Tuple[BaseElement], pattern_context
+ ) -> Union[int, tuple]:
"""
Finds possible elements that could match the pattern, ignoring future
pattern variable definitions, but taking into account already fixed
variables.
"""
# TODO: fixed_vars!
+ evaluation: Evaluation = pattern_context["evaluation"]
+ vars_dict: Optional[dict] = pattern_context.setdefault("vars_dict", {})
count = 0
for element in elements:
- if self.does_match(element, evaluation, vars_dict):
+ if self.does_match(
+ element, {"evaluation": evaluation, "vars_dict": vars_dict}
+ ):
count += 1
return count
@@ -715,10 +658,9 @@ def sort(self):
def match_expression_with_one_identity(
+ self: BasePattern,
+ expression: Expression,
parms: dict,
- element_index: Optional[int],
- element_count: Optional[int],
- head: Symbol,
):
"""
Process expressions with the attribute OneIdentity.
@@ -731,7 +673,6 @@ def match_expression_with_one_identity(
# set of default values, and a single pattern.
from mathics.builtin.patterns import Pattern
- self: ExpressionPattern = parms["self"]
vars_dict: dict = parms["vars_dict"]
evaluation: Evaluation = parms["evaluation"]
@@ -789,19 +730,24 @@ def match_expression_with_one_identity(
del optionals[""]
vars_dict.update(optionals)
# Try to match the non-optional element with the expression
- new_pattern.match(
- parms["yield_func"],
- parms["expression"],
- vars_dict,
- evaluation,
- head=head,
- element_index=element_index,
- element_count=element_count,
- fully=parms["fully"],
- )
-
-
-def basic_match_expression(parms):
+ # no_parms={
+ # "yield_func":parms["yield_func"],
+ # "vars_dict":vars_dict,
+ # "evaluation":evaluation,
+ # "head":head,
+ # "element_index":element_index,
+ # "element_count":element_count,
+ # "fully":parms["fully"],
+ # }
+
+ # TODO: remove me eventually
+ del parms["attributes"]
+ new_pattern.match(expression=expression, pattern_context=parms)
+
+
+def basic_match_expression(
+ self: ExpressionPattern, expression: Expression, parms: dict
+):
"""
Try to match a pattern with an expression
"""
@@ -810,10 +756,7 @@ def basic_match_expression(parms):
# if self.elements:
# next_element = self.elements[0]
# next_elements = self.elements[1:]
-
- self: ExpressionPattern = parms["self"]
yield_func: Callable = parms["yield_func"]
- expression: Expression = parms["expression"]
vars_dict: dict = parms["vars_dict"]
evaluation: Evaluation = parms["evaluation"]
attributes: int = parms["attributes"]
@@ -852,7 +795,8 @@ def yield_choice(pre_vars):
if not unmatched_elements:
raise StopGenerator_ExpressionPattern_match()
if not element.does_match(
- unmatched_elements[0], evaluation, pre_vars
+ unmatched_elements[0],
+ {"evaluation": evaluation, "vars_dict": pre_vars},
):
raise StopGenerator_ExpressionPattern_match()
unmatched_elements = unmatched_elements[1:]
@@ -862,10 +806,12 @@ def yield_choice(pre_vars):
if not leading_blanks:
candidates = element.get_match_candidates_count(
unmatched_elements,
- expression,
- attributes,
- evaluation,
- pre_vars,
+ {
+ "expression": expression,
+ "attributes": attributes,
+ "evaluation": evaluation,
+ "vars_dict": pre_vars,
+ },
)
if candidates < match_count[0]:
raise StopGenerator_ExpressionPattern_match()
@@ -877,17 +823,19 @@ def yield_choice(pre_vars):
# def yield_element(new_vars, rest):
# yield_func(new_vars, rest)
self.match_element(
- yield_func,
- next_element,
- tuple(next_elements),
- ([], expression.elements),
- pre_vars,
- expression,
- attributes,
- evaluation,
- first=True,
- fully=fully,
- element_count=len(self.elements),
+ element=next_element,
+ pattern_context={
+ "yield_func": yield_func,
+ "rest_elements": tuple(next_elements),
+ "rest_expression": ([], expression.elements),
+ "vars_dict": pre_vars,
+ "expression": expression,
+ "attributes": attributes,
+ "evaluation": evaluation,
+ "first": True,
+ "fully": fully,
+ "element_count": len(self.elements),
+ },
)
# for head_vars, _ in self.head.match(expression.get_head(), vars,
@@ -898,14 +846,29 @@ def yield_head(head_vars, _):
# expression, attributes, head_vars)
# for pre_vars in pre_choices:
- self.get_pre_choices(self, yield_choice, expression, attributes, head_vars)
+ self.get_pre_choices(
+ self,
+ expression,
+ {
+ "yield_choice": yield_choice,
+ "attributes": attributes,
+ "vars_dict": head_vars,
+ },
+ )
else:
if not expression.elements:
yield_func(head_vars, None)
else:
return
- self.head.match(yield_head, expression.get_head(), vars_dict, evaluation)
+ self.head.match(
+ expression.get_head(),
+ {
+ "yield_func": yield_head,
+ "vars_dict": vars_dict,
+ "evaluation": evaluation,
+ },
+ )
def expression_pattern_match_element_orderless(
@@ -976,10 +939,12 @@ def expression_pattern_match_element_process_items(
items_rest: Union[tuple, list],
parms: dict,
):
+ """
+ Try to match sequences built from items
+ against the pattern.
+ """
# Include wrappings like Plus[a, b] only if not all items taken
# - in that case we would match the same expression over and over.
-
- attributes: int = parms["attributes"]
element_count: int = parms["element_count"]
expression: Expression = parms["expression"]
evaluation: Evaluation = parms["evaluation"]
@@ -1015,19 +980,15 @@ def element_yield(next_vars_parm, next_rest_parm):
def match_yield(new_vars, _):
if parms["rest_elements"]:
+ new_parms = parms.copy()
+ new_parms["rest_expression"] = items_rest
+ new_parms["rest_elements"] = parms["next_rest_elements"]
+ new_parms["vars_dict"] = new_vars
+ new_parms["element_index"] = parms["next_index"]
+ new_parms["yield_func"] = element_yield
+ del new_parms["element"]
pattern.match_element(
- element_yield,
- parms["next_element"],
- parms["next_rest_elements"],
- items_rest,
- new_vars,
- expression,
- attributes,
- evaluation,
- fully=fully,
- depth=parms["depth"],
- element_index=parms["next_index"],
- element_count=element_count,
+ element=parms["next_element"], pattern_context=new_parms
)
else:
if not fully or (not items_rest[0] and not items_rest[1]):
@@ -1035,34 +996,34 @@ def match_yield(new_vars, _):
def yield_wrapping(item):
parms["element"].match(
- match_yield,
item,
- parms["vars"],
- evaluation,
- fully=True,
- head=expression.head,
- element_index=parms["element_index"],
- element_count=element_count,
+ {
+ "yield_func": match_yield,
+ "vars_dict": parms["vars_dict"],
+ "evaluation": evaluation,
+ "fully": True,
+ "head": expression.head,
+ "element_index": parms["element_index"],
+ "element_count": element_count,
+ },
)
+ # parms = parms.copy()
+ parms["max_count"] = parms["match_count"][1]
+ parms["include_flattened"] = include_flattened
+ # {"max_count":parms["match_count"][1],
+ # "expression":expression,
+ # "attributes":attributes,
+ # "include_flattened":include_flattened}
pattern.get_wrappings(
- yield_wrapping,
- tuple(items),
- parms["match_count"][1],
- expression,
- attributes,
- include_flattened=include_flattened,
+ yield_func=yield_wrapping, items=tuple(items), pattern_context=parms
)
# TODO: these two functions should collect all their arguments
# in a dict
def get_pre_choices_with_order(
- pat: ExpressionPattern,
- yield_choice: Callable,
- expression: Expression,
- attributes: int,
- vars_dict: dict,
+ pat: ExpressionPattern, expression: Expression, pattern_context
):
"""
Yield pre choices for expressions without
@@ -1072,15 +1033,11 @@ def get_pre_choices_with_order(
the parameter `yield_choice` with the collected
var_dict.
"""
- yield_choice(vars_dict)
+ pattern_context["yield_choice"](pattern_context["vars_dict"])
def get_pre_choices_orderless(
- pat: ExpressionPattern,
- yield_choice: Callable,
- expression: Expression,
- attributes: int,
- vars_dict: dict,
+ pat: ExpressionPattern, expression: Expression, pattern_context
):
"""
Yield pre choices for expressions with
@@ -1089,10 +1046,13 @@ def get_pre_choices_orderless(
This case is more involved, since the pattern can include subpatterns.
"""
+ yield_choice: Callable = pattern_context["yield_choice"]
+ vars_dict: dict = pattern_context["vars_dict"]
+
patterns = pat.filter_elements("Pattern")
# a dict with entries having patterns with the same name
# which are not in vars_dict.
- groups = {}
+ groups: dict = {}
prev_pattern = prev_name = None
for pattern in patterns:
name = pattern.elements[0].get_name()
@@ -1163,7 +1123,10 @@ def per_name(yield_name: Callable, groups: Tuple, vars_dict: dict):
# # FIXME: this call is wrong and needs a
# # wrapper_function as the 1st parameter.
# wrappings = pat.get_wrappings(
- # sequence, match_count[1], expression, attributes
+ # items=sequence,
+ # max_count=match_count[1],
+ # expression=expression,
+ # attributes=attributes
# )
# for wrapping in wrappings:
# 1/0
diff --git a/mathics/core/rules.py b/mathics/core/rules.py
index ddf946eaa..83feb70c3 100644
--- a/mathics/core/rules.py
+++ b/mathics/core/rules.py
@@ -162,7 +162,15 @@ def yield_match(vars, rest):
# only first possibility counts
try:
- self.pattern.match(yield_match, expression, {}, evaluation, fully=fully)
+ self.pattern.match(
+ expression,
+ pattern_context={
+ "yield_func": yield_match,
+ "vars_dict": {},
+ "evaluation": evaluation,
+ "fully": fully,
+ },
+ )
except StopGenerator_BaseRule as exc:
# FIXME: figure where these values are not getting set or updated properly.
# For now we have to take a pessimistic view
diff --git a/mathics/eval/patterns.py b/mathics/eval/patterns.py
index 0b9e7ee44..dfecc670d 100644
--- a/mathics/eval/patterns.py
+++ b/mathics/eval/patterns.py
@@ -18,7 +18,10 @@ def yield_func(vars, rest):
raise _StopGeneratorMatchQ(True)
try:
- self.form.match(yield_func, expr, {}, evaluation)
+ self.form.match(
+ expr,
+ {"yield_func": yield_func, "vars_dict": {}, "evaluation": evaluation},
+ )
except _StopGeneratorMatchQ:
return True
return False
diff --git a/mathics/eval/test.py b/mathics/eval/test.py
index c7b66fdfe..0f380219b 100644
--- a/mathics/eval/test.py
+++ b/mathics/eval/test.py
@@ -20,7 +20,15 @@ def yield_match(vars, rest):
# return False
try:
- form.match(yield_match, item, {}, evaluation, fully=False)
+ form.match(
+ item,
+ pattern_context={
+ "yield_func": yield_match,
+ "vars_dict": {},
+ "evaluation": evaluation,
+ "fully": False,
+ },
+ )
except _StopGeneratorBaseElementIsFree as exc:
return exc.value
diff --git a/mathics/eval/testing_expressions.py b/mathics/eval/testing_expressions.py
index 5edfe8b8f..e0b8a6f28 100644
--- a/mathics/eval/testing_expressions.py
+++ b/mathics/eval/testing_expressions.py
@@ -140,7 +140,7 @@ def check(level, expr):
return SymbolFalse
depth = len(dims) - 1 # None doesn't count
- if not pattern.does_match(Integer(depth), evaluation):
+ if not pattern.does_match(Integer(depth), {"evaluation": evaluation}):
return SymbolFalse
return SymbolTrue
@@ -154,7 +154,7 @@ def check_SparseArrayQ(expr, pattern, test, evaluation: Evaluation):
pattern = BasePattern.create(pattern, evaluation=evaluation)
dims, default_value, rules = expr.elements[1:]
- if not pattern.does_match(Integer(len(dims.elements)), evaluation):
+ if not pattern.does_match(Integer(len(dims.elements)), {"evaluation": evaluation}):
return SymbolFalse
array_size = Expression(SymbolTimes, *dims.elements).evaluate(evaluation)