Skip to content

Commit

Permalink
Merge pull request #946 from Mathics3/Fix_ArrayQ
Browse files Browse the repository at this point in the history
Fix ArrayQ for SparseArray
  • Loading branch information
rocky authored Dec 14, 2023
2 parents 3969568 + 0eae3bf commit ebf604d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 35 deletions.
37 changes: 5 additions & 32 deletions mathics/builtin/testing_expressions/list_oriented.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from mathics.core.evaluation import Evaluation
from mathics.core.exceptions import InvalidLevelspecError
from mathics.core.expression import Expression
from mathics.core.rules import Pattern
from mathics.core.symbols import Atom, SymbolFalse, SymbolTrue
from mathics.core.systemsymbols import SymbolSubsetQ
from mathics.core.systemsymbols import SymbolSubsetQ # , SymbolSparseArray
from mathics.eval.parts import python_levelspec
from mathics.eval.testing_expressions import check_ArrayQ # , check_SparseArrayQ


class ArrayQ(Builtin):
Expand Down Expand Up @@ -51,37 +51,10 @@ class ArrayQ(Builtin):
def eval(self, expr, pattern, test, evaluation: Evaluation):
"ArrayQ[expr_, pattern_, test_]"

pattern = Pattern.create(pattern)

dims = [len(expr.get_elements())] # to ensure an atom is not an array

def check(level, expr):
if not expr.has_form("List", None):
test_expr = Expression(test, expr)
if test_expr.evaluate(evaluation) != SymbolTrue:
return False
level_dim = None
else:
level_dim = len(expr.elements)

if len(dims) > level:
if dims[level] != level_dim:
return False
else:
dims.append(level_dim)
if level_dim is not None:
for element in expr.elements:
if not check(level + 1, element):
return False
return True

if not check(0, expr):
return SymbolFalse
# if not isinstance(expr, Atom) and expr.head.sameQ(SymbolSparseArray):
# return check_SparseArrayQ(expr, pattern, test, evaluation)

depth = len(dims) - 1 # None doesn't count
if not pattern.does_match(Integer(depth), evaluation):
return SymbolFalse
return SymbolTrue
return check_ArrayQ(expr, pattern, test, evaluation)


class DisjointQ(Test):
Expand Down
69 changes: 66 additions & 3 deletions mathics/eval/testing_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import sympy

from mathics.core.atoms import Complex, Integer0, Integer1, IntegerM1
from mathics.core.atoms import Complex, Integer, Integer0, Integer1, IntegerM1
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.systemsymbols import SymbolDirectedInfinity
from mathics.core.rules import Pattern
from mathics.core.symbols import SymbolFalse, SymbolTimes, SymbolTrue
from mathics.core.systemsymbols import SymbolDirectedInfinity, SymbolSparseArray


def do_cmp(x1, x2) -> Optional[int]:

# don't attempt to compare complex numbers
for x in (x1, x2):
# TODO: Send message General::nord
Expand Down Expand Up @@ -99,3 +101,64 @@ def expr_min(elements):

def is_number(sympy_value) -> bool:
return hasattr(sympy_value, "is_number") or isinstance(sympy_value, sympy.Float)


def check_ArrayQ(expr, pattern, test, evaluation: Evaluation):
"Check if expr is an Array which test yields true for each of its elements."

pattern = Pattern.create(pattern)

dims = [len(expr.get_elements())] # to ensure an atom is not an array

def check(level, expr):
if not expr.has_form("List", None):
test_expr = Expression(test, expr)
if test_expr.evaluate(evaluation) != SymbolTrue:
return False
level_dim = None
else:
level_dim = len(expr.elements)

if len(dims) > level:
if dims[level] != level_dim:
return False
else:
dims.append(level_dim)
if level_dim is not None:
for element in expr.elements:
if not check(level + 1, element):
return False
return True

if not check(0, expr):
return SymbolFalse

depth = len(dims) - 1 # None doesn't count
if not pattern.does_match(Integer(depth), evaluation):
return SymbolFalse

return SymbolTrue


def check_SparseArrayQ(expr, pattern, test, evaluation: Evaluation):
"Check if expr is a SparseArray which test yields true for each of its elements."

if not expr.head.sameQ(SymbolSparseArray):
return SymbolFalse

pattern = Pattern.create(pattern)
dims, default_value, rules = expr.elements[1:]
if not pattern.does_match(Integer(len(dims.elements)), evaluation):
return SymbolFalse

array_size = Expression(SymbolTimes, *dims.elements).evaluate(evaluation)
if array_size.value > len(rules.elements): # expr is not full
test_expr = Expression(test, default_value) # test default value
if test_expr.evaluate(evaluation) != SymbolTrue:
return SymbolFalse
for rule in rules.elements:
test_expr = Expression(test, rule.elements[-1])
if test_expr.evaluate(evaluation) != SymbolTrue:
return SymbolFalse

return SymbolTrue

0 comments on commit ebf604d

Please sign in to comment.