diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 0aa19b20ae..43704b7931 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -277,9 +277,17 @@ def unit_range(r: RangeLike) -> UnitRange: raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") +@dataclasses.dataclass(frozen=True) +class NamedRange: + dims: Dimension + urange: UnitRange + + def __str__(self) -> str: + return f"{self.dims}={self.urange}" + + IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple -NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ @@ -309,7 +317,7 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: - return UnitRange.is_finite(v[1]) + return UnitRange.is_finite(v.urange) def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: @@ -351,7 +359,7 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: - return (v[0], unit_range(v[1])) + return NamedRange(dims=v[0], urange=unit_range(v[1])) _Rng = TypeVar( @@ -439,7 +447,7 @@ def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ... def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: if isinstance(index, int): - return self.dims[index], self.ranges[index] + return named_range((self.dims[index], self.ranges[index])) elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] @@ -447,7 +455,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) - return self.dims[index_pos], self.ranges[index_pos] + return named_range((self.dims[index_pos], self.ranges[index_pos])) except ValueError as ex: raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex else: @@ -957,15 +965,15 @@ def from_offset( def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: if not isinstance(image_range, UnitRange): - if image_range[0] != self.codomain: + if image_range.dims != self.codomain: raise ValueError( - f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." + f"Dimension '{image_range.dims}' does not match the codomain dimension '{self.codomain}'." ) - image_range = image_range[1] + image_range = image_range.urange assert isinstance(image_range, UnitRange) - return ((self.codomain, image_range - self.offset),) + return (named_range((self.codomain, image_range - self.offset)),) def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f9f1ba47e0..b66f577afb 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -205,7 +205,9 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 column_range: common.NamedRange = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) + self.data = ( + data if isinstance(data, np.ndarray) else np.full(len(column_range.urange), data) + ) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -746,7 +748,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange assert column_range is not None col: list[ @@ -823,7 +825,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.urange)) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -864,7 +866,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1090,7 +1092,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - def restrict(self, item: common.AnyIndexSpec) -> common.Field: if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off - d, r = item[0] + d, r = item.dims assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work @@ -1489,7 +1491,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get()[1] + column_range = column_range_cvar.get().urange if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.")