Skip to content

Commit

Permalink
outer: disallow non-object numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Aug 6, 2024
1 parent bb955e3 commit d567ad7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
16 changes: 11 additions & 5 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,8 +949,7 @@ def outer(a: Any, b: Any) -> Any:
Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional
object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer`
always returns a matrix). Here the definition of "scalar" includes
all non-array-container types and any scalar-like array container types
(including non-object numpy arrays).
all non-array-container types and any scalar-like array container types.
If *a* and *b* are both array containers, the result will have the same type
as *a*. If both are array containers and neither is an object array, they must
Expand All @@ -968,12 +967,19 @@ def treat_as_scalar(x: Any) -> bool:
# This condition is whether "ndarrays should broadcast inside x".
and NumpyObjectArray not in x.__class__._outer_bcast_types)

a_is_ndarray = isinstance(a, np.ndarray)
b_is_ndarray = isinstance(b, np.ndarray)

if a_is_ndarray and a.dtype != object:
raise TypeError("passing a non-object numpy array is not allowed")
if b_is_ndarray and b.dtype != object:
raise TypeError("passing a non-object numpy array is not allowed")

if treat_as_scalar(a) or treat_as_scalar(b):
return a*b
# After this point, "isinstance(o, ndarray)" means o is an object array.
elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
elif a_is_ndarray and b_is_ndarray:
return np.outer(a, b)
elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
elif a_is_ndarray or b_is_ndarray:
return map_array_container(lambda x: outer(x, b), a)
else:
if type(a) is not type(b):
Expand Down
9 changes: 0 additions & 9 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,15 +1531,6 @@ def equal(a, b):
b_bcast_dc_of_dofs.momentum),
enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy))

# Non-object numpy arrays should be treated as scalars
ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass))
assert equal(
outer(ary_of_floats, b_bcast_dc_of_dofs),
ary_of_floats*b_bcast_dc_of_dofs)
assert equal(
outer(a_bcast_dc_of_dofs, ary_of_floats),
a_bcast_dc_of_dofs*ary_of_floats)

# }}}


Expand Down

0 comments on commit d567ad7

Please sign in to comment.