Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: concat_where for boundary conditions #1468

Merged
merged 67 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
bf1f1f8
skip value connectivity
havogt Jan 18, 2024
5495937
fix formatting
havogt Jan 23, 2024
512d795
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Jan 30, 2024
c7b01eb
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Jan 31, 2024
d51cfca
cleanup parts and tests
havogt Feb 1, 2024
fdb4423
skip fvm test with no atlas
havogt Feb 1, 2024
1e0e228
testcase which requires broadcasting the mask
havogt Feb 2, 2024
05bdc67
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 6, 2024
96090e8
add comment
havogt Feb 6, 2024
8c408dd
cleanup
havogt Feb 6, 2024
ea798d1
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 6, 2024
b96280e
fix lowering issue past to itir
havogt Feb 7, 2024
8669f33
fix bug
havogt Feb 7, 2024
fab1185
fix connectivity names
havogt Feb 9, 2024
d24f6a3
explicit xp.newaxis
havogt Feb 9, 2024
dcf17e2
wrap the mask hypercube
havogt Feb 9, 2024
3804eb6
prepare configurable skip_value
havogt Feb 9, 2024
e4a6f39
fix test
havogt Feb 9, 2024
1f9ac85
skip value refactoring
havogt Feb 9, 2024
07cffa6
fix skip_value check
havogt Feb 9, 2024
93bf889
fix assert
havogt Feb 10, 2024
cb1d017
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 12, 2024
ce6adde
prototype
havogt Feb 14, 2024
6462610
fix bug in reverse sub and div
havogt Feb 14, 2024
5c92491
alternative concat_where that deals with multiple ranges
havogt Feb 14, 2024
35bc515
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 22, 2024
75d1b03
SKIP_VALUE -> _DEFAULT_SKIP_VALUE
havogt Feb 22, 2024
649d30b
Merge remote-tracking branch 'local/embedded_skip_value_connectivity'…
havogt Feb 23, 2024
26a2668
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Feb 23, 2024
cbfd824
format
havogt Feb 23, 2024
971dc44
add tests for scalar binary with field
havogt Feb 23, 2024
ceb1a09
cleanup tests
havogt Feb 23, 2024
ddf6667
cleanup
havogt Feb 23, 2024
77ce553
Merge branch 'fix_reverse_ops' into concat_where
havogt Feb 23, 2024
864ddd3
add concat_where for embedded
havogt Feb 23, 2024
a028503
add tests and very hacked version of broadcasting
havogt Feb 23, 2024
05ae764
add TODOs
havogt Feb 23, 2024
19a176f
cleanup
havogt Feb 24, 2024
6dd79d8
refactoring
havogt Feb 24, 2024
ab78dc4
more cleanups
havogt Feb 24, 2024
f891e37
address review comments
havogt Feb 26, 2024
ddcc272
change scalar value
havogt Feb 26, 2024
5c81c03
Merge remote-tracking branch 'upstream/main' into fix_reverse_ops
havogt Feb 26, 2024
9816121
Merge remote-tracking branch 'origin/fix_reverse_ops' into concat_where
havogt Feb 26, 2024
b604a7a
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Feb 26, 2024
cd11b75
Merge commit 'ea852984dbf22ec0f2bb72ed454d7a1392040478' into concat_w…
havogt Mar 5, 2024
05fd105
default format
havogt Mar 5, 2024
0b1f12c
Merge commit '4c8f706a9f3cceff946f128022390c406523a7a1' into concat_w…
havogt Mar 5, 2024
d3db930
Merge commit '77a205b6b31d9854e0e15d01d91349047ec0c426' into concat_w…
havogt Mar 5, 2024
f22c149
fix tests
havogt Mar 6, 2024
80f21ff
describe algorithm
havogt Mar 7, 2024
c9dfc8d
add docstring to concat_where, but test not working
havogt Mar 7, 2024
14acc84
add tests
havogt Mar 7, 2024
071a9e0
switch back to list[tuple]
havogt Mar 7, 2024
a6186fc
add unit_range.is_empty
havogt Mar 11, 2024
6413753
add more tests
havogt Mar 11, 2024
c798f97
address more review comments
havogt Mar 11, 2024
54428f7
steal refactoring from nfarabullini/as_offset_embedded
havogt Mar 11, 2024
c178f25
move to experimental
havogt Mar 11, 2024
a99561e
address review comments
havogt Mar 13, 2024
5755d62
add a test (wip)
havogt Mar 13, 2024
4ebf2d6
add todos
havogt Mar 13, 2024
2c89af3
fix refactoring bugs
havogt Mar 14, 2024
8b268ed
add tests in field_operators
havogt Mar 14, 2024
b15a0aa
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Mar 14, 2024
c2738b6
add type ignore
havogt Mar 14, 2024
70274a7
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Mar 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return obj.start is not Infinity.NEGATIVE

