Skip to content

Commit

Permalink
more annotations. Collecting arguments into dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Sep 17, 2024
1 parent 40ceb2d commit 1490408
Showing 1 changed file with 117 additions and 121 deletions.
238 changes: 117 additions & 121 deletions mathics/core/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
)
from mathics.core.util import permutations, subranges, subsets

# FIXME: create definitions in systemsymbols for missing items below.
SYSTEM_SYMBOLS_PATTERNS = symbol_set(
SymbolAlternatives,
SymbolBlank,
Expand Down Expand Up @@ -325,25 +324,25 @@ def __repr__(self):

def match_symbol(
self,
yield_func,
expression,
vars_dict,
evaluation,
head=None,
element_index=None,
element_count=None,
fully=True,
yield_func: Callable,
expression: BaseElement,
vars_dict: dict,
evaluation: Evaluation,
head: Optional[Symbol] = None,
element_index: Optional[int] = None,
element_count: Optional[int] = None,
fully: bool = True,
):
"""Match against a symbol"""
if expression is self.atom:
yield_func(vars_dict, None)

def get_match_symbol_candidates(
self,
elements,
expression,
attributes,
evaluation,
elements: tuple,
expression: BaseElement,
attributes: int,
evaluation: Evaluation,
vars_dict: Optional[dict] = None,
):
"""Find the candidates that matches with the pattern"""
Expand Down Expand Up @@ -416,31 +415,31 @@ def match(

if not A_FLAT & attributes:
fully = True

parms = {
"attributes": attributes,
"evaluation": evaluation,
"expression": expression,
"fully": fully,
"self": self,
"vars_dict": vars_dict,
"yield_func": yield_func,
}

if not isinstance(expression, Atom):
try:
basic_match_expression(
self,
yield_func,
expression,
vars_dict,
evaluation,
attributes,
fully,
parms,
)
except StopGenerator_ExpressionPattern_match:
return

if A_ONE_IDENTITY & attributes:
match_expression_with_one_identity(
self,
vars_dict,
evaluation,
yield_func,
element_index,
element_count,
head,
expression,
fully,
parms,
element_index=element_index,
element_count=element_count,
head=head,
)

def get_pre_choices(
Expand Down Expand Up @@ -661,11 +660,15 @@ def match_element(
less_first = len(rest_elements) > 0

if A_ORDERLESS & attributes:
parms = {
"expression": expression,
"element": element,
"vars_dict": vars_dict,
"attributes": attributes,
}

sets = expression_pattern_match_element_orderless(
expression,
element,
vars_dict,
attributes,
parms,
candidates,
element_candidates,
less_first,
Expand All @@ -683,36 +686,30 @@ def match_element(
*set_lengths,
)

parms = {
"attributes": attributes,
"depth": depth + 1,
"element": element,
"element_count": element_count,
"element_index": element_index,
"evaluation": evaluation,
"expression": expression,
"fully": fully,
"match_count": match_count,
"next_index": element_index + 1,
"pattern": self,
"rest_elements": rest_elements,
"rest_expression": rest_expression,
"try_flattened": try_flattened,
"yield_func": yield_func,
"vars": vars_dict,
}
if rest_elements:
next_element = rest_elements[0]
next_rest_elements = rest_elements[1:]
else:
next_element = None
next_rest_elements = None
parms["next_element"] = rest_elements[0]
parms["next_rest_elements"] = rest_elements[1:]

for items, items_rest in sets:
expression_pattern_match_element_process_items(
self,
vars_dict,
element,
next_element,
items,
items_rest,
expression,
yield_func,
rest_expression,
rest_elements,
try_flattened,
match_count,
attributes,
next_rest_elements,
element_count,
element_index,
evaluation,
fully,
depth + 1,
element_index + 1,
)
expression_pattern_match_element_process_items(items, items_rest, parms)

def get_match_candidates(
self,
Expand Down Expand Up @@ -762,15 +759,10 @@ def sort(self):


def match_expression_with_one_identity(
self,
vars_dict,
evaluation,
yield_func,
element_index,
element_count,
head,
expression,
fully,
parms: dict,
element_index: int,
element_count: int,
head: Symbol,
):
"""
Process expressions with the attribute OneIdentity.
Expand All @@ -781,10 +773,15 @@ def match_expression_with_one_identity(

# This tries to reduce the pattern to a non empty
# set of default values, and a single pattern.
default_indx = 0
optionals = {}
new_pattern = None
pattern_head = self.head.expr

self: ExpressionPattern = parms["self"]
vars_dict: dict = parms["vars_dict"]
evaluation: Evaluation = parms["evaluation"]

default_indx: int = 0
optionals: dict = {}
new_pattern: Optional[Pattern] = None
pattern_head: Expression = self.head.expr
for pat_elem in self.elements:
default_indx += 1
if isinstance(pat_elem, AtomPattern):
Expand Down Expand Up @@ -836,20 +833,18 @@ def match_expression_with_one_identity(
vars_dict.update(optionals)
# Try to match the non-optional element with the expression
new_pattern.match(
yield_func,
expression,
parms["yield_func"],
parms["expression"],
vars_dict,
evaluation,
head=head,
element_index=element_index,
element_count=element_count,
fully=fully,
fully=parms["fully"],
)


def basic_match_expression(
self, yield_func, expression, vars_dict, evaluation, attributes, fully
):
def basic_match_expression(parms):
"""
Try to match a pattern with an expression
"""
Expand All @@ -859,6 +854,14 @@ def basic_match_expression(
# next_element = self.elements[0]
# next_elements = self.elements[1:]

self: ExpressionPattern = parms["self"]
yield_func: Callable = parms["yield_func"]
expression: Expression = parms["expression"]
vars_dict: dict = parms["vars_dict"]
evaluation: Evaluation = parms["evaluation"]
attributes: int = parms["attributes"]
fully: bool = parms["fully"]

def yield_choice(pre_vars):
next_element = self.elements[0]
next_elements = self.elements[1:]
Expand Down Expand Up @@ -949,14 +952,11 @@ def yield_head(head_vars, _):


def expression_pattern_match_element_orderless(
expression,
element,
vars_dict,
attributes,
candidates,
element_candidates,
less_first,
set_lengths,
parms: dict,
candidates: tuple,
element_candidates: list,
less_first: bool,
set_lengths: tuple,
):
"""
match element for orderless expressions
Expand All @@ -965,16 +965,18 @@ def expression_pattern_match_element_orderless(
# otherwise, constructing a set() is very slow for large lists.
# performance test case:
# x = Range[100000]; Timing[Combinatorica`BinarySearch[x, 100]]

element: BaseElement = parms["element"]
element_candidates = set(element_candidates) # for fast lookup

sets = None
if element.get_head_name() == "System`Pattern":
varname = element.elements[0].get_name()
existing = vars_dict.get(varname, None)
existing = parms["vars_dict"].get(varname, None)
if existing is not None:
head = existing.get_head()
if head.get_name() == "System`Sequence" or (
A_FLAT & attributes and head == expression.get_head()
A_FLAT & parms["attributes"] and head == parms["expression"].get_head()
):
needed = existing.elements
else:
Expand All @@ -1001,31 +1003,25 @@ def expression_pattern_match_element_orderless(


def expression_pattern_match_element_process_items(
self,
vars_dict,
element,
next_element,
items,
items_rest,
expression,
yield_func,
rest_expression,
rest_elements,
try_flattened,
match_count,
attributes,
next_rest_elements,
element_count,
element_index,
evaluation,
fully,
next_depth,
next_index,
items: tuple,
items_rest: tuple,
parms: dict,
):
# Include wrappings like Plus[a, b] only if not all items taken
# - in that case we would match the same expression over and over.

include_flattened = try_flattened and 0 < len(items) < len(expression.elements)
attributes: int = parms["attributes"]
element_count: int = parms["element_count"]
expression: Expression = parms["expression"]
evaluation: Evaluation = parms["evaluation"]
fully: bool = parms["fully"]
pattern: ExpressionPattern = parms["pattern"]
rest_expression = parms["rest_expression"]
yield_func: Callable = parms["yield_func"]

include_flattened: bool = parms["try_flattened"] and 0 < len(items) < len(
expression.elements
)

# Don't try flattened when the expression would remain the same!

Expand All @@ -1049,41 +1045,41 @@ def element_yield(next_vars_parm, next_rest_parm):
)

def match_yield(new_vars, _):
if rest_elements:
self.match_element(
if parms["rest_elements"]:
pattern.match_element(
element_yield,
next_element,
next_rest_elements,
parms["next_element"],
parms["next_rest_elements"],
items_rest,
new_vars,
expression,
attributes,
evaluation,
fully=fully,
depth=next_depth,
element_index=next_index,
depth=parms["depth"],
element_index=parms["next_index"],
element_count=element_count,
)
else:
if not fully or (not items_rest[0] and not items_rest[1]):
yield_func(new_vars, items_rest)

def yield_wrapping(item):
element.match(
parms["element"].match(
match_yield,
item,
vars_dict,
parms["vars"],
evaluation,
fully=True,
head=expression.head,
element_index=element_index,
element_index=parms["element_index"],
element_count=element_count,
)

self.get_wrappings(
pattern.get_wrappings(
yield_wrapping,
tuple(items),
match_count[1],
parms["match_count"][1],
expression,
attributes,
include_flattened=include_flattened,
Expand Down

0 comments on commit 1490408

Please sign in to comment.