Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Extend cartesian offset syntax #1484

Merged
merged 32 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e37fb21
Implement support for syntax field(I+1,J-1) for embedded
SF-N Mar 8, 2024
54c630f
Merge branch 'main' into cartesian_offset_syntax
SF-N Mar 8, 2024
53da840
Make _call_recursive and add type annotation (to be checked)
SF-N Mar 8, 2024
7073713
Try to fix arguments and annotations
SF-N Mar 11, 2024
c80e394
Fix test_laplacian
SF-N Mar 11, 2024
4c3dddf
Update and extend various parts
SF-N Mar 11, 2024
5b237d7
Remove comma
SF-N Mar 11, 2024
e479b96
Add support fir IDim + 1 + 1
SF-N Mar 13, 2024
b7fdc68
Only support one argument in case order matters
SF-N Mar 13, 2024
c50b083
Modify laplacian tests to avoid symmetry, extend visit_shift to also …
SF-N Mar 14, 2024
c1f2194
Use squared inpit field in laplacian tests, remove support for IDim +…
SF-N Mar 22, 2024
876e5b0
merge main
SF-N Mar 22, 2024
2c22920
Add annotation
SF-N Mar 22, 2024
b57afa0
Start working on review comments
SF-N Mar 27, 2024
79e8154
Extend offset_provider automatically
SF-N Mar 27, 2024
c080415
Update comments and add support for IDim+1+2 again (in foast_to_itir)
SF-N Mar 28, 2024
f95c6ca
Minor refactoring
SF-N Mar 28, 2024
72e2a1e
Reformat wrt pre-commit
SF-N Mar 28, 2024
d10ae99
Fix some problems occuring during PMAP-L implementation
SF-N Apr 11, 2024
5f8ab74
Implement review comments
SF-N Apr 12, 2024
c2d2188
Minor
SF-N Apr 12, 2024
f425e43
Minor
SF-N Apr 12, 2024
684fa0d
Remove support for (Idim + 1) + 2 again
SF-N Apr 12, 2024
6567c65
pre-commit changes
SF-N Apr 12, 2024
dc5616f
Merge origin/main
tehrengruber Sep 9, 2024
25786f8
General cleanup, added & adopted new offset syntax into new GTIR lowe…
tehrengruber Sep 9, 2024
267660f
Minor cleanups
tehrengruber Sep 9, 2024
1bc5a7e
Merge branch 'main' into cartesian_offset_syntax
tehrengruber Sep 9, 2024
c3e0a9a
Fix broken tests
tehrengruber Sep 9, 2024
2d2bd45
Merge remote-tracking branch 'origin_sf_n/cartesian_offset_syntax' in…
tehrengruber Sep 9, 2024
9dbe6c2
Cleanup
tehrengruber Sep 9, 2024
4d9394a
Fix dace sdfg-convertible program
tehrengruber Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ def __str__(self) -> str:
return self.value


def dimension_to_implicit_offset(dim: str) -> str:
"""
Return name of offset implicitly defined by a dimension.

Each dimension implicitly also defines an offset, such that we can allow syntax like::

field(TDim + 1)

without having to explicitly define an offset for ``TDim``. This function defines the respective
naming convention.
"""
return f"_{dim}Off"


@dataclasses.dataclass(frozen=True)
class Dimension:
value: str
Expand All @@ -81,6 +95,18 @@ def __str__(self) -> str:
def __call__(self, val: int) -> NamedIndex:
return NamedIndex(self, val)

def __add__(self, offset: int) -> ConnectivityField:
# TODO(sf-n): just to avoid circular import. Move or refactor the FieldOffset to avoid this.
from gt4py.next.ffront import fbuiltins
SF-N marked this conversation as resolved.
Show resolved Hide resolved

assert isinstance(self.value, str)
return fbuiltins.FieldOffset(
dimension_to_implicit_offset(self.value), source=self, target=(self,)
)[offset]

def __sub__(self, offset: int) -> ConnectivityField:
return self + (-offset)


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -672,7 +698,11 @@ def as_scalar(self) -> core_defs.ScalarT: ...

# Operators
@abc.abstractmethod
def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...
def __call__(
self,
index_field: ConnectivityField | fbuiltins.FieldOffset,
*args: ConnectivityField | fbuiltins.FieldOffset,
) -> Field: ...

@abc.abstractmethod
def __getitem__(self, item: AnyIndexSpec) -> Self: ...
Expand Down Expand Up @@ -992,7 +1022,11 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa
assert isinstance(image_range, UnitRange)
return (named_range((self.domain_dim, image_range - self.offset)),)

def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField:
def premap(
self,
index_field: ConnectivityField | fbuiltins.FieldOffset,
*args: ConnectivityField | fbuiltins.FieldOffset,
) -> ConnectivityField:
raise NotImplementedError()