@property
havogt marked this conversation as resolved.
Show resolved Hide resolved
def is_empty(self) -> bool:
return (
self.start == 0 and self.stop == 0
) # post_init ensures that empty is represented as UnitRange(0, 0)

def __repr__(self) -> str:
return f"UnitRange({self.start}, {self.stop})"

Expand Down Expand Up @@ -422,7 +428,7 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]:

@property
def is_empty(self) -> bool:
return any(rng == UnitRange(0, 0) for rng in self.ranges)
return any(rng.is_empty for rng in self.ranges)

@overload
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ...
Expand Down
28 changes: 14 additions & 14 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def domain_intersection(
Return the intersection of the given domains.

Example:
>>> I = common.Dimension("I")
>>> domain_intersection(
... common.domain({I: (0, 5)}), common.domain({I: (1, 3)})
... ) # doctest: +ELLIPSIS
Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),))
>>> I = common.Dimension("I")
>>> domain_intersection(
... common.domain({I: (0, 5)}), common.domain({I: (1, 3)})
... ) # doctest: +ELLIPSIS
Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),))
"""
return functools.reduce(
operator.and_,
Expand All @@ -117,22 +117,22 @@ def domain_intersection(
)


def intersect_domains(
def restrict_to_intersection(
*domains: common.Domain,
ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None,
) -> tuple[common.Domain, ...]:
"""
Return the with each other intersected domains, ignoring 'ignore_dims' dimensions for the intersection.

Example:
havogt marked this conversation as resolved.
Show resolved Hide resolved
>>> I = common.Dimension("I")
>>> J = common.Dimension("J")
>>> res = intersect_domains(
... common.domain({I: (0, 5), J: (1, 2)}),
... common.domain({I: (1, 3), J: (0, 3)}),
... ignore_dims=J,
... )
>>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)}))
>>> I = common.Dimension("I")
>>> J = common.Dimension("J")
>>> res = intersect_domains(
... common.domain({I: (0, 5), J: (1, 2)}),
... common.domain({I: (1, 3), J: (0, 3)}),
... ignore_dims=J,
... )
>>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)}))
"""
ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,)
intersection_without_ignore_dims = domain_intersection(*[
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/embedded/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(


class NonContiguousDomain(gt4py_exceptions.GT4PyError):
"""Describes an error where a domain would become non-contiguous after an operation."""

msg: str
egparedes marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, msg: str):
Expand Down
20 changes: 12 additions & 8 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
context as embedded_context,
exceptions as embedded_exceptions,
)
from gt4py.next.ffront import fbuiltins
from gt4py.next.ffront import experimental, fbuiltins
from gt4py.next.iterator import embedded as itir_embedded


Expand Down Expand Up @@ -540,7 +540,9 @@ def _compute_mask_ranges(
ind = 0
res = []
for i in range(1, mask.shape[0]):
if (mask_i := bool(mask[i].item())) != cur:
if (
mask_i := bool(mask[i].item())
) != cur: # `.item()` to extract the scalar from a 0-d array in case of e.g. cupy
res.append((cur, common.UnitRange(ind, i)))
cur = mask_i
ind = i
Expand Down Expand Up @@ -579,12 +581,13 @@ def _intersect_fields(
*fields: common.Field | core_defs.Scalar,
ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None,
) -> tuple[common.Field, ...]:
# TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations
# TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations;
# currently blocked, because requiring the `_to_field` function, see comment there.
nd_array_class = _get_nd_array_class(*fields)
promoted_dims = common.promote_dims(*[f.domain.dims for f in fields if common.is_field(f)])
promoted_dims = common.promote_dims(*(f.domain.dims for f in fields if common.is_field(f)))
broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields]

intersected_domains = embedded_common.intersect_domains(
intersected_domains = embedded_common.restrict_to_intersection(
*[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims
)

Expand All @@ -597,7 +600,7 @@ def _intersect_fields(
)


def _concat_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]:
def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]:
if not domains:
return common.Domain()
dim_start = domains[0][dim][1].start
Expand All @@ -619,7 +622,7 @@ def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty
):
raise ValueError("Fields to concatenate must not overlap.")
new_domain = _concat_domains(*[f.domain for f in fields], dim=dim)
new_domain = _stack_domains(*[f.domain for f in fields], dim=dim)
if new_domain is None:
raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")
nd_array_class = _get_nd_array_class(*fields)
Expand All @@ -646,6 +649,7 @@ def _concat_where(
# intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

# TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils
# compute the consecutive ranges (first relative, then domain) of true and false values
mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = (
_compute_mask_ranges(mask_field.ndarray)
Expand Down Expand Up @@ -689,7 +693,7 @@ def _concat_where(
return cls_.from_array(result_array, domain=result_domain)


NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation


def _make_reduction(
Expand Down
74 changes: 47 additions & 27 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,50 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass

from gt4py.next.type_system import type_specifications as ts


@dataclass
class BuiltInFunction:
__gt_type: ts.FunctionType

def __call__(self, *args, **kwargs):
"""Act as an empty place holder for the built in function."""

def __gt_type__(self):
return self.__gt_type


as_offset = BuiltInFunction(
ts.FunctionType(
pos_only_args=[
ts.DeferredType(constraint=ts.OffsetType),
ts.DeferredType(constraint=ts.FieldType),
],
pos_or_kw_args={},
kw_only_args={},
returns=ts.DeferredType(constraint=ts.OffsetType),
)
)
from typing import Tuple

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction


@BuiltInFunction
def as_offset(
offset_: FieldOffset,
field: common.Field,
/,
) -> common.ConnectivityField:
raise NotImplementedError()


@WhereBuiltinFunction
def concat_where(
mask: common.Field,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
/,
) -> common.Field | Tuple:
"""
Concatenates two field fields based on a 1D mask.

