Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Support for direct field operator call with domain arg #1779

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
365c942
Support for direct field operator call with domain arg
tehrengruber Dec 10, 2024
7bb21fe
Merge branch 'main' into direct_fo_call_with_domain_arg
tehrengruber Dec 10, 2024
aed4d1e
Support for calling a program with field arguments whose domain does …
tehrengruber Dec 10, 2024
1e0aa93
Merge branch 'field_arg_with_non_zero_domain_start' into direct_fo_ca…
tehrengruber Dec 10, 2024
f722c14
Add test for input arg with different domain
tehrengruber Dec 11, 2024
c5a61e9
Fix format
tehrengruber Dec 11, 2024
9e09c86
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Dec 11, 2024
9deb814
update dace backend
edopao Dec 11, 2024
61feb99
Fix failing tests
tehrengruber Jan 10, 2025
30a4911
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
9d97ea7
Merge branch 'field_arg_with_non_zero_domain_start' into direct_fo_ca…
tehrengruber Jan 10, 2025
3f15911
Disable in dace backend
tehrengruber Jan 10, 2025
fd95ff4
Merge branch 'main' into direct_fo_call_with_domain_arg
tehrengruber Jan 10, 2025
052c54b
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 10, 2025
d1009be
Fix gpu tests
tehrengruber Jan 10, 2025
0d903cc
Address review comments
tehrengruber Jan 10, 2025
7b77c9f
Merge remote-tracking branch 'origin_tehrengruber/field_arg_with_non_…
tehrengruber Jan 10, 2025
a6cf988
Merge origin/main
tehrengruber Jan 10, 2025
858a573
Merge branch 'main' into field_arg_with_non_zero_domain_start
tehrengruber Jan 14, 2025
da65ca1
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 15, 2025
e6e640c
dace support for domain range and field origin
edopao Jan 15, 2025
a9f67f9
minor edit
edopao Jan 15, 2025
b97232a
Revert "minor edit"
edopao Jan 16, 2025
56ec88d
Revert "dace support for domain range and field origin"
edopao Jan 16, 2025
ad68fac
Merge remote-tracking branch 'origin/main' into field_arg_with_non_ze…
edopao Jan 16, 2025
a28fbf3
skip dace orchestration tests
edopao Jan 16, 2025
9637866
skip dace test_halo_exchange_helper_attrs
edopao Jan 16, 2025
6f58be6
Merge direct_fo_call_with_domain_arg
tehrengruber Jan 17, 2025
63cb251
Merge remote-tracking branch 'origin/main' into direct_fo_call_with_d…
tehrengruber Jan 17, 2025
9fb7d2e
Fix pytest mark
tehrengruber Jan 17, 2025
bbdae9d
Fix pytest mark
tehrengruber Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,10 @@ def __call__(self, *args, **kwargs) -> None:
if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
if "domain" in kwargs:
domain = common.domain(kwargs.pop("domain"))
out = out[domain]

args, kwargs = type_info.canonicalize_arguments(
self.foast_stage.foast_node.type, args, kwargs
)
Expand Down
26 changes: 16 additions & 10 deletions src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,40 +83,46 @@ def _process_args(
# TODO(tehrengruber): Previously this function was called with the actual arguments
# not their type. The check using the shape here is not functional anymore and
# should instead be placed in a proper location.
shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)]
ranges_and_dims = [*_field_constituents_range_and_dims(args[param_idx], param.type)]
# check that all non-scalar like constituents have the same shape and dimension, e.g.
# for `(scalar, (field1, field2))` the two fields need to have the same shape and
# dimension
if shapes_and_dims:
shape, dims = shapes_and_dims[0]
if ranges_and_dims:
range_, dims = ranges_and_dims[0]
if not all(
el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims
el_range == range_ and el_dims == dims
for (el_range, el_dims) in ranges_and_dims
):
raise ValueError(
"Constituents of composite arguments (e.g. the elements of a"
" tuple) need to have the same shape and dimensions."
)
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
size_args.extend(
shape if shape else [ts.ScalarType(kind=ts.ScalarKind.INT32)] * len(dims) # type: ignore[arg-type] # shape is always empty
range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty
)
return tuple(rewritten_args), tuple(size_args), kwargs


def _field_constituents_shape_and_dims(
def _field_constituents_range_and_dims(
Copy link
Contributor Author

@tehrengruber tehrengruber Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the code-base is in an extremely poor state, let's not prettify a blobfish and invest time on improving this here.

arg: Any, # TODO(havogt): improve typing
arg_type: ts.DataType,
) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]:
) -> Iterator[tuple[tuple[tuple[int, int], ...], list[common.Dimension]]]:
match arg_type:
case ts.TupleType():
for el, el_type in zip(arg, arg_type.types):
yield from _field_constituents_shape_and_dims(el, el_type)
yield from _field_constituents_range_and_dims(el, el_type)
case ts.FieldType():
dims = type_info.extract_dims(arg_type)
if isinstance(arg, ts.TypeSpec): # TODO
yield (tuple(), dims)
elif dims:
assert hasattr(arg, "shape") and len(arg.shape) == len(dims)
yield (arg.shape, dims)
assert (
hasattr(arg, "domain")
and isinstance(arg.domain, common.Domain)
and len(arg.domain.dims) == len(dims)
)
yield (tuple((r.start, r.stop) for r in arg.domain.ranges), dims)
else:
yield from [] # ignore 0-dim fields
case ts.ScalarType():
Expand Down
34 changes: 18 additions & 16 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]
return iter(scanops_per_axis.keys()).__next__()