__call__ = premap
Expand Down
11 changes: 10 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,16 @@ def premap(
assert len(conn_fields) == 1
return _remapping_premap(self, conn_fields[0])

__call__ = premap # type: ignore[assignment]
def __call__(
self,
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
return functools.reduce(
lambda field, current_index_field: field.premap(current_index_field),
[index_field, *args],
self,
)

def restrict(self, index: common.AnyIndexSpec) -> NdArrayField:
new_domain, buffer_slice = self._slice(index)
Expand Down
29 changes: 29 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from gt4py.next import (
allocators as next_allocators,
backend as next_backend,
common,
embedded as next_embedded,
errors,
)
Expand Down Expand Up @@ -185,7 +186,35 @@ def itir(self) -> itir.FencilDefinition:
return self.backend.transforms_prog.past_to_itir(no_args_past).program
return past_to_itir.PastToItirFactory()(no_args_past).program

@functools.cached_property
def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]:
"""
Add all implicit offset providers.

Each dimension implicitly defines an offset provider such that we can allow syntax like::

field(TDim + 1)

This function adds these implicit offset providers.
"""
# TODO(tehrengruber): We add all dimensions here regardless of whether they are cartesian
# or unstructured. While it is conceptually fine, but somewhat meaningless,
# to do something `Cell+1` the GTFN backend for example doesn't support these. We should
# find a way to avoid adding these dimensions, but since we don't have the grid type here
# and since the dimensions don't this information either, we just add all dimensions here
# and filter them out in the backends that don't support this.
implicit_offset_provider = {}
params = self.past_stage.past_node.params
SF-N marked this conversation as resolved.
Show resolved Hide resolved
for param in params:
if isinstance(param.type, ts.FieldType):
for dim in param.type.dims:
implicit_offset_provider.update(
{common.dimension_to_implicit_offset(dim.value): dim}
)
return implicit_offset_provider

def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None:
offset_provider = offset_provider | self._implicit_offset_provider
if self.backend is None:
warnings.warn(
UserWarning(
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,18 @@ def _deduce_compare_type(
def _deduce_binop_type(
self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# e.g. `IDim+1`
if (
SF-N marked this conversation as resolved.
Show resolved Hide resolved
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and type_info.is_integral(right.type)
):
return ts.OffsetType(source=left.type.dim, target=(left.type.dim,))
SF-N marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(left.type, ts.OffsetType):
raise errors.DSLError(
node.location, f"Type '{left.type}' can not be used in operator '{node.op}'."
)

logical_ops = {
dialect_ast_enums.BinaryOperator.BIT_AND,
dialect_ast_enums.BinaryOperator.BIT_OR,
Expand Down
76 changes: 57 additions & 19 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Never
from gt4py.next import common
from gt4py.next.ffront import (
dialect_ast_enums,
experimental as experimental_builtins,
Expand Down Expand Up @@ -219,25 +220,62 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
match node.args[0]:
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
shift_offset = im.shift(offset_name, offset_index)
return im.as_fieldop(im.lambda_("__it")(im.deref(shift_offset("__it"))))(
self.visit(node.func, **kwargs)
)
case foast.Name(id=offset_name):
return im.as_fieldop_neighbors(str(offset_name), self.visit(node.func, **kwargs))
case foast.Call(func=foast.Name(id="as_offset")):
# TODO(havogt): discuss this representation
func_args = node.args[0]
offset_dim = func_args.args[0]
assert isinstance(offset_dim, foast.Name)
shift_offset = im.shift(offset_dim.id, im.deref("__offset"))
return im.as_fieldop(
im.lambda_("__it", "__offset")(im.deref(shift_offset("__it")))
)(self.visit(node.func, **kwargs), self.visit(func_args.args[1], **kwargs))
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
current_expr = self.visit(node.func, **kwargs)

for arg in node.args:
match arg:
# `field(Off[idx])`
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
current_expr = im.as_fieldop(
im.lambda_("__it")(im.deref(im.shift(offset_name, offset_index)("__it")))
)(current_expr)
# `field(Dim + idx)`
case foast.BinOp(
op=dialect_ast_enums.BinaryOperator.ADD
| dialect_ast_enums.BinaryOperator.SUB,
left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs
right=foast.Constant(value=offset_index),
):
if arg.op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
current_expr = im.as_fieldop(
# TODO(SF-N): we rely on the naming-convention that the cartesian dimensions
# are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the
# offset provider. This is a rather unclean solution and should be
# improved.
im.lambda_("__it")(
im.deref(
im.shift(
common.dimension_to_implicit_offset(dimension), offset_index
)("__it")
)
)
)(current_expr)
# `field(Off)`
case foast.Name(id=offset_name):
# only a single unstructured shift is supported so returning here is fine even though we
# are in a loop.
assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern
return im.as_fieldop_neighbors(
str(offset_name), self.visit(node.func, **kwargs)
)
# `field(as_offset(Off, offset_field))`
case foast.Call(func=foast.Name(id="as_offset")):
func_args = arg
# TODO(tehrengruber): Discuss representation. We could use the type system to
# deduce the offset dimension instead of (e.g. to allow aliasing).
offset_dim = func_args.args[0]
assert isinstance(offset_dim, foast.Name)
offset_field = self.visit(func_args.args[1], **kwargs)
current_expr = im.as_fieldop(
im.lambda_("__it", "__offset")(
im.deref(im.shift(offset_dim.id, im.deref("__offset"))("__it"))
)
)(current_expr, offset_field)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")

return current_expr

def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
Expand Down
72 changes: 55 additions & 17 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.extended_typing import Never
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.ffront import (
dialect_ast_enums,
fbuiltins,
Expand Down Expand Up @@ -276,23 +277,60 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
match node.args[0]:
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
shift_offset = im.shift(offset_name, offset_index)
case foast.Name(id=offset_name):
return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs))
case foast.Call(func=foast.Name(id="as_offset")):
func_args = node.args[0]
offset_dim = func_args.args[0]
assert isinstance(offset_dim, foast.Name)
shift_offset = im.shift(
offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs))
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))(
self.visit(node.func, **kwargs)
)
current_expr = self.visit(node.func, **kwargs)

for arg in node.args:
match arg:
# `field(Off[idx])`
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
current_expr = im.lift(
im.lambda_("it")(im.deref(im.shift(offset_name, offset_index)("it")))
)(current_expr)
# `field(Dim + idx)`
case foast.BinOp(
op=dialect_ast_enums.BinaryOperator.ADD
| dialect_ast_enums.BinaryOperator.SUB,
left=foast.Name(id=dimension),
right=foast.Constant(value=offset_index),
):
if arg.op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
current_expr = im.lift(
# TODO(SF-N): we rely on the naming-convention that the cartesian dimensions
# are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the
# offset provider. This is a rather unclean solution and should be
# improved.
im.lambda_("it")(
im.deref(
im.shift(
common.dimension_to_implicit_offset(dimension), offset_index
)("it")
)
)
)(current_expr)
# `field(Off)`
case foast.Name(id=offset_name):
# only a single unstructured shift is supported so returning here is fine even though we
# are in a loop.
assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern
return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs))
# `field(as_offset(Off, offset_field))`
case foast.Call(func=foast.Name(id="as_offset")):
func_args = arg
# TODO(tehrengruber): Use type system to deduce the offset dimension instead of
# (e.g. to allow aliasing)
offset_dim = func_args.args[0]
assert isinstance(offset_dim, foast.Name)
SF-N marked this conversation as resolved.
Show resolved Hide resolved
offset_it = self.visit(func_args.args[1], **kwargs)
current_expr = im.lift(
im.lambda_("it", "offset")(
im.deref(im.shift(offset_dim.id, im.deref("offset"))("it"))
)
)(current_expr, offset_it)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")

return current_expr

def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import factory

from gt4py.eve import NodeTranslator, concepts, traits
from gt4py.next import common, config
from gt4py.next import common, config, errors
from gt4py.next.ffront import (
fbuiltins,
gtcallable,
Expand Down Expand Up @@ -435,9 +435,10 @@ def _compute_field_slice(node: past.Subscript) -> list[past.Slice]:
node_dims = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims):
raise ValueError(
raise errors.DSLError(
node.location,
f"Too many indices for field '{out_field_name}': field is {len(node_dims)}"
f"-dimensional, but {len(out_field_slice_)} were indexed."
f"-dimensional, but {len(out_field_slice_)} were indexed.",
)
return out_field_slice_

Expand Down
12 changes: 10 additions & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,11 @@ def as_scalar(self) -> core_defs.IntegralScalar:
assert self._cur_index is not None
return self._cur_index

def premap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field:
def premap(
self,
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

Expand Down Expand Up @@ -1273,7 +1277,11 @@ def ndarray(self) -> core_defs.NDArrayObject:
def asnumpy(self) -> np.ndarray:
raise NotImplementedError()

def premap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field:
def premap(
self,
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

Expand Down
11 changes: 11 additions & 0 deletions src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ def _collect_offset_definitions(
)
else:
assert grid_type == common.GridType.UNSTRUCTURED
# TODO(tehrengruber): The implicit offset providers added to support syntax like
# `KDim+1` can also include horizontal dimensions. Cartesian shifts in this
# dimension are not supported by the backend and also never occur in user code.
# We just skip these here for now, but this is not a clean solution. Not having
# any unstructured dimensions in here would be preferred.
if (
dim.kind == common.DimensionKind.HORIZONTAL
and offset_name == common.dimension_to_implicit_offset(dim.value)
):
continue

if not dim.kind == common.DimensionKind.VERTICAL:
raise ValueError(
"Mapping an offset to a horizontal dimension in unstructured is not allowed."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
raise ValueError(
"[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method."
)
offset_provider = self.connectivities # tables are None at this point
offset_provider = (
self.connectivities | self._implicit_offset_provider
) # tables are None at this point

sdfg = self.backend.executor.otf_workflow.step.translation.generate_sdfg( # type: ignore[union-attr]
self.itir,
Expand Down
Loading
Loading