Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature[next]: Extend cartesian offset syntax #1484

Merged
merged 32 commits into from
Sep 10, 2024
Merged
Changes from 12 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e37fb21
Implement support for syntax field(I+1,J-1) for embedded
SF-N Mar 8, 2024
54c630f
Merge branch 'main' into cartesian_offset_syntax
SF-N Mar 8, 2024
53da840
Make _call_recursive and add type annotation (to be checked)
SF-N Mar 8, 2024
7073713
Try to fix arguments and annotations
SF-N Mar 11, 2024
c80e394
Fix test_laplacian
SF-N Mar 11, 2024
4c3dddf
Update and extend various parts
SF-N Mar 11, 2024
5b237d7
Remove comma
SF-N Mar 11, 2024
e479b96
Add support fir IDim + 1 + 1
SF-N Mar 13, 2024
b7fdc68
Only support one argument in case order matters
SF-N Mar 13, 2024
c50b083
Modify laplacian tests to avoid symmetry, extend visit_shift to also …
SF-N Mar 14, 2024
c1f2194
Use squared inpit field in laplacian tests, remove support for IDim +…
SF-N Mar 22, 2024
876e5b0
merge main
SF-N Mar 22, 2024
2c22920
Add annotation
SF-N Mar 22, 2024
b57afa0
Start working on review comments
SF-N Mar 27, 2024
79e8154
Extend offset_provider automatically
SF-N Mar 27, 2024
c080415
Update comments and add support for IDim+1+2 again (in foast_to_itir)
SF-N Mar 28, 2024
f95c6ca
Minor refactoring
SF-N Mar 28, 2024
72e2a1e
Reformat wrt pre-commit
SF-N Mar 28, 2024
d10ae99
Fix some problems occuring during PMAP-L implementation
SF-N Apr 11, 2024
5f8ab74
Implement review comments
SF-N Apr 12, 2024
c2d2188
Minor
SF-N Apr 12, 2024
f425e43
Minor
SF-N Apr 12, 2024
684fa0d
Remove support for (Idim + 1) + 2 again
SF-N Apr 12, 2024
6567c65
pre-commit changes
SF-N Apr 12, 2024
dc5616f
Merge origin/main
tehrengruber Sep 9, 2024
25786f8
General cleanup, added & adopted new offset syntax into new GTIR lowe…
tehrengruber Sep 9, 2024
267660f
Minor cleanups
tehrengruber Sep 9, 2024
1bc5a7e
Merge branch 'main' into cartesian_offset_syntax
tehrengruber Sep 9, 2024
c3e0a9a
Fix broken tests
tehrengruber Sep 9, 2024
2d2bd45
Merge remote-tracking branch 'origin_sf_n/cartesian_offset_syntax' in…
tehrengruber Sep 9, 2024
9dbe6c2
Cleanup
tehrengruber Sep 9, 2024
4d9394a
Fix dace sdfg-convertible program
tehrengruber Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
@@ -87,6 +87,18 @@ def __str__(self) -> str:
def __call__(self, val: int) -> NamedIndex:
return NamedIndex(self, val)

def __add__(self, offset: int):
from gt4py.next.ffront import fbuiltins

assert isinstance(self.value, str)
return fbuiltins.FieldOffset(f"{self.value}off", source=self, target=(self,))[offset]

def __sub__(self, offset: int):
from gt4py.next.ffront import fbuiltins

assert isinstance(self.value, str)
return fbuiltins.FieldOffset(f"{self.value}off", source=self, target=(self,))[-offset]


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
@@ -646,7 +658,11 @@ def as_scalar(self) -> core_defs.ScalarT: ...

# Operators
@abc.abstractmethod
def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...
def __call__(
self,
index_field: ConnectivityField | fbuiltins.FieldOffset,
*args: ConnectivityField | fbuiltins.FieldOffset,
) -> Field: ...