The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields.
Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain.

TODO: I can't get this doctest to run, even after copying the __doc__ in the decorator
Example:
>>> I = common.Dimension("I")
>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2], domain={I: (0, 2)})
>>> false_field = common._field([3, 4, 5], domain={I: (1, 4)})
>>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)})

>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2, 3], domain={I: (0, 3)})
>>> false_field = common._field(
... [4], domain={I: (2, 3)}
... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values
havogt marked this conversation as resolved.
Show resolved Hide resolved
"""
raise NotImplementedError()


EXPERIMENTAL_FUN_BUILTIN_NAMES = ["as_offset", "concat_where"]
30 changes: 0 additions & 30 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,36 +202,6 @@ def where(
raise NotImplementedError()


@WhereBuiltinFunction
def concat_where(
mask: common.Field,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
/,
) -> common.Field | Tuple:
"""
Concatenates two field fields based on a 1D mask.

The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields.
Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain.

TODO: I can't get this doctest to run, even after copying the __doc__ in the decorator
Example:
>>> I = common.Dimension("I")
>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2], domain={I: (0, 2)})
>>> false_field = common._field([3, 4, 5], domain={I: (1, 4)})
>>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)})

>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2, 3], domain={I: (0, 3)})
>>> false_field = common._field(
... [4], domain={I: (2, 3)}
... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values
"""
raise NotImplementedError()


@BuiltInFunction
def astype(
value: common.Field | core_defs.ScalarT | Tuple,
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.next.common import DimensionKind
from gt4py.next.ffront import ( # noqa
dialect_ast_enums,
experimental,
fbuiltins,
type_info as ti_ffront,
type_specifications as ts_ffront,
Expand Down Expand Up @@ -727,7 +728,8 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
isinstance(new_func.type, ts.FunctionType)
and not type_info.is_concrete(return_type)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES
and new_func.id
in (fbuiltins.FUN_BUILTIN_NAMES + experimental.EXPERIMENTAL_FUN_BUILTIN_NAMES)
):
visitor = getattr(self, f"_visit_{new_func.id}")
return visitor(new_node, **kwargs)
Expand Down Expand Up @@ -941,8 +943,6 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call:
location=node.location,
)

_visit_concat_where = _visit_where

def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call:
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)
broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts
Expand Down
13 changes: 9 additions & 4 deletions tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
canonicalize_any_index_sequence,
iterate_domain,
sub_domain,
intersect_domains,
restrict_to_intersection,
domain_intersection,
)

Expand Down Expand Up @@ -194,25 +194,30 @@ def test_domain_intersection():
assert result == expected


def test_domain_intersection_empty():
result = domain_intersection()
assert result == common.Domain()


def test_intersect_domains():
testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)}))
result = intersect_domains(*testee, ignore_dims=J)
result = restrict_to_intersection(*testee, ignore_dims=J)

expected = (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)}))
assert result == expected


def test_intersect_domains_ignore_dims_none():
testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)}))
result = intersect_domains(*testee)
result = restrict_to_intersection(*testee)

expected = (domain_intersection(*testee),) * 2
assert result == expected


def test_intersect_domains_ignore_all_dims():
testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)}))
result = intersect_domains(*testee, ignore_dims=(I, J))
result = restrict_to_intersection(*testee, ignore_dims=(I, J))

expected = testee
assert result == expected
14 changes: 14 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,20 @@ def test_hypercube(index_array, expected):
([6, 7, 8, 9, 10], None),
([42, 7, 42, 9, 42], None),
),
(
# parts of mask_ranges are concatenated
([True, True, False, False], None),
([1, 2], {D0: (1, 3)}),
([3, 4], {D0: (1, 3)}),
([1, 4], {D0: (1, 3)}),
),
(
# parts of mask_ranges are concatenated and yield non-contiguous domain
([True, False, True, False], None),
([1, 2], {D0: (0, 2)}),
([3, 4], {D0: (2, 4)}),
None,
),
],
)
def test_concat_where(
Expand Down
Loading
Loading