From 4dcb4ec205a60f74e94fd4ac4d1cbe0fa5cc620a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 14 Mar 2024 19:36:04 +0100 Subject: [PATCH] feat[next]: concat_where for boundary conditions (#1468) Introduces `concat_where` to be used for boundary conditions. `where` will intersect all 3 fields. Therefore `where(klevel == 0, boundary_layer, interior)` will not work if `boundary_layer` is only defined on `klevel == 0`. `concat_where` will concatenate the fields that are selected by the mask regions. The `mask` is currently required to be 1 dimensional. --- src/gt4py/next/common.py | 8 + src/gt4py/next/embedded/common.py | 45 ++- src/gt4py/next/embedded/exceptions.py | 10 + src/gt4py/next/embedded/nd_array_field.py | 213 +++++++++++-- src/gt4py/next/embedded/operators.py | 15 +- src/gt4py/next/ffront/experimental.py | 74 +++-- src/gt4py/next/ffront/fbuiltins.py | 9 +- .../ffront/foast_passes/type_deduction.py | 6 +- src/gt4py/next/ffront/foast_to_itir.py | 7 +- .../ffront_tests/test_concat_where.py | 146 +++++++++ .../ffront_tests/test_execution.py | 24 -- .../feature_tests/ffront_tests/test_where.py | 118 +++++++ .../unit_tests/embedded_tests/test_common.py | 41 +++ .../embedded_tests/test_nd_array_field.py | 290 +++++++++++++----- tests/next_tests/unit_tests/test_common.py | 24 +- 15 files changed, 860 insertions(+), 170 deletions(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 0aa19b20ae..f21ee0e736 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -168,6 +168,11 @@ def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: # classmethod since TypeGuards requires the guarded obj as separate argument return obj.start is not Infinity.NEGATIVE + def is_empty(self) -> bool: + return ( + self.start == 0 and self.stop == 0 + ) # post_init ensures that empty is represented as UnitRange(0, 0) + def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @@ -428,6 +433,9 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: # classmethod since TypeGuards requires the guarded obj as separate argument return all(UnitRange.is_finite(rng) for rng in obj.ranges) + def is_empty(self) -> bool: + return any(rng.is_empty() for rng in self.ranges) + @overload def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 0ba39d377a..f09e1b1ac3 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -97,7 +97,19 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) -def intersect_domains(*domains: common.Domain) -> common.Domain: +def domain_intersection( + *domains: common.Domain, +) -> common.Domain: + """ + Return the intersection of the given domains. + + Example: + >>> I = common.Dimension("I") + >>> domain_intersection( + ... common.domain({I: (0, 5)}), common.domain({I: (1, 3)}) + ... ) # doctest: +ELLIPSIS + Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),)) + """ return functools.reduce( operator.and_, domains, @@ -105,6 +117,37 @@ def intersect_domains(*domains: common.Domain) -> common.Domain: ) +def restrict_to_intersection( + *domains: common.Domain, + ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, +) -> tuple[common.Domain, ...]: + """ + Return the with each other intersected domains, ignoring 'ignore_dims' dimensions for the intersection. + + Example: + >>> I = common.Dimension("I") + >>> J = common.Dimension("J") + >>> res = restrict_to_intersection( + ... common.domain({I: (0, 5), J: (1, 2)}), + ... common.domain({I: (1, 3), J: (0, 3)}), + ... ignore_dims=J, + ... ) + >>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)})) + """ + ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) + intersection_without_ignore_dims = domain_intersection(*[ + common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple]) + for domain in domains + ]) + return tuple( + common.Domain(*[ + (d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1]) + for d, r in domain + ]) + for domain in domains + ) + + def iterate_domain(domain: common.Domain): for i in itertools.product(*[list(r) for r in domain.ranges]): yield tuple(zip(domain.dims, i)) diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 393123db36..9306e9a002 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -36,3 +36,13 @@ def __init__( self.indices = indices self.index = index self.dim = dim + + +class NonContiguousDomain(gt4py_exceptions.GT4PyError): + """Describes an error where a domain would become non-contiguous after an operation.""" + + detail: str + + def __init__(self, detail: str): + super().__init__(f"Operation would result in a non-contiguous domain: `{detail}`.") + self.detail = detail diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1760cb17e8..5a07328531 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -18,7 +18,7 @@ import functools from collections.abc import Callable, Sequence from types import ModuleType -from typing import ClassVar +from typing import ClassVar, Iterable import numpy as np from numpy import typing as npt @@ -26,8 +26,12 @@ from gt4py._core import definitions as core_defs from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common -from gt4py.next.embedded import common as embedded_common, context as embedded_context -from gt4py.next.ffront import fbuiltins +from gt4py.next.embedded import ( + common as embedded_common, + context as embedded_context, + exceptions as embedded_exceptions, +) +from gt4py.next.ffront import experimental, fbuiltins from gt4py.next.iterator import embedded as itir_embedded @@ -42,20 +46,22 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] +def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArrayField]: + for f in fields: + if isinstance(f, NdArrayField): + return f.__class__ + raise AssertionError("No 'NdArrayField' found in the arguments.") + + def _make_builtin( builtin_name: str, array_builtin_name: str, reverse=False ) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: - first = None - for f in fields: - if isinstance(f, NdArrayField): - first = f - break - assert first is not None - xp = first.__class__.array_ns + cls_ = _get_nd_array_class(*fields) + xp = cls_.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = embedded_common.intersect_domains(*[ + domain_intersection = embedded_common.domain_intersection(*[ f.domain for f in fields if common.is_field(f) ]) @@ -76,7 +82,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: if reverse: transformed.reverse() new_data = op(*transformed) - return first.__class__.from_array(new_data, domain=domain_intersection) + return cls_.from_array(new_data, domain=domain_intersection) _builtin_op.__name__ = builtin_name return _builtin_op @@ -423,10 +429,7 @@ def inverse_image( if relative_ranges is None: raise ValueError("Restriction generates non-contiguous dimensions.") - new_dims = [ - common.named_range((d, rr + ar.start)) - for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges) - ] + new_dims = _relative_ranges_to_domain(relative_ranges, self.domain) self._cache[cache_key] = new_dims @@ -448,6 +451,14 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field: __getitem__ = restrict +def _relative_ranges_to_domain( + relative_ranges: Sequence[common.UnitRange], domain: common.Domain +) -> common.Domain: + return common.Domain( + dims=domain.dims, ranges=[rr + ar.start for ar, rr in zip(domain.ranges, relative_ranges)] + ) + + def _hypercube( index_array: core_defs.NDArrayObject, image_range: common.UnitRange, @@ -519,6 +530,172 @@ def _hypercube( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) +def _compute_mask_ranges( + mask: core_defs.NDArrayObject, +) -> list[tuple[bool, common.UnitRange]]: + """Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.""" + # TODO: does it make sense to upgrade this naive algorithm to numpy? + assert mask.ndim == 1 + cur = bool(mask[0].item()) + ind = 0 + res = [] + for i in range(1, mask.shape[0]): + if ( + mask_i := bool(mask[i].item()) + ) != cur: # `.item()` to extract the scalar from a 0-d array in case of e.g. cupy + res.append((cur, common.UnitRange(ind, i))) + cur = mask_i + ind = i + res.append((cur, common.UnitRange(ind, mask.shape[0]))) + return res + + +def _trim_empty_domains( + lst: Iterable[tuple[bool, common.Domain]], +) -> list[tuple[bool, common.Domain]]: + """Remove empty domains from beginning and end of the list.""" + lst = list(lst) + if not lst: + return lst + if lst[0][1].is_empty(): + return _trim_empty_domains(lst[1:]) + if lst[-1][1].is_empty(): + return _trim_empty_domains(lst[:-1]) + return lst + + +def _to_field( + value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField] +) -> common.Field: + # TODO(havogt): this function is only to workaround broadcasting of scalars, once we have a ConstantField, we can broadcast to that directly + return ( + value + if common.is_field(value) + else nd_array_field_type.from_array( + nd_array_field_type.array_ns.asarray(value), domain=common.Domain() + ) + ) + + +def _intersect_fields( + *fields: common.Field | core_defs.Scalar, + ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, +) -> tuple[common.Field, ...]: + # TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations; + # currently blocked, because requiring the `_to_field` function, see comment there. + nd_array_class = _get_nd_array_class(*fields) + promoted_dims = common.promote_dims(*(f.domain.dims for f in fields if common.is_field(f))) + broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields] + + intersected_domains = embedded_common.restrict_to_intersection( + *[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims + ) + + return tuple( + nd_array_class.from_array( + f.ndarray[_get_slices_from_domain_slice(f.domain, intersected_domain)], + domain=intersected_domain, + ) + for f, intersected_domain in zip(broadcasted_fields, intersected_domains, strict=True) + ) + + +def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: + if not domains: + return common.Domain() + dim_start = domains[0][dim][1].start + dim_stop = dim_start + for domain in domains: + if not domain[dim][1].start == dim_stop: + return None + else: + dim_stop = domain[dim][1].stop + return domains[0].replace(dim, (dim, common.UnitRange(dim_start, dim_stop))) + + +def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: + # TODO(havogt): this function could be extended to a general concat + # currently only concatenate along the given dimension and requires the fields to be ordered + + if ( + len(fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() + ): + raise ValueError("Fields to concatenate must not overlap.") + new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) + if new_domain is None: + raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") + nd_array_class = _get_nd_array_class(*fields) + return nd_array_class.from_array( + nd_array_class.array_ns.concatenate( + [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], + axis=new_domain.dim_index(dim), + ), + domain=new_domain, + ) + + +def _concat_where( + mask_field: common.Field, true_field: common.Field, false_field: common.Field +) -> common.Field: + cls_ = _get_nd_array_class(mask_field, true_field, false_field) + xp = cls_.array_ns + if mask_field.domain.ndim != 1: + raise NotImplementedError( + "'concat_where': Can only concatenate fields with a 1-dimensional mask." + ) + mask_dim = mask_field.domain.dims[0] + + # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain + t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) + + # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils + # compute the consecutive ranges (first relative, then domain) of true and false values + mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = ( + _compute_mask_ranges(mask_field.ndarray) + ) + mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( + (mask, _relative_ranges_to_domain((relative_range,), mask_field.domain)) + for mask, relative_range in mask_values_to_relative_range_mapping + ) + # mask domains intersected with the respective fields + mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( + ( + mask_value, + embedded_common.domain_intersection( + t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain + ), + ) + for mask_value, mask_domain in mask_values_to_domain_mapping + ) + + # remove the empty domains from the beginning and end + mask_values_to_intersected_domains_mapping = _trim_empty_domains( + mask_values_to_intersected_domains_mapping + ) + if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): + raise embedded_exceptions.NonContiguousDomain( + f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." + ) + + # slice the fields with the domain ranges + transformed = [ + t_broadcasted[d] if v else f_broadcasted[d] + for v, d in mask_values_to_intersected_domains_mapping + ] + + # stack the fields together + if transformed: + return _concat(*transformed, dim=mask_dim) + else: + result_domain = common.Domain((mask_dim, common.UnitRange(0, 0))) + result_array = xp.empty(result_domain.shape) + return cls_.from_array(result_array, domain=result_domain) + + +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] + + def _make_reduction( builtin_name: str, array_builtin_name: str, initial_value_op: Callable ) -> Callable[ @@ -635,7 +812,7 @@ def __setitem__( common._field.register(jnp.ndarray, JaxArrayField.from_array) -def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: +def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) -> common.Field: if field.domain.dims == new_dimensions: return field domain_slice: list[slice | None] = [] @@ -645,7 +822,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] domain_slice.append(slice(None)) named_ranges.append((dim, field.domain[pos][1])) else: - domain_slice.append(np.newaxis) + domain_slice.append(None) # np.newaxis named_ranges.append((dim, common.UnitRange.infinite())) return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 0982024090..023da5c5f8 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -83,6 +83,14 @@ def scan_loop(hpos): return res +def _get_out_domain( + out: common.MutableField | tuple[common.MutableField | tuple, ...], +) -> common.Domain: + return embedded_common.domain_intersection(*[ + f.domain for f in utils.flatten_nested_tuple((out,)) + ]) + + def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): if "out" in kwargs: # called from program or direct field_operator as program @@ -102,10 +110,7 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): domain = kwargs.pop("domain", None) - flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) - assert all(f.domain == flattened_out[0].domain for f in flattened_out) - - out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain + out_domain = common.domain(domain) if domain is not None else _get_out_domain(out) new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) @@ -149,7 +154,7 @@ def impl(target: common.MutableField, source: common.Field): def _intersect_scan_args( *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> common.Domain: - return embedded_common.intersect_domains(*[ + return embedded_common.domain_intersection(*[ arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg) ]) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 39da80a5de..b69a118713 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -12,30 +12,50 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass - -from gt4py.next.type_system import type_specifications as ts - - -@dataclass -class BuiltInFunction: - __gt_type: ts.FunctionType - - def __call__(self, *args, **kwargs): - """Act as an empty place holder for the built in function.""" - - def __gt_type__(self): - return self.__gt_type - - -as_offset = BuiltInFunction( - ts.FunctionType( - pos_only_args=[ - ts.DeferredType(constraint=ts.OffsetType), - ts.DeferredType(constraint=ts.FieldType), - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=ts.DeferredType(constraint=ts.OffsetType), - ) -) +from typing import Tuple + +from gt4py._core import definitions as core_defs +from gt4py.next import common +from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction + + +@BuiltInFunction +def as_offset( + offset_: FieldOffset, + field: common.Field, + /, +) -> common.ConnectivityField: + raise NotImplementedError() + + +@WhereBuiltinFunction +def concat_where( + mask: common.Field, + true_field: common.Field | core_defs.ScalarT | Tuple, + false_field: common.Field | core_defs.ScalarT | Tuple, + /, +) -> common.Field | Tuple: + """ + Concatenates two field fields based on a 1D mask. + + The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields. + Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain. + + TODO(havogt): I can't get this doctest to run, even after copying the __doc__ in the decorator + Example: + >>> I = common.Dimension("I") + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2], domain={I: (0, 2)}) + >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) + >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) + + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) + >>> false_field = common._field( + ... [4], domain={I: (2, 3)} + ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values + """ + raise NotImplementedError() + + +EXPERIMENTAL_FUN_BUILTIN_NAMES = ["as_offset", "concat_where"] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 03b0cdc1a1..8d9552ab6d 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -25,7 +25,6 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, embedded from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS -from gt4py.next.ffront.experimental import as_offset # noqa: F401 [unused-import] from gt4py.next.iterator import runtime from gt4py.next.type_system import type_specifications as ts @@ -59,6 +58,10 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.FieldType elif t is common.Dimension: return ts.DimensionType + elif t is FieldOffset: + return ts.OffsetType + elif t is common.ConnectivityField: + return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType elif t is type: @@ -88,6 +91,7 @@ class BuiltInFunction(Generic[_R, _P]): def __post_init__(self): object.__setattr__(self, "name", f"{self.function.__module__}.{self.function.__name__}") + object.__setattr__(self, "__doc__", self.function.__doc__) def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: impl = self.dispatch(*args) @@ -147,7 +151,7 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return tuple(self(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) @@ -295,7 +299,6 @@ def impl( "broadcast", "where", "astype", - "as_offset", *MATH_BUILTIN_NAMES, ] diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index f9d33b5814..982655e079 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -20,6 +20,7 @@ from gt4py.next.common import DimensionKind from gt4py.next.ffront import ( # noqa dialect_ast_enums, + experimental, fbuiltins, type_info as ti_ffront, type_specifications as ts_ffront, @@ -761,7 +762,8 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: isinstance(new_func.type, ts.FunctionType) and not type_info.is_concrete(return_type) and isinstance(new_func, foast.Name) - and new_func.id in fbuiltins.FUN_BUILTIN_NAMES + and new_func.id + in (fbuiltins.FUN_BUILTIN_NAMES + experimental.EXPERIMENTAL_FUN_BUILTIN_NAMES) ): visitor = getattr(self, f"_visit_{new_func.id}") return visitor(new_node, **kwargs) @@ -976,6 +978,8 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, ) + _visit_concat_where = _visit_where + def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 2b6324f485..7cdece6c59 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -25,6 +25,7 @@ lowering_utils, type_specifications as ts_ffront, ) +from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind from gt4py.next.iterator import ir as itir @@ -317,7 +318,9 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: return self._visit_shift(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in FUN_BUILTIN_NAMES: + elif isinstance(node.func, foast.Name) and node.func.id in ( + FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES + ): visitor = getattr(self, f"_visit_{node.func.id}") return visitor(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: @@ -375,6 +378,8 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: node.type, ) + _visit_concat_where = _visit_where + def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self.visit(node.args[0], **kwargs) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py new file mode 100644 index 0000000000..9da6d260e5 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -0,0 +1,146 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +from typing import Tuple +import pytest +from next_tests.integration_tests.cases import ( + KDim, + cartesian_case, +) +from gt4py import next as gtx +from gt4py.next.ffront.experimental import concat_where +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) + + +def test_boundary_same_size_fields(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_boundary_horizontal_slice(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJField + ) -> cases.IJKField: + return concat_where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary.asnumpy()[:, :, np.newaxis], + interior.asnumpy(), + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_boundary_single_layer(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + np.broadcast_to(boundary.asnumpy(), interior.shape), + interior.asnumpy(), + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_alternating_mask(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: + return concat_where(k % 2 == 0, f1, f0) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + f0 = cases.allocate(cartesian_case, testee, "f0")() + f1 = cases.allocate(cartesian_case, testee, "f1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) + + cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + + +@pytest.mark.uses_tuple_returns +def test_with_tuples(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, + interior0: cases.IJKField, + boundary0: cases.IJField, + interior1: cases.IJKField, + boundary1: cases.IJField, + ) -> Tuple[cases.IJKField, cases.IJKField]: + return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref0 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy()[:, :, np.newaxis], + interior0.asnumpy(), + ) + ref1 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary1.asnumpy()[:, :, np.newaxis], + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, + k, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index d571c61590..a2c203b163 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -31,7 +31,6 @@ int64, minimum, neighbor_sum, - where, ) from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn @@ -1056,29 +1055,6 @@ def program_domain_tuple( ) -@pytest.mark.uses_cartesian_shift -def test_where_k_offset(cartesian_case): - @gtx.field_operator - def fieldop_where_k_offset( - inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType] - ) -> cases.IKField: - return where(k_index > 0, inp(Koff[-1]), 2) - - @gtx.program - def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): - fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) - - inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() - k_index = cases.allocate( - cartesian_case, fieldop_where_k_offset, "k_index", strategy=cases.IndexInitializer() - )() - out = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() - - ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), out.asnumpy()) - - cases.verify(cartesian_case, prog, inp, k_index, out=out, ref=ref) - - def test_undefined_symbols(cartesian_case): with pytest.raises(errors.DSLError, match="Undeclared symbol"): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py new file mode 100644 index 0000000000..2fc31e6574 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py @@ -0,0 +1,118 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +from typing import Tuple +import pytest +from next_tests.integration_tests.cases import ( + IDim, + JDim, + KDim, + Koff, + cartesian_case, +) +from gt4py import next as gtx +from gt4py.next.ffront.fbuiltins import where, broadcast +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) + + +@pytest.mark.uses_cartesian_shift +def test_where_k_offset(cartesian_case): + @gtx.field_operator + def fieldop_where_k_offset( + inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType] + ) -> cases.IKField: + return where(k_index > 0, inp(Koff[-1]), 2) + + @gtx.program + def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): + fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) + + inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() + k_index = cases.allocate( + cartesian_case, fieldop_where_k_offset, "k_index", strategy=cases.IndexInitializer() + )() + out = cases.allocate(cartesian_case, fieldop_where_k_offset, cases.RETURN)() + + ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), out.asnumpy()) + + cases.verify(cartesian_case, prog, inp, k_index, out=out, ref=ref) + + +def test_same_size_fields(cartesian_case): + # Note boundaries can only be implemented with `where` if both fields have the same size, see `concat_where` + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +@pytest.mark.uses_tuple_returns +def test_with_tuples(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, + interior0: cases.IJKField, + boundary0: cases.IJField, + interior1: cases.IJKField, + boundary1: cases.IJField, + ) -> Tuple[cases.IJKField, cases.IJKField]: + return where( + broadcast(k, (IDim, JDim, KDim)) == 0, (boundary0, boundary1), (interior0, interior1) + ) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref0 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy()[:, :, np.newaxis], + interior0.asnumpy(), + ) + ref1 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary1.asnumpy()[:, :, np.newaxis], + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, + k, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 91f15ee936..9765273f94 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -24,6 +24,8 @@ canonicalize_any_index_sequence, iterate_domain, sub_domain, + restrict_to_intersection, + domain_intersection, ) @@ -180,3 +182,42 @@ def test_slicing(slices, expected): else: testee = canonicalize_any_index_sequence(slices) assert testee == expected + + +def test_domain_intersection(): + # see also tests in unit_tests/test_common.py for tests with 2 domains: `dom0 & dom1` + testee = (common.domain({I: (0, 5)}), common.domain({I: (1, 3)}), common.domain({I: (0, 3)})) + + result = domain_intersection(*testee) + + expected = testee[0] & testee[1] & testee[2] + assert result == expected + + +def test_domain_intersection_empty(): + result = domain_intersection() + assert result == common.Domain() + + +def test_intersect_domains(): + testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + result = restrict_to_intersection(*testee, ignore_dims=J) + + expected = (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + assert result == expected + + +def test_intersect_domains_ignore_dims_none(): + testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + result = restrict_to_intersection(*testee) + + expected = (domain_intersection(*testee),) * 2 + assert result == expected + + +def test_intersect_domains_ignore_all_dims(): + testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + result = restrict_to_intersection(*testee, ignore_dims=(I, J)) + + expected = testee + assert result == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 3ec630c949..7c932533a6 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -14,7 +14,7 @@ import itertools import math import operator -from typing import Callable, Iterable +from typing import Callable, Iterable, Optional import numpy as np import pytest @@ -29,9 +29,9 @@ from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data -IDim = Dimension("IDim") -JDim = Dimension("JDim") -KDim = Dimension("KDim") +D0 = Dimension("D0") +D1 = Dimension("D1") +D2 = Dimension("D2") def nd_array_implementation_params(): @@ -78,6 +78,13 @@ def unary_logical_op(request): yield request.param +def _make_default_domain(shape: tuple[int, ...]) -> Domain: + return common.Domain( + dims=tuple(Dimension(f"D{i}") for i in range(len(shape))), + ranges=tuple(UnitRange(0, s) for s in shape), + ) + + def _make_field_or_scalar( lst: Iterable | core_defs.Scalar, nd_array_implementation, *, domain=None, dtype=None ): @@ -88,9 +95,7 @@ def _make_field_or_scalar( return dtype(lst) buffer = nd_array_implementation.asarray(lst, dtype=dtype) if domain is None: - domain = tuple( - (common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape) - ) + domain = _make_default_domain(buffer.shape) return common._field( buffer, domain=domain, @@ -147,16 +152,14 @@ def test_where_builtin_different_domain(nd_array_implementation): true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32) - cond_field = common._field( - nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2}) - ) + cond_field = common._field(nd_array_implementation.asarray(cond), domain=common.domain({D1: 2})) true_field = common._field( nd_array_implementation.asarray(true_), - domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}), + domain=common.domain({D0: common.UnitRange(0, 2), D1: common.UnitRange(-1, 2)}), ) false_field = common._field( nd_array_implementation.asarray(false_), - domain=common.domain({JDim: common.UnitRange(-1, 3)}), + domain=common.domain({D1: common.UnitRange(-1, 3)}), ) expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1]) @@ -259,8 +262,8 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): @pytest.mark.parametrize( "dims,expected_indices", [ - ((IDim,), (slice(5, 10), None)), - ((JDim,), (None, slice(5, 10))), + ((D0,), (slice(5, 10), None)), + ((D1,), (None, slice(5, 10))), ], ) def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): @@ -268,7 +271,7 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) arr2 = np.ones((5, 5)) - arr2_domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 10), UnitRange(5, 10))) + arr2_domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 10), UnitRange(5, 10))) field1 = common._field(arr1, domain=arr1_domain) field2 = common._field(arr2, domain=arr2_domain) @@ -387,33 +390,33 @@ def test_cartesian_remap_implementation(): "new_dims,field,expected_domain", [ (( - (IDim,), + (D0,), common._field( - np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), + Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), )), (( - (IDim, JDim), + (D0, D1), common._field( - np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), + Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange.infinite())), )), (( - (IDim, JDim), + (D0, D1), common._field( - np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), + Domain(dims=(D0, D1), ranges=(UnitRange.infinite(), UnitRange(0, 10))), )), (( - (IDim, JDim, KDim), + (D0, D1, D2), common._field( - np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) ), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), ), )), @@ -427,13 +430,13 @@ def test_field_broadcast(new_dims, field, expected_domain): @pytest.mark.parametrize( "domain_slice", [ - ((IDim, UnitRange(0, 10)),), - common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), + ((D0, UnitRange(0, 10)),), + common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), ], ) def test_get_slices_with_named_indices_3d_to_1d(domain_slice): field_domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) slices = _get_slices_from_domain_slice(field_domain, domain_slice) assert slices == (slice(0, 10, None), slice(None), slice(None)) @@ -441,18 +444,18 @@ def test_get_slices_with_named_indices_3d_to_1d(domain_slice): def test_get_slices_with_named_index(): field_domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) - named_index = ((IDim, UnitRange(0, 10)), (JDim, 2), (KDim, 3)) + named_index = ((D0, UnitRange(0, 10)), (D1, 2), (D2, 3)) slices = _get_slices_from_domain_slice(field_domain, named_index) assert slices == (slice(0, 10, None), 2, 3) def test_get_slices_invalid_type(): field_domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) - new_domain = ((IDim, "1"),) + new_domain = ((D0, "1"),) with pytest.raises(ValueError): _get_slices_from_domain_slice(field_domain, new_domain) @@ -462,39 +465,39 @@ def test_get_slices_invalid_type(): [ ( ( - (IDim, UnitRange(7, 9)), - (JDim, UnitRange(8, 10)), + (D0, UnitRange(7, 9)), + (D1, UnitRange(8, 10)), ), - (IDim, JDim, KDim), + (D0, D1, D2), (2, 2, 15), ), ( ( - (IDim, UnitRange(7, 9)), - (KDim, UnitRange(12, 20)), + (D0, UnitRange(7, 9)), + (D2, UnitRange(12, 20)), ), - (IDim, JDim, KDim), + (D0, D1, D2), (2, 10, 8), ), - (common.Domain(dims=(IDim,), ranges=(UnitRange(7, 9),)), (IDim, JDim, KDim), (2, 10, 15)), - (((IDim, 8),), (JDim, KDim), (10, 15)), - (((JDim, 9),), (IDim, KDim), (5, 15)), - (((KDim, 11),), (IDim, JDim), (5, 10)), + (common.Domain(dims=(D0,), ranges=(UnitRange(7, 9),)), (D0, D1, D2), (2, 10, 15)), + (((D0, 8),), (D1, D2), (10, 15)), + (((D1, 9),), (D0, D2), (5, 15)), + (((D2, 11),), (D0, D1), (5, 10)), ( ( - (IDim, 8), - (JDim, UnitRange(8, 10)), + (D0, 8), + (D1, UnitRange(8, 10)), ), - (JDim, KDim), + (D1, D2), (2, 15), ), - ((IDim, 5), (JDim, KDim), (10, 15)), - ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), + ((D0, 5), (D1, D2), (10, 15)), + ((D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), ], ) def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] @@ -506,11 +509,11 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): def test_absolute_indexing_dim_sliced(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) - indexed_field_1 = field[JDim(8) : JDim(10), IDim(5) : IDim(9)] - expected = field[(IDim, UnitRange(5, 9)), (JDim, UnitRange(8, 10))] + indexed_field_1 = field[D1(8) : D1(10), D0(5) : D0(9)] + expected = field[(D0, UnitRange(5, 9)), (D1, UnitRange(8, 10))] assert common.is_field(indexed_field_1) assert indexed_field_1 == expected @@ -518,11 +521,11 @@ def test_absolute_indexing_dim_sliced(): def test_absolute_indexing_dim_sliced_single_slice(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) - indexed_field_1 = field[KDim(11)] - indexed_field_2 = field[(KDim, 11)] + indexed_field_1 = field[D2(11)] + indexed_field_2 = field[(D2, 11)] assert common.is_field(indexed_field_1) assert indexed_field_1 == indexed_field_2 @@ -530,28 +533,28 @@ def test_absolute_indexing_dim_sliced_single_slice(): def test_absolute_indexing_wrong_dim_sliced(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) - with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'JDim' and 'IDim'."): - field[JDim(8) : IDim(10)] + with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'D1' and 'D0'."): + field[D1(8) : D0(10)] def test_absolute_indexing_empty_dim_sliced(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) with pytest.raises(IndexError, match="Lower bound needs to be specified"): - field[: IDim(10)] + field[: D0(10)] def test_absolute_indexing_value_return(): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(10, 20), UnitRange(5, 15))) field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((IDim, 12), (JDim, 6)) + named_index = ((D0, 12), (D1, 6)) assert common.is_field(field) value = field[named_index] @@ -565,28 +568,28 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 4))), + Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 4))), ), - ((slice(None, 5),), (5, 10), Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 12)))), + ((slice(None, 5),), (5, 10), Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 12)))), ( (Ellipsis, 1), (10,), - Domain((IDim, UnitRange(5, 15))), + Domain((D0, UnitRange(5, 15))), ), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((IDim, UnitRange(7, 8)), (JDim, UnitRange(7, 9))), + Domain((D0, UnitRange(7, 8)), (D1, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((IDim, UnitRange(6, 7))), + Domain((D0, UnitRange(6, 7))), ), ], ) def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.ones((10, 10)), domain=domain) indexed_field = field[index] @@ -598,17 +601,17 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_shape, expected_domain", [ - ((1, slice(None), 2), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), + ((1, slice(None), 2), (15,), Domain(dims=(D1,), ranges=(UnitRange(10, 25),))), ( (slice(None), slice(None), 2), (10, 15), - Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(10, 25))), + Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(10, 25))), ), ( (slice(None),), (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), @@ -616,7 +619,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): (slice(None), slice(None), slice(None)), (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), @@ -624,16 +627,16 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): (slice(None)), (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), - ((0, Ellipsis, 0), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), + ((0, Ellipsis, 0), (15,), Domain(dims=(D1,), ranges=(UnitRange(10, 25),))), ( Ellipsis, (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), @@ -641,7 +644,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): ) def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ) field = common._field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] @@ -659,7 +662,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): ], ) def test_relative_indexing_value_return(index, expected_value): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) indexed_field = field[index] @@ -668,16 +671,16 @@ def test_relative_indexing_value_return(index, expected_value): @pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]]) def test_relative_indexing_out_of_bounds(lazy_slice): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common._field(np.ones((10, 10)), domain=domain) with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) -@pytest.mark.parametrize("index", [IDim, "1", (IDim, JDim)]) +@pytest.mark.parametrize("index", [D0, "1", (D0, D1)]) def test_field_unsupported_index(index): - domain = common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + domain = common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) field = common._field(np.ones((10,)), domain=domain) with pytest.raises(IndexError, match="Unsupported index type"): field[index] @@ -690,14 +693,14 @@ def test_field_unsupported_index(index): ((1, slice(None)), np.ones((10,)) * 42.0), ( (1, slice(None)), - common._field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), + common._field(np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(0, 10)))), ), ], ) def test_setitem(index, value): field = common._field( np.arange(100).reshape(10, 10), - domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + domain=common.Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) expected = np.copy(field.asnumpy()) @@ -711,11 +714,11 @@ def test_setitem(index, value): def test_setitem_wrong_domain(): field = common._field( np.arange(100).reshape(10, 10), - domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + domain=common.Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) value_incompatible = common._field( - np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) + np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(-5, 5))) ) with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): @@ -908,3 +911,120 @@ def test_hypercube(index_array, expected): result = nd_array_field._hypercube(index_array, image_range, np, skip_value) assert result == expected + + +@pytest.mark.parametrize( + "mask_data, true_data, false_data, expected", + [ + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], None), + ([6, 7, 8, 9, 10], None), + ([1, 7, 3, 9, 5], None), + ), + ( + ([True, False, True, False], None), + ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + ([6, 7, 8, 9], {D0: (1, 5)}), + ([3, 6, 5, 8], {D0: (0, 4)}), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + ([6, 7, 8, 9, 10], {D0: (1, 6)}), + ([3, 6, 5, 8], {D0: (0, 4)}), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + ([6, 7, 8, 9, 10], {D0: (2, 7)}), + None, + ), + ( + # empty result domain + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], {D0: (-5, 0)}), + ([6, 7, 8, 9, 10], {D0: (5, 10)}), + ([], {D0: (0, 0)}), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], {D0: (-4, 1)}), + ([6, 7, 8, 9, 10], {D0: (5, 10)}), + ([5], {D0: (0, 1)}), + ), + ( + # broadcasting true_field + ([True, False, True, False, True], {D0: 5}), + ([1, 2, 3, 4, 5], {D0: 5}), + ([[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], {D0: 5, D1: 2}), + ([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]], {D0: 5, D1: 2}), + ), + ( + ([True, False, True, False, True], None), + (42, None), + ([6, 7, 8, 9, 10], None), + ([42, 7, 42, 9, 42], None), + ), + ( + # parts of mask_ranges are concatenated + ([True, True, False, False], None), + ([1, 2], {D0: (1, 3)}), + ([3, 4], {D0: (1, 3)}), + ([1, 4], {D0: (1, 3)}), + ), + ( + # parts of mask_ranges are concatenated and yield non-contiguous domain + ([True, False, True, False], None), + ([1, 2], {D0: (0, 2)}), + ([3, 4], {D0: (2, 4)}), + None, + ), + ], +) +def test_concat_where( + nd_array_implementation, + mask_data: tuple[list[bool], Optional[common.DomainLike]], + true_data: tuple[list[int], Optional[common.DomainLike]], + false_data: tuple[list[int], Optional[common.DomainLike]], + expected: Optional[tuple[list[int], Optional[common.DomainLike]]], +): + mask_lst, mask_domain = mask_data + true_lst, true_domain = true_data + false_lst, false_domain = false_data + + mask_field = _make_field_or_scalar( + mask_lst, + nd_array_implementation=nd_array_implementation, + domain=common.domain(mask_domain) if mask_domain is not None else None, + dtype=bool, + ) + true_field = _make_field_or_scalar( + true_lst, + nd_array_implementation=nd_array_implementation, + domain=common.domain(true_domain) if true_domain is not None else None, + dtype=np.int32, + ) + false_field = _make_field_or_scalar( + false_lst, + nd_array_implementation=nd_array_implementation, + domain=common.domain(false_domain) if false_domain is not None else None, + dtype=np.int32, + ) + + if expected is None: + with pytest.raises(embedded_exceptions.NonContiguousDomain): + nd_array_field._concat_where(mask_field, true_field, false_field) + else: + expected_lst, expected_domain_like = expected + expected_array = np.asarray(expected_lst) + expected_domain = ( + common.domain(expected_domain_like) + if expected_domain_like is not None + else _make_default_domain(expected_array.shape) + ) + + result = nd_array_field._concat_where(mask_field, true_field, false_field) + + assert expected_domain == result.domain + np.testing.assert_allclose(result.asnumpy(), expected_array) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 7650e90c3c..ce940131c3 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -14,7 +14,6 @@ import operator from typing import Optional, Pattern -import numpy as np import pytest from gt4py.next.common import ( @@ -22,7 +21,6 @@ DimensionKind, Domain, Infinity, - NamedRange, UnitRange, domain, named_range, @@ -92,10 +90,12 @@ def test_unbounded_max_min(value): assert min(Infinity.NEGATIVE, value) == Infinity.NEGATIVE -def test_empty_range(): +@pytest.mark.parametrize("empty_range", [UnitRange(1, 0), UnitRange(1, -1)]) +def test_empty_range(empty_range): expected = UnitRange(0, 0) - assert UnitRange(1, 1) == expected - assert UnitRange(1, -1) == expected + + assert empty_range == expected + assert empty_range.is_empty() @pytest.fixture @@ -257,6 +257,20 @@ def test_domain_length(a_domain): assert len(a_domain) == 3 +@pytest.mark.parametrize( + "empty_domain, expected", + [ + (Domain(), False), + (Domain((IDim, UnitRange(0, 10))), False), + (Domain((IDim, UnitRange(0, 0))), True), + (Domain((IDim, UnitRange(0, 0)), (JDim, UnitRange(0, 1))), True), + (Domain((IDim, UnitRange(0, 1)), (JDim, UnitRange(0, 0))), True), + ], +) +def test_empty_domain(empty_domain, expected): + assert empty_domain.is_empty() == expected + + @pytest.mark.parametrize( "domain_like", [