Skip to content

Commit

Permalink
next batch of highlevel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 17, 2024
1 parent a07adfa commit 4c51e6b
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 16 deletions.
89 changes: 89 additions & 0 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,92 @@ class Slicer:

def __getitem__(self, where):
return where



# Define a type alias for a slice or int (can be a single axis or a sequence of axes)
AxisSlice: tp.TypeAlias = tp.Union[tuple, slice, int, tp.EllipsisType, None]
NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice]


def _normalize_slice(
where: AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice],
named_axis: AxisTuple,
) -> AxisSlice:
"""
Normalizes the given slice based on the named axis. The slice can be a dictionary mapping axis names to slices,
a tuple of slices, an ellipsis, or a single slice. The named axis is a tuple of axis names.
Args:
where (AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice]): The slice to normalize.
named_axis (AxisTuple): The named axis.
Returns:
AxisSlice: The normalized slice.
Examples:
>>> _normalize_slice({"x": slice(1, 5)}, ("x", "y", "z"))
(slice(1, 5, None), slice(None, None, None), slice(None, None, None))
>>> _normalize_slice((slice(1, 5), slice(2, 10)), ("x", "y", "z"))
(slice(1, 5, None), slice(2, 10, None))
>>> _normalize_slice(..., ("x", "y", "z"))
(slice(None, None, None), slice(None, None, None), slice(None, None, None))
>>> _normalize_slice(slice(1, 5), ("x", "y", "z"))
slice(1, 5, None)
"""
if isinstance(where, dict):
return tuple(where.get(axis, slice(None)) for axis in named_axis)
elif isinstance(where, tuple):
raise NotImplementedError()
return where


def _propagate_named_axis_through_slice(
where: AxisSlice,
named_axis: AxisTuple,
) -> AxisTuple:
"""
Propagate named axis based on where slice to output array.
Examples:
>>> _propagate_named_axis_through_slice(None, ("x", "y", "z"))
(None, "x", "y", "z")
>>> _propagate_named_axis_through_slice((..., None), ("x", "y", "z"))
("x", "y", "z", None)
>>> _propagate_named_axis_through_slice(0, ("x", "y", "z"))
("y", "z")
>>> _propagate_named_axis_through_slice(1, ("x", "y", "z"))
("x", "z")
>>> _propagate_named_axis_through_slice(2, ("x", "y", "z"))
("x", "y")
>>> _propagate_named_axis_through_slice(..., ("x", "y", "z"))
("x", "y", "z")
>>> _propagate_named_axis_through_slice(slice(0, 1), ("x", "y", "z"))
("x", "y", "z")
>>> _propagate_named_axis_through_slice((0, slice(0, 1)), ("x", "y", "z"))
("y", "z")
"""
if where is None:
return (None,) + named_axis
elif where is (..., None):
return named_axis + (None,)
elif where is Ellipsis:
return named_axis
elif isinstance(where, int):
return named_axis[:where] + named_axis[where+1:]
elif isinstance(where, slice):
return named_axis
elif isinstance(where, tuple):
return tuple(_propagate_named_axis_through_slice(w, named_axis) for w in where)
else:
raise ValueError("Invalid slice type")
26 changes: 21 additions & 5 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,11 +1079,27 @@ def __getitem__(self, where):
have the same dimension as the array being indexed.
"""
with ak._errors.SlicingErrorContext(self, where):
return wrap_layout(
prepare_layout(self._layout[where]),
self._behavior,
allow_other=True,
attrs=self._attrs,
# normalize for potential named axis
from awkward._namedaxis import _normalize_slice, _get_named_axis, _supports_named_axis

out_named_axis=None
if _supports_named_axis(self):
named_axis = _get_named_axis(self)

# Step 1: normalize the slice
where = _normalize_slice(where, named_axis)

# Step 2: propagate named axis to the output array
out_named_axis = named_axis

return ak.with_named_axis(
array=wrap_layout(
prepare_layout(self._layout[where]),
self._behavior,
allow_other=True,
attrs=self._attrs,
),
named_axis=out_named_axis,
)

def __bytes__(self) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_to_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def to_backend(array, backend, *, highlevel=True, behavior=None, attrs=None):
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
backend (`"cpu"`, `"cuda"`, or `"jax"`): If `"cpu"`, the array structure is
backend (`"cpu"`, `"cuda"`, `"jax"`, or `"typetracer"`): If `"cpu"`, the array structure is
recursively copied (if need be) to main memory for use with
the default Numpy backend; if `"cuda"`, the structure is copied
to the GPU(s) for use with CuPy. If `"jax"`, the structure is
Expand Down
77 changes: 67 additions & 10 deletions tests/test_2596_named_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,17 @@ def test_named_axis_ak_local_index():


def test_named_axis_ak_mask():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
mask = array > 3

named_array = ak.with_named_axis(array, ("x", "y"))
named_mask = named_array > 3

assert ak.all(ak.mask(array, mask) == ak.mask(named_array, mask))
assert ak.all(ak.mask(array, mask) == ak.mask(named_array, named_mask))

assert ak.mask(named_array, mask).named_axis == named_array.named_axis
assert ak.mask(named_array, named_mask).named_axis == named_array.named_axis


def test_named_axis_ak_max():
Expand Down Expand Up @@ -942,7 +952,11 @@ def test_named_axis_ak_to_arrow_table():


def test_named_axis_ak_to_backend():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.to_backend(named_array, "typetracer").named_axis == named_array.named_axis


def test_named_axis_ak_to_buffers():
Expand Down Expand Up @@ -991,7 +1005,13 @@ def test_named_axis_ak_to_numpy():


def test_named_axis_ak_to_packed():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.all(ak.to_packed(array) == ak.to_packed(named_array))

assert ak.to_packed(named_array).named_axis == named_array.named_axis


def test_named_axis_ak_to_parquet():
Expand Down Expand Up @@ -1059,7 +1079,13 @@ def test_named_axis_ak_validity_error():


def test_named_axis_ak_values_astype():
assert True
array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.all(ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32))

assert ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis


def test_named_axis_ak_var():
Expand All @@ -1071,32 +1097,63 @@ def test_named_axis_ak_where():


def test_named_axis_ak_with_field():
# skip
assert True


def test_named_axis_ak_with_name():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.with_name(named_array, "new_name").named_axis == named_array.named_axis


def test_named_axis_ak_with_named_axis():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert named_array.named_axis == ("x", "y")


def test_named_axis_ak_with_parameter():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.with_parameter(named_array, "param", 1.0).named_axis == named_array.named_axis


def test_named_axis_ak_without_field():
# skip
assert True


def test_named_axis_ak_without_parameters():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

named_array_with_parameteter = ak.with_parameter(named_array, "param", 1.0)

assert ak.without_parameters(named_array).named_axis == named_array.named_axis


def test_named_axis_ak_zeros_like():
assert True
array = ak.Array([[1, 2], [3], [], [4, 5, 6]])

named_array = ak.with_named_axis(array, ("x", "y"))

assert ak.all(ak.zeros_like(array) == ak.zeros_like(named_array))

assert ak.zeros_like(named_array).named_axis == named_array.named_axis


def test_named_axis_ak_zip():
named_array1 = ak.with_named_axis(ak.Array([1,2,3]), ("a",))
named_array2 = ak.with_named_axis(ak.Array([[4,5,6], [], [7]]), ("x", "y"))

record = ak.zip({"x": named_array1, "y": named_array2})

# TODO: need to implement broadcasting properly first
assert True

0 comments on commit 4c51e6b

Please sign in to comment.