Skip to content

Commit

Permalink
Unify pytato reduction handling, support 'initial' arg
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 28, 2021
1 parent 1810b0e commit c677cc9
Showing 1 changed file with 43 additions and 11 deletions.
54 changes: 43 additions & 11 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
pass


class _NoValue:
pass


class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
"""
A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
Expand Down Expand Up @@ -91,22 +95,50 @@ def minimum(self, x, y):
def where(self, criterion, then, else_):
return rec_multimap_array_container(pt.where, criterion, then, else_)

def sum(self, a, axis=None, dtype=None):
def _pt_sum(ary):
@staticmethod
def _reduce(container_binop, array_reduce,
ary, *,
axis, dtype, initial):
def container_reduce(ctr):
if initial is _NoValue:
try:
return reduce(container_binop, ctr)
except TypeError as exc:
assert "empty sequence" in str(exc)
raise ValueError("zero-size reduction operation "
"without supplied 'initial' value")
else:
return reduce(container_binop, ctr, initial)

def actual_array_reduce(ary):
if dtype not in [ary.dtype, None]:
raise NotImplementedError

return pt.sum(ary, axis=axis)

return rec_map_reduce_array_container(sum, _pt_sum, a)

def min(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a)
if initial is _NoValue:
return array_reduce(ary, axis=axis)
else:
return array_reduce(ary, axis=axis, initial=initial)

def max(self, a, axis=None):
return rec_map_reduce_array_container(
partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
container_reduce,
actual_array_reduce,
ary)

# * appears where positional signature starts diverging from numpy
def sum(self, a, axis=None, dtype=None, *, initial=0):
import operator
return self._reduce(operator.add, pt.sum, a,
axis=axis, dtype=dtype, initial=initial)

# * appears where positional signature starts diverging from numpy
def min(self, a, axis=None, *, initial=_NoValue):
return self._reduce(pt.minimum, pt.amin, a,
axis=axis, dtype=None, initial=initial)

# * appears where positional signature starts diverging from numpy
def max(self, a, axis=None, *, initial=_NoValue):
return self._reduce(pt.maximum, pt.amax, a,
axis=axis, dtype=None, initial=initial)

def stack(self, arrays, axis=0):
return rec_multimap_array_container(
Expand Down

0 comments on commit c677cc9

Please sign in to comment.