diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py index a91dd87fd4..7da5dcd8fd 100644 --- a/src/awkward/_namedaxis.py +++ b/src/awkward/_namedaxis.py @@ -571,7 +571,6 @@ def __getitem__(self, where): return where - # Define a type alias for a slice or int (can be a single axis or a sequence of axes) AxisSlice: tp.TypeAlias = tp.Union[tuple, slice, int, tp.EllipsisType, None] NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice] @@ -646,12 +645,12 @@ def _propagate_named_axis_through_slice( """ if where is None: return (None,) + named_axis - elif where is (..., None): + elif where == (..., None): return named_axis + (None,) elif where is Ellipsis: return named_axis elif isinstance(where, int): - return named_axis[:where] + named_axis[where+1:] + return named_axis[:where] + named_axis[where + 1 :] elif isinstance(where, slice): return named_axis elif isinstance(where, tuple): diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index 1dc8b3262a..aff90e90ea 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -1080,9 +1080,13 @@ def __getitem__(self, where): """ with ak._errors.SlicingErrorContext(self, where): # normalize for potential named axis - from awkward._namedaxis import _normalize_slice, _get_named_axis, _supports_named_axis + from awkward._namedaxis import ( + _get_named_axis, + _normalize_slice, + _supports_named_axis, + ) - out_named_axis=None + out_named_axis = None if _supports_named_axis(self): named_axis = _get_named_axis(self) diff --git a/tests/test_2596_named_axis.py b/tests/test_2596_named_axis.py index 4b2783531b..89270cc341 100644 --- a/tests/test_2596_named_axis.py +++ b/tests/test_2596_named_axis.py @@ -1011,7 +1011,7 @@ def test_named_axis_ak_to_packed(): assert ak.all(ak.to_packed(array) == ak.to_packed(named_array)) - assert ak.to_packed(named_array).named_axis == named_array.named_axis + assert ak.to_packed(named_array).named_axis == named_array.named_axis def test_named_axis_ak_to_parquet(): @@ -1083,9 +1083,13 @@ def test_named_axis_ak_values_astype(): named_array = ak.with_named_axis(array, ("x", "y")) - assert ak.all(ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32)) + assert ak.all( + ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32) + ) - assert ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis + assert ( + ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis + ) def test_named_axis_ak_var(): @@ -1100,6 +1104,7 @@ def test_named_axis_ak_with_field(): # skip assert True + def test_named_axis_ak_with_name(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) @@ -1121,7 +1126,10 @@ def test_named_axis_ak_with_parameter(): named_array = ak.with_named_axis(array, ("x", "y")) - assert ak.with_parameter(named_array, "param", 1.0).named_axis == named_array.named_axis + assert ( + ak.with_parameter(named_array, "param", 1.0).named_axis + == named_array.named_axis + ) def test_named_axis_ak_without_field(): @@ -1150,8 +1158,8 @@ def test_named_axis_ak_zeros_like(): def test_named_axis_ak_zip(): - named_array1 = ak.with_named_axis(ak.Array([1,2,3]), ("a",)) - named_array2 = ak.with_named_axis(ak.Array([[4,5,6], [], [7]]), ("x", "y")) + named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("a",)) + named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) record = ak.zip({"x": named_array1, "y": named_array2})