Skip to content

Commit

Permalink
next batch of highlevel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 16, 2024
1 parent e338332 commit a07adfa
Show file tree
Hide file tree
Showing 33 changed files with 676 additions and 178 deletions.
9 changes: 5 additions & 4 deletions src/awkward/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
T = TypeVar("T")

if TYPE_CHECKING:
from awkward._namedaxis import AttrsNamedAxisMapping
from awkward.highlevel import Array
from awkward.highlevel import Record as HighLevelRecord

Expand Down Expand Up @@ -56,9 +57,7 @@ def merge_mappings(


class HighLevelContext:
def __init__(
self, behavior: Mapping | None = None, attrs: Mapping[str, Any] | None = None
):
def __init__(self, behavior: Mapping | None = None, attrs: Mapping | None = None):
self._behavior = behavior
self._attrs = attrs
self._is_finalized = False
Expand All @@ -81,8 +80,10 @@ def _ensure_not_finalized(self):
raise RuntimeError("HighLevelContext has already been finalized")

@property
def attrs(self) -> Mapping[str, Any] | None:
def attrs(self) -> Mapping | AttrsNamedAxisMapping:
self._ensure_finalized()
if self._attrs is None:
self._attrs = {}
return self._attrs

@property
Expand Down
177 changes: 153 additions & 24 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,94 @@
from awkward._regularize import is_integer

if tp.TYPE_CHECKING:
from awkward._layout import HighLevelContext
from awkward.highlevel import Array
pass


# axis names are hashables, mostly strings,
# except for integers, which are reserved for positional axis.
AxisName: tp.TypeAlias = tp.Hashable

AxisMapping: tp.TypeAlias = tp.Mapping[AxisName, int] # e.g.: {"x": 0, "y": 1, "z": 2}
AxisTuple: tp.TypeAlias = tp.Tuple[
AxisName, ...
] # e.g.: ("x", "y", None) where None is a wildcard
# e.g.: {"x": 0, "y": 1, "z": 2}
AxisMapping: tp.TypeAlias = tp.Mapping[AxisName, int]
# e.g.: ("x", "y", None) where None is a wildcard
AxisTuple: tp.TypeAlias = tp.Tuple[AxisName, ...]

_NamedAxisKey: str = "__named_axis__"

_NamedAxisKey: str = "__named_axis__" # reserved for named axis

def _supports_named_axis(array: Array | HighLevelContext) -> bool:
"""Check if the given array supports named axis.

class AttrsNamedAxisMapping(tp.TypedDict, total=False):
_NamedAxisKey: AxisMapping


@tp.runtime_checkable
class MaybeSupportsNamedAxis(tp.Protocol):
@property
def attrs(self) -> tp.Mapping | AttrsNamedAxisMapping: ...


def _get_named_axis(
ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping | tp.Mapping,
) -> AxisTuple:
"""
Retrieves the named axis from the given context. The context can be an object that supports named axis
or a dictionary that includes a named axis mapping.
Args:
ctx (MaybeSupportsNamedAxis | AttrsNamedAxisMapping): The context from which to retrieve the named axis.
Returns:
AxisTuple: The named axis retrieved from the context. If the context does not include a named axis,
an empty tuple is returned.
Examples:
>>> class Test(MaybeSupportsNamedAxis):
... @property
... def attrs(self):
... return {_NamedAxisKey: {"x": 0, "y": 1, "z": 2}}
...
>>> _get_named_axis(Test())
("x", "y", "z")
>>> _get_named_axis({_NamedAxisKey: {"x": 0, "y": 1, "z": 2}})
("x", "y", "z")
>>> _get_named_axis({"other_key": "other_value"})
()
"""
if isinstance(ctx, MaybeSupportsNamedAxis):
return _get_named_axis(ctx.attrs)
elif isinstance(ctx, tp.Mapping) and _NamedAxisKey in ctx:
return _axis_mapping_to_tuple(ctx[_NamedAxisKey])
else:
return ()


def _supports_named_axis(ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping) -> bool:
"""Check if the given ctx supports named axis.
Args:
array (Array): The array to check.
ctx (SupportsNamedAxis or AttrsNamedAxisMapping): The ctx to check.
Returns:
bool: True if the array supports named axis, False otherwise.
bool: True if the ctx supports named axis, False otherwise.
"""
return bool((getattr(array, "attrs", {}) or {}).get(_NamedAxisKey, {}))
return bool(_get_named_axis(ctx))


def _positional_axis_from_named_axis(named_axis: AxisTuple) -> tuple[int, ...]:
"""
Converts a named axis to a positional axis.
Args:
named_axis (AxisTuple): The named axis to convert.
Returns:
tuple[int, ...]: The positional axis corresponding to the named axis.
Examples:
>>> _positional_axis_from_named_axis(("x", "y", "z"))
(0, 1, 2)
"""
return tuple(range(len(named_axis)))


class TmpNamedAxisMarker:
Expand All @@ -41,16 +103,67 @@ class TmpNamedAxisMarker:


def _is_valid_named_axis(axis: AxisName) -> bool:
"""
Checks if the given axis is a valid named axis. A valid named axis is a hashable object that is not an integer.
Args:
axis (AxisName): The axis to check.
Returns:
bool: True if the axis is a valid named axis, False otherwise.
Examples:
>>> _is_valid_named_axis("x")
True
>>> _is_valid_named_axis(1)
False
"""
return isinstance(axis, AxisName) and not is_integer(axis)


def _check_valid_axis(axis: AxisName) -> AxisName:
"""
Checks if the given axis is a valid named axis. If not, raises a ValueError.
Args:
axis (AxisName): The axis to check.
Returns:
AxisName: The axis if it is a valid named axis.
Raises:
ValueError: If the axis is not a valid named axis.
Examples:
>>> _check_valid_axis("x")
"x"
>>> _check_valid_axis(1)
Traceback (most recent call last):
...
ValueError: Axis names must be hashable and not int, got 1
"""
if not _is_valid_named_axis(axis):
raise ValueError(f"Axis names must be hashable and not int, got {axis!r}")
return axis


def _check_axis_mapping_unique_values(axis_mapping: AxisMapping) -> None:
"""
Checks if the values in the given axis mapping are unique. If not, raises a ValueError.
Args:
axis_mapping (AxisMapping): The axis mapping to check.
Raises:
ValueError: If the values in the axis mapping are not unique.
Examples:
>>> _check_axis_mapping_unique_values({"x": 0, "y": 1, "z": 2})
>>> _check_axis_mapping_unique_values({"x": 0, "y": 0, "z": 2})
Traceback (most recent call last):
...
ValueError: Named axis mapping must be unique for each positional axis, got: {"x": 0, "y": 0, "z": 2}
"""
if len(set(axis_mapping.values())) != len(axis_mapping):
raise ValueError(
f"Named axis mapping must be unique for each positional axis, got: {axis_mapping}"
Expand Down Expand Up @@ -112,15 +225,13 @@ def _axis_mapping_to_tuple(axis_mapping: AxisMapping) -> AxisTuple:
def _any_axis_to_positional_axis(
axis: AxisName | AxisTuple,
named_axis: AxisTuple,
positional_axis: tuple[int, ...],
) -> AxisTuple | int | None:
"""
Converts any axis (int, AxisName, AxisTuple, or None) to a positional axis (int or AxisTuple).
Args:
axis (int | AxisName | AxisTuple | None): The axis to convert. Can be an integer, an AxisName, an AxisTuple, or None.
named_axis (AxisTuple): The named axis mapping to use for conversion.
positional_axis (tuple[int, ...]): The positional axis mapping to use for conversion.
Returns:
int | AxisTuple | None: The converted axis. Will be an integer, an AxisTuple, or None.
Expand All @@ -129,31 +240,27 @@ def _any_axis_to_positional_axis(
ValueError: If the axis is not found in the named axis mapping.
Examples:
>>> _any_axis_to_positional_axis("x", ("x", "y", "z"), (0, 1, 2))
>>> _any_axis_to_positional_axis("x", ("x", "y", "z"))
0
>>> _any_axis_to_positional_axis(("x", "z"), ("x", "y", "z"), (0, 1, 2))
>>> _any_axis_to_positional_axis(("x", "z"), ("x", "y", "z"))
(0, 2)
"""
if isinstance(axis, (tuple, list)):
return tuple(
_one_axis_to_positional_axis(ax, named_axis, positional_axis) for ax in axis
)
return tuple(_one_axis_to_positional_axis(ax, named_axis) for ax in axis)
else:
return _one_axis_to_positional_axis(axis, named_axis, positional_axis)
return _one_axis_to_positional_axis(axis, named_axis)


def _one_axis_to_positional_axis(
axis: AxisName | None,
named_axis: AxisTuple,
positional_axis: tuple[int, ...],
) -> int | None:
"""
Converts a single axis (int, AxisName, or None) to a positional axis (int or None).
Args:
axis (int | AxisName | None): The axis to convert. Can be an integer, an AxisName, or None.
named_axis (AxisTuple): The named axis mapping to use for conversion.
positional_axis (tuple[int, ...]): The positional axis mapping to use for conversion.
Returns:
int | None: The converted axis. Will be an integer or None.
Expand All @@ -162,9 +269,10 @@ def _one_axis_to_positional_axis(
ValueError: If the axis is not found in the named axis mapping.
Examples:
>>> _one_axis_to_positional_axis("x", ("x", "y", "z"), (0, 1, 2))
>>> _one_axis_to_positional_axis("x", ("x", "y", "z"))
0
"""
positional_axis = _positional_axis_from_named_axis(named_axis)
if isinstance(axis, int) or axis is None:
return axis
elif axis in named_axis:
Expand Down Expand Up @@ -229,11 +337,11 @@ def _set_named_axis_to_attrs(
# - "keep one" (_keep_named_axis): Keep one named axes in the output array, e.g.: `ak.firsts`
# - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories
# - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum`
# - "add one" (_add_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate, ak.singletons` (not clear yet...)
# - "unify" (_unify_named_axis): Unify the named axis in the output array given two input arrays, e.g.: `__add__`
# - "collapse" (_collapse_named_axis): Collapse multiple named axis to None in the output array, e.g.: `ak.flatten`
# - "permute" (_permute_named_axis): Permute the named axis in the output array, e.g.: `ak.transpose` (does this exist?)
# - "contract" (_contract_named_axis): Contract the named axis in the output array, e.g.: `matmul` (does this exist?)
# - "adding" (_adding_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate` (not clear yet...)


def _identity_named_axis(
Expand Down Expand Up @@ -328,6 +436,27 @@ def _remove_named_axis(
return tuple(name for i, name in enumerate(named_axis) if i != axis)


def _add_named_axis(
axis: int,
named_axis: AxisTuple,
) -> AxisTuple:
"""
Adds a wildcard named axis (None) to the named_axis after the position of the specified axis.
Args:
axis (int): The index after which to add the wildcard named axis.
named_axis (AxisTuple): The current named axis.
Returns:
AxisTuple: The new named axis after adding the wildcard named axis.
Examples:
>>> _add_named_axis(1, ("x", "y", "z"))
("x", "y", None, "z")
"""
return named_axis[: axis + 1] + (None,) + named_axis[axis + 1 :]


def _permute_named_axis(
axis: int,
named_axis: AxisTuple,
Expand Down
11 changes: 5 additions & 6 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from awkward._behavior import behavior_of, get_array_class, get_record_class
from awkward._layout import wrap_layout
from awkward._namedaxis import (
AttrsNamedAxisMapping,
AxisTuple,
_axis_mapping_to_tuple,
_NamedAxisKey,
_get_named_axis,
_supports_named_axis,
)
from awkward._nplikes.numpy import Numpy
Expand Down Expand Up @@ -363,7 +363,7 @@ def _update_class(self):
self.__class__ = get_array_class(self._layout, self._behavior)

@property
def attrs(self) -> Mapping[str, Any]:
def attrs(self) -> Mapping | AttrsNamedAxisMapping:
"""
The mutable mapping containing top-level metadata, which is serialised
with the array during pickling.
Expand Down Expand Up @@ -466,10 +466,9 @@ def positional_axis(self) -> tuple[int, ...]:
return tuple(range(self.ndim))

@property
def named_axis(self) -> AxisTuple:
def named_axis(self) -> AxisTuple | None:
if _supports_named_axis(self):
named_axis_mapping = self.attrs[_NamedAxisKey]
return _axis_mapping_to_tuple(named_axis_mapping)
return _get_named_axis(self)
else:
return (None,) * self.ndim

Expand Down
7 changes: 3 additions & 4 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_get_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
Expand Down Expand Up @@ -82,14 +83,12 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)
axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx))

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
out_named_axis = _keep_named_axis(_get_named_axis(ctx), None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

Expand Down
7 changes: 3 additions & 4 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._namedaxis import (
_get_named_axis,
_is_valid_named_axis,
_keep_named_axis,
_one_axis_to_positional_axis,
Expand Down Expand Up @@ -82,14 +83,12 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if _is_valid_named_axis(axis):
# Handle named axis
# Step 1: Normalize named axis to positional axis
axis = _one_axis_to_positional_axis(
axis, array.named_axis, array.positional_axis
)
axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx))

# Step 2: propagate named axis from input to output,
# keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
# keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
out_named_axis = _keep_named_axis(array.named_axis, None)
out_named_axis = _keep_named_axis(_get_named_axis(ctx), None)
if not keepdims:
out_named_axis = _remove_named_axis(axis, out_named_axis)

Expand Down
Loading

0 comments on commit a07adfa

Please sign in to comment.