Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 12, 2024
1 parent 556362e commit 16a119b
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 55 deletions.
1 change: 0 additions & 1 deletion src/awkward/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
)
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of
from awkward._namedaxis import AxisMapping, AxisTuple
from awkward._nplikes.dispatch import nplike_of_obj
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
Expand Down
6 changes: 4 additions & 2 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from awkward._regularize import is_integer

if tp.TYPE_CHECKING:
from awkward.highlevel import Array
from awkward._layout import HighLevelContext
from awkward.highlevel import Array


# axis names are hashables, mostly strings,
Expand Down Expand Up @@ -48,7 +48,9 @@ def _check_valid_axis(axis: AxisName) -> AxisName:

def _check_axis_mapping_unique_values(axis_mapping: AxisMapping) -> None:
if len(set(axis_mapping.values())) != len(axis_mapping):
raise ValueError(f"Named axis mapping must be unique for each positional axis, got: {axis_mapping}")
raise ValueError(
f"Named axis mapping must be unique for each positional axis, got: {axis_mapping}"
)


def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping:
Expand Down
7 changes: 6 additions & 1 deletion src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of, get_array_class, get_record_class
from awkward._layout import wrap_layout
from awkward._namedaxis import _axis_tuple_to_mapping, _axis_mapping_to_tuple, AxisMapping, AxisTuple, _NamedAxisKey, _set_named_axis_to_attrs, _supports_named_axis
from awkward._namedaxis import (
AxisTuple,
_axis_mapping_to_tuple,
_NamedAxisKey,
_supports_named_axis,
)
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._operators import NDArrayOperatorsMixin
Expand Down
13 changes: 10 additions & 3 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis, _check_valid_axis, _one_axis_to_positional_axis, _remove_named_axis, _identity_named_axis
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis

__all__ = ("all",)

Expand Down Expand Up @@ -75,7 +80,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if _supports_named_axis(ctx) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(axis, array.named_axis, array.positional_axis)
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
Expand Down
13 changes: 10 additions & 3 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis, _check_valid_axis, _identity_named_axis, _remove_named_axis, _one_axis_to_positional_axis
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("any",)

Expand Down Expand Up @@ -75,7 +80,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if _supports_named_axis(ctx) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(axis, array.named_axis, array.positional_axis)
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
Expand Down
13 changes: 10 additions & 3 deletions src/awkward/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis, _one_axis_to_positional_axis, _check_valid_axis, _identity_named_axis, _remove_named_axis
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("argmax", "nanargmax")

Expand Down Expand Up @@ -140,7 +145,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if _supports_named_axis(ctx or {}) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(axis, array.named_axis, array.positional_axis)
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
Expand Down
13 changes: 10 additions & 3 deletions src/awkward/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis, _check_valid_axis, _one_axis_to_positional_axis, _remove_named_axis, _identity_named_axis
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("argmin", "nanargmin")

Expand Down Expand Up @@ -137,7 +142,9 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if _supports_named_axis(ctx or {}) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(axis, array.named_axis, array.positional_axis)
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
Expand Down
13 changes: 9 additions & 4 deletions src/awkward/operations/ak_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _supports_named_axis, _check_valid_axis, _identity_named_axis, _remove_named_axis, _one_axis_to_positional_axis
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_one_axis_to_positional_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer, regularize_axis

__all__ = ("argsort",)

Expand Down Expand Up @@ -78,13 +82,14 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
if _supports_named_axis(ctx or {}) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(axis, array.named_axis, array.positional_axis)
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# use strategy "keep all" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)


if not isinstance(axis, int) and axis is not None:
raise TypeError(f"'axis' must be an integer or None by now, not {axis!r}")

Expand Down
9 changes: 7 additions & 2 deletions src/awkward/operations/ak_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import _one_axis_to_positional_axis, _remove_named_axis, _supports_named_axis, _identity_named_axis, _check_valid_axis
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_integer

__all__ = ("max", "nanmax")

Expand Down
17 changes: 12 additions & 5 deletions src/awkward/operations/ak_with_named_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import AxisMapping, AxisTuple, _set_named_axis_to_attrs, _supports_named_axis, _NamedAxisKey
from awkward._namedaxis import (
AxisMapping,
AxisTuple,
_NamedAxisKey,
_set_named_axis_to_attrs,
)
from awkward._nplikes.numpy_like import NumpyMetadata

__all__ = ("with_named_axis",)
Expand Down Expand Up @@ -53,7 +58,7 @@ def with_named_axis(


def _impl(array, named_axis, highlevel, behavior, attrs):
if not named_axis: # no-op, e.g. if named_axis is None or () or {}
if not named_axis: # no-op, e.g. if named_axis is None or () or {}
return array

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
Expand All @@ -66,11 +71,13 @@ def _impl(array, named_axis, highlevel, behavior, attrs):
for k, i in named_axis.items():
if not isinstance(i, int):
raise TypeError(f"named_axis must map axis name to integer, not {i}")
if i < 0: # handle negative axis index
if i < 0: # handle negative axis index
i += ndim
if i < 0 or i >= ndim:
raise ValueError(f"named_axis index out of range: {i} not in [0, {ndim})")
_named_axis = _named_axis[:i] + (k,) + _named_axis[i+1:]
raise ValueError(
f"named_axis index out of range: {i} not in [0, {ndim})"
)
_named_axis = _named_axis[:i] + (k,) + _named_axis[i + 1 :]
elif isinstance(named_axis, tuple):
_named_axis = named_axis
else:
Expand Down
Loading

0 comments on commit 16a119b

Please sign in to comment.