diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 905fd0f8..17173a94 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): pass +class _NoValue: + pass + + class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): """ A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`. @@ -91,22 +95,50 @@ def minimum(self, x, y): def where(self, criterion, then, else_): return rec_multimap_array_container(pt.where, criterion, then, else_) - def sum(self, a, axis=None, dtype=None): - def _pt_sum(ary): + @staticmethod + def _reduce(container_binop, array_reduce, + ary, *, + axis, dtype, initial): + def container_reduce(ctr): + if initial is _NoValue: + try: + return reduce(container_binop, ctr) + except TypeError as exc: + assert "empty sequence" in str(exc) + raise ValueError("zero-size reduction operation " + "without supplied 'initial' value") + else: + return reduce(container_binop, ctr, initial) + + def actual_array_reduce(ary): if dtype not in [ary.dtype, None]: raise NotImplementedError - return pt.sum(ary, axis=axis) - - return rec_map_reduce_array_container(sum, _pt_sum, a) - - def min(self, a, axis=None): - return rec_map_reduce_array_container( - partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) + if initial is _NoValue: + return array_reduce(ary, axis=axis) + else: + return array_reduce(ary, axis=axis, initial=initial) - def max(self, a, axis=None): return rec_map_reduce_array_container( - partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a) + container_reduce, + actual_array_reduce, + ary) + + # * appears where positional signature starts diverging from numpy + def sum(self, a, axis=None, dtype=None, *, initial=0): + import operator + return self._reduce(operator.add, pt.sum, a, + axis=axis, dtype=dtype, initial=initial) + + # * appears where positional signature starts diverging from numpy + def min(self, a, axis=None, *, initial=_NoValue): + return self._reduce(pt.minimum, pt.amin, a, + axis=axis, dtype=None, initial=initial) + + # * appears where positional signature starts diverging from numpy + def max(self, a, axis=None, *, initial=_NoValue): + return self._reduce(pt.maximum, pt.amax, a, + axis=axis, dtype=None, initial=initial) def stack(self, arrays, axis=0): return rec_multimap_array_container(