diff --git a/mathics/builtin/list/constructing.py b/mathics/builtin/list/constructing.py index 014f467d5..72c673b40 100644 --- a/mathics/builtin/list/constructing.py +++ b/mathics/builtin/list/constructing.py @@ -439,7 +439,7 @@ def eval(self, expr, patterns, f, evaluation: Evaluation): def listener(e, tag): result = False for pattern, items in sown: - if pattern.does_match(tag, evaluation): + if pattern.does_match(tag, {"evaluation": evaluation}): for item in items: if item[0].sameQ(tag): item[1].append(e) diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py index addd2a600..6c53df3a6 100644 --- a/mathics/builtin/patterns.py +++ b/mathics/builtin/patterns.py @@ -37,9 +37,9 @@ The attributes 'Flat', 'Orderless', and 'OneIdentity' affect pattern matching. """ -from typing import Callable, List, Optional as OptionalType, Tuple, Union +from typing import List, Optional as OptionalType, Tuple, Union -from mathics.core.atoms import Integer, Number, Rational, Real, String +from mathics.core.atoms import Integer, Integer2, Number, Rational, Real, String from mathics.core.attributes import ( A_HOLD_ALL, A_HOLD_FIRST, @@ -54,6 +54,7 @@ PatternError, PatternObject, PostfixOperator, + Test, ) from mathics.core.element import BaseElement, EvalMixin from mathics.core.evaluation import Evaluation @@ -63,7 +64,14 @@ from mathics.core.pattern import BasePattern, StopGenerator from mathics.core.rules import Rule from mathics.core.symbols import Atom, Symbol, SymbolList, SymbolTrue -from mathics.core.systemsymbols import SymbolBlank, SymbolDefault, SymbolDispatch +from mathics.core.systemsymbols import ( + SymbolBlank, + SymbolDefault, + SymbolDispatch, + SymbolInfinity, + SymbolRule, + SymbolRuleDelayed, +) from mathics.eval.parts import python_levelspec # This tells documentation how to sort this module @@ -85,11 +93,6 @@ class Rule_(BinaryOperator): = a + b + d >> {x,x^2,y} /. x->3 = {3, 9, y} - """ - - # TODO: An error message should appear when Rule is called with a wrong - # number of arguments - """ >> a /. Rule[1, 2, 3] -> t : Rule called with 3 arguments; 2 arguments are expected. = a @@ -102,6 +105,13 @@ class Rule_(BinaryOperator): needs_verbatim = True summary_text = "a replacement rule" + def eval_rule(self, elems, evaluation): + """Rule[elems___]""" + num_parms = len(elems.get_sequence()) + if num_parms != 2: + evaluation.message("Rule", "argrx", "Rule", Integer(num_parms), Integer2) + return None + class RuleDelayed(BinaryOperator): """ @@ -123,6 +133,15 @@ class RuleDelayed(BinaryOperator): operator = ":>" summary_text = "a rule that keeps the replacement unevaluated" + def eval_rule_delayed(self, elems, evaluation): + """RuleDelayed[elems___]""" + num_parms = len(elems.get_sequence()) + if num_parms != 2: + evaluation.message( + "RuleDelayed", "argrx", "RuleDelayed", Integer(num_parms), Integer2 + ) + return None + # TODO: disentangle me def create_rules( @@ -130,20 +149,24 @@ def create_rules( expr: Expression, name: str, evaluation: Evaluation, - extra_args: List = [], + extra_args: OptionalType[List] = None, ) -> Tuple[Union[List[Rule], BaseElement], bool]: """ - This function implements `Replace`, `ReplaceAll`, `ReplaceRepeated` and `ReplaceList` eval methods. - `name` controls which of these methods is implemented. These methods applies the rule / list of rules - `rules_expr` over the expression `expr`, using the evaluation context `evaluation`. - - The result is a tuple of two elements. If the second element is `True`, then the first element is the result of the method. + This function implements `Replace`, `ReplaceAll`, `ReplaceRepeated` + and `ReplaceList` eval methods. + `name` controls which of these methods is implemented. These methods + applies the rule / list of rules + `rules_expr` over the expression `expr`, using the evaluation context + `evaluation`. + + The result is a tuple of two elements. If the second element is `True`, + then the first element is the result of the method. If `False`, the first element of the tuple is a list of rules. """ if isinstance(rules_expr, Dispatch): return rules_expr.rules, False - elif rules_expr.has_form("Dispatch", None): + if rules_expr.has_form("Dispatch", None): return Dispatch(rules_expr.elements, evaluation) if rules_expr.has_form("List", None): @@ -164,6 +187,8 @@ def create_rules( break if all_lists: + if extra_args is None: + extra_args = [] return ( ListExpression( *[ @@ -173,28 +198,28 @@ def create_rules( ), True, ) - else: - evaluation.message(name, "rmix", rules_expr) + + evaluation.message(name, "rmix", rules_expr) + return None, True + + result = [] + for rule in rules: + if rule.head not in (SymbolRule, SymbolRuleDelayed): + evaluation.message(name, "reps", rule) return None, True - else: - result = [] - for rule in rules: - if rule.get_head_name() not in ("System`Rule", "System`RuleDelayed"): - evaluation.message(name, "reps", rule) - return None, True - elif len(rule.elements) != 2: - evaluation.message( - # TODO: shorten names here - rule.get_head_name(), - "argrx", - rule.get_head_name(), - 3, - 2, - ) - return None, True - else: - result.append(Rule(rule.elements[0], rule.elements[1])) - return result, False + if len(rule.elements) != 2: + evaluation.message( + # TODO: shorten names here + rule.get_head_name(), + "argrx", + rule.get_head_name(), + 3, + 2, + ) + return None, True + + result.append(Rule(rule.elements[0], rule.elements[1])) + return result, False class Replace(Builtin): @@ -272,7 +297,7 @@ def eval_levelspec(self, expr, rules, ls, evaluation, options): heads = self.get_option(options, "Heads", evaluation) is SymbolTrue - result, applied = expr.do_apply_rules( + result, _ = expr.do_apply_rules( rules, evaluation, level=0, @@ -285,6 +310,8 @@ def eval_levelspec(self, expr, rules, ls, evaluation, options): except PatternError: evaluation.message("Replace", "reps", rules) + return None + class ReplaceAll(BinaryOperator): """ @@ -348,10 +375,11 @@ def eval(self, expr, rules, evaluation: Evaluation): rules, ret = create_rules(rules, expr, "ReplaceAll", evaluation) if ret: return rules - result, applied = expr.do_apply_rules(rules, evaluation) + result, _ = expr.do_apply_rules(rules, evaluation) return result except PatternError: evaluation.message("Replace", "reps", rules) + return None class ReplaceRepeated(BinaryOperator): @@ -450,7 +478,10 @@ class ReplaceList(Builtin):
'ReplaceList[$expr$, $rules$]' -
returns a list of all possible results of applying $rules$ +
returns a list of all possible results when applying $rules$ \ + to $expr$. +
'ReplaceList[$expr$, $rules$, $n$]' +
returns a list of at most $n$ results when applying $rules$ \ to $expr$.
@@ -485,20 +516,28 @@ class ReplaceList(Builtin): summary_text = "list of possible replacement results" def eval( - self, expr: BaseElement, rules: BaseElement, max: Number, evaluation: Evaluation + self, + expr: BaseElement, + rules: BaseElement, + maxidx: Number, + evaluation: Evaluation, ) -> OptionalType[BaseElement]: - "ReplaceList[expr_, rules_, max_:Infinity]" + "ReplaceList[expr_, rules_, maxidx_:Infinity]" - if max.get_name() == "System`Infinity": + # TODO: the below handles Infinity getting added as a + # default argument, when it is passed explitly, e.g. + # ReplaceList[expr, {}, Infinity], then Infinity + # comes in as DirectedInfinity[1]. + if maxidx == SymbolInfinity: max_count = None else: - max_count = max.get_int_value() + max_count = maxidx.get_int_value() if max_count is None or max_count < 0: evaluation.message("ReplaceList", "innf", 3) - return + return None try: rules, ret = create_rules( - rules, expr, "ReplaceList", evaluation, extra_args=[max] + rules, expr, "ReplaceList", evaluation, extra_args=[maxidx] ) except PatternError: evaluation.message("Replace", "reps", rules) @@ -507,12 +546,12 @@ def eval( if ret: return rules - list = [] + list_result = [] for rule in rules: result = rule.apply(expr, evaluation, return_list=True, max_list=max_count) - list.extend(result) + list_result.extend(result) - return ListExpression(*list) + return ListExpression(*list_result) class PatternTest(BinaryOperator, PatternObject): @@ -545,7 +584,7 @@ class PatternTest(BinaryOperator, PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(PatternTest, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) # This class has an important effect in the general performance, # since all the rules that requires specify the type of patterns # call it. Then, for simple checks like `NumberQ` or `NumericQ` @@ -576,7 +615,10 @@ def init( if match_function: self.match = match_function - def match_atom(self, yield_func, expression, vars, evaluation, **kwargs): + def match_atom(self, expression: Expression, pattern_context: dict): + """Match function for AtomQ""" + yield_func = pattern_context["yield_func"] + def yield_match(vars_2, rest): items = expression.get_sequence() # Here we use a `for` loop instead an all over iterator @@ -588,9 +630,15 @@ def yield_match(vars_2, rest): else: yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + # TODO: clarify why we need to use copy here. + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_string(self, expression: Expression, pattern_context: dict): + """Match function for StringQ""" + yield_func = pattern_context["yield_func"] - def match_string(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() for item in items: @@ -599,9 +647,14 @@ def yield_match(vars_2, rest): else: yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_numberq(self, expression: Expression, pattern_context: dict): + """Match function for NumberQ""" + yield_func = pattern_context["yield_func"] - def match_numberq(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() for item in items: @@ -610,9 +663,15 @@ def yield_match(vars_2, rest): else: yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_numericq(self, expression: Expression, pattern_context: dict): + """Match function for NumericQ""" + yield_func = pattern_context["yield_func"] + evaluation = pattern_context["evaluation"] - def match_numericq(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() for item in items: @@ -621,9 +680,14 @@ def yield_match(vars_2, rest): else: yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_real_numberq(self, expression: Expression, pattern_context: dict): + """Match function for RealValuedNumberQ""" + yield_func = pattern_context["yield_func"] - def match_real_numberq(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() for item in items: @@ -632,9 +696,14 @@ def yield_match(vars_2, rest): else: yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_positive(self, expression: Expression, pattern_context: dict): + """Match function for PositiveQ""" + yield_func = pattern_context["yield_func"] - def match_positive(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() if all( @@ -643,9 +712,14 @@ def yield_match(vars_2, rest): ): yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_negative(self, expression: Expression, pattern_context: dict): + """Match function for NegativeQ""" + yield_func = pattern_context["yield_func"] - def match_negative(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() if all( @@ -654,9 +728,14 @@ def yield_match(vars_2, rest): ): yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_nonpositive(self, expression: Expression, pattern_context: dict): + """Match function for NonPositiveQ""" + yield_func = pattern_context["yield_func"] - def match_nonpositive(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() if all( @@ -665,9 +744,14 @@ def yield_match(vars_2, rest): ): yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) + + def match_nonnegative(self, expression: Expression, pattern_context: dict): + """Match function for NonNegativeQ""" + yield_func = pattern_context["yield_func"] - def match_nonnegative(self, yield_func, expression, vars, evaluation, **kwargs): def yield_match(vars_2, rest): items = expression.get_sequence() if all( @@ -676,35 +760,41 @@ def yield_match(vars_2, rest): ): yield_func(vars_2, None) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) def quick_pattern_test(self, candidate, test, evaluation: Evaluation): + """Pattern test for some other special cases""" if test == "System`NegativePowerQ": return ( candidate.has_form("Power", 2) and isinstance(candidate.elements[1], (Integer, Rational, Real)) and candidate.elements[1].value < 0 ) - elif test == "System`NotNegativePowerQ": + if test == "System`NotNegativePowerQ": return not ( candidate.has_form("Power", 2) and isinstance(candidate.elements[1], (Integer, Rational, Real)) and candidate.elements[1].value < 0 ) - else: - from mathics.core.builtin import Test - - builtin = None - builtin = evaluation.definitions.get_definition(test) - if builtin: - builtin = builtin.builtin - if builtin is not None and isinstance(builtin, Test): - return builtin.test(candidate) + + builtin = None + builtin = evaluation.definitions.get_definition(test) + if builtin: + builtin = builtin.builtin + if builtin is not None and isinstance(builtin, Test): + return builtin.test(candidate) return None - def match(self, yield_func, expression, vars, evaluation, **kwargs): - # def match(self, yield_func, expression, vars, evaluation, **kwargs): - # for vars_2, rest in self.pattern.match(expression, vars, evaluation): + def match(self, expression: Expression, pattern_context: dict): + """Match expression with PatternTest""" + evaluation = pattern_context["evaluation"] + vars_dict = pattern_context["vars_dict"] + yield_func = pattern_context["yield_func"] + + # def match(self, yield_func, expression, vars_dict, evaluation, **kwargs): + # for vars_2, rest in self.pattern.match(expression, vars_dict, evaluation): def yield_match(vars_2, rest): testname = self.test_name items = expression.get_sequence() @@ -713,25 +803,31 @@ def yield_match(vars_2, rest): quick_test = self.quick_pattern_test(item, testname, evaluation) if quick_test is False: break - elif quick_test is True: + if quick_test is True: continue # raise StopGenerator - else: - test_expr = Expression(self.test, item) - test_value = test_expr.evaluate(evaluation) - if test_value is not SymbolTrue: - break - # raise StopGenerator + test_expr = Expression(self.test, item) + test_value = test_expr.evaluate(evaluation) + if test_value is not SymbolTrue: + break + # raise StopGenerator else: yield_func(vars_2, None) # try: - self.pattern.match(yield_match, expression, vars, evaluation) + self.pattern.match( + expression, + { + "yield_func": yield_match, + "vars_dict": vars_dict, + "evaluation": evaluation, + }, + ) # except StopGenerator: # pass - def get_match_count(self, vars={}): - return self.pattern.get_match_count(vars) + def get_match_count(self, vars_dict: OptionalType[dict] = None): + return self.pattern.get_match_count(vars_dict) class Alternatives(BinaryOperator, PatternObject): @@ -763,31 +859,34 @@ class Alternatives(BinaryOperator, PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(Alternatives, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.alternatives = [ BasePattern.create(element, evaluation=evaluation) for element in expr.elements ] - def match(self, yield_func, expression, vars, evaluation, **kwargs): + def match(self, expression: Expression, pattern_context: dict): + """Match with Alternatives""" for alternative in self.alternatives: - # for new_vars, rest in alternative.match( - # expression, vars, evaluation): - # yield_func(new_vars, rest) - alternative.match(yield_func, expression, vars, evaluation) - - def get_match_count(self, vars={}): - range = None + # for new_vars_dict, rest in alternative.match( + # expression, vars_dict, evaluation): + # yield_func(new_vars_dict, rest) + alternative.match(expression, pattern_context) + + def get_match_count( + self, vars_dict: OptionalType[dict] = None + ) -> Union[None, int, tuple]: + range_lst = None for alternative in self.alternatives: - sub = alternative.get_match_count(vars) - if range is None: - range = list(sub) + sub = alternative.get_match_count(vars_dict) + if range_lst is None: + range_lst = tuple(sub) else: - if sub[0] < range[0]: - range[0] = sub[0] - if range[1] is None or sub[1] > range[1]: - range[1] = sub[1] - return range + if sub[0] < range_lst[0]: + range_lst[0] = sub[0] + if range_lst[1] is None or sub[1] > range_lst[1]: + range_lst[1] = sub[1] + return tuple(range_lst) class _StopGeneratorExcept(StopGenerator): @@ -826,23 +925,28 @@ class Except(PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(Except, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.c = BasePattern.create(expr.elements[0], evaluation=evaluation) if len(expr.elements) == 2: self.p = BasePattern.create(expr.elements[1], evaluation=evaluation) else: self.p = BasePattern.create(Expression(SymbolBlank), evaluation=evaluation) - def match(self, yield_func, expression, vars, evaluation, **kwargs): - def except_yield_func(vars, rest): + def match(self, expression: Expression, pattern_context: dict): + """Match with Exception Pattern""" + + def except_yield_func(vars_dict, rest): raise _StopGeneratorExcept(True) + new_pattern_context = pattern_context.copy() + new_pattern_context["yield_func"] = except_yield_func + try: - self.c.match(except_yield_func, expression, vars, evaluation) + self.c.match(expression, new_pattern_context) except _StopGeneratorExcept: pass else: - self.p.match(yield_func, expression, vars, evaluation) + self.p.match(expression, pattern_context) class Verbatim(PatternObject): @@ -874,12 +978,16 @@ class Verbatim(PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(Verbatim, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.content = expr.elements[0] - def match(self, yield_func, expression, vars, evaluation, **kwargs): + def match(self, expression: Expression, pattern_context: dict): + """Match with Verbatim Pattern""" + vars_dict = pattern_context["vars_dict"] + yield_func = pattern_context["yield_func"] + if self.content.sameQ(expression): - yield_func(vars, None) + yield_func(vars_dict, None) class HoldPattern(PatternObject): @@ -910,14 +1018,14 @@ class HoldPattern(PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(HoldPattern, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation) - def match(self, yield_func, expression, vars, evaluation, **kwargs): - # for new_vars, rest in self.pattern.match( - # expression, vars, evaluation): - # yield new_vars, rest - self.pattern.match(yield_func, expression, vars, evaluation) + def match(self, expression: Expression, pattern_context: dict): + # for new_vars_dict, rest in self.pattern.match( + # expression, vars_dict, evaluation): + # yield new_vars_dict, rest + self.pattern.match(expression, pattern_context) class Pattern(PatternObject): @@ -976,7 +1084,11 @@ class Pattern(PatternObject): } rules = { - "MakeBoxes[Verbatim[Pattern][symbol_Symbol, blank_Blank|blank_BlankSequence|blank_BlankNullSequence], f:StandardForm|TraditionalForm|InputForm|OutputForm]": "MakeBoxes[symbol, f] <> MakeBoxes[blank, f]", + ( + "MakeBoxes[Verbatim[Pattern][symbol_Symbol, blank_Blank|" + "blank_BlankSequence|blank_BlankNullSequence], " + "f:StandardForm|TraditionalForm|InputForm|OutputForm]" + ): "MakeBoxes[symbol, f] <> MakeBoxes[blank, f]", # 'StringForm["`1``2`", HoldForm[symbol], blank]', } @@ -996,51 +1108,62 @@ def init( varname = expr.elements[0].get_name() if varname is None or varname == "": self.error("patvar", expr) - super(Pattern, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.varname = varname self.pattern = BasePattern.create(expr.elements[1], evaluation=evaluation) def __repr__(self): return "" % repr(self.pattern) - def get_match_count(self, vars={}): - return self.pattern.get_match_count(vars) + def get_match_count( + self, vars_dict: OptionalType[dict] = None + ) -> Union[int, tuple]: + return self.pattern.get_match_count(vars_dict) + + def match(self, expression: Expression, pattern_context: dict): + """Match with a (named) pattern""" + yield_func = pattern_context["yield_func"] + vars_dict = pattern_context["vars_dict"] - def match(self, yield_func, expression, vars_dict, evaluation, **kwargs): existing = vars_dict.get(self.varname, None) if existing is None: - new_vars = vars_dict.copy() - new_vars[self.varname] = expression + new_vars_dict = vars_dict.copy() + new_vars_dict[self.varname] = expression + pattern_context = pattern_context.copy() + pattern_context["vars_dict"] = new_vars_dict # for vars_2, rest in self.pattern.match( - # expression, new_vars, evaluation): + # expression, new_vars_dict, evaluation): # yield vars_2, rest - if type(self.pattern) is OptionsPattern: + if isinstance(self.pattern, OptionsPattern): self.pattern.match( - yield_func, expression, new_vars, evaluation, **kwargs + expression=expression, pattern_context=pattern_context ) else: - self.pattern.match(yield_func, expression, new_vars, evaluation) + self.pattern.match( + expression=expression, pattern_context=pattern_context + ) else: if existing.sameQ(expression): yield_func(vars_dict, None) - def get_match_candidates( - self, elements: tuple, expression, attributes, evaluation, vars_dict=None - ): - if vars_dict is None: - vars_dict = {} + def get_match_candidates(self, elements: tuple, pattern_context: dict) -> tuple: + """ + Return a sub-tuple of elements that match with + the pattern. + Optional parameters provide information + about the context where the elements and the + patterns come from. + """ + vars_dict = pattern_context.get("vars_dict", {}) + existing = vars_dict.get(self.varname, None) if existing is None: - return self.pattern.get_match_candidates( - elements, expression, attributes, evaluation, vars_dict - ) - else: - # Treat existing variable as verbatim - verbatim_expr = Expression(SymbolVerbatim, existing) - verbatim = Verbatim(verbatim_expr) - return verbatim.get_match_candidates( - elements, expression, attributes, evaluation, vars_dict - ) + return self.pattern.get_match_candidates(elements, pattern_context) + + # Treat existing variable as verbatim + verbatim_expr = Expression(SymbolVerbatim, existing) + verbatim = Verbatim(verbatim_expr) + return verbatim.get_match_candidates(elements, pattern_context) class Optional(BinaryOperator, PatternObject): @@ -1099,24 +1222,21 @@ class Optional(BinaryOperator, PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(Optional, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation) if len(expr.elements) == 2: self.default = expr.elements[1] else: self.default = None - def match( - self, - yield_func, - expression, - vars, - evaluation, - head=None, - element_index=None, - element_count=None, - **kwargs, - ): + def match(self, expression: Expression, pattern_context: dict): + head = pattern_context.get("head", None) + evaluation = pattern_context["evaluation"] + element_index = pattern_context.get("element_index", None) + element_count = pattern_context.get("element_count", None) + vars_dict = pattern_context["vars_dict"] + yield_func = pattern_context["yield_func"] + if expression.has_form("Sequence", 0): if self.default is None: if head is None: # head should be given by match_element! @@ -1135,11 +1255,18 @@ def match( default = self.default expression = default - # for vars_2, rest in self.pattern.match(expression, vars, evaluation): + # for vars_2, rest in self.pattern.match(expression, vars_dict, evaluation): # yield vars_2, rest - self.pattern.match(yield_func, expression, vars, evaluation) + self.pattern.match( + expression, + { + "yield_func": yield_func, + "vars_dict": vars_dict, + "evaluation": evaluation, + }, + ) - def get_match_count(self, vars={}): + def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple: return (0, 1) @@ -1149,6 +1276,10 @@ def get_default_value( k: OptionalType[int] = None, n: OptionalType[int] = None, ): + """ + Get the default value associated to a name, and optionally, + to a position in the expression. + """ pos = [] if k is not None: pos.append(k) @@ -1191,7 +1322,7 @@ def __new__(cls, *args, **kwargs): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(_Blank, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) if expr.elements: self.head = expr.elements[0] else: @@ -1237,20 +1368,16 @@ class Blank(_Blank): } summary_text = "match to any single expression" - def match( - self, - yield_func: Callable, - expression: Expression, - vars: dict, - evaluation: Evaluation, - **kwargs, - ): + def match(self, expression: Expression, pattern_context: dict): + vars_dict = pattern_context["vars_dict"] + yield_func = pattern_context["yield_func"] + if not expression.has_form("Sequence", 0): if self.head is not None: if expression.get_head().sameQ(self.head): - yield_func(vars, None) + yield_func(vars_dict, None) else: - yield_func(vars, None) + yield_func(vars_dict, None) class BlankSequence(_Blank): @@ -1292,14 +1419,9 @@ class BlankSequence(_Blank): } summary_text = "match to a non-empty sequence of elements" - def match( - self, - yield_func: Callable, - expression: Expression, - vars: dict, - evaluation: Evaluation, - **kwargs, - ): + def match(self, expression: Expression, pattern_context: dict): + vars_dict = pattern_context["vars_dict"] + yield_func = pattern_context["yield_func"] elements = expression.get_sequence() if not elements: return @@ -1310,11 +1432,11 @@ def match( ok = False break if ok: - yield_func(vars, None) + yield_func(vars_dict, None) else: - yield_func(vars, None) + yield_func(vars_dict, None) - def get_match_count(self, vars={}): + def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple: return (1, None) @@ -1341,14 +1463,10 @@ class BlankNullSequence(_Blank): } summary_text = "match to a sequence of zero or more elements" - def match( - self, - yield_func: Callable, - expression: Expression, - vars: dict, - evaluation: Evaluation, - **kwargs, - ): + def match(self, expression: Expression, pattern_context: dict): + """Match with a BlankNullSequence""" + vars_dict = pattern_context["vars_dict"] + yield_func = pattern_context["yield_func"] elements = expression.get_sequence() if self.head: ok = True @@ -1357,11 +1475,11 @@ def match( ok = False break if ok: - yield_func(vars, None) + yield_func(vars_dict, None) else: - yield_func(vars, None) + yield_func(vars_dict, None) - def get_match_count(self, vars={}): + def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple: return (0, None) @@ -1398,12 +1516,12 @@ class Repeated(PostfixOperator, PatternObject): def init( self, expr: Expression, - min: int = 1, + min_idx: int = 1, evaluation: OptionalType[Evaluation] = None, ): self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation) self.max = None - self.min = min + self.min = min_idx if len(expr.elements) == 2: element_1 = expr.elements[1] allnumbers = not any( @@ -1417,32 +1535,43 @@ def init( else: self.error("range", 2, expr) - def match(self, yield_func, expression, vars, evaluation, **kwargs): + def match(self, expression: Expression, pattern_context: dict): + """Match with Repeated[...]""" + yield_func = pattern_context["yield_func"] + vars_dict = pattern_context["vars_dict"] + evaluation = pattern_context["evaluation"] elements = expression.get_sequence() if len(elements) < self.min: return if self.max is not None and len(elements) > self.max: return - def iter(yield_iter, rest_elements, vars): + def iter_fn(yield_iter, rest_elements, vars_dict): if rest_elements: - # for new_vars, rest in self.pattern.match(rest_elements[0], - # vars, evaluation): - def yield_match(new_vars, rest): - # for sub_vars, sub_rest in iter(rest_elements[1:], + # for new_vars_dict, rest in self.pattern.match(rest_elements[0], + # vars_dict, evaluation): + def yield_match(new_vars_dict, rest): + # for sub_vars_dict, sub_rest in iter(rest_elements[1:], # new_vars): - # yield sub_vars, rest - iter(yield_iter, rest_elements[1:], new_vars) + # yield sub_vars_dict, rest + iter_fn(yield_iter, rest_elements[1:], new_vars_dict) - self.pattern.match(yield_match, rest_elements[0], vars, evaluation) + self.pattern.match( + rest_elements[0], + { + "yield_func": yield_match, + "vars_dict": vars_dict, + "evaluation": evaluation, + }, + ) else: - yield_iter(vars, None) + yield_iter(vars_dict, None) - # for vars, rest in iter(elements, vars): - # yield_func(vars, rest) - iter(yield_func, elements, vars) + # for vars_dict, rest in iter(elements, vars): + # yield_func(vars_dict, rest) + iter_fn(yield_func, elements, vars_dict) - def get_match_count(self, vars={}): + def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple: return (self.min, self.max) @@ -1467,7 +1596,7 @@ class RepeatedNull(Repeated): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(RepeatedNull, self).init(expr, min=0, evaluation=evaluation) + super().init(expr, min_idx=0, evaluation=evaluation) class Shortest(Builtin): @@ -1548,7 +1677,7 @@ class Condition(BinaryOperator, PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(Condition, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) self.test = expr.elements[1] # if (expr.elements[0].get_head_name() == "System`Condition" and # len(expr.elements[0].elements) == 2): @@ -1557,23 +1686,22 @@ def init( # else: self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation) - def match( - self, - yield_func: Callable, - expression: Expression, - vars: dict, - evaluation: Evaluation, - **kwargs, - ): - # for new_vars, rest in self.pattern.match(expression, vars, + def match(self, expression: Expression, pattern_context: dict): + """Match with Condition pattern""" + # for new_vars_dict, rest in self.pattern.match(expression, vars_dict, # evaluation): - def yield_match(new_vars, rest): - test_expr = self.test.replace_vars(new_vars) + evaluation = pattern_context["evaluation"] + yield_func = pattern_context["yield_func"] + + def yield_match(new_vars_dict, rest): + test_expr = self.test.replace_vars(new_vars_dict) test_result = test_expr.evaluate(evaluation) if test_result is SymbolTrue: - yield_func(new_vars, rest) + yield_func(new_vars_dict, rest) - self.pattern.match(yield_match, expression, vars, evaluation) + pattern_context = pattern_context.copy() + pattern_context["yield_func"] = yield_match + self.pattern.match(expression, pattern_context) class OptionsPattern(PatternObject): @@ -1622,7 +1750,7 @@ class OptionsPattern(PatternObject): def init( self, expr: Expression, evaluation: OptionalType[Evaluation] = None ) -> None: - super(OptionsPattern, self).init(expr, evaluation=evaluation) + super().init(expr, evaluation=evaluation) try: self.defaults = expr.elements[0] except IndexError: @@ -1630,16 +1758,12 @@ def init( # function. Set to not None in self.match self.defaults = None - def match( - self, - yield_func: Callable, - expression: Expression, - vars: dict, - evaluation: Evaluation, - **kwargs, - ): + def match(self, expression: Expression, pattern_context: dict): + """Match with an OptionsPattern""" + head = pattern_context.get("head", None) + evaluation = pattern_context["evaluation"] if self.defaults is None: - self.defaults = kwargs.get("head") + self.defaults = head if self.defaults is None: # we end up here with OptionsPattern that do not have any # default options defined, e.g. with this code: @@ -1664,28 +1788,27 @@ def match( if option_values is None: return values.update(option_values) - new_vars = vars.copy() + new_vars_dict = pattern_context["vars_dict"].copy() for name, value in values.items(): - new_vars["_option_" + name] = value - yield_func(new_vars, None) + new_vars_dict["_option_" + name] = value + pattern_context["yield_func"](new_vars_dict, None) - def get_match_count(self, vars: dict = {}): + def get_match_count(self, vars_dict: OptionalType[dict] = None) -> tuple: return (0, None) def get_match_candidates( - self, - elements: Tuple[BaseElement], - expression: Expression, - attributes: int, - evaluation: Evaluation, - vars: dict = {}, - ): + self, elements: Tuple[BaseElement], pattern_context: dict + ) -> tuple: + """ + Return the sub-tuple of elements that matches with the pattern. + """ + def _match(element: Expression): return element.has_form(("Rule", "RuleDelayed"), 2) or element.has_form( "List", None ) - return [element for element in elements if _match(element)] + return tuple((element for element in elements if _match(element))) class Dispatch(Atom): @@ -1697,7 +1820,7 @@ def __init__(self, rulelist: Expression, evaluation: Evaluation) -> None: self._elements = None self._head = SymbolDispatch - def get_sort_key(self) -> tuple: + def get_sort_key(self, pattern_sort: bool = False) -> tuple: return self.src.get_sort_key() def get_atom_name(self): @@ -1785,15 +1908,14 @@ def eval_create( # WMA does not raise this message: just leave it unevaluated, # and raise an error when the dispatch rule is used. evaluation.message("Dispatch", "invrpl", rule) - return + return None try: return Dispatch(flatten_list, evaluation) except Exception: - return + return None def eval_normal(self, dispatch: Dispatch, evaluation: Evaluation) -> ListExpression: """Normal[dispatch_Dispatch]""" if isinstance(dispatch, Dispatch): return dispatch.src - else: - return dispatch.elements[0] + return dispatch.elements[0] diff --git a/mathics/core/builtin.py b/mathics/core/builtin.py index abd5df3e3..f95f4058c 100644 --- a/mathics/core/builtin.py +++ b/mathics/core/builtin.py @@ -1198,11 +1198,11 @@ def get_lookup_name(self) -> str: return self.get_name() def get_match_candidates( - self, elements: Tuple[BaseElement], expression, attributes, evaluation, vars={} + self, elements: Tuple[BaseElement], pattern_context: dict ) -> Tuple[BaseElement]: return elements - def get_match_count(self, vars={}): + def get_match_count(self, vars_dict: dict = {}): return (1, 1) def get_sort_key(self, pattern_sort=False) -> tuple: diff --git a/mathics/core/pattern.py b/mathics/core/pattern.py index c96166b28..2d2bcbcb6 100644 --- a/mathics/core/pattern.py +++ b/mathics/core/pattern.py @@ -1,19 +1,22 @@ # cython: language_level=3 # cython: profile=False # -*- coding: utf-8 -*- -"""Core to Mathics3 is are patterns which match symbolic expressions. A pattern are built up in a custon pattern notation. +"""Core to Mathics3 is are patterns which match symbolic expressions. Patterns +are built up in a custon pattern notation. The parts of a pattern are called "Pattern Objects". -While there is a built-in function which allows users to match parts of expressions, patterns are also used in applying of transformation +While there is a built-in function which allows users to match parts of +expressions, patterns are also used in applying of transformation rules and deciding functions that get applied. -See also: mathics.core.rules and https://reference.wolfram.com/language/tutorial/PatternsAndTransformationRules.html +See also: mathics.core.rules and +https://reference.wolfram.com/language/tutorial/PatternsAndTransformationRules.html """ from abc import ABC from itertools import chain -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union from mathics.core.atoms import Integer from mathics.core.attributes import A_FLAT, A_ONE_IDENTITY, A_ORDERLESS @@ -236,17 +239,7 @@ def has_form(self, *args): """Compare the expression against a form""" return self.expr.has_form(*args) - def match( - self, - yield_func: Callable, - expression: BaseElement, - vars_dict: dict, - evaluation: Evaluation, - head: Optional[Symbol] = None, - element_index: Optional[int] = None, - element_count: Optional[int] = None, - fully: bool = True, - ): + def match(self, expression: BaseElement, pattern_context: dict): """ Check if the expression matches the pattern (self). If it does, calls `yield_func`. @@ -263,19 +256,14 @@ def match( """ raise NotImplementedError - def does_match( - self, - expression: BaseElement, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - fully: bool = True, - ) -> bool: + def does_match(self, expression: BaseElement, pattern_context: dict) -> bool: """returns True if `expression` matches self or we have reached the end of the matches, and False if it does not. """ + evaluation: Evaluation = pattern_context["evaluation"] + vars_dict: Optional[dict] = pattern_context.setdefault("vars_dict", {}) + fully: bool = pattern_context.get("fully", True) - if vars_dict is None: - vars_dict = {} # for sub_vars, rest in self.match( # nopep8 # expression, vars, evaluation, fully=fully): # return True @@ -284,38 +272,37 @@ def yield_match(sub_vars, rest): raise StopGenerator_Pattern(True) try: - self.match(yield_match, expression, vars_dict, evaluation, fully=fully) + self.match( + expression=expression, + pattern_context={ + "yield_func": yield_match, + "vars_dict": vars_dict, + "evaluation": evaluation, + "fully": fully, + }, + ) except StopGenerator_Pattern as exc: return exc.value return False def get_match_candidates( - self, - elements: Tuple[BaseElement], - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - ): + self, elements: Tuple[BaseElement], pattern_context: dict + ) -> tuple: """ - Get the candidates that matches with the pattern. + Get the a sub-tuple of elements that are candidates + matching with the pattern. + + Optional parameters provide information + about the context where the elements and the + patterns come from. """ return tuple() def get_match_candidates_count( - self, - elements: Tuple[BaseElement], - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - ): + self, elements: Tuple[BaseElement], pattern_context: dict + ) -> Union[int, tuple]: """Return the number of candidates that match with the pattern.""" - return len( - self.get_match_candidates( - elements, expression, attributes, evaluation, vars_dict - ) - ) + return len(self.get_match_candidates(elements, pattern_context)) def sameQ(self, other: BaseElement) -> bool: """Mathics SameQ""" @@ -342,61 +329,42 @@ def __repr__(self): def match_symbol( self, - yield_func: Callable, expression: BaseElement, - vars_dict: dict, - evaluation: Evaluation, - head: Optional[Symbol] = None, - element_index: Optional[int] = None, - element_count: Optional[int] = None, - fully: bool = True, + pattern_context, ): """Match against a symbol""" + assert isinstance(expression, BaseElement) if expression is self.atom: - yield_func(vars_dict, None) + pattern_context["yield_func"](pattern_context["vars_dict"], None) def get_match_symbol_candidates( - self, - elements: tuple, - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - ): - """Find the candidates that matches with the pattern""" - return [element for element in elements if element is self.atom] + self, elements: tuple, pattern_context: dict + ) -> tuple: + """Find the sub-tuple of elements that matches with the pattern""" + return tuple((element for element in elements if element is self.atom)) - def match( - self, - yield_func: Callable, - expression: BaseElement, - vars_dict: dict, - evaluation: Evaluation, - head: Optional[Symbol] = None, - element_index: Optional[int] = None, - element_count: Optional[int] = None, - fully: bool = True, - ): + def match(self, expression: BaseElement, pattern_context: dict): """Try to match the patterh with the expression.""" + if isinstance(expression, Atom) and expression.sameQ(self.atom): # yield vars, None - yield_func(vars_dict, None) + pattern_context["yield_func"](pattern_context["vars_dict"], None) def get_match_candidates( - self, - elements: Tuple[BaseElement], - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - ): - return [ - element - for element in elements - if (isinstance(element, Atom) and element.sameQ(self.atom)) - ] + self, elements: Tuple[BaseElement], pattern_context: dict + ) -> tuple: + """ + Return a sub-tuple of elements that matches with the pattern. + """ + return tuple( + ( + element + for element in elements + if (isinstance(element, Atom) and element.sameQ(self.atom)) + ) + ) - def get_match_count(self, vars_dict: Optional[dict] = None): + def get_match_count(self, vars_dict: Optional[dict] = None) -> Tuple[int, int]: """The number of matches""" return (1, 1) @@ -434,6 +402,7 @@ def __init__( def __set_pattern_attributes__(self, attributes): if attributes is None or self.attributes is not None: + self.get_pre_choices = self._get_pre_choices return self.attributes = attributes @@ -447,18 +416,13 @@ def __set_pattern_attributes__(self, attributes): element.isliteral for element in self.elements ) - def match( - self, - yield_func: Callable, - expression: BaseElement, - vars_dict: dict, - evaluation: Evaluation, - head: Optional[Symbol] = None, - element_index: Optional[int] = None, - element_count: Optional[int] = None, - fully: bool = True, - ): + def match(self, expression: BaseElement, pattern_context: dict): """Try to match the pattern against an Expression""" + evaluation = pattern_context["evaluation"] + yield_func = pattern_context["yield_func"] + vars_dict = pattern_context["vars_dict"] + fully = pattern_context.get("fully", True) + evaluation.check_stopped() if self.isliteral: if expression.sameQ(self.expr): @@ -475,48 +439,38 @@ def match( if not A_FLAT & attributes: fully = True - parms = { - "attributes": attributes, - "evaluation": evaluation, - "expression": expression, - "fully": fully, - "self": self, - "vars_dict": vars_dict, - "yield_func": yield_func, - } + parms = pattern_context.copy() + parms["fully"] = fully + parms["attributes"] = attributes + parms.setdefault("head", None) + parms.setdefault("element_index", None) + parms.setdefault("element_count", None) if not isinstance(expression, Atom): try: basic_match_expression( + self, + expression, parms, ) except StopGenerator_ExpressionPattern_match: return if A_ONE_IDENTITY & attributes: - match_expression_with_one_identity( - parms, - element_index=element_index, - element_count=element_count, - head=head, - ) + match_expression_with_one_identity(self, expression=expression, parms=parms) - def get_pre_choices( - self, - yield_choice: Callable, - expression: BaseElement, - attributes: int, - vars_dict: dict, + def _get_pre_choices( + self, expression: BaseElement, yield_choice: Callable, pattern_context: dict ): """ If not Orderless, call yield_choice with vars as the parameter. """ + attributes = pattern_context.get("attributes") + assert isinstance(attributes, int) if A_ORDERLESS & attributes: - get_pre_choices_orderless( - self, yield_choice, expression, attributes, vars_dict - ) + get_pre_choices_orderless(self, expression, pattern_context) else: - yield_choice(vars_dict) + pattern_context["yield_choice"](pattern_context["vars_dict"]) def filter_elements(self, head_name: str): """Filter the elements with a given head_name""" @@ -528,23 +482,28 @@ def filter_elements(self, head_name: str): def __repr__(self): return f"" - def get_match_count(self, vars_dict: Optional[dict] = None): + def get_match_count(self, vars_dict: Optional[dict] = None) -> Tuple[int, int]: """the number of matches""" return (1, 1) - def get_wrappings( - self, - yield_func: Callable, - items: Tuple, - max_count: Optional[int], - expression: Expression, - attributes: int, - include_flattened: bool = True, - ): - """Get the possible wrappings""" + def get_wrappings(self, yield_func: Callable, items: Tuple, pattern_context: dict): + """ + Get the possible wrappings + + If items has length 1, apply yield_func to the unique element. + Otherwise, apply it to a sequence. If the expression has the + attribute `Orderless`, apply it to all the possible orders. + Finally , if the expression is `Flat`, and the parameter `include_flattened` + is `True`, apply yield_func to the expression with the head of the original + expression applied to the original sequence. + """ if len(items) == 1: yield_func(items[0]) else: + max_count: Optional[int] = pattern_context["max_count"] + expression: Expression = pattern_context["expression"] + attributes: int = pattern_context["attributes"] + include_flattened: bool = pattern_context.get("include_flattened", True) if max_count is None or len(items) <= max_count: if A_ORDERLESS & attributes: for perm in permutations(items): @@ -555,26 +514,25 @@ def get_wrappings( sequence = Expression(SymbolSequence, *items) sequence.pattern_sequence = True yield_func(sequence) + # TODO: check if this should not be applied to each possible + # orders if A_ORDERLESS. if A_FLAT & attributes and include_flattened: yield_func(Expression(expression.get_head(), *items)) def match_element( self, - yield_func: Callable, element: BaseElement, - rest_elements: Tuple, - rest_expression: Tuple[List, List], - vars_dict: dict, - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - element_index: int = 1, - element_count: Optional[int] = None, - first: bool = False, - fully: bool = True, - depth: int = 1, + pattern_context, ): """Try to match an element.""" + attributes: int = pattern_context["attributes"] + evaluation: Evaluation = pattern_context["evaluation"] + expression: BaseElement = pattern_context["expression"] + first: bool = pattern_context.setdefault("first", False) + fully: bool = pattern_context.setdefault("fully", True) + vars_dict: dict = pattern_context["vars_dict"] + rest_expression: tuple = pattern_context["rest_expression"] + rest_elements: tuple = pattern_context["rest_elements"] if rest_expression is None: rest_expression = ([], []) @@ -582,11 +540,7 @@ def match_element( match_count = element.get_match_count(vars_dict) element_candidates = element.get_match_candidates( - tuple(rest_expression[1]), # element.candidates, - expression, - attributes, - evaluation, - vars_dict, + tuple(rest_expression[1]), pattern_context # element.candidates, ) if len(element_candidates) < match_count[0]: @@ -642,24 +596,14 @@ def match_element( *set_lengths, ) - parms = { - "attributes": attributes, - "depth": depth + 1, - "element": element, - "element_count": element_count, - "element_index": element_index, - "evaluation": evaluation, - "expression": expression, - "fully": fully, - "match_count": match_count, - "next_index": element_index + 1, - "pattern": self, - "rest_elements": rest_elements, - "rest_expression": rest_expression, - "try_flattened": try_flattened, - "yield_func": yield_func, - "vars": vars_dict, - } + parms = pattern_context.copy() + parms["depth"] = parms.get("depth", 1) + 1 + parms["next_index"] = parms.setdefault("element_index", 1) + 1 + parms["pattern"] = self + parms["try_flattened"] = try_flattened + parms["match_count"] = match_count + parms["element"] = element + if rest_elements: parms["next_element"] = rest_elements[0] parms["next_rest_elements"] = rest_elements[1:] @@ -668,44 +612,43 @@ def match_element( expression_pattern_match_element_process_items(items, items_rest, parms) def get_match_candidates( - self, - elements: Tuple[BaseElement], - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - ): + self, elements: Tuple[BaseElement], pattern_context + ) -> tuple: """ Finds possible elements that could match the pattern, ignoring future pattern variable definitions, but taking into account already fixed variables. """ # TODO: fixed_vars! - - return [ - element - for element in elements - if self.does_match(element, evaluation, vars_dict) - ] + evaluation: Evaluation = pattern_context["evaluation"] + vars_dict: Optional[dict] = pattern_context.setdefault("vars_dict", {}) + return tuple( + ( + element + for element in elements + if self.does_match( + element, {"evaluation": evaluation, "vars_dict": vars_dict} + ) + ) + ) def get_match_candidates_count( - self, - elements: Tuple[BaseElement], - expression: BaseElement, - attributes: int, - evaluation: Evaluation, - vars_dict: Optional[dict] = None, - ): + self, elements: Tuple[BaseElement], pattern_context + ) -> Union[int, tuple]: """ Finds possible elements that could match the pattern, ignoring future pattern variable definitions, but taking into account already fixed variables. """ # TODO: fixed_vars! + evaluation: Evaluation = pattern_context["evaluation"] + vars_dict: Optional[dict] = pattern_context.setdefault("vars_dict", {}) count = 0 for element in elements: - if self.does_match(element, evaluation, vars_dict): + if self.does_match( + element, {"evaluation": evaluation, "vars_dict": vars_dict} + ): count += 1 return count @@ -715,10 +658,9 @@ def sort(self): def match_expression_with_one_identity( + self: BasePattern, + expression: Expression, parms: dict, - element_index: Optional[int], - element_count: Optional[int], - head: Symbol, ): """ Process expressions with the attribute OneIdentity. @@ -731,7 +673,6 @@ def match_expression_with_one_identity( # set of default values, and a single pattern. from mathics.builtin.patterns import Pattern - self: ExpressionPattern = parms["self"] vars_dict: dict = parms["vars_dict"] evaluation: Evaluation = parms["evaluation"] @@ -789,19 +730,24 @@ def match_expression_with_one_identity( del optionals[""] vars_dict.update(optionals) # Try to match the non-optional element with the expression - new_pattern.match( - parms["yield_func"], - parms["expression"], - vars_dict, - evaluation, - head=head, - element_index=element_index, - element_count=element_count, - fully=parms["fully"], - ) - - -def basic_match_expression(parms): + # no_parms={ + # "yield_func":parms["yield_func"], + # "vars_dict":vars_dict, + # "evaluation":evaluation, + # "head":head, + # "element_index":element_index, + # "element_count":element_count, + # "fully":parms["fully"], + # } + + # TODO: remove me eventually + del parms["attributes"] + new_pattern.match(expression=expression, pattern_context=parms) + + +def basic_match_expression( + self: ExpressionPattern, expression: Expression, parms: dict +): """ Try to match a pattern with an expression """ @@ -810,10 +756,7 @@ def basic_match_expression(parms): # if self.elements: # next_element = self.elements[0] # next_elements = self.elements[1:] - - self: ExpressionPattern = parms["self"] yield_func: Callable = parms["yield_func"] - expression: Expression = parms["expression"] vars_dict: dict = parms["vars_dict"] evaluation: Evaluation = parms["evaluation"] attributes: int = parms["attributes"] @@ -852,7 +795,8 @@ def yield_choice(pre_vars): if not unmatched_elements: raise StopGenerator_ExpressionPattern_match() if not element.does_match( - unmatched_elements[0], evaluation, pre_vars + unmatched_elements[0], + {"evaluation": evaluation, "vars_dict": pre_vars}, ): raise StopGenerator_ExpressionPattern_match() unmatched_elements = unmatched_elements[1:] @@ -862,10 +806,12 @@ def yield_choice(pre_vars): if not leading_blanks: candidates = element.get_match_candidates_count( unmatched_elements, - expression, - attributes, - evaluation, - pre_vars, + { + "expression": expression, + "attributes": attributes, + "evaluation": evaluation, + "vars_dict": pre_vars, + }, ) if candidates < match_count[0]: raise StopGenerator_ExpressionPattern_match() @@ -877,17 +823,19 @@ def yield_choice(pre_vars): # def yield_element(new_vars, rest): # yield_func(new_vars, rest) self.match_element( - yield_func, - next_element, - tuple(next_elements), - ([], expression.elements), - pre_vars, - expression, - attributes, - evaluation, - first=True, - fully=fully, - element_count=len(self.elements), + element=next_element, + pattern_context={ + "yield_func": yield_func, + "rest_elements": tuple(next_elements), + "rest_expression": ([], expression.elements), + "vars_dict": pre_vars, + "expression": expression, + "attributes": attributes, + "evaluation": evaluation, + "first": True, + "fully": fully, + "element_count": len(self.elements), + }, ) # for head_vars, _ in self.head.match(expression.get_head(), vars, @@ -898,14 +846,29 @@ def yield_head(head_vars, _): # expression, attributes, head_vars) # for pre_vars in pre_choices: - self.get_pre_choices(self, yield_choice, expression, attributes, head_vars) + self.get_pre_choices( + self, + expression, + { + "yield_choice": yield_choice, + "attributes": attributes, + "vars_dict": head_vars, + }, + ) else: if not expression.elements: yield_func(head_vars, None) else: return - self.head.match(yield_head, expression.get_head(), vars_dict, evaluation) + self.head.match( + expression.get_head(), + { + "yield_func": yield_head, + "vars_dict": vars_dict, + "evaluation": evaluation, + }, + ) def expression_pattern_match_element_orderless( @@ -976,10 +939,12 @@ def expression_pattern_match_element_process_items( items_rest: Union[tuple, list], parms: dict, ): + """ + Try to match sequences built from items + against the pattern. + """ # Include wrappings like Plus[a, b] only if not all items taken # - in that case we would match the same expression over and over. - - attributes: int = parms["attributes"] element_count: int = parms["element_count"] expression: Expression = parms["expression"] evaluation: Evaluation = parms["evaluation"] @@ -1015,19 +980,15 @@ def element_yield(next_vars_parm, next_rest_parm): def match_yield(new_vars, _): if parms["rest_elements"]: + new_parms = parms.copy() + new_parms["rest_expression"] = items_rest + new_parms["rest_elements"] = parms["next_rest_elements"] + new_parms["vars_dict"] = new_vars + new_parms["element_index"] = parms["next_index"] + new_parms["yield_func"] = element_yield + del new_parms["element"] pattern.match_element( - element_yield, - parms["next_element"], - parms["next_rest_elements"], - items_rest, - new_vars, - expression, - attributes, - evaluation, - fully=fully, - depth=parms["depth"], - element_index=parms["next_index"], - element_count=element_count, + element=parms["next_element"], pattern_context=new_parms ) else: if not fully or (not items_rest[0] and not items_rest[1]): @@ -1035,34 +996,34 @@ def match_yield(new_vars, _): def yield_wrapping(item): parms["element"].match( - match_yield, item, - parms["vars"], - evaluation, - fully=True, - head=expression.head, - element_index=parms["element_index"], - element_count=element_count, + { + "yield_func": match_yield, + "vars_dict": parms["vars_dict"], + "evaluation": evaluation, + "fully": True, + "head": expression.head, + "element_index": parms["element_index"], + "element_count": element_count, + }, ) + # parms = parms.copy() + parms["max_count"] = parms["match_count"][1] + parms["include_flattened"] = include_flattened + # {"max_count":parms["match_count"][1], + # "expression":expression, + # "attributes":attributes, + # "include_flattened":include_flattened} pattern.get_wrappings( - yield_wrapping, - tuple(items), - parms["match_count"][1], - expression, - attributes, - include_flattened=include_flattened, + yield_func=yield_wrapping, items=tuple(items), pattern_context=parms ) # TODO: these two functions should collect all their arguments # in a dict def get_pre_choices_with_order( - pat: ExpressionPattern, - yield_choice: Callable, - expression: Expression, - attributes: int, - vars_dict: dict, + pat: ExpressionPattern, expression: Expression, pattern_context ): """ Yield pre choices for expressions without @@ -1072,15 +1033,11 @@ def get_pre_choices_with_order( the parameter `yield_choice` with the collected var_dict. """ - yield_choice(vars_dict) + pattern_context["yield_choice"](pattern_context["vars_dict"]) def get_pre_choices_orderless( - pat: ExpressionPattern, - yield_choice: Callable, - expression: Expression, - attributes: int, - vars_dict: dict, + pat: ExpressionPattern, expression: Expression, pattern_context ): """ Yield pre choices for expressions with @@ -1089,10 +1046,13 @@ def get_pre_choices_orderless( This case is more involved, since the pattern can include subpatterns. """ + yield_choice: Callable = pattern_context["yield_choice"] + vars_dict: dict = pattern_context["vars_dict"] + patterns = pat.filter_elements("Pattern") # a dict with entries having patterns with the same name # which are not in vars_dict. - groups = {} + groups: dict = {} prev_pattern = prev_name = None for pattern in patterns: name = pattern.elements[0].get_name() @@ -1163,7 +1123,10 @@ def per_name(yield_name: Callable, groups: Tuple, vars_dict: dict): # # FIXME: this call is wrong and needs a # # wrapper_function as the 1st parameter. # wrappings = pat.get_wrappings( - # sequence, match_count[1], expression, attributes + # items=sequence, + # max_count=match_count[1], + # expression=expression, + # attributes=attributes # ) # for wrapping in wrappings: # 1/0 diff --git a/mathics/core/rules.py b/mathics/core/rules.py index ddf946eaa..83feb70c3 100644 --- a/mathics/core/rules.py +++ b/mathics/core/rules.py @@ -162,7 +162,15 @@ def yield_match(vars, rest): # only first possibility counts try: - self.pattern.match(yield_match, expression, {}, evaluation, fully=fully) + self.pattern.match( + expression, + pattern_context={ + "yield_func": yield_match, + "vars_dict": {}, + "evaluation": evaluation, + "fully": fully, + }, + ) except StopGenerator_BaseRule as exc: # FIXME: figure where these values are not getting set or updated properly. # For now we have to take a pessimistic view diff --git a/mathics/eval/patterns.py b/mathics/eval/patterns.py index 0b9e7ee44..dfecc670d 100644 --- a/mathics/eval/patterns.py +++ b/mathics/eval/patterns.py @@ -18,7 +18,10 @@ def yield_func(vars, rest): raise _StopGeneratorMatchQ(True) try: - self.form.match(yield_func, expr, {}, evaluation) + self.form.match( + expr, + {"yield_func": yield_func, "vars_dict": {}, "evaluation": evaluation}, + ) except _StopGeneratorMatchQ: return True return False diff --git a/mathics/eval/test.py b/mathics/eval/test.py index c7b66fdfe..0f380219b 100644 --- a/mathics/eval/test.py +++ b/mathics/eval/test.py @@ -20,7 +20,15 @@ def yield_match(vars, rest): # return False try: - form.match(yield_match, item, {}, evaluation, fully=False) + form.match( + item, + pattern_context={ + "yield_func": yield_match, + "vars_dict": {}, + "evaluation": evaluation, + "fully": False, + }, + ) except _StopGeneratorBaseElementIsFree as exc: return exc.value diff --git a/mathics/eval/testing_expressions.py b/mathics/eval/testing_expressions.py index 5edfe8b8f..e0b8a6f28 100644 --- a/mathics/eval/testing_expressions.py +++ b/mathics/eval/testing_expressions.py @@ -140,7 +140,7 @@ def check(level, expr): return SymbolFalse depth = len(dims) - 1 # None doesn't count - if not pattern.does_match(Integer(depth), evaluation): + if not pattern.does_match(Integer(depth), {"evaluation": evaluation}): return SymbolFalse return SymbolTrue @@ -154,7 +154,7 @@ def check_SparseArrayQ(expr, pattern, test, evaluation: Evaluation): pattern = BasePattern.create(pattern, evaluation=evaluation) dims, default_value, rules = expr.elements[1:] - if not pattern.does_match(Integer(len(dims.elements)), evaluation): + if not pattern.does_match(Integer(len(dims.elements)), {"evaluation": evaluation}): return SymbolFalse array_size = Expression(SymbolTimes, *dims.elements).evaluate(evaluation)