Skip to content

Commit

Permalink
edits for NamedRange class
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Mar 13, 2024
1 parent d5b83a8 commit fd7bd24
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
26 changes: 17 additions & 9 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,17 @@ def unit_range(r: RangeLike) -> UnitRange:
raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.")


@dataclasses.dataclass(frozen=True)
class NamedRange:
dims: Dimension
urange: UnitRange

def __str__(self) -> str:
return f"{self.dims}={self.urange}"


IntIndex: TypeAlias = int | core_defs.IntegralScalar
NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/
Expand Down Expand Up @@ -309,7 +317,7 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]:


def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]:
return UnitRange.is_finite(v[1])
return UnitRange.is_finite(v.urange)


def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
Expand Down Expand Up @@ -351,7 +359,7 @@ def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence:


def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange:
return (v[0], unit_range(v[1]))
return NamedRange(dims=v[0], urange=unit_range(v[1]))


_Rng = TypeVar(
Expand Down Expand Up @@ -439,15 +447,15 @@ def __getitem__(self, index: Dimension) -> tuple[Dimension, _Rng]: ...

def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain:
if isinstance(index, int):
return self.dims[index], self.ranges[index]
return named_range((self.dims[index], self.ranges[index]))
elif isinstance(index, slice):
dims_slice = self.dims[index]
ranges_slice = self.ranges[index]
return Domain(dims=dims_slice, ranges=ranges_slice)
elif isinstance(index, Dimension):
try:
index_pos = self.dims.index(index)
return self.dims[index_pos], self.ranges[index_pos]
return named_range((self.dims[index_pos], self.ranges[index_pos]))
except ValueError as ex:
raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex
else:
Expand Down Expand Up @@ -957,15 +965,15 @@ def from_offset(

def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]:
if not isinstance(image_range, UnitRange):
if image_range[0] != self.codomain:
if image_range.dims != self.codomain:
raise ValueError(
f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'."
f"Dimension '{image_range.dims}' does not match the codomain dimension '{self.codomain}'."
)

image_range = image_range[1]
image_range = image_range.urange

assert isinstance(image_range, UnitRange)
return ((self.codomain, image_range - self.offset),)
return (named_range((self.codomain, image_range - self.offset)),)

def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField:
raise NotImplementedError()
Expand Down
14 changes: 8 additions & 6 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
self.kstart = kstart
assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673
column_range: common.NamedRange = column_range_cvar.get()
self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data)
self.data = (
data if isinstance(data, np.ndarray) else np.full(len(column_range.urange), data)
)

def __getitem__(self, i: int) -> Any:
result = self.data[i - self.kstart]
Expand Down Expand Up @@ -746,7 +748,7 @@ def _make_tuple(
except embedded_exceptions.IndexOutOfBounds:
return _UNDEFINED
else:
column_range = column_range_cvar.get()[1]
column_range = column_range_cvar.get().urange
assert column_range is not None

col: list[
Expand Down Expand Up @@ -823,7 +825,7 @@ def deref(self) -> Any:
assert isinstance(k_pos, int)
# the following range describes a range in the field
# (negative values are relative to the origin, not relative to the size)
slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1]))
slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range.urange))

assert _is_concrete_position(shifted_pos)
position = {**shifted_pos, **slice_column}
Expand Down Expand Up @@ -864,7 +866,7 @@ def make_in_iterator(
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
if column_axis is not None:
column_range = column_range_cvar.get()[1]
column_range = column_range_cvar.get().urange
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
assert column_range is not None
new_pos[column_axis] = column_range.start
Expand Down Expand Up @@ -1090,7 +1092,7 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -

def restrict(self, item: common.AnyIndexSpec) -> common.Field:
if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off
d, r = item[0]
d, r = item.dims
assert d == self._dimension
assert isinstance(r, core_defs.INTEGRAL_TYPES)
return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work
Expand Down Expand Up @@ -1489,7 +1491,7 @@ def _column_dtype(elem: Any) -> np.dtype:
@builtins.scan.register(EMBEDDED)
def scan(scan_pass, is_forward: bool, init):
def impl(*iters: ItIterator):
column_range = column_range_cvar.get()[1]
column_range = column_range_cvar.get().urange
if column_range is None:
raise RuntimeError("Column range is not defined, cannot scan.")

Expand Down

0 comments on commit fd7bd24

Please sign in to comment.