Skip to content

Commit

Permalink
feat[next]: concat_where for boundary conditions (GridTools#1468)
Browse files Browse the repository at this point in the history
Introduces `concat_where` to be used for boundary conditions.

`where` will intersect all 3 fields. Therefore `where(klevel == 0,
boundary_layer, interior)` will not work if `boundary_layer` is only
defined on `klevel == 0`.

`concat_where` will concatenate the fields that are selected by the mask
regions. The `mask` is currently required to be 1-dimensional.
  • Loading branch information
havogt authored and edopao committed Mar 18, 2024
1 parent ca2c70a commit 4dcb4ec
Show file tree
Hide file tree
Showing 15 changed files with 860 additions and 170 deletions.
8 changes: 8 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return obj.start is not Infinity.NEGATIVE

def is_empty(self) -> bool:
    """Return whether both `start` and `stop` of this range are 0."""
    # post_init ensures that empty is represented as UnitRange(0, 0)
    return self.stop == 0 if self.start == 0 else False

def __repr__(self) -> str:
    """Return the textual form `UnitRange(<start>, <stop>)`."""
    start, stop = self.start, self.stop
    return f"UnitRange({start}, {stop})"

Expand Down Expand Up @@ -428,6 +433,9 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return all(UnitRange.is_finite(rng) for rng in obj.ranges)

def is_empty(self) -> bool:
    """Return `True` as soon as one of the ranges in `self.ranges` reports being empty."""
    for unit_range in self.ranges:
        if unit_range.is_empty():
            return True
    return False

@overload
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ...

Expand Down
45 changes: 44 additions & 1 deletion src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,57 @@ def _absolute_sub_domain(
return common.Domain(*named_ranges)


def intersect_domains(*domains: common.Domain) -> common.Domain:
def domain_intersection(
    *domains: common.Domain,
) -> common.Domain:
    """
    Return the intersection of the given domains.

    The domains are combined pairwise with `&`, starting from a domain without
    any dimensions; calling this without arguments returns that starting domain.

    Example:
        >>> I = common.Dimension("I")
        >>> domain_intersection(
        ...     common.domain({I: (0, 5)}), common.domain({I: (1, 3)})
        ... )  # doctest: +ELLIPSIS
        Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),))
    """
    intersection = common.Domain(dims=tuple(), ranges=tuple())
    for domain in domains:
        intersection = intersection & domain
    return intersection


def restrict_to_intersection(
    *domains: common.Domain,
    ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None,
) -> tuple[common.Domain, ...]:
    """
    Restrict every domain to the intersection of all given domains.

    Dimensions listed in 'ignore_dims' do not take part in the intersection:
    each domain keeps its own range in those dimensions.

    Example:
        >>> I = common.Dimension("I")
        >>> J = common.Dimension("J")
        >>> res = restrict_to_intersection(
        ...     common.domain({I: (0, 5), J: (1, 2)}),
        ...     common.domain({I: (1, 3), J: (0, 3)}),
        ...     ignore_dims=J,
        ... )
        >>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)}))
    """
    # a single dimension (or the `None` default) is wrapped so that `in` works uniformly
    if isinstance(ignore_dims, tuple):
        excluded_dims = ignore_dims
    else:
        excluded_dims = (ignore_dims,)

    # intersect only the dimensions that are not excluded
    reduced_domains = []
    for domain in domains:
        kept = [(dim, rng) for dim, rng in domain if dim not in excluded_dims]
        reduced_domains.append(common.Domain(*kept))
    shared = domain_intersection(*reduced_domains)

    # rebuild each domain: own range for excluded dims, shared range otherwise
    restricted = []
    for domain in domains:
        named_ranges = []
        for dim, rng in domain:
            if dim in excluded_dims:
                named_ranges.append((dim, rng))
            else:
                named_ranges.append((dim, shared[dim][1]))
        restricted.append(common.Domain(*named_ranges))
    return tuple(restricted)


def iterate_domain(domain: common.Domain):
    """Yield, for each index combination of `domain.ranges`, a tuple of `(dim, index)` pairs."""
    dims = domain.dims
    indices_per_dim = [list(unit_range) for unit_range in domain.ranges]
    for point in itertools.product(*indices_per_dim):
        yield tuple(zip(dims, point))
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/embedded/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@ def __init__(
self.indices = indices
self.index = index
self.dim = dim


class NonContiguousDomain(gt4py_exceptions.GT4PyError):
    """Error raised when an operation would leave a domain non-contiguous."""

    # caller-supplied description of the offending operation, embedded in the message
    detail: str

    def __init__(self, detail: str):
        message = f"Operation would result in a non-contiguous domain: `{detail}`."
        super().__init__(message)
        self.detail = detail
213 changes: 195 additions & 18 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
import functools
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import ClassVar
from typing import ClassVar, Iterable

import numpy as np
from numpy import typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.ffront import fbuiltins
from gt4py.next.embedded import (
common as embedded_common,
context as embedded_context,
exceptions as embedded_exceptions,
)
from gt4py.next.ffront import experimental, fbuiltins
from gt4py.next.iterator import embedded as itir_embedded


Expand All @@ -42,20 +46,22 @@
jnp: Optional[ModuleType] = None # type:ignore[no-redef]


def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArrayField]:
    """Return the class of the first argument that is an `NdArrayField`.

    Raises:
        AssertionError: if none of the arguments is an `NdArrayField`.
    """
    first_nd_array_field = next((f for f in fields if isinstance(f, NdArrayField)), None)
    if first_nd_array_field is None:
        raise AssertionError("No 'NdArrayField' found in the arguments.")
    return first_nd_array_field.__class__


def _make_builtin(
builtin_name: str, array_builtin_name: str, reverse=False
) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
first = None
for f in fields:
if isinstance(f, NdArrayField):
first = f
break
assert first is not None
xp = first.__class__.array_ns
cls_ = _get_nd_array_class(*fields)
xp = cls_.array_ns
op = getattr(xp, array_builtin_name)

domain_intersection = embedded_common.intersect_domains(*[
domain_intersection = embedded_common.domain_intersection(*[
f.domain for f in fields if common.is_field(f)
])

Expand All @@ -76,7 +82,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
if reverse:
transformed.reverse()
new_data = op(*transformed)
return first.__class__.from_array(new_data, domain=domain_intersection)
return cls_.from_array(new_data, domain=domain_intersection)

_builtin_op.__name__ = builtin_name
return _builtin_op
Expand Down Expand Up @@ -423,10 +429,7 @@ def inverse_image(
if relative_ranges is None:
raise ValueError("Restriction generates non-contiguous dimensions.")

new_dims = [
common.named_range((d, rr + ar.start))
for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges)
]
new_dims = _relative_ranges_to_domain(relative_ranges, self.domain)

self._cache[cache_key] = new_dims

Expand All @@ -448,6 +451,14 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field:
__getitem__ = restrict


def _relative_ranges_to_domain(
    relative_ranges: Sequence[common.UnitRange], domain: common.Domain
) -> common.Domain:
    """Return a domain over `domain.dims` whose ranges are `relative_ranges`, each offset by the start of the matching range of `domain`."""
    absolute_ranges = []
    for reference_range, relative_range in zip(domain.ranges, relative_ranges):
        absolute_ranges.append(relative_range + reference_range.start)
    return common.Domain(dims=domain.dims, ranges=absolute_ranges)


def _hypercube(
index_array: core_defs.NDArrayObject,
image_range: common.UnitRange,
Expand Down Expand Up @@ -519,6 +530,172 @@ def _hypercube(
# Register the function built by `_make_builtin("where", "where")` as the `NdArrayField` implementation of `fbuiltins.where`.
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _compute_mask_ranges(
    mask: core_defs.NDArrayObject,
) -> list[tuple[bool, common.UnitRange]]:
    """Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.

    Each entry is `(value, UnitRange(start, stop))`, where `[start, stop)` is a
    maximal run of consecutive mask positions (relative to the mask array) that
    all hold `value`. The entries are ordered by position.

    A zero-length mask contains no runs and yields an empty list.
    """
    # TODO: does it make sense to upgrade this naive algorithm to numpy?
    assert mask.ndim == 1
    size = mask.shape[0]
    if size == 0:
        # nothing to scan; indexing `mask[0]` below would raise an IndexError
        return []

    res = []
    # `.item()` to extract the scalar from a 0-d array in case of e.g. cupy
    cur = bool(mask[0].item())
    ind = 0
    for i in range(1, size):
        mask_i = bool(mask[i].item())
        if mask_i != cur:
            # the value flipped: close the current run and start a new one at `i`
            res.append((cur, common.UnitRange(ind, i)))
            cur, ind = mask_i, i
    # close the last run, which always extends to the end of the mask
    res.append((cur, common.UnitRange(ind, size)))
    return res


def _trim_empty_domains(
    lst: Iterable[tuple[bool, common.Domain]],
) -> list[tuple[bool, common.Domain]]:
    """Remove empty domains from beginning and end of the list.

    Entries whose domain reports `is_empty()` are dropped from both ends;
    empty domains enclosed by non-empty ones are kept. The input is
    materialized into a new list, so any iterable is accepted.

    Implemented iteratively with two indices, so long runs of empty domains
    neither copy the list repeatedly nor run into the recursion limit.
    """
    lst = list(lst)
    first, last = 0, len(lst)
    # skip leading empty domains
    while first < last and lst[first][1].is_empty():
        first += 1
    # skip trailing empty domains
    while last > first and lst[last - 1][1].is_empty():
        last -= 1
    return lst[first:last]


def _to_field(
    value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField]
) -> common.Field:
    """Return `value` unchanged if it is a field, otherwise wrap it in a field of `nd_array_field_type` with an empty `Domain()`."""
    # TODO(havogt): this function is only to workaround broadcasting of scalars, once we have a ConstantField, we can broadcast to that directly
    if common.is_field(value):
        return value
    scalar_array = nd_array_field_type.array_ns.asarray(value)
    return nd_array_field_type.from_array(scalar_array, domain=common.Domain())


def _intersect_fields(
    *fields: common.Field | core_defs.Scalar,
    ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None,
) -> tuple[common.Field, ...]:
    """Broadcast all arguments to their promoted dimensions and slice each one to the shared domain.

    The shared domain is computed by `restrict_to_intersection`, which leaves
    the dimensions in `ignore_dims` untouched. Scalars are turned into fields
    via `_to_field` first.
    """
    # TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations;
    # currently blocked, because requiring the `_to_field` function, see comment there.
    nd_array_class = _get_nd_array_class(*fields)
    field_dims = (f.domain.dims for f in fields if common.is_field(f))
    promoted_dims = common.promote_dims(*field_dims)
    broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields]

    restricted_domains = embedded_common.restrict_to_intersection(
        *[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims
    )

    result = []
    for f, restricted_domain in zip(broadcasted_fields, restricted_domains, strict=True):
        slices = _get_slices_from_domain_slice(f.domain, restricted_domain)
        result.append(nd_array_class.from_array(f.ndarray[slices], domain=restricted_domain))
    return tuple(result)


def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]:
    """Stack `domains` along `dim` into one domain, or return `None` if they are not adjacent.

    The domains must be given in order: each range in `dim` has to start where
    the previous one stopped. The result is the first domain with its `dim`
    range replaced by the combined range. Without arguments an empty
    `Domain()` is returned.
    """
    if not domains:
        return common.Domain()

    first_domain = domains[0]
    dim_start = first_domain[dim][1].start
    dim_stop = dim_start
    for domain in domains:
        dim_range = domain[dim][1]
        if dim_range.start != dim_stop:
            # gap or overlap between consecutive domains
            return None
        dim_stop = dim_range.stop
    stacked_range = common.UnitRange(dim_start, dim_stop)
    return first_domain.replace(dim, (dim, stacked_range))


def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
    """Concatenate `fields` along `dim` into a single field.

    Raises:
        ValueError: if more than one field is given and the intersection of
            their domains is not empty.
        NonContiguousDomain: if `_stack_domains` cannot stack the domains along `dim`.
    """
    # TODO(havogt): this function could be extended to a general concat
    # currently only concatenate along the given dimension and requires the fields to be ordered
    domains = [f.domain for f in fields]
    if len(fields) > 1:
        overlap = embedded_common.domain_intersection(*domains)
        if not overlap.is_empty():
            raise ValueError("Fields to concatenate must not overlap.")

    new_domain = _stack_domains(*domains, dim=dim)
    if new_domain is None:
        raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")

    nd_array_class = _get_nd_array_class(*fields)
    xp = nd_array_class.array_ns
    # materialize every buffer at the full shape of its domain before concatenating
    arrays = [xp.broadcast_to(f.ndarray, f.domain.shape) for f in fields]
    return nd_array_class.from_array(
        xp.concatenate(arrays, axis=new_domain.dim_index(dim)),
        domain=new_domain,
    )


def _concat_where(
    mask_field: common.Field, true_field: common.Field, false_field: common.Field
) -> common.Field:
    """Concatenate slices of `true_field` and `false_field` as selected by a 1-dimensional mask.

    For every run of equal values in `mask_field`, the matching region is taken
    from `true_field` (mask is true) or `false_field` (mask is false); the
    pieces are then concatenated along the mask dimension.

    Raises:
        NotImplementedError: if the mask domain is not 1-dimensional.
        NonContiguousDomain: if, after dropping empty regions at both ends,
            an empty region remains between the selected pieces.
    """
    cls_ = _get_nd_array_class(mask_field, true_field, false_field)
    xp = cls_.array_ns
    if mask_field.domain.ndim != 1:
        raise NotImplementedError(
            "'concat_where': Can only concatenate fields with a 1-dimensional mask."
        )
    mask_dim = mask_field.domain.dims[0]

    # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
    t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

    # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils
    # for every run of equal mask values: relative range -> mask domain -> domain intersected with the selected field
    intersected_domains: list[tuple[bool, common.Domain]] = []
    for mask_value, relative_range in _compute_mask_ranges(mask_field.ndarray):
        mask_domain = _relative_ranges_to_domain((relative_range,), mask_field.domain)
        selected_domain = t_broadcasted.domain if mask_value else f_broadcasted.domain
        intersected_domains.append(
            (mask_value, embedded_common.domain_intersection(selected_domain, mask_domain))
        )

    # remove the empty domains from the beginning and end
    trimmed_domains = _trim_empty_domains(intersected_domains)
    if any(d.is_empty() for _, d in trimmed_domains):
        raise embedded_exceptions.NonContiguousDomain(
            f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in trimmed_domains]}."
        )

    # slice the fields with the domain ranges
    transformed = [t_broadcasted[d] if v else f_broadcasted[d] for v, d in trimmed_domains]

    if not transformed:
        # nothing selected: return an empty field over the mask dimension
        result_domain = common.Domain((mask_dim, common.UnitRange(0, 0)))
        result_array = xp.empty(result_domain.shape)
        return cls_.from_array(result_array, domain=result_domain)

    # stack the fields together
    return _concat(*transformed, dim=mask_dim)


# Register `_concat_where` as the `NdArrayField` implementation of `experimental.concat_where`.
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type]


def _make_reduction(
builtin_name: str, array_builtin_name: str, initial_value_op: Callable
) -> Callable[
Expand Down Expand Up @@ -635,7 +812,7 @@ def __setitem__(
common._field.register(jnp.ndarray, JaxArrayField.from_array)


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) -> common.Field:
if field.domain.dims == new_dimensions:
return field
domain_slice: list[slice | None] = []
Expand All @@ -645,7 +822,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
domain_slice.append(slice(None))
named_ranges.append((dim, field.domain[pos][1]))
else:
domain_slice.append(np.newaxis)
domain_slice.append(None) # np.newaxis
named_ranges.append((dim, common.UnitRange.infinite()))
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))

Expand Down
Loading

0 comments on commit 4dcb4ec

Please sign in to comment.