Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

edits for NamedSLice canonicalization #1520

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def __str__(self) -> str:
NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex]
AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedSlice | NamedIndex]
RelativeIndexSequence: TypeAlias = tuple[
slice | IntIndex | types.EllipsisType, ...
] # is a tuple but called Sequence for symmetry
Expand Down Expand Up @@ -341,7 +341,9 @@ def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:


def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]:
return isinstance(v, Sequence) and all(isinstance(e, (NamedRange, NamedIndex)) for e in v)
return isinstance(v, Sequence) and all(
isinstance(e, NamedIndex) or is_named_slice(e) for e in v
)


def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]:
Expand Down
85 changes: 70 additions & 15 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Doma
index_sequence = common.as_any_index_sequence(index)

if common.is_absolute_index_sequence(index_sequence):
return _absolute_sub_domain(domain, index_sequence)
# TODO: ignore type for now
return _absolute_sub_domain(domain, index_sequence) # type: ignore[arg-type]

if common.is_relative_index_sequence(index_sequence):
return _relative_sub_domain(domain, index_sequence)
Expand Down Expand Up @@ -68,21 +69,51 @@ def _relative_sub_domain(

return common.Domain(*named_ranges)

def _find_index_of_slice(dim, index):
for i_ind, ind in enumerate(index):
if isinstance(ind, slice):
if (ind.start is not None and ind.start.dim == dim) or (ind.stop is not None and ind.stop.dim == dim):
return i_ind
else:
return None
return None

def _absolute_sub_domain(
domain: common.Domain, index: common.AbsoluteIndexSequence
domain: common.Domain, index: Sequence[common.NamedIndex | common.NamedSlice]
) -> common.Domain:
named_ranges: list[common.NamedRange] = []
for i, (dim, rng) in enumerate(domain):
if (pos := _find_index_of_dim(dim, index)) is not None:
if (pos :=_find_index_of_slice(dim, index)) is not None:
# if i < len(index) and isinstance(index[i], common.NamedSlice):
index_i_start = index[pos].start # type: ignore[union-attr] # slice has this attr
index_i_stop = index[pos].stop # type: ignore[union-attr] # slice has this attr
if index_i_start is None:
index_or_range = index_i_stop.value
index_dim = index_i_stop.dim
elif index_i_stop is None:
index_or_range = index_i_start.value
index_dim = index_i_start.dim
else:
if not common.unit_range((index_i_start.value, index_i_stop.value)) <= rng:
raise embedded_exceptions.IndexOutOfBounds(
domain=domain, indices=index, index=pos, dim=dim
)
index_dim = index_i_start.dim
index_or_range = common.unit_range((index_i_start.value, index_i_stop.value))
if index_dim == dim:
named_ranges.append(common.NamedRange(dim, index_or_range))
else:
# dimension not mentioned in slice
named_ranges.append(common.NamedRange(dim, domain.ranges[i]))
elif (pos := _find_index_of_dim(dim, index)) is not None:
# elif (pos := _find_index_of_dim(dim, index)) is not None:
named_idx = index[pos]
_, idx = named_idx
_, idx = named_idx # type: ignore[misc] # named_idx is not a slice
if isinstance(idx, common.UnitRange):
if not idx <= rng:
raise embedded_exceptions.IndexOutOfBounds(
domain=domain, indices=index, index=named_idx, dim=dim
)

named_ranges.append(common.NamedRange(dim, idx))
else:
# not in new domain
Expand Down Expand Up @@ -184,21 +215,43 @@ def _find_index_of_dim(
dim: common.Dimension,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
) -> Optional[int]:
for i, (d, _) in enumerate(domain_slice):
if dim == d:
return i
if not isinstance(domain_slice, tuple):
for i, (d, _) in enumerate(domain_slice):
if dim == d:
return i
return None
return None


def canonicalize_any_index_sequence(index: common.AnyIndexSpec) -> common.AnyIndexSpec:
# TODO: instead of canonicalizing to `NamedRange`, we should canonicalize to `NamedSlice`
def canonicalize_any_index_sequence(
index: common.AnyIndexSpec, domain: common.Domain
) -> common.AnyIndexSpec:
new_index: common.AnyIndexSpec = (index,) if isinstance(index, slice) else index
if isinstance(new_index, tuple) and all(isinstance(i, slice) for i in new_index):
new_index = tuple([_named_slice_to_named_range(i) for i in new_index]) # type: ignore[arg-type, assignment] # all i's are slices as per if statement
dims_ls = []
dims = True
for i_ind, ind in enumerate(new_index):
if ind.start is not None and isinstance(ind.start, common.NamedIndex):
dims_ls.append(ind.start.dim)
elif ind.stop is not None and isinstance(ind.stop, common.NamedIndex):
dims_ls.append(ind.stop.dim)
else:
dims = False
dims_ls.append(i_ind)
new_index = tuple([_create_slice(i, domain.ranges[domain.dims.index(dims_ls[idx]) if dims else dims_ls[idx]]) for idx, i in enumerate(new_index)])
elif isinstance(new_index, common.Domain):
new_index = tuple([_from_named_range_to_slice(idx) for idx in new_index])
return new_index


def _named_slice_to_named_range(idx: common.NamedSlice) -> common.NamedRange | common.NamedSlice:
def _from_named_range_to_slice(idx: common.NamedRange) -> common.NamedSlice:
return common.NamedSlice(
common.NamedIndex(dim=idx.dim, value=idx.unit_range.start),
common.NamedIndex(dim=idx.dim, value=idx.unit_range.stop),
)


def _create_slice(idx: common.NamedSlice, bounds: common.UnitRange) -> common.NamedSlice:
assert hasattr(idx, "start") and hasattr(idx, "stop")
if common.is_named_slice(idx):
start_dim, start_value = idx.start
Expand All @@ -208,9 +261,11 @@ def _named_slice_to_named_range(idx: common.NamedSlice) -> common.NamedRange | c
f"Dimensions slicing mismatch between '{start_dim.value}' and '{stop_dim.value}'."
)
assert isinstance(start_value, int) and isinstance(stop_value, int)
return common.NamedRange(start_dim, common.UnitRange(start_value, stop_value))
return idx
if isinstance(idx.start, common.NamedIndex) and idx.stop is None:
raise IndexError(f"Upper bound needs to be specified for {idx}.")
idx_stop = common.NamedIndex(dim=idx.start.dim, value=bounds.stop)
return common.NamedSlice(idx.start, idx_stop, idx.step)
if isinstance(idx.stop, common.NamedIndex) and idx.start is None:
raise IndexError(f"Lower bound needs to be specified for {idx}.")
idx_start = common.NamedIndex(dim=idx.stop.dim, value=bounds.start)
return common.NamedSlice(idx_start, idx.stop, idx.step)
return idx
22 changes: 18 additions & 4 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def __invert__(self) -> NdArrayField:
def _slice(
self, index: common.AnyIndexSpec
) -> tuple[common.Domain, common.RelativeIndexSequence]:
index = embedded_common.canonicalize_any_index_sequence(index)
index = embedded_common.canonicalize_any_index_sequence(index, self.domain)
new_domain = embedded_common.sub_domain(self.domain, index)