@abc.abstractmethod
def __getitem__(self, item: AnyIndexSpec) -> Self: ...
@@ -785,14 +801,12 @@ def __eq__(self, other: Any) -> Never:
def __ne__(self, other: Any) -> Never:
raise TypeError("'ConnectivityField' does not support this operation.")

def __add__(self, other: Field | core_defs.IntegralScalar) -> Never:
raise TypeError("'ConnectivityField' does not support this operation.")
def __add__(self, other: Field | core_defs.IntegralScalar) -> Field: ...

def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe
raise TypeError("'ConnectivityField' does not support this operation.")

def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never:
raise TypeError("'ConnectivityField' does not support this operation.")
def __sub__(self, other: Field | core_defs.IntegralScalar) -> Field: ...

def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe
raise TypeError("'ConnectivityField' does not support this operation.")
@@ -901,6 +915,18 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig
def ndarray(self) -> Never:
raise NotImplementedError()

def __add__(self, other: Field | core_defs.IntegralScalar) -> Field:
if isinstance(other, int):
return CartesianConnectivity(dimension=self.codomain, offset=self.offset + other)
else:
raise TypeError("'ConnectivityField' does not support this operation.")

def __sub__(self, other: Field | core_defs.IntegralScalar) -> Field:
if isinstance(other, int):
return CartesianConnectivity(dimension=self.codomain, offset=self.offset - other)
else:
raise TypeError("'ConnectivityField' does not support this operation.")

def asnumpy(self) -> Never:
raise NotImplementedError()

@@ -957,7 +983,11 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa
assert isinstance(image_range, UnitRange)
return (named_range((self.codomain, image_range - self.offset)),)

def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField:
def remap(
self,
index_field: ConnectivityField | fbuiltins.FieldOffset,
*args: ConnectivityField | fbuiltins.FieldOffset,
) -> ConnectivityField:
raise NotImplementedError()

__call__ = remap
9 changes: 8 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
@@ -226,7 +226,14 @@ def remap(
dtype=self.dtype,
)

__call__ = remap # type: ignore[assignment]
def __call__(
self,
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
return functools.reduce(
lambda field, connectivity: field.remap(connectivity), [index_field, *args], self
)

def restrict(self, index: common.AnyIndexSpec) -> NdArrayField:
new_domain, buffer_slice = self._slice(index)
13 changes: 13 additions & 0 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
@@ -620,6 +620,19 @@ def _deduce_binop_type(
right: foast.Expr,
**kwargs: Any,
) -> Optional[ts.TypeSpec]:
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and type_info.is_integral(right.type)
):
return ts.OffsetType(source=left.type.dim, target=(left.type.dim,))
if (
isinstance(left.type, ts.OffsetType)
and isinstance(right.type, ts.ScalarType)
and type_info.is_integral(right.type)
):
return ts.OffsetType(source=left.type.source, target=(left.type.source,))

