Skip to content

Commit

Permalink
feat[next]: variadic generic type (#1486)
Browse files Browse the repository at this point in the history
Variadic generic type: 
```python
ShapeT = TypeVarTuple("ShapeT")


class Dims(Generic[Unpack[ShapeT]]):
    shape: tuple[Unpack[ShapeT]]


DimsT = TypeVar("DimsT", bound=Dims, covariant=True)
```
such that field type is as follows: 
```python
Field[Dims[D0, D1, ...], DType]
```
  • Loading branch information
nfarabullini authored Mar 13, 2024
1 parent cb2056c commit d5b83a8
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 38 deletions.
63 changes: 32 additions & 31 deletions docs/user/next/QuickstartGuide.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The following snippet imports the most commonly used features that are needed to
import numpy as np
import gt4py.next as gtx
from gt4py.next import float64, neighbor_sum, where
from gt4py.next import float64, neighbor_sum, where, Dims
```

#### Fields
Expand Down Expand Up @@ -91,11 +91,12 @@ Let's see an example for a field operator that adds two fields elementwise:

```{code-cell} ipython3
@gtx.field_operator
def add(a: gtx.Field[[CellDim, KDim], float64],
b: gtx.Field[[CellDim, KDim], float64]) -> gtx.Field[[CellDim, KDim], float64]:
def add(a: gtx.Field[gtx.Dims[CellDim, KDim], float64],
b: gtx.Field[gtx.Dims[CellDim, KDim], float64]) -> gtx.Field[gtx.Dims[CellDim, KDim], float64]:
return a + b
```

\_Note: for now `Dims` is not mandatory, hence this type hint is also accepted: `gtx.Field[[CellDim, KDim], float64]`
You can call field operators from [programs](#Programs), other field operators, or directly. The code snippet below shows a direct call, in which case you have to supply two additional arguments: `out`, which is a field to write the return value to, and `offset_provider`, which is left empty for now. The result of the field operator is a field with all entries equal to 5, but for brevity, only the average and the standard deviation of the entries are printed:

```{code-cell} ipython3
Expand All @@ -115,9 +116,9 @@ This example program below calls the above elementwise addition field operator t

```{code-cell} ipython3
@gtx.program
def run_add(a : gtx.Field[[CellDim, KDim], float64],
b : gtx.Field[[CellDim, KDim], float64],
result : gtx.Field[[CellDim, KDim], float64]):
def run_add(a : gtx.Field[gtx.Dims[CellDim, KDim], float64],
b : gtx.Field[gtx.Dims[CellDim, KDim], float64],
result : gtx.Field[gtx.Dims[CellDim, KDim], float64]):
add(a, b, out=result)
add(b, result, out=result)
```
Expand Down Expand Up @@ -247,11 +248,11 @@ Pay attention to the syntax where the field offset `E2C` can be freely accessed

```{code-cell} ipython3
@gtx.field_operator
def nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]:
def nearest_cell_to_edge(cell_values: gtx.Field[gtx.Dims[CellDim], float64]) -> gtx.Field[gtx.Dims[EdgeDim], float64]:
return cell_values(E2C[0])
@gtx.program
def run_nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64], out : gtx.Field[[EdgeDim], float64]):
def run_nearest_cell_to_edge(cell_values: gtx.Field[gtx.Dims[CellDim], float64], out : gtx.Field[gtx.Dims[EdgeDim], float64]):
nearest_cell_to_edge(cell_values, out=out)
run_nearest_cell_to_edge(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider})
Expand All @@ -273,12 +274,12 @@ Similarly to the previous example, the output is once again a field on edges. Th

```{code-cell} ipython3
@gtx.field_operator
def sum_adjacent_cells(cells : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]:
# type of cells(E2C) is gtx.Field[[CellDim, E2CDim], float64]
def sum_adjacent_cells(cells : gtx.Field[gtx.Dims[CellDim], float64]) -> gtx.Field[gtx.Dims[EdgeDim], float64]:
# type of cells(E2C) is gtx.Field[gtx.Dims[CellDim, E2CDim], float64]
return neighbor_sum(cells(E2C), axis=E2CDim)
@gtx.program
def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Field[[EdgeDim], float64]):
def run_sum_adjacent_cells(cells : gtx.Field[gtx.Dims[CellDim], float64], out : gtx.Field[gtx.Dims[EdgeDim], float64]):
sum_adjacent_cells(cells, out=out)
run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider})
Expand All @@ -302,7 +303,7 @@ This function takes 3 input arguments:
- mask: a field with dtype boolean
- true branch: a tuple, a field, or a scalar
- false branch: a tuple, a field, of a scalar
The mask can be directly a field of booleans (e.g. `gtx.Field[[CellDim], bool]`) or an expression evaluating to this type (e.g. `gtx.Field[[CellDim], float64] > 3`).
The mask can be directly a field of booleans (e.g. `gtx.Field[gtx.Dims[CellDim], bool]`) or an expression evaluating to this type (e.g. `gtx.Field[[CellDim], float64] > 3`).
The `where` builtin loops over each entry of the mask and returns values corresponding to the same indexes of either the true or the false branch.
In the case where the true and false branches are either fields or scalars, the resulting output will be a field including all dimensions from all inputs. For example:

