Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: prevent reducers like ak.sum on records (v2) #1607

Merged
merged 8 commits into from
Aug 25, 2022
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
return self.toByteMaskedArray()._reduce_next(
reducer,
Expand All @@ -569,6 +570,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
mask_length = self._mask.length

Expand Down Expand Up @@ -899,6 +900,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

if not branch and negaxis == depth:
Expand Down
53 changes: 32 additions & 21 deletions src/awkward/_v2/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def dummy(self):
def local_index(self, axis):
return self._local_index(axis, 0)

def _reduce(self, reducer, axis=-1, mask=True, keepdims=False):
def _reduce(self, reducer, axis=-1, mask=True, keepdims=False, behavior=None):
if axis is None:
raise ak._v2._util.error(NotImplementedError)

Expand Down Expand Up @@ -861,39 +861,50 @@ def _reduce(self, reducer, axis=-1, mask=True, keepdims=False):
1,
mask,
keepdims,
behavior,
)

return next[0]

def argmin(self, axis=-1, mask=True, keepdims=False):
return self._reduce(awkward._v2._reducers.ArgMin, axis, mask, keepdims)
def argmin(self, axis=-1, mask=True, keepdims=False, behavior=None):
return self._reduce(
awkward._v2._reducers.ArgMin, axis, mask, keepdims, behavior
)

def argmax(self, axis=-1, mask=True, keepdims=False):
return self._reduce(awkward._v2._reducers.ArgMax, axis, mask, keepdims)
def argmax(self, axis=-1, mask=True, keepdims=False, behavior=None):
return self._reduce(
awkward._v2._reducers.ArgMax, axis, mask, keepdims, behavior
)

def count(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Count, axis, mask, keepdims)
def count(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Count, axis, mask, keepdims, behavior)

def count_nonzero(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.CountNonzero, axis, mask, keepdims)
def count_nonzero(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(
awkward._v2._reducers.CountNonzero, axis, mask, keepdims, behavior
)

def sum(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Sum, axis, mask, keepdims)
def sum(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Sum, axis, mask, keepdims, behavior)

def prod(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Prod, axis, mask, keepdims)
def prod(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Prod, axis, mask, keepdims, behavior)

def any(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Any, axis, mask, keepdims)
def any(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Any, axis, mask, keepdims, behavior)

def all(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.All, axis, mask, keepdims)
def all(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.All, axis, mask, keepdims, behavior)

def min(self, axis=-1, mask=True, keepdims=False, initial=None):
return self._reduce(awkward._v2._reducers.Min(initial), axis, mask, keepdims)
def min(self, axis=-1, mask=True, keepdims=False, initial=None, behavior=None):
return self._reduce(
awkward._v2._reducers.Min(initial), axis, mask, keepdims, behavior
)

def max(self, axis=-1, mask=True, keepdims=False, initial=None):
return self._reduce(awkward._v2._reducers.Max(initial), axis, mask, keepdims)
def max(self, axis=-1, mask=True, keepdims=False, initial=None, behavior=None):
return self._reduce(
awkward._v2._reducers.Max(initial), axis, mask, keepdims, behavior
)

def argsort(self, axis=-1, ascending=True, stable=False, kind=None, order=None):
negaxis = -axis
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
as_numpy = self.toNumpyArray(reducer.preferred_dtype)
return as_numpy._reduce_next(
Expand All @@ -272,6 +273,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
branch, depth = self.branch_depth

Expand Down Expand Up @@ -1013,6 +1014,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

# If we are reducing the contents of this layout,
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
branch, depth = self.branch_depth

Expand All @@ -1390,6 +1391,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

# If we are reducing the contents of this layout,
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
return self.toListOffsetArray64(True)._reduce_next(
reducer,
Expand All @@ -1244,6 +1245,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
4 changes: 4 additions & 0 deletions src/awkward/_v2/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
if self._offsets.dtype != np.dtype(np.int64) or (
self._offsets.nplike.known_data and self._offsets[0] != 0
Expand All @@ -1481,6 +1482,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

branch, depth = self.branch_depth
Expand Down Expand Up @@ -1586,6 +1588,7 @@ def _reduce_next(
maxnextparents[0] + 1,
mask,
False,
behavior,
)

out = ak._v2.contents.ListArray(
Expand Down Expand Up @@ -1641,6 +1644,7 @@ def _reduce_next(
globalstarts_length,
mask,
keepdims,
behavior,
)

outoffsets = ak._v2.index.Index64.empty(outlength + 1, self._nplike)
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
if len(self._data.shape) != 1 or not self.is_contiguous:
return self.toRegularArray()._reduce_next(
Expand All @@ -1133,6 +1134,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

if isinstance(self.nplike, ak.nplike.Jax):
Expand Down
41 changes: 20 additions & 21 deletions src/awkward/_v2/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,30 +824,29 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
contents = []
for content in self._contents:
contents.append(
content[: self._length]._reduce_next(
reducer,
negaxis,
starts,
shifts,
parents,
outlength,
mask,
keepdims,
reducer_name = reducer.__name__.split(".")[0].lower()

if behavior is None or not (
any(
x[0].__module__.split(".")[-1].split("_")[-1] == reducer_name
for x in behavior.keys()
)
):
raise ak._v2._util.error(
TypeError(
"no ak.{} overloads for custom types: {}".format(
reducer_name, ", ".join(self._fields)
)
)
)
else:
raise ak._v2._util.error(
NotImplementedError(
"overloading reducers for RecordArrays has not been implemented yet"
)
)

return ak._v2.contents.RecordArray(
contents,
self._fields,
outlength,
None,
None,
self._nplike,
)

def _validity_error(self, path):
for i, cont in enumerate(self.contents):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
out = self.toListOffsetArray64(True)._reduce_next(
reducer,
Expand All @@ -995,6 +996,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

if not self._content.dimension_optiontype:
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
simplified = self.simplify_uniontype(mergebool=True)
if isinstance(simplified, UnionArray):
Expand All @@ -1150,6 +1151,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
next = self._content
if isinstance(next, ak._v2.contents.RegularArray):
Expand All @@ -502,6 +503,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.all(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.all(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.any(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.any(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def _impl(array, axis, keepdims, mask_identity, flatten_records):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.argmax(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.argmax(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def _impl(array, axis, keepdims, mask_identity, flatten_records):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.argmin(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.argmin(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.count(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.count(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_count_nonzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.count_nonzero(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.count_nonzero(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
Loading