diff --git a/mathics/core/pattern.py b/mathics/core/pattern.py index 746ba5a99..7f1fa4b3c 100644 --- a/mathics/core/pattern.py +++ b/mathics/core/pattern.py @@ -410,7 +410,9 @@ def match( evaluation.check_stopped() if self.attributes is None: - self.attributes = self.head.get_attributes(evaluation.definitions) + self.__set_pattern_attributes__( + self.head.get_attributes(evaluation.definitions) + ) attributes = self.attributes if not A_FLAT & attributes: @@ -453,115 +455,25 @@ def get_pre_choices( If not Orderless, call yield_choice with vars as the parameter. """ if A_ORDERLESS & attributes: - self.sort() - patterns = self.filter_elements("Pattern") - # a dict with entries having patterns with the same name - # which are not in vars_dict. - groups = {} - prev_pattern = prev_name = None - for pattern in patterns: - name = pattern.elements[0].get_name() - existing = vars_dict.get(name, None) - if existing is None: - # There's no need for pre-choices if the variable is - # already set. - if name == prev_name: - if name in groups: - groups[name].append(pattern) - else: - groups[name] = [prev_pattern, pattern] - prev_pattern = pattern - prev_name = name - # prev_element = None - - # count duplicate elements - expr_groups = {} - for element in expression.elements: - expr_groups[element] = expr_groups.get(element, 0) + 1 - - def per_name(yield_name: Callable, groups: Tuple, vars_dict: dict): - """ - Yields possible variable settings (dictionaries) for the - remaining pattern groups - """ - # TODO: check why this condition is never reached in tests. - if groups: - # name, patterns = groups[0] - - # match_count = [0, None] - # for pattern in patterns: - # sub_match_count = pattern.get_match_count() - # if sub_match_count[0] > match_count[0]: - # match_count[0] = sub_match_count[0] - # if match_count[1] is None or ( - # sub_match_count[1] is not None - # and sub_match_count[1] < match_count[1] - # ): - # match_count[1] = sub_match_count[1] - # # possibilities = [{}] - # # sum = 0 - - # def per_expr(yield_expr, expr_groups, sum_int=0): - # """ - # Yields possible values (sequence lists) for the current - # variable (name) taking into account the - # (expression, count)'s in expr_groups - # """ - - # if expr_groups: - # expr, count = expr_groups.popitem() - # max_per_pattern = count // len(patterns) - # for per_pattern in range(max_per_pattern, -1, -1): - # for next_expr in per_expr( # nopep8 - # expr_groups, sum_int + per_pattern - # ): - # yield_expr([expr] * per_pattern + next_expr) - # else: - # if sum_int >= match_count[0]: - # yield_expr([]) - # # Until we learn that the below is incorrect, - # # we'll return basically no match. - # yield None - - # # for sequence in per_expr(expr_groups.items()): - # def yield_expr(sequence): - # # FIXME: this call is wrong and needs a - # # wrapper_function as the 1st parameter. - # wrappings = self.get_wrappings( - # sequence, match_count[1], expression, attributes - # ) - # for wrapping in wrappings: - # 1/0 - # # for next in per_name(groups[1:], vars_dict): - - # def yield_next(next_expr): - # setting = next_expr.copy() - # setting[name] = wrapping - # yield_name(setting) - - # per_name(yield_next, groups[1:], vars_dict) - - # per_expr(yield_expr, expr_groups) - pass - else: # no groups left - yield_name(vars_dict) - - # for setting in per_name(groups.items(), vars): - # def yield_name(setting): - # yield_func(setting) - per_name(yield_choice, tuple(groups.items()), vars_dict) + get_pre_choices_orderless( + self, yield_choice, expression, attributes, vars_dict + ) else: yield_choice(vars_dict) def __init__(self, expr: Expression, evaluation: Optional[Evaluation] = None): self.expr = expr head = expr.head - self.attributes = ( + attributes = ( None if evaluation is None else head.get_attributes(evaluation.definition) ) + self.__set_pattern_attributes__(attributes) self.head = Pattern.create(head) self.elements = [Pattern.create(element) for element in expr.elements] + def __set_pattern_attributes__(self, attributes): + self.attributes = attributes + def filter_elements(self, head_name: str): """Filter the elements with a given head_name""" head_name = ensure_context(head_name) @@ -1084,3 +996,109 @@ def yield_wrapping(item): attributes, include_flattened=include_flattened, ) + + +def get_pre_choices_orderless( + pat: ExpressionPattern, + yield_choice: Callable, + expression: Expression, + attributes: int, + vars_dict: dict, +): + pat.sort() + patterns = pat.filter_elements("Pattern") + # a dict with entries having patterns with the same name + # which are not in vars_dict. + groups = {} + prev_pattern = prev_name = None + for pattern in patterns: + name = pattern.elements[0].get_name() + existing = vars_dict.get(name, None) + if existing is None: + # There's no need for pre-choices if the variable is + # already set. + if name == prev_name: + if name in groups: + groups[name].append(pattern) + else: + groups[name] = [prev_pattern, pattern] + prev_pattern = pattern + prev_name = name + # prev_element = None + + # count duplicate elements + expr_groups = {} + for element in expression.elements: + expr_groups[element] = expr_groups.get(element, 0) + 1 + + def per_name(yield_name: Callable, groups: Tuple, vars_dict: dict): + """ + Yields possible variable settings (dictionaries) for the + remaining pattern groups + """ + # TODO: check why this condition is never reached in tests. + if groups: + # name, patterns = groups[0] + + # match_count = [0, None] + # for pattern in patterns: + # sub_match_count = pattern.get_match_count() + # if sub_match_count[0] > match_count[0]: + # match_count[0] = sub_match_count[0] + # if match_count[1] is None or ( + # sub_match_count[1] is not None + # and sub_match_count[1] < match_count[1] + # ): + # match_count[1] = sub_match_count[1] + # # possibilities = [{}] + # # sum = 0 + + # def per_expr(yield_expr, expr_groups, sum_int=0): + # """ + # Yields possible values (sequence lists) for the current + # variable (name) taking into account the + # (expression, count)'s in expr_groups + # """ + + # if expr_groups: + # expr, count = expr_groups.popitem() + # max_per_pattern = count // len(patterns) + # for per_pattern in range(max_per_pattern, -1, -1): + # for next_expr in per_expr( # nopep8 + # expr_groups, sum_int + per_pattern + # ): + # yield_expr([expr] * per_pattern + next_expr) + # else: + # if sum_int >= match_count[0]: + # yield_expr([]) + # # Until we learn that the below is incorrect, + # # we'll return basically no match. + # yield None + + # # for sequence in per_expr(expr_groups.items()): + # def yield_expr(sequence): + # # FIXME: this call is wrong and needs a + # # wrapper_function as the 1st parameter. + # wrappings = pat.get_wrappings( + # sequence, match_count[1], expression, attributes + # ) + # for wrapping in wrappings: + # 1/0 + # # for next in per_name(groups[1:], vars_dict): + + # def yield_next(next_expr): + # setting = next_expr.copy() + # setting[name] = wrapping + # yield_name(setting) + + # per_name(yield_next, groups[1:], vars_dict) + + # per_expr(yield_expr, expr_groups) + pass + else: # no groups left + yield_name(vars_dict) + + # for setting in per_name(groups.items(), vars): + # def yield_name(setting): + # yield_func(setting) + per_name(yield_choice, tuple(groups.items()), vars_dict)