Skip to content

Commit

Permalink
more reorganizing
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Sep 17, 2024
1 parent 24a75f1 commit 3b0e96f
Showing 1 changed file with 176 additions and 104 deletions.
280 changes: 176 additions & 104 deletions mathics/core/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,42 +661,16 @@ def match_element(
less_first = len(rest_elements) > 0

if A_ORDERLESS & attributes:
# we only want element_candidates to be a set if we're orderless.
# otherwise, constructing a set() is very slow for large lists.
# performance test case:
# x = Range[100000]; Timing[Combinatorica`BinarySearch[x, 100]]
element_candidates = set(element_candidates) # for fast lookup

sets = None
if element.get_head_name() == "System`Pattern":
varname = element.elements[0].get_name()
existing = vars_dict.get(varname, None)
if existing is not None:
head = existing.get_head()
if head.get_name() == "System`Sequence" or (
A_FLAT & attributes and head == expression.get_head()
):
needed = existing.elements
else:
needed = [existing]
available = list(candidates)
for needed_element in needed:
if (
needed_element in available
and needed_element in element_candidates # nopep8
):
available.remove(needed_element)
else:
return
sets = [(needed, ([], available))]

if sets is None:
sets = subsets(
candidates,
included=element_candidates,
less_first=less_first,
*set_lengths,
)
sets = expression_pattern_match_element_orderless(
expression,
element,
vars_dict,
attributes,
candidates,
element_candidates,
less_first,
set_lengths,
)
else:
# a generator that yields partitions of
# candidates as [before | block | after ]
Expand All @@ -708,80 +682,36 @@ def match_element(
less_first=less_first,
*set_lengths,
)

if rest_elements:
next_element = rest_elements[0]
next_rest_elements = rest_elements[1:]
next_depth = depth + 1
next_index = element_index + 1
else:
next_element = None
next_rest_elements = None

for items, items_rest in sets:
# Include wrappings like Plus[a, b] only if not all items taken
# - in that case we would match the same expression over and over.

include_flattened = try_flattened and 0 < len(items) < len(
expression.elements
)

# Don't try flattened when the expression would remain the same!

def element_yield(next_vars_parm, next_rest_parm):
# if next_rest is None:
# next_rest = ([], [])
# yield_func(next_vars, (rest_expression[0] + items_rest[0],
# next_rest[1]))
if next_rest_parm is None:
yield_func(
next_vars_parm,
(list(chain(rest_expression[0], items_rest[0])), []),
)
else:
yield_func(
next_vars_parm,
(
list(chain(rest_expression[0], items_rest[0])),
next_rest_parm[1],
),
)

def match_yield(new_vars, _):
if rest_elements:
self.match_element(
element_yield,
next_element,
next_rest_elements,
items_rest,
new_vars,
expression,
attributes,
evaluation,
fully=fully,
depth=next_depth,
element_index=next_index,
element_count=element_count,
)
else:
if not fully or (not items_rest[0] and not items_rest[1]):
yield_func(new_vars, items_rest)

def yield_wrapping(item):
element.match(
match_yield,
item,
vars_dict,
evaluation,
fully=True,
head=expression.head,
element_index=element_index,
element_count=element_count,
)

self.get_wrappings(
yield_wrapping,
tuple(items),
match_count[1],
expression_pattern_match_element_process_items(
self,
vars_dict,
element,
next_element,
items,
items_rest,
expression,
yield_func,
rest_expression,
rest_elements,
try_flattened,
match_count,
attributes,
include_flattened=include_flattened,
next_rest_elements,
element_count,
element_index,
evaluation,
fully,
depth + 1,
element_index + 1,
)

def get_match_candidates(
Expand Down Expand Up @@ -827,7 +757,7 @@ def get_match_candidates_count(
return count

def sort(self):
"""Sot the elements according to their sort key"""
"""Sort the elements according to their sort key"""
self.elements.sort(key=lambda e: e.get_sort_key(pattern_sort=True))


Expand Down Expand Up @@ -1016,3 +946,145 @@ def yield_head(head_vars, _):
return

self.head.match(yield_head, expression.get_head(), vars_dict, evaluation)


def expression_pattern_match_element_orderless(
expression,
element,
vars_dict,
attributes,
candidates,
element_candidates,
less_first,
set_lengths,
):
"""
match element for orderless expressions
"""
# we only want element_candidates to be a set if we're orderless.
# otherwise, constructing a set() is very slow for large lists.
# performance test case:
# x = Range[100000]; Timing[Combinatorica`BinarySearch[x, 100]]
element_candidates = set(element_candidates) # for fast lookup

sets = None
if element.get_head_name() == "System`Pattern":
varname = element.elements[0].get_name()
existing = vars_dict.get(varname, None)
if existing is not None:
head = existing.get_head()
if head.get_name() == "System`Sequence" or (
A_FLAT & attributes and head == expression.get_head()
):
needed = existing.elements
else:
needed = [existing]
available = list(candidates)
for needed_element in needed:
if (
needed_element in available
and needed_element in element_candidates # nopep8
):
available.remove(needed_element)
else:
return set()
sets = [(needed, ([], available))]

if sets is None:
sets = subsets(
candidates,
included=element_candidates,
less_first=less_first,
*set_lengths,
)
return sets


def expression_pattern_match_element_process_items(
self,
vars_dict,
element,
next_element,
items,
items_rest,
expression,
yield_func,
rest_expression,
rest_elements,
try_flattened,
match_count,
attributes,
next_rest_elements,
element_count,
element_index,
evaluation,
fully,
next_depth,
next_index,
):
# Include wrappings like Plus[a, b] only if not all items taken
# - in that case we would match the same expression over and over.

include_flattened = try_flattened and 0 < len(items) < len(expression.elements)

# Don't try flattened when the expression would remain the same!

def element_yield(next_vars_parm, next_rest_parm):
# if next_rest is None:
# next_rest = ([], [])
# yield_func(next_vars, (rest_expression[0] + items_rest[0],
# next_rest[1]))
if next_rest_parm is None:
yield_func(
next_vars_parm,
(list(chain(rest_expression[0], items_rest[0])), []),
)
else:
yield_func(
next_vars_parm,
(
list(chain(rest_expression[0], items_rest[0])),
next_rest_parm[1],
),
)

def match_yield(new_vars, _):
if rest_elements:
self.match_element(
element_yield,
next_element,
next_rest_elements,
items_rest,
new_vars,
expression,
attributes,
evaluation,
fully=fully,
depth=next_depth,
element_index=next_index,
element_count=element_count,
)
else:
if not fully or (not items_rest[0] and not items_rest[1]):
yield_func(new_vars, items_rest)

def yield_wrapping(item):
element.match(
match_yield,
item,
vars_dict,
evaluation,
fully=True,
head=expression.head,
element_index=element_index,
element_count=element_count,
)

self.get_wrappings(
yield_wrapping,
tuple(items),
match_count[1],
expression,
attributes,
include_flattened=include_flattened,
)

0 comments on commit 3b0e96f

Please sign in to comment.