From 4c51e6b3d0344d2145d8717cd03997357376ae89 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Tue, 17 Sep 2024 09:48:10 -0400 Subject: [PATCH] next batch of highlevel functions --- src/awkward/_namedaxis.py | 89 +++++++++++++++++++++++++ src/awkward/highlevel.py | 26 ++++++-- src/awkward/operations/ak_to_backend.py | 2 +- tests/test_2596_named_axis.py | 77 ++++++++++++++++++--- 4 files changed, 178 insertions(+), 16 deletions(-) diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py index f79c6ca07c..a91dd87fd4 100644 --- a/src/awkward/_namedaxis.py +++ b/src/awkward/_namedaxis.py @@ -569,3 +569,92 @@ class Slicer: def __getitem__(self, where): return where + + + +# Define a type alias for a slice or int (can be a single axis or a sequence of axes) +AxisSlice: tp.TypeAlias = tp.Union[tuple, slice, int, tp.EllipsisType, None] +NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice] + + +def _normalize_slice( + where: AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice], + named_axis: AxisTuple, +) -> AxisSlice: + """ + Normalizes the given slice based on the named axis. The slice can be a dictionary mapping axis names to slices, + a tuple of slices, an ellipsis, or a single slice. The named axis is a tuple of axis names. + + Args: + where (AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice]): The slice to normalize. + named_axis (AxisTuple): The named axis. + + Returns: + AxisSlice: The normalized slice. + + Examples: + >>> _normalize_slice({"x": slice(1, 5)}, ("x", "y", "z")) + (slice(1, 5, None), slice(None, None, None), slice(None, None, None)) + + >>> _normalize_slice((slice(1, 5), slice(2, 10)), ("x", "y", "z")) + (slice(1, 5, None), slice(2, 10, None)) + + >>> _normalize_slice(..., ("x", "y", "z")) + (slice(None, None, None), slice(None, None, None), slice(None, None, None)) + + >>> _normalize_slice(slice(1, 5), ("x", "y", "z")) + slice(1, 5, None) + """ + if isinstance(where, dict): + return tuple(where.get(axis, slice(None)) for axis in named_axis) + elif isinstance(where, tuple): + raise NotImplementedError() + return where + + +def _propagate_named_axis_through_slice( + where: AxisSlice, + named_axis: AxisTuple, +) -> AxisTuple: + """ + Propagate named axis based on where slice to output array. + + Examples: + >>> _propagate_named_axis_through_slice(None, ("x", "y", "z")) + (None, "x", "y", "z") + + >>> _propagate_named_axis_through_slice((..., None), ("x", "y", "z")) + ("x", "y", "z", None) + + >>> _propagate_named_axis_through_slice(0, ("x", "y", "z")) + ("y", "z") + + >>> _propagate_named_axis_through_slice(1, ("x", "y", "z")) + ("x", "z") + + >>> _propagate_named_axis_through_slice(2, ("x", "y", "z")) + ("x", "y") + + >>> _propagate_named_axis_through_slice(..., ("x", "y", "z")) + ("x", "y", "z") + + >>> _propagate_named_axis_through_slice(slice(0, 1), ("x", "y", "z")) + ("x", "y", "z") + + >>> _propagate_named_axis_through_slice((0, slice(0, 1)), ("x", "y", "z")) + ("y", "z") + """ + if where is None: + return (None,) + named_axis + elif where is (..., None): + return named_axis + (None,) + elif where is Ellipsis: + return named_axis + elif isinstance(where, int): + return named_axis[:where] + named_axis[where+1:] + elif isinstance(where, slice): + return named_axis + elif isinstance(where, tuple): + return tuple(_propagate_named_axis_through_slice(w, named_axis) for w in where) + else: + raise ValueError("Invalid slice type") diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index c6a23f470d..1dc8b3262a 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -1079,11 +1079,27 @@ def __getitem__(self, where): have the same dimension as the array being indexed. """ with ak._errors.SlicingErrorContext(self, where): - return wrap_layout( - prepare_layout(self._layout[where]), - self._behavior, - allow_other=True, - attrs=self._attrs, + # normalize for potential named axis + from awkward._namedaxis import _normalize_slice, _get_named_axis, _supports_named_axis + + out_named_axis=None + if _supports_named_axis(self): + named_axis = _get_named_axis(self) + + # Step 1: normalize the slice + where = _normalize_slice(where, named_axis) + + # Step 2: propagate named axis to the output array + out_named_axis = named_axis + + return ak.with_named_axis( + array=wrap_layout( + prepare_layout(self._layout[where]), + self._behavior, + allow_other=True, + attrs=self._attrs, + ), + named_axis=out_named_axis, ) def __bytes__(self) -> bytes: diff --git a/src/awkward/operations/ak_to_backend.py b/src/awkward/operations/ak_to_backend.py index f65a2c0a81..8d93e2de94 100644 --- a/src/awkward/operations/ak_to_backend.py +++ b/src/awkward/operations/ak_to_backend.py @@ -17,7 +17,7 @@ def to_backend(array, backend, *, highlevel=True, behavior=None, attrs=None): """ Args: array: Array-like data (anything #ak.to_layout recognizes). - backend (`"cpu"`, `"cuda"`, or `"jax"`): If `"cpu"`, the array structure is + backend (`"cpu"`, `"cuda"`, `"jax"`, or `"typetracer"`): If `"cpu"`, the array structure is recursively copied (if need be) to main memory for use with the default Numpy backend; if `"cuda"`, the structure is copied to the GPU(s) for use with CuPy. If `"jax"`, the structure is diff --git a/tests/test_2596_named_axis.py b/tests/test_2596_named_axis.py index 15f8c9a449..4b2783531b 100644 --- a/tests/test_2596_named_axis.py +++ b/tests/test_2596_named_axis.py @@ -632,7 +632,17 @@ def test_named_axis_ak_local_index(): def test_named_axis_ak_mask(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + mask = array > 3 + + named_array = ak.with_named_axis(array, ("x", "y")) + named_mask = named_array > 3 + + assert ak.all(ak.mask(array, mask) == ak.mask(named_array, mask)) + assert ak.all(ak.mask(array, mask) == ak.mask(named_array, named_mask)) + + assert ak.mask(named_array, mask).named_axis == named_array.named_axis + assert ak.mask(named_array, named_mask).named_axis == named_array.named_axis def test_named_axis_ak_max(): @@ -942,7 +952,11 @@ def test_named_axis_ak_to_arrow_table(): def test_named_axis_ak_to_backend(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.to_backend(named_array, "typetracer").named_axis == named_array.named_axis def test_named_axis_ak_to_buffers(): @@ -991,7 +1005,13 @@ def test_named_axis_ak_to_numpy(): def test_named_axis_ak_to_packed(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.to_packed(array) == ak.to_packed(named_array)) + + assert ak.to_packed(named_array).named_axis == named_array.named_axis def test_named_axis_ak_to_parquet(): @@ -1059,7 +1079,13 @@ def test_named_axis_ak_validity_error(): def test_named_axis_ak_values_astype(): - assert True + array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32)) + + assert ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis def test_named_axis_ak_var(): @@ -1071,32 +1097,63 @@ def test_named_axis_ak_where(): def test_named_axis_ak_with_field(): + # skip assert True - def test_named_axis_ak_with_name(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.with_name(named_array, "new_name").named_axis == named_array.named_axis def test_named_axis_ak_with_named_axis(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert named_array.named_axis == ("x", "y") def test_named_axis_ak_with_parameter(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.with_parameter(named_array, "param", 1.0).named_axis == named_array.named_axis def test_named_axis_ak_without_field(): + # skip assert True def test_named_axis_ak_without_parameters(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + named_array_with_parameteter = ak.with_parameter(named_array, "param", 1.0) + + assert ak.without_parameters(named_array).named_axis == named_array.named_axis def test_named_axis_ak_zeros_like(): - assert True + array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + + named_array = ak.with_named_axis(array, ("x", "y")) + + assert ak.all(ak.zeros_like(array) == ak.zeros_like(named_array)) + + assert ak.zeros_like(named_array).named_axis == named_array.named_axis def test_named_axis_ak_zip(): + named_array1 = ak.with_named_axis(ak.Array([1,2,3]), ("a",)) + named_array2 = ak.with_named_axis(ak.Array([[4,5,6], [], [7]]), ("x", "y")) + + record = ak.zip({"x": named_array1, "y": named_array2}) + + # TODO: need to implement broadcasting properly first assert True