index_sequence = common.as_any_index_sequence(index)
Expand Down Expand Up @@ -831,7 +831,7 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA

def _get_slices_from_domain_slice(
domain: common.Domain,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex],
domain_slice: common.AbsoluteIndexSequence,
) -> common.RelativeIndexSequence:
"""Generate slices for sub-array extraction based on named ranges or named indices within a Domain.

Expand All @@ -850,8 +850,22 @@ def _get_slices_from_domain_slice(
slice_indices: list[slice | common.IntIndex] = []

for pos_old, (dim, _) in enumerate(domain):
if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None:
_, index_or_range = domain_slice[pos]
#if pos_old < len(domain_slice) and isinstance(domain_slice[pos_old], slice):
if (pos := embedded_common._find_index_of_slice(dim, domain_slice)) is not None:
if domain_slice[pos].start is None: # type: ignore[union-attr]
index_or_range = domain_slice[pos].stop.value # type: ignore[union-attr]
elif domain_slice[pos].stop is None: # type: ignore[union-attr]
index_or_range = domain_slice[pos].start.value # type: ignore[union-attr]
else:
index_or_range = common.unit_range(
(domain_slice[pos].start.value, domain_slice[pos].stop.value) # type: ignore[union-attr]
)
slice_indices.append(_compute_slice(index_or_range, domain, pos_old))
elif (
pos_old < len(domain_slice)
and (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None
):
_, index_or_range = domain_slice[pos] # type: ignore[misc]
slice_indices.append(_compute_slice(index_or_range, domain, pos_old))
else:
slice_indices.append(slice(None))
Expand Down
72 changes: 45 additions & 27 deletions tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from gt4py.next import common
from gt4py.next.common import UnitRange, NamedIndex, NamedRange
from gt4py.next.common import UnitRange, NamedIndex, NamedRange, NamedSlice
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next.embedded.common import (
_slice_range,
Expand Down Expand Up @@ -51,22 +51,22 @@ def test_slice_range(rng, slce, expected):
@pytest.mark.parametrize(
"domain, index, expected",
[
([(I, (2, 5))], 1, []),
([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]),
([(I, (2, 5))], NamedIndex(I, 2), []),
([(I, (2, 5))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]),
([(I, (-2, 3))], 1, []),
([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]),
([(I, (-2, 3))], NamedIndex(I, 1), []),
([(I, (-2, 3))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]),
([(I, (-2, 3))], -5, []),
([(I, (-2, 3))], -6, IndexError),
([(I, (-2, 3))], slice(-7, -6), IndexError),
([(I, (-2, 3))], slice(-6, -7), IndexError),
([(I, (-2, 3))], 4, []),
([(I, (-2, 3))], 5, IndexError),
([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]),
([(I, (-2, 3))], slice(5, 6), IndexError),
# ([(I, (2, 5))], 1, []),
# ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]),
([(I, (2, 5))], NamedIndex(I, 1), []),
# ([(I, (2, 5))], NamedSlice(I(2), I(3)), [(I, (2, 3))]),
# ([(I, (-2, 3))], 1, []),
# ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]),
# ([(I, (-2, 3))], NamedIndex(I, 1), []),
# ([(I, (-2, 3))], NamedRange(I, UnitRange(2, 3)), [(I, (2, 3))]),
# ([(I, (-2, 3))], -5, []),
# ([(I, (-2, 3))], -6, IndexError),
# ([(I, (-2, 3))], slice(-7, -6), IndexError),
# ([(I, (-2, 3))], slice(-6, -7), IndexError),
# ([(I, (-2, 3))], 4, []),
# ([(I, (-2, 3))], 5, IndexError),
# ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]),
# ([(I, (-2, 3))], slice(5, 6), IndexError),
([(I, (-2, 3))], NamedIndex(I, -3), IndexError),
([(I, (-2, 3))], NamedRange(I, UnitRange(-3, -2)), IndexError),
([(I, (-2, 3))], NamedIndex(I, 3), IndexError),
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_slice_range(rng, slce, expected):
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(NamedRange(J, UnitRange(4, 5)), NamedIndex(I, 2)),
(NamedSlice(J(4), J(5)), NamedIndex(I, 2)),
[(J, (4, 5)), (K, (4, 7))],
),
(
Expand Down Expand Up @@ -145,24 +145,42 @@ def test_iterate_domain():


@pytest.mark.parametrize(
"slices, expected",
"slices, expected, domain",
[
[slice(I(3), I(4)), (NamedRange(I, common.UnitRange(3, 4)),)],
[
slice(I(3), I(4)),
tuple([NamedSlice(I(3), I(4))]),
common.Domain(dims=tuple([I]), ranges=tuple([common.UnitRange(start=0, stop=10)])),
],
[
(slice(J(3), J(6)), slice(I(3), I(5))),
(NamedRange(J, common.UnitRange(3, 6)), NamedRange(I, common.UnitRange(3, 5))),
(NamedSlice(J(3), J(6)), NamedSlice(I(3), I(5))),
common.Domain(dims=tuple([I, J]), ranges=tuple([common.UnitRange(start=0, stop=10), common.UnitRange(start=0, stop=10)])),
],
[
slice(I(1), J(7)),
IndexError,
common.Domain(dims=tuple([I, J]),
ranges=tuple([common.UnitRange(start=0, stop=10), common.UnitRange(start=0, stop=10)])),
],
[
slice(I(1), None),
tuple([NamedSlice(I(1), I(10))]),
common.Domain(dims=tuple([I]), ranges=tuple([common.UnitRange(start=0, stop=10)])),
],
[
slice(None, K(8)),
tuple([NamedSlice(K(0), K(8))]),
common.Domain(dims=tuple([K]), ranges=tuple([common.UnitRange(start=0, stop=10)])),
],
[slice(I(1), J(7)), IndexError],
[slice(I(1), None), IndexError],
[slice(None, K(8)), IndexError],
],
)
def test_slicing(slices, expected):
def test_slicing(slices, expected, domain):
if expected is IndexError:
with pytest.raises(IndexError):
canonicalize_any_index_sequence(slices)
canonicalize_any_index_sequence(slices, domain)
else:
testee = canonicalize_any_index_sequence(slices)
testee = canonicalize_any_index_sequence(slices, domain)
assert testee == expected


Expand Down
Loading
Loading