Expand All @@ -312,8 +313,8 @@ result_where = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
b = 6.0
@gtx.field_operator
def conditional(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float
) -> gtx.Field[[CellDim, KDim], float64]:
def conditional(mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: float
) -> gtx.Field[gtx.Dims[CellDim, KDim], float64]:
return where(mask, a, b)
conditional(mask, a, b, out=result_where, offset_provider={})
Expand All @@ -329,13 +330,13 @@ result_1 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
result_2 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
@gtx.field_operator
def _conditional_tuple(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float
) -> tuple[gtx.Field[[CellDim, KDim], float64], gtx.Field[[CellDim, KDim], float64]]:
def _conditional_tuple(mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: float
) -> tuple[gtx.Field[gtx.Dims[CellDim, KDim], float64], gtx.Field[gtx.Dims[CellDim, KDim], float64]]:
return where(mask, (a, b), (b, a))
@gtx.program
def conditional_tuple(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float,
result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDim], float64]
def conditional_tuple(mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: float,
result_1: gtx.Field[gtx.Dims[CellDim, KDim], float64], result_2: gtx.Field[gtx.Dims[CellDim, KDim], float64]
):
_conditional_tuple(mask, a, b, out=(result_1, result_2))
Expand All @@ -360,17 +361,17 @@ result_2 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
@gtx.field_operator
def _conditional_tuple_nested(
mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: gtx.Field[[CellDim, KDim], float64], c: gtx.Field[[CellDim, KDim], float64], d: gtx.Field[[CellDim, KDim], float64]
mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: gtx.Field[gtx.Dims[CellDim, KDim], float64], c: gtx.Field[gtx.Dims[CellDim, KDim], float64], d: gtx.Field[gtx.Dims[CellDim, KDim], float64]
) -> tuple[
tuple[gtx.Field[[CellDim, KDim], float64], gtx.Field[[CellDim, KDim], float64]],
tuple[gtx.Field[[CellDim, KDim], float64], gtx.Field[[CellDim, KDim], float64]],
tuple[gtx.Field[gtx.Dims[CellDim, KDim], float64], gtx.Field[gtx.Dims[CellDim, KDim], float64]],
tuple[gtx.Field[gtx.Dims[CellDim, KDim], float64], gtx.Field[gtx.Dims[CellDim, KDim], float64]],
]:
return where(mask, ((a, b), (b, a)), ((c, d), (d, c)))
@gtx.program
def conditional_tuple_nested(
mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: gtx.Field[[CellDim, KDim], float64], c: gtx.Field[[CellDim, KDim], float64], d: gtx.Field[[CellDim, KDim], float64],
result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDim], float64]
mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: gtx.Field[gtx.Dims[CellDim, KDim], float64], c: gtx.Field[gtx.Dims[CellDim, KDim], float64], d: gtx.Field[gtx.Dims[CellDim, KDim], float64],
result_1: gtx.Field[gtx.Dims[CellDim, KDim], float64], result_2: gtx.Field[gtx.Dims[CellDim, KDim], float64]
):
_conditional_tuple_nested(mask, a, b, c, d, out=((result_1, result_2), (result_2, result_1)))
Expand Down Expand Up @@ -425,19 +426,19 @@ The second lines first creates a temporary field using `edge_differences(C2E)`,

```{code-cell} ipython3
@gtx.field_operator
def pseudo_lap(cells : gtx.Field[[CellDim], float64],
edge_weights : gtx.Field[[CellDim, C2EDim], float64]) -> gtx.Field[[CellDim], float64]:
edge_differences = cells(E2C[0]) - cells(E2C[1]) # type: gtx.Field[[EdgeDim], float64]
def pseudo_lap(cells : gtx.Field[gtx.Dims[CellDim], float64],
edge_weights : gtx.Field[gtx.Dims[CellDim, C2EDim], float64]) -> gtx.Field[gtx.Dims[CellDim], float64]:
edge_differences = cells(E2C[0]) - cells(E2C[1]) # type: gtx.Field[gtx.Dims[EdgeDim], float64]
return neighbor_sum(edge_differences(C2E) * edge_weights, axis=C2EDim)
```