def _size_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_size_{dim}"
def _range_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_{dim}_range"


def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]:
Expand Down Expand Up @@ -217,13 +217,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
)
if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType`
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
index_type = ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
)
for dim_idx in range(len(fields_dims[0])):
size_params.append(
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
id=_range_arg_from_field(param.id, dim_idx),
type=ts.TupleType(types=[index_type, index_type]),
)
)

Expand Down Expand Up @@ -286,7 +287,8 @@ def _visit_slice_bound(
self,
slice_bound: Optional[past.Constant],
default_value: itir.Expr,
dim_size: itir.Expr,
start_idx: itir.Expr,
stop_idx: itir.Expr,
**kwargs: Any,
) -> itir.Expr:
if slice_bound is None:
Expand All @@ -296,11 +298,9 @@ def _visit_slice_bound(
slice_bound.type
)
if slice_bound.value < 0:
lowered_bound = itir.FunCall(
fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)]
)
lowered_bound = im.plus(stop_idx, self.visit(slice_bound, **kwargs))
else:
lowered_bound = self.visit(slice_bound, **kwargs)
lowered_bound = im.plus(start_idx, self.visit(slice_bound, **kwargs))
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
Expand Down Expand Up @@ -348,8 +348,9 @@ def _construct_itir_domain_arg(
domain_args = []
domain_args_kind = []
for dim_i, dim in enumerate(out_dims):
# an expression for the size of a dimension
dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i))
# an expression for the range of a dimension
dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i))
dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range)
# bounds
lower: itir.Expr
upper: itir.Expr
Expand All @@ -359,11 +360,12 @@ def _construct_itir_domain_arg(
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
dim_size,
dim_start,
dim_start,
dim_stop,
)
upper = self._visit_slice_bound(
slices[dim_i].upper if slices else None, dim_size, dim_size
slices[dim_i].upper if slices else None, dim_stop, dim_start, dim_stop
)

if dim.kind == common.DimensionKind.LOCAL:
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/otf/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]:
return None


def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]:
"""
Yield the size of each field argument in each dimension.

Expand All @@ -136,7 +136,9 @@ def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
if first_field:
yield from iter_size_args((first_field,))
case common.Field():
yield from arg.ndarray.shape
for range_ in arg.domain.ranges:
assert isinstance(range_, common.UnitRange)
yield (range_.start, range_.stop)
case _:
pass

Expand All @@ -156,6 +158,7 @@ def iter_size_compile_args(
)
if field_constituents:
# we only need the first field, because all fields in a tuple must have the same dims and sizes
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
yield from [
ts.ScalarType(kind=ts.ScalarKind.INT32) for dim in field_constituents[0].dims
ts.TupleType(types=[index_type, index_type]) for dim in field_constituents[0].dims
]
12 changes: 9 additions & 3 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def verify(
fieldview_prog: decorator.FieldOperator | decorator.Program,
*args: FieldViewArg,
ref: ReferenceValue,
domain: Optional[dict[common.Dimension, tuple[int, int]]] = None,
out: Optional[FieldViewInout] = None,
inout: Optional[FieldViewInout] = None,
offset_provider: Optional[OffsetProvider] = None,
Expand All @@ -405,6 +406,8 @@ def verify(
or tuple of fields here and they will be compared to ``ref`` under
the assumption that the fieldview code stores its results in
them.
domain: If given will be passed to the fieldview code as ``domain=``
keyword argument.
offset_provider: An override for the test case's offset_provider.
Use with care!
comparison: A comparison function, which will be called as
Expand All @@ -414,10 +417,13 @@ def verify(
used as an argument to the fieldview program and compared against ``ref``.
Else, ``inout`` will not be passed and compared to ``ref``.
"""
kwargs = {}
if out:
run(case, fieldview_prog, *args, out=out, offset_provider=offset_provider)
else:
run(case, fieldview_prog, *args, offset_provider=offset_provider)
kwargs["out"] = out
if domain:
kwargs["domain"] = domain

run(case, fieldview_prog, *args, **kwargs, offset_provider=offset_provider)

out_comp = out or inout
assert out_comp is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import pytest

from gt4py.next import errors
from gt4py.next import errors, common, constructors
from gt4py.next.ffront.decorator import field_operator, program, scan_operator
from gt4py.next.ffront.fbuiltins import broadcast, int32

Expand Down Expand Up @@ -296,3 +296,20 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te
)
is not None
)


def test_direct_fo_call_with_domain_arg(cartesian_case):
@field_operator
def testee(inp: IField) -> IField:
return inp

size = cartesian_case.default_sizes[IDim]
inp = cases.allocate(cartesian_case, testee, "inp").unique()()
out = cases.allocate(
cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42)
)()
ref = np.zeros(size)
ref[0] = ref[-1] = 42
ref[1:-1] = inp.ndarray[1:-1]

cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest

import gt4py.next as gtx
from gt4py.next import errors
from gt4py.next import errors, constructors, common

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
Expand Down Expand Up @@ -251,3 +251,18 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField):
ValueError, match=(r"Dimensions in out field and field domain are not equivalent")
):
cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={})


def test_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend)

size = cartesian_case.default_sizes[IDim]

inp = cases.allocate(cartesian_case, copy_program, "in_field").unique()()
out = constructors.empty(
common.domain({IDim: (1, size - 2)}),
allocator=cartesian_case.allocator,
)
ref = inp.ndarray[1:-2]

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)
Loading