logical_ops = {
dialect_ast_enums.BinaryOperator.BIT_AND,
dialect_ast_enums.BinaryOperator.BIT_OR,
55 changes: 38 additions & 17 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
@@ -295,23 +295,44 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
match node.args[0]:
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
shift_offset = im.shift(offset_name, offset_index)
case foast.Name(id=offset_name):
return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs))
case foast.Call(func=foast.Name(id="as_offset")):
func_args = node.args[0]
offset_dim = func_args.args[0]
assert isinstance(offset_dim, foast.Name)
shift_offset = im.shift(
offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs))
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))(
self.visit(node.func, **kwargs)
)
shift_offsets = []
for i in range(len(node.args)):
match node.args[i]:
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
shift_offsets.append(im.shift(offset_name, offset_index))
case foast.BinOp(
op=dialect_ast_enums.BinaryOperator.ADD
| dialect_ast_enums.BinaryOperator.SUB,
left=foast.Name(id=dimension),
right=foast.Constant(value=offset_index),
):
if node.args[i].op == dialect_ast_enums.BinaryOperator.SUB: # type: ignore[attr-defined] # ensured by pattern
offset_index *= -1
shift_offsets.append(im.shift(f"{dimension}off", offset_index))
case foast.Name(id=offset_name):
assert len(node.args) == 1
return im.lifted_neighbors(
str(offset_name), self.visit(node.func, **kwargs)
) # Todo: fix return statement to take care of several args
case foast.Call(func=foast.Name(id="as_offset")):
func_args = node.args[i]
offset_dim = func_args.args[0] # type: ignore[attr-defined] # ensured by pattern
assert isinstance(offset_dim, foast.Name)
shift_offsets.append(
im.shift(offset_dim.id, im.deref(self.visit(func_args.args[1], **kwargs))) # type: ignore[attr-defined] # ensured by pattern
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")

ret = im.lift(im.lambda_("it")(im.deref(shift_offsets[0]("it"))))
for i in range(len(shift_offsets) - 1):
ret = im.lift(im.lambda_("it")(im.deref(shift_offsets[i + 1]("it"))))(
ret(self.visit(node.func, **kwargs))
)
if len(shift_offsets) == 1:
return ret(self.visit(node.func, **kwargs))
else:
return ret

def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
7 changes: 4 additions & 3 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
import factory

from gt4py.eve import NodeTranslator, concepts, traits
from gt4py.next import common, config
from gt4py.next import common, config, errors
from gt4py.next.ffront import (
fbuiltins,
gtcallable,
@@ -412,9 +412,10 @@ def _compute_field_slice(node: past.Subscript) -> list[past.Slice]:
node_dims_ls = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims_ls, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls):
raise ValueError(
raise errors.DSLError(
node.location,
f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}"
f"-dimensional, but {len(out_field_slice_)} were indexed."
f"-dimensional, but {len(out_field_slice_)} were indexed.",
)
return out_field_slice_

12 changes: 10 additions & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
@@ -1089,7 +1089,11 @@ def as_scalar(self) -> core_defs.IntegralScalar:
assert self._cur_index is not None
return self._cur_index

def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field:
def remap(
self,
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

@@ -1209,7 +1213,11 @@ def ndarray(self) -> core_defs.NDArrayObject:
def asnumpy(self) -> np.ndarray:
raise NotImplementedError()

def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field:
def remap(
self,
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()

19 changes: 11 additions & 8 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
@@ -701,20 +701,23 @@ def function_signature_incompatibilities_field(
args: list[ts.TypeSpec],
kwargs: dict[str, ts.TypeSpec],
) -> Iterator[str]:
if len(args) != 1:
yield f"Function takes 1 argument, but {len(args)} were given."
return

if not isinstance(args[0], ts.OffsetType):
yield f"Expected first argument to be of type '{ts.OffsetType}', got '{args[0]}'."
if len(args) < 1: # Todo: is this the right condition in general?
yield f"Function takes at least 1 argument, but {len(args)} were given."
return
for arg in args:
if not isinstance(arg, ts.OffsetType):
yield f"Expected first argument to be of type '{ts.OffsetType}', got '{arg}'." # Todo fix message
return
if len(args) > 1 and len(arg.target) > 1:
yield f"Function takes only 1 argument in unstructured case, but {len(args)} were given."
return

if kwargs:
yield f"Got unexpected keyword argument(s) '{', '.join(kwargs.keys())}'."
return

source_dim = args[0].source
target_dims = args[0].target
source_dim = args[0].source # type: ignore[attr-defined] # ensured by loop above
target_dims = args[0].target # type: ignore[attr-defined] # ensured by loop above
# TODO: This code does not handle ellipses for dimensions. Fix it.
assert field_type.dims is not ...
if field_type.dims and source_dim not in field_type.dims:
9 changes: 8 additions & 1 deletion tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
@@ -483,7 +483,14 @@ def cartesian_case(
):
yield Case(
exec_alloc_descriptor if exec_alloc_descriptor.executor else None,
offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim},
offset_provider={
"Ioff": IDim,
"Joff": JDim,
"Koff": KDim,
"IDimoff": IDim,
"JDimoff": JDim,
"KDimoff": KDim,
},
default_sizes={IDim: 10, JDim: 10, KDim: 10},
grid_type=common.GridType.CARTESIAN,
allocator=exec_alloc_descriptor.allocator,
Original file line number Diff line number Diff line change
@@ -23,18 +23,28 @@
exec_alloc_descriptor,
)


pytestmark = pytest.mark.uses_cartesian_shift


@gtx.field_operator
def lap(in_field: gtx.Field[[IDim, JDim], "float"]) -> gtx.Field[[IDim, JDim], "float"]:
return (
-4.0 * in_field
+ in_field(Ioff[1])
+ in_field(Joff[1])
+ in_field(Ioff[-1])
+ in_field(Joff[-1])
+ in_field(IDim + 1)
+ in_field(JDim + 1)
+ in_field(IDim - 1)
+ in_field(JDim - 1)
)


@gtx.field_operator
def skewedlap(in_field: gtx.Field[[IDim, JDim], "float"]) -> gtx.Field[[IDim, JDim], "float"]:
return (
-4.0 * in_field
+ in_field(IDim + 1, JDim + 1)
+ in_field(IDim + 1, JDim - 1)
+ in_field(IDim - 1, JDim + 1)
+ in_field(IDim - 1, JDim - 1)
)


@@ -51,6 +61,14 @@ def lap_program(
lap(in_field, out=out_field[1:-1, 1:-1])


@gtx.program
def skewedlap_program(
in_field: gtx.Field[[IDim, JDim], "float"],
out_field: gtx.Field[[IDim, JDim], "float"],
):
skewedlap(in_field, out=out_field[1:-1, 1:-1])


@gtx.program
def laplap_program(
in_field: gtx.Field[[IDim, JDim], "float"],
@@ -59,13 +77,24 @@ def laplap_program(
laplap(in_field, out=out_field[2:-2, 2:-2])


def square(inp):
"""Compute the square of the field entries"""
return inp[:, :] * inp[:, :]


def lap_ref(inp):
"""Compute the laplacian using numpy"""
return -4.0 * inp[1:-1, 1:-1] + inp[:-2, 1:-1] + inp[2:, 1:-1] + inp[1:-1, :-2] + inp[1:-1, 2:]
return -4.0 * inp[1:-1, 1:-1] + inp[2:, 1:-1] + inp[1:-1, 2:] + inp[:-2, 1:-1] + inp[1:-1, :-2]


def skewedlap_ref(inp):
"""Compute the laplacian using numpy"""
return -4.0 * inp[1:-1, 1:-1] + inp[2:, 2:] + inp[2:, :-2] + inp[:-2, 2:] + inp[:-2, :-2]


def test_ffront_lap(cartesian_case):
in_field = cases.allocate(cartesian_case, lap_program, "in_field")()
in_field = square(in_field)
out_field = cases.allocate(cartesian_case, lap_program, "out_field")()

cases.verify(
@@ -77,7 +106,25 @@ def test_ffront_lap(cartesian_case):
ref=lap_ref(in_field.ndarray),
)


def test_ffront_skewedlap(cartesian_case):
in_field = cases.allocate(cartesian_case, skewedlap_program, "in_field")()
in_field = square(in_field)
out_field = cases.allocate(cartesian_case, skewedlap_program, "out_field")()

cases.verify(
cartesian_case,
skewedlap_program,
in_field,
out_field,
inout=out_field[1:-1, 1:-1],
ref=skewedlap_ref(in_field.ndarray),
)


def test_ffront_laplap(cartesian_case):
in_field = cases.allocate(cartesian_case, laplap_program, "in_field")()
in_field = square(in_field)
out_field = cases.allocate(cartesian_case, laplap_program, "out_field")()

cases.verify(