Skip to content

Commit

Permalink
fix type hints & safer control flow when checking for named axis
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 13, 2024
1 parent c96f61e commit 7f7db85
Show file tree
Hide file tree
Showing 27 changed files with 254 additions and 225 deletions.
18 changes: 10 additions & 8 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@

# axis names are hashables, mostly strings,
# except for integers, which are reserved for positional axis.
AxisName = tp.Hashable
AxisName: tp.TypeAlias = tp.Hashable

AxisMapping = tp.Mapping[AxisName, int] # e.g.: {"x": 0, "y": 1, "z": 2}
AxisTuple = tuple[
AxisName | int | None, ...
] # e.g.: ("x", "y", None) where None is a wildcard
AxisTuple = tuple[AxisName, ...] # e.g.: ("x", "y", None) where None is a wildcard

_NamedAxisKey: str = "__named_axis__"

Expand All @@ -40,8 +38,12 @@ class TmpNamedAxisMarker:
"""


def _is_valid_named_axis(axis: AxisName) -> bool:
return isinstance(axis, AxisName) and not is_integer(axis)


def _check_valid_axis(axis: AxisName) -> AxisName:
if not isinstance(axis, AxisName) and not is_integer(axis):
if not _is_valid_named_axis(axis):
raise ValueError(f"Axis names must be hashable and not int, got {axis!r}")
return axis

Expand Down Expand Up @@ -106,10 +108,10 @@ def _axis_mapping_to_tuple(axis_mapping: AxisMapping) -> AxisTuple:


def _any_axis_to_positional_axis(
axis: int | AxisName | AxisTuple | None,
axis: AxisName | AxisTuple,
named_axis: AxisTuple,
positional_axis: tuple[int, ...],
) -> int | AxisTuple | None:
) -> AxisTuple | int | None:
"""
Converts any axis (int, AxisName, AxisTuple, or None) to a positional axis (int or AxisTuple).
Expand Down Expand Up @@ -139,7 +141,7 @@ def _any_axis_to_positional_axis(


def _one_axis_to_positional_axis(
axis: int | AxisName | None,
axis: AxisName | None,
named_axis: AxisTuple,
positional_axis: tuple[int, ...],
) -> int | None:
Expand Down
31 changes: 16 additions & 15 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
Expand Down Expand Up @@ -78,19 +78,20 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

axis = regularize_axis(axis)

Expand Down
31 changes: 16 additions & 15 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
Expand Down Expand Up @@ -78,19 +78,20 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

axis = regularize_axis(axis)

Expand Down
31 changes: 16 additions & 15 deletions src/awkward/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
Expand Down Expand Up @@ -143,19 +143,20 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx or {}) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

axis = regularize_axis(axis)

Expand Down
31 changes: 16 additions & 15 deletions src/awkward/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
Expand Down Expand Up @@ -140,19 +140,20 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx or {}) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

axis = regularize_axis(axis)

Expand Down
25 changes: 13 additions & 12 deletions src/awkward/operations/ak_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_supports_named_axis,
)
Expand Down Expand Up @@ -80,16 +80,17 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx or {}) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# use strategy "keep all" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# use strategy "keep all" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)

axis = regularize_axis(axis)

Expand Down
5 changes: 0 additions & 5 deletions src/awkward/operations/ak_cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,6 @@ def cartesian(


def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs):
out_named_axis = None
if _supports_named_axis(arrays[0]) and not is_integer(axis):
# Named axis handling
raise NotImplementedError()

axis = regularize_axis(axis)

with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
Expand Down
7 changes: 6 additions & 1 deletion src/awkward/operations/ak_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

import awkward as ak
from awkward._attrs import attrs_of_obj
from awkward._behavior import behavior_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import (
HighLevelContext,
Expand Down Expand Up @@ -184,7 +186,10 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr

# propagate named axis to output
out = sumwxy / ufuncs.sqrt(sumwxx * sumwyy)
out_ctx = HighLevelContext(behavior=out.behavior, attrs=out.attrs).finalize()
out_ctx = HighLevelContext(
behavior=behavior_of_obj(out),
attrs=attrs_of_obj(out),
).finalize()

return out_ctx.wrap(
maybe_highlevel_to_lowlevel(out),
Expand Down
31 changes: 16 additions & 15 deletions src/awkward/operations/ak_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
Expand Down Expand Up @@ -120,19 +120,20 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

axis = regularize_axis(axis)

Expand Down
31 changes: 16 additions & 15 deletions src/awkward/operations/ak_count_nonzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_check_valid_axis,
_identity_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
_remove_named_axis,
_supports_named_axis,
Expand Down Expand Up @@ -79,19 +79,20 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")

out_named_axis = None
if _supports_named_axis(ctx) and _check_valid_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _identity_named_axis(array.named_axis)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)
if _supports_named_axis(ctx):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

axis = regularize_axis(axis)

Expand Down
Loading

0 comments on commit 7f7db85

Please sign in to comment.