Skip to content

Commit

Permalink
next batch of high-level functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 13, 2024
1 parent 16a119b commit c96f61e
Show file tree
Hide file tree
Showing 57 changed files with 738 additions and 471 deletions.
2 changes: 1 addition & 1 deletion src/awkward/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def maybe_highlevel_to_lowlevel(obj):
Args:
obj: an object
Calls #ak.to_layout and returns the result iff. the object is a high-level
Calls #ak.to_layout and returns the result if the object is a high-level
Awkward object, otherwise the object is returned as-is.
This function should be removed once scalars are properly handled by `to_layout`.
Expand Down
50 changes: 50 additions & 0 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def _set_named_axis_to_attrs(
#
# The strategies are:
# - "keep all" (_identity_named_axis): Keep all named axes in the output array, e.g.: `ak.drop_none`
# - "keep one" (_keep_named_axis): Keep one named axes in the output array, e.g.: `ak.firsts`
# - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories
# - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum`
# - "unify" (_unify_named_axis): Unify the named axis in the output array given two input arrays, e.g.: `__add__`
# - "collapse" (_collapse_named_axis): Collapse multiple named axis to None in the output array, e.g.: `ak.flatten`
Expand Down Expand Up @@ -250,6 +252,54 @@ def _identity_named_axis(
return tuple(named_axis)


def _keep_named_axis(
named_axis: AxisTuple,
axis: int | None = None,
) -> AxisTuple:
"""
Determines the new named axis after keeping the specified axis. This is useful, for example,
when applying an operation that keeps only one axis.
Args:
named_axis (AxisTuple): The current named axis.
axis (int | None, optional): The index of the axis to keep. If None, all axes are kept. Default is None.
Returns:
AxisTuple: The new named axis after keeping the specified axis.
Examples:
>>> _keep_named_axis(("x", "y", "z"), 1)
("y",)
>>> _keep_named_axis(("x", "y", "z"))
("x", "y", "z")
"""
return tuple(named_axis) if axis is None else (named_axis[axis],)


def _remove_all_named_axis(
named_axis: AxisTuple,
n: int | None = None,
) -> AxisTuple:
"""
Determines the new named axis after removing all axes. This is useful, for example,
when applying an operation that removes all axes.
Args:
named_axis (AxisTuple): The current named axis.
n (int | None, optional): The number of axes to remove. If None, all axes are removed. Default is None.
Returns:
AxisTuple: The new named axis after removing all axes. All elements will be None.
Examples:
>>> _remove_all_named_axis(("x", "y", "z"))
(None, None, None)
>>> _remove_all_named_axis(("x", "y", "z"), 2)
(None, None)
"""
return (None,) * (len(named_axis) if n is None else n)


def _remove_named_axis(
axis: int | None,
named_axis: AxisTuple,
Expand Down
5 changes: 4 additions & 1 deletion src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("all",)

Expand Down Expand Up @@ -91,7 +92,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

if not isinstance(axis, int) and axis is not None:
axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

reducer = ak._reducers.All()
Expand Down
5 changes: 4 additions & 1 deletion src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("any",)

Expand Down Expand Up @@ -91,7 +92,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

if not isinstance(axis, int) and axis is not None:
axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

reducer = ak._reducers.Any()
Expand Down
8 changes: 1 addition & 7 deletions src/awkward/operations/ak_argcartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

import awkward as ak
from awkward._dispatch import high_level_function
from awkward._namedaxis import _supports_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("argcartesian",)

Expand Down Expand Up @@ -108,11 +107,6 @@ def argcartesian(


def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs):
out_named_axis = None
if _supports_named_axis(arrays[0]) and not is_integer(axis):
# Named axis handling
raise NotImplementedError()

axis = regularize_axis(axis)

if isinstance(arrays, Mapping):
Expand Down
8 changes: 1 addition & 7 deletions src/awkward/operations/ak_argcombinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("argcombinations",)

Expand Down Expand Up @@ -94,11 +93,6 @@ def _impl(
behavior,
attrs,
):
out_named_axis = None
if _supports_named_axis(array) and not is_integer(axis):
# Named axis handling
raise NotImplementedError()

axis = regularize_axis(axis)

if parameters is None:
Expand Down
5 changes: 4 additions & 1 deletion src/awkward/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("argmax", "nanargmax")

Expand Down Expand Up @@ -156,7 +157,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

if not isinstance(axis, int) and axis is not None:
axis = regularize_axis(axis)

if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

# axis = regularize_axis(axis)
Expand Down
7 changes: 4 additions & 3 deletions src/awkward/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("argmin", "nanargmin")

Expand Down Expand Up @@ -153,10 +154,10 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

if not isinstance(axis, int) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis)

# axis = regularize_axis(axis)
if not is_integer(axis) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

reducer = ak._reducers.ArgMin()

Expand Down
7 changes: 5 additions & 2 deletions src/awkward/operations/ak_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("argsort",)

Expand Down Expand Up @@ -90,8 +91,10 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
# use strategy "keep all" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)

if not isinstance(axis, int) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")
axis = regularize_axis(axis)

if not is_integer(axis):
raise TypeError(f"'axis' must be an integer by now, not {axis!r}")

out = ak._do.argsort(layout, axis, ascending, stable)

Expand Down
16 changes: 15 additions & 1 deletion src/awkward/operations/ak_categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _remove_all_named_axis

__all__ = ("categories",)

Expand Down Expand Up @@ -49,6 +50,19 @@ def action(layout, **kwargs):

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

ak._do.recursively_apply(layout, action)

return ctx.wrap(output, highlevel=highlevel)
wrapped_out = ctx.wrap(output, highlevel=highlevel)

# propagate named axis from input to output,
# use strategy "drop all" (see: awkward._namedaxis)
out_named_axis = _remove_all_named_axis(wrapped_out.named_axis, n=wrapped_out.ndim)

return ak.operations.ak_with_named_axis._impl(
wrapped_out,
named_axis=out_named_axis,
highlevel=highlevel,
behavior=ctx.behavior,
attrs=ctx.attrs,
)
8 changes: 1 addition & 7 deletions src/awkward/operations/ak_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis

__all__ = ("combinations",)

Expand Down Expand Up @@ -215,11 +214,6 @@ def _impl(
behavior,
attrs,
):
out_named_axis = None
if _supports_named_axis(array) and not is_integer(axis):
# Named axis handling
raise NotImplementedError()

axis = regularize_axis(axis)

if with_name is None:
Expand Down
8 changes: 1 addition & 7 deletions src/awkward/operations/ak_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from awkward._dispatch import high_level_function
from awkward._do import mergeable
from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis
from awkward._namedaxis import _supports_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._parameters import type_parameters_equal
from awkward._regularize import is_integer, regularize_axis
from awkward._regularize import regularize_axis
from awkward._typing import Sequence
from awkward.contents import Content
from awkward.operations.ak_fill_none import fill_none
Expand Down Expand Up @@ -93,11 +92,6 @@ def _merge_as_union(


def _impl(arrays, axis, mergebool, highlevel, behavior, attrs):
out_named_axis = None
if _supports_named_axis(arrays[0]) and not is_integer(axis):
# Named axis handling
raise NotImplementedError()

axis = regularize_axis(axis)

# Simple single-array, axis=0 fast-path
Expand Down
18 changes: 7 additions & 11 deletions src/awkward/operations/ak_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
ensure_same_backend,
maybe_highlevel_to_lowlevel,
)
from awkward._namedaxis import _supports_named_axis
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("corr",)

Expand Down Expand Up @@ -87,13 +85,6 @@ def corr(


def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
out_named_axis = None
if _supports_named_axis(x) and not is_integer(axis):
# Named axis handling
raise NotImplementedError()

axis = regularize_axis(axis)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, y_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
Expand Down Expand Up @@ -190,8 +181,13 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
behavior=ctx.behavior,
attrs=ctx.attrs,
)
return ctx.wrap(
maybe_highlevel_to_lowlevel(sumwxy / ufuncs.sqrt(sumwxx * sumwyy)),

# propagate named axis to output
out = sumwxy / ufuncs.sqrt(sumwxx * sumwyy)
out_ctx = HighLevelContext(behavior=out.behavior, attrs=out.attrs).finalize()

return out_ctx.wrap(
maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
allow_other=True,
)
Loading

0 comments on commit c96f61e

Please sign in to comment.