The program itself is just a shallow wrapper over the `pseudo_lap` field operator. The significant part is how offset providers for both the edge-to-cell and cell-to-edge connectivities are supplied when the program is called:

```{code-cell} ipython3
@gtx.program
def run_pseudo_laplacian(cells : gtx.Field[[CellDim], float64],
edge_weights : gtx.Field[[CellDim, C2EDim], float64],
out : gtx.Field[[CellDim], float64]):
def run_pseudo_laplacian(cells : gtx.Field[gtx.Dims[CellDim], float64],
edge_weights : gtx.Field[gtx.Dims[CellDim, C2EDim], float64],
out : gtx.Field[gtx.Dims[CellDim], float64]):
pseudo_lap(cells, edge_weights, out=out)
result_pseudo_lap = gtx.as_field([CellDim], np.zeros(shape=(6,)))
Expand All @@ -454,7 +455,7 @@ As a closure, here is an example of chaining field operators, which is very simp

```{code-cell} ipython3
@gtx.field_operator
def pseudo_laplap(cells : gtx.Field[[CellDim], float64],
edge_weights : gtx.Field[[CellDim, C2EDim], float64]) -> gtx.Field[[CellDim], float64]:
def pseudo_laplap(cells : gtx.Field[gtx.Dims[CellDim], float64],
edge_weights : gtx.Field[gtx.Dims[CellDim, C2EDim], float64]) -> gtx.Field[gtx.Dims[CellDim], float64]:
return pseudo_lap(pseudo_lap(cells, edge_weights), edge_weights)
```
12 changes: 11 additions & 1 deletion src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,17 @@
"""

from . import common, ffront, iterator, program_processors, type_inference
from .common import Dimension, DimensionKind, Domain, Field, GridType, UnitRange, domain, unit_range
from .common import (
Dimension,
DimensionKind,
Dims,
Domain,
Field,
GridType,
UnitRange,
domain,
unit_range,
)
from .constructors import as_connectivity, as_field, empty, full, ones, zeros
from .embedded import ( # Just for registering field implementations
nd_array_field as _nd_array_field,
Expand Down
10 changes: 9 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
TypeAlias,
TypeGuard,
TypeVar,
TypeVarTuple,
Unpack,
cast,
extended_runtime_checkable,
overload,
Expand All @@ -51,9 +53,15 @@


DimT = TypeVar("DimT", bound="Dimension") # , covariant=True)
DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True)
ShapeT = TypeVarTuple("ShapeT")


class Dims(Generic[Unpack[ShapeT]]):
shape: tuple[Unpack[ShapeT]]


DimsT = TypeVar("DimsT", bound=Dims, covariant=True)

Tag: TypeAlias = str


Expand Down
12 changes: 7 additions & 5 deletions src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import types
import typing
from typing import Any, ForwardRef, Optional, Union
from typing import Any, ForwardRef, Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -106,16 +106,18 @@ def from_type_hint(
case common.Field:
if (n_args := len(args)) != 2:
raise ValueError(f"Field type requires two arguments, got {n_args}: '{type_hint}'.")

dims: Union[Ellipsis, list[common.Dimension]] = []
dims: list[common.Dimension] = []
dim_arg, dtype_arg = args
dim_arg = (
list(typing.get_args(dim_arg))
if typing.get_origin(dim_arg) is common.Dims
else dim_arg
)
if isinstance(dim_arg, list):
for d in dim_arg:
if not isinstance(d, common.Dimension):
raise ValueError(f"Invalid field dimension definition '{d}'.")
dims.append(d)
elif dim_arg is Ellipsis:
dims = dim_arg
else:
raise ValueError(f"Invalid field dimensions '{dim_arg}'.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,19 @@ def test_invalid_symbol_types():
type_translation.from_type_hint(typing.Callable[[int], str])
with pytest.raises(ValueError, match="Invalid callable annotations"):
type_translation.from_type_hint(typing.Callable[[int], float])


@pytest.mark.parametrize(
"value, expected_dims",
[
(common.Dims[IDim, JDim], [IDim, JDim]),
(common.Dims[IDim, np.float64], ValueError),
(common.Dims["IDim"], ValueError),
],
)
def test_generic_variadic_dims(value, expected_dims):
if expected_dims == ValueError:
with pytest.raises(ValueError, match="Invalid field dimension definition"):
type_translation.from_type_hint(gtx.Field[value, np.int32])
else:
assert type_translation.from_type_hint(gtx.Field[value, np.int32]).dims == expected_dims

0 comments on commit d5b83a8

Please sign in to comment.