Skip to content

Commit

Permalink
split prechoices
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Sep 17, 2024
1 parent 1490408 commit 4b7befb
Showing 1 changed file with 117 additions and 99 deletions.
216 changes: 117 additions & 99 deletions mathics/core/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ def match(

evaluation.check_stopped()
if self.attributes is None:
self.attributes = self.head.get_attributes(evaluation.definitions)
self.__set_pattern_attributes__(
self.head.get_attributes(evaluation.definitions)
)
attributes = self.attributes

if not A_FLAT & attributes:
Expand Down Expand Up @@ -453,115 +455,25 @@ def get_pre_choices(
If not Orderless, call yield_choice with vars as the parameter.
"""
if A_ORDERLESS & attributes:
self.sort()
patterns = self.filter_elements("Pattern")
# a dict with entries having patterns with the same name
# which are not in vars_dict.
groups = {}
prev_pattern = prev_name = None
for pattern in patterns:
name = pattern.elements[0].get_name()
existing = vars_dict.get(name, None)
if existing is None:
# There's no need for pre-choices if the variable is
# already set.
if name == prev_name:
if name in groups:
groups[name].append(pattern)
else:
groups[name] = [prev_pattern, pattern]
prev_pattern = pattern
prev_name = name
# prev_element = None

# count duplicate elements
expr_groups = {}
for element in expression.elements:
expr_groups[element] = expr_groups.get(element, 0) + 1

def per_name(yield_name: Callable, groups: Tuple, vars_dict: dict):
"""
Yields possible variable settings (dictionaries) for the
remaining pattern groups
"""
# TODO: check why this condition is never reached in tests.
if groups:
# name, patterns = groups[0]

# match_count = [0, None]
# for pattern in patterns:
# sub_match_count = pattern.get_match_count()
# if sub_match_count[0] > match_count[0]:
# match_count[0] = sub_match_count[0]
# if match_count[1] is None or (
# sub_match_count[1] is not None
# and sub_match_count[1] < match_count[1]
# ):
# match_count[1] = sub_match_count[1]
# # possibilities = [{}]
# # sum = 0

# def per_expr(yield_expr, expr_groups, sum_int=0):
# """
# Yields possible values (sequence lists) for the current
# variable (name) taking into account the
# (expression, count)'s in expr_groups
# """

# if expr_groups:
# expr, count = expr_groups.popitem()
# max_per_pattern = count // len(patterns)
# for per_pattern in range(max_per_pattern, -1, -1):
# for next_expr in per_expr( # nopep8
# expr_groups, sum_int + per_pattern
# ):
# yield_expr([expr] * per_pattern + next_expr)
# else:
# if sum_int >= match_count[0]:
# yield_expr([])
# # Until we learn that the below is incorrect,
# # we'll return basically no match.
# yield None

# # for sequence in per_expr(expr_groups.items()):
# def yield_expr(sequence):
# # FIXME: this call is wrong and needs a
# # wrapper_function as the 1st parameter.
# wrappings = self.get_wrappings(
# sequence, match_count[1], expression, attributes
# )
# for wrapping in wrappings:
# 1/0
# # for next in per_name(groups[1:], vars_dict):

# def yield_next(next_expr):
# setting = next_expr.copy()
# setting[name] = wrapping
# yield_name(setting)

# per_name(yield_next, groups[1:], vars_dict)

# per_expr(yield_expr, expr_groups)
pass
else: # no groups left
yield_name(vars_dict)

# for setting in per_name(groups.items(), vars):
# def yield_name(setting):
# yield_func(setting)
per_name(yield_choice, tuple(groups.items()), vars_dict)
get_pre_choices_orderless(
self, yield_choice, expression, attributes, vars_dict
)
else:
yield_choice(vars_dict)

def __init__(self, expr: Expression, evaluation: Optional[Evaluation] = None):
self.expr = expr
head = expr.head
self.attributes = (
attributes = (
None if evaluation is None else head.get_attributes(evaluation.definition)
)
self.__set_pattern_attributes__(attributes)
self.head = Pattern.create(head)
self.elements = [Pattern.create(element) for element in expr.elements]

def __set_pattern_attributes__(self, attributes):
self.attributes = attributes

def filter_elements(self, head_name: str):
"""Filter the elements with a given head_name"""
head_name = ensure_context(head_name)
Expand Down Expand Up @@ -1084,3 +996,109 @@ def yield_wrapping(item):
attributes,
include_flattened=include_flattened,
)


def get_pre_choices_orderless(
pat: ExpressionPattern,
yield_choice: Callable,
expression: Expression,
attributes: int,
vars_dict: dict,
):
pat.sort()
patterns = pat.filter_elements("Pattern")
# a dict with entries having patterns with the same name
# which are not in vars_dict.
groups = {}
prev_pattern = prev_name = None
for pattern in patterns:
name = pattern.elements[0].get_name()
existing = vars_dict.get(name, None)
if existing is None:
# There's no need for pre-choices if the variable is
# already set.
if name == prev_name:
if name in groups:
groups[name].append(pattern)
else:
groups[name] = [prev_pattern, pattern]
prev_pattern = pattern
prev_name = name
# prev_element = None

# count duplicate elements
expr_groups = {}
for element in expression.elements:
expr_groups[element] = expr_groups.get(element, 0) + 1

def per_name(yield_name: Callable, groups: Tuple, vars_dict: dict):
"""
Yields possible variable settings (dictionaries) for the
remaining pattern groups
"""
# TODO: check why this condition is never reached in tests.
if groups:
# name, patterns = groups[0]

# match_count = [0, None]
# for pattern in patterns:
# sub_match_count = pattern.get_match_count()
# if sub_match_count[0] > match_count[0]:
# match_count[0] = sub_match_count[0]
# if match_count[1] is None or (
# sub_match_count[1] is not None
# and sub_match_count[1] < match_count[1]
# ):
# match_count[1] = sub_match_count[1]
# # possibilities = [{}]
# # sum = 0

# def per_expr(yield_expr, expr_groups, sum_int=0):
# """
# Yields possible values (sequence lists) for the current
# variable (name) taking into account the
# (expression, count)'s in expr_groups
# """

# if expr_groups:
# expr, count = expr_groups.popitem()
# max_per_pattern = count // len(patterns)
# for per_pattern in range(max_per_pattern, -1, -1):
# for next_expr in per_expr( # nopep8
# expr_groups, sum_int + per_pattern
# ):
# yield_expr([expr] * per_pattern + next_expr)
# else:
# if sum_int >= match_count[0]:
# yield_expr([])
# # Until we learn that the below is incorrect,
# # we'll return basically no match.
# yield None

# # for sequence in per_expr(expr_groups.items()):
# def yield_expr(sequence):
# # FIXME: this call is wrong and needs a
# # wrapper_function as the 1st parameter.
# wrappings = pat.get_wrappings(
# sequence, match_count[1], expression, attributes
# )
# for wrapping in wrappings:
# 1/0
# # for next in per_name(groups[1:], vars_dict):

# def yield_next(next_expr):
# setting = next_expr.copy()
# setting[name] = wrapping
# yield_name(setting)

# per_name(yield_next, groups[1:], vars_dict)

# per_expr(yield_expr, expr_groups)
pass
else: # no groups left
yield_name(vars_dict)

# for setting in per_name(groups.items(), vars):
# def yield_name(setting):
# yield_func(setting)
per_name(yield_choice, tuple(groups.items()), vars_dict)

0 comments on commit 4b7befb

Please sign in to comment.