Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: as_offset implementation in embedded #1397

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
5b6b8b7
as_offset implementation in embedded
nfarabullini Dec 13, 2023
81c0141
Merge branch 'main' into as_offset_embedded
nfarabullini Dec 13, 2023
1acb1d8
edit to exclusion_matrices
nfarabullini Dec 13, 2023
b94e81c
edit to exclusion_matrices
nfarabullini Dec 13, 2023
f815712
resolved some pre-commit errors
nfarabullini Dec 13, 2023
67f8117
resolved some pre-commit errors
nfarabullini Dec 13, 2023
e17ff41
implemented EXPERIMENTAL_FUN_BUILTIN_NAMES
nfarabullini Dec 14, 2023
0e61be2
edits for as_offset
nfarabullini Jan 4, 2024
0219e72
additional cleanup
nfarabullini Jan 4, 2024
762d7eb
additional cleanup
nfarabullini Jan 4, 2024
fa1c588
reverted a couple of edits
nfarabullini Jan 4, 2024
38c052c
ran pre-commit
nfarabullini Jan 4, 2024
e8d6e5e
edit to test
nfarabullini Jan 4, 2024
782375b
edit for md dimensional field
nfarabullini Jan 5, 2024
654b14d
replaced connectivity with restricted
nfarabullini Jan 5, 2024
186b81d
edit to as_offset in experimental
nfarabullini Jan 5, 2024
e50eb64
small clenaup
nfarabullini Jan 5, 2024
09a4c44
updated code to checked vars
nfarabullini Jan 5, 2024
9f2bcc6
ran pre-commit
nfarabullini Jan 5, 2024
90f6796
removed [0][0] indexing
nfarabullini Jan 5, 2024
b19ba78
edits for tests and others for as_offset
nfarabullini Jan 9, 2024
c5464f3
edits to test
nfarabullini Jan 15, 2024
7c3e9cb
ran pre-commit
nfarabullini Jan 15, 2024
2ea83f1
edits
nfarabullini Jan 15, 2024
6387049
edits for other offsets
nfarabullini Jan 15, 2024
1e892b4
changes to path
nfarabullini Jan 15, 2024
2f3d9f6
edit for dace backend
nfarabullini Jan 16, 2024
b941b92
update with main
nfarabullini Jan 16, 2024
81fc838
trout attempt for fieldoffset in cache
nfarabullini Jan 16, 2024
7261ae7
edit suggested by Edoardo
nfarabullini Jan 16, 2024
ba1a91f
edit to offset_invariants
nfarabullini Jan 16, 2024
e074b10
edits following Hannes' review
nfarabullini Jan 17, 2024
ea7f20f
ran pre-commit
nfarabullini Jan 17, 2024
c483a0f
commented test out
nfarabullini Jan 18, 2024
b363cc2
placed test back
nfarabullini Jan 18, 2024
391508b
edit to failing test
nfarabullini Jan 18, 2024
c668dc8
Update tests/next_tests/integration_tests/feature_tests/ffront_tests/…
nfarabullini Jan 18, 2024
83a4b38
Merge branch 'main' of https://github.com/nfarabullini/gt4py into as_…
nfarabullini Jan 18, 2024
7e26c20
edits to dimensions refactoring
nfarabullini Jan 22, 2024
8bc6725
minor cleanup
nfarabullini Jan 22, 2024
e3e8db8
edit to test
nfarabullini Jan 23, 2024
5b6e553
Merge branch 'ruff-config' into as_offset_embedded
egparedes Mar 4, 2024
dacc6d5
Merge new style lint config
egparedes Mar 4, 2024
34fdeb7
Merge branch 'ruff-config' into as_offset_embedded
egparedes Mar 4, 2024
84abf3d
Recover deleted pieces after merging with main
egparedes Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
if index < 0:
index += len(self.dims)
new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ())
dims = self.dims[:index] + new_dims + self.dims[index + 1 :]
ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :]
if new_dims == self.dims:
dims = new_dims
ranges = new_ranges
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
else:
dims = self.dims[:index] + new_dims + self.dims[index + 1 :]
ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :]

return Domain(dims=dims, ranges=ranges)

Expand Down
45 changes: 42 additions & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.ffront import fbuiltins
from gt4py.next.ffront import experimental, fbuiltins


try:
Expand Down Expand Up @@ -74,6 +74,24 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
return _builtin_op


def _take_mdim(
input_arr: core_defs.NDArrayObject,
restricted_connectivity: core_defs.NDArrayObject,
new_domain: common.Domain,
dim: common.Dimension,
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
) -> core_defs.NDArrayObject:
offset_abs = [
restricted_connectivity if d == dim else np.indices(restricted_connectivity.shape)[d_i]
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
for d_i, d in enumerate(new_domain.dims)
]
new_buffer_flat = np.take(
np.asarray(input_arr).flatten(),
np.ravel_multi_index(tuple(offset_abs), input_arr.shape).flatten(),
)
new_buffer = new_buffer_flat.reshape(restricted_connectivity.shape)
return new_buffer


_Value: TypeAlias = common.Field | core_defs.ScalarT
_P = ParamSpec("_P")
_R = TypeVar("_R", _Value, tuple[_Value, ...])
Expand Down Expand Up @@ -198,8 +216,11 @@ def remap(
# then compute the index array
xp = self.array_ns
new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)
if self._ndarray.ndim > 1 and restricted_connectivity_domain == new_domain:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the second part of this condition? restricted_connectivity_domain == new_domain

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to avoid entering this condition in cases like:

    @gtx.field_operator
    def testee(a: gtx.Field[[Vertex, KDim], float]) -> gtx.Field[[Edge, KDim], float]:
        return a(E2V[0])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain this if else branch and are you sure all cases are handled? I am confused...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using FieldOffsets, only the specific dimensions related to the offset are taken into account.
Say I have this field_operator:

    @gtx.field_operator
    def testee(a: gtx.Field[[Edge, KDim], int]) -> gtx.Field[[Vertex, KDim], int]:
        tmp = neighbor_sum(a(V2E), axis=V2EDim)
        return tmp

Here the restricted_connectivity_domain will be over [Edge, V2E] and will exclude KDim. In this case using the regular xp.take works.

When using as_offset, xp.take is also ok to use if the offset_field contains only one dimension.

However, when restricted_connectivity_domain contains multiple dimensions that are exactly the same as in new_domain, we have seen that xp.take does not work and hence had to create _take_mdim

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what about restricted_connectivity_domain.dims == new_domain.dims, but ranges are different?

new_buffer = _take_mdim(self._ndarray, new_idx_array, new_domain, dim)
else:
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)

return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)

Expand Down Expand Up @@ -592,6 +613,24 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA
NdArrayField.register_builtin_func(fbuiltins.astype, _astype)


def _as_offset(offset_: fbuiltins.FieldOffset, field: common.Field) -> NdArrayConnectivityField:
if isinstance(field, NdArrayField):
# change field.ndarray from relative to absolute
offset_dim = np.squeeze(
np.where(list(map(lambda x: x == offset_.source, field.domain.dims)))
).item()
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
new_connectivity = np.indices(field.ndarray.shape)[offset_dim] + field.ndarray
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
return NumPyArrayConnectivityField.from_array(
new_connectivity, codomain=offset_.source, domain=field.domain
)
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
raise AssertionError(
"This is the NdArrayConnectivityField implementation of `experimental.as_offset`."
)


NdArrayField.register_builtin_func(experimental.as_offset, _as_offset) # type: ignore[has-type] #type specified in experimental


def _get_slices_from_domain_slice(
domain: common.Domain,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
Expand Down
37 changes: 15 additions & 22 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,23 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
import numpy as np

from gt4py.next.type_system import type_specifications as ts
from gt4py.next import common
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset


@dataclass
class BuiltInFunction:
__gt_type: ts.FunctionType
@BuiltInFunction
def as_offset(
offset_: FieldOffset,
field: common.Field,
/,
) -> common.ConnectivityField:
offset_dim = np.squeeze(
np.where(list(map(lambda x: x == offset_.source, field.domain.dims)))
).item()
new_connectivity = np.indices(field.ndarray.shape)[offset_dim] + field.ndarray
return common.connectivity(new_connectivity, codomain=offset_.source, domain=field.domain)
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved

def __call__(self, *args, **kwargs):
"""Act as an empty place holder for the built in function."""

def __gt_type__(self):
return self.__gt_type


as_offset = BuiltInFunction(
ts.FunctionType(
pos_only_args=[
ts.DeferredType(constraint=ts.OffsetType),
ts.DeferredType(constraint=ts.FieldType),
],
pos_or_kw_args={},
kw_only_args={},
returns=ts.DeferredType(constraint=ts.OffsetType),
)
)
EXPERIMENTAL_FUN_BUILTIN_NAMES = ["as_offset"]
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from gt4py._core import definitions as core_defs
from gt4py.next import common, embedded
from gt4py.next.common import Dimension, Field # noqa: F401 # direct import for TYPE_BUILTINS
from gt4py.next.ffront.experimental import as_offset # noqa: F401
from gt4py.next.iterator import runtime
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -58,6 +57,10 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
return ts.FieldType
elif t is common.Dimension:
return ts.DimensionType
elif t is FieldOffset:
return ts.OffsetType
elif t is common.ConnectivityField:
return ts.OffsetType
havogt marked this conversation as resolved.
Show resolved Hide resolved
elif t is core_defs.ScalarT:
return ts.ScalarType
elif t is type:
Expand Down Expand Up @@ -300,7 +303,6 @@ def impl(
"broadcast",
"where",
"astype",
"as_offset",
] + MATH_BUILTIN_NAMES

BUILTIN_NAMES = TYPE_BUILTIN_NAMES + FUN_BUILTIN_NAMES
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.next.common import DimensionKind
from gt4py.next.ffront import ( # noqa
dialect_ast_enums,
experimental,
fbuiltins,
type_info as ti_ffront,
type_specifications as ts_ffront,
Expand Down Expand Up @@ -717,7 +718,8 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call:
isinstance(new_func.type, ts.FunctionType)
and not type_info.is_concrete(return_type)
and isinstance(new_func, foast.Name)
and new_func.id in fbuiltins.FUN_BUILTIN_NAMES
and new_func.id
in (fbuiltins.FUN_BUILTIN_NAMES + experimental.EXPERIMENTAL_FUN_BUILTIN_NAMES)
):
visitor = getattr(self, f"_visit_{new_func.id}")
return visitor(new_node, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,29 @@ def get_cache_id(
column_axis: Optional[common.Dimension],
offset_provider: Mapping[str, Any],
) -> str:
max_neighbors = [
(k, v.max_neighbors)
for k, v in offset_provider.items()
def offset_invariants(offset):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are the changes in this file needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as_offset was breaking with the dace backend. These changes were discussed with @edopao and @petiaccja

The dace backend does not consider offset_dim to compute the key to access the build cache

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have included this change in one of my PRs, which I have already merged. So you won't need to do any change in this file after you rebase (but you'll get a rebase conflict).

if isinstance(
v,
offset,
(
itir_embedded.NeighborTableOffsetProvider,
itir_embedded.StridedNeighborOffsetProvider,
),
)
):
return offset.origin_axis, offset.neighbor_axis, offset.max_neighbors
if isinstance(offset, common.Dimension):
return (offset,)
return tuple()

offset_cache_keys = [
(name, offset_invariants(offset)) for name, offset in offset_provider.items()
]
cache_id_args = [
str(arg)
for arg in (
program,
*arg_types,
column_axis,
*max_neighbors,
*offset_cache_keys,
)
]
m = hashlib.sha256()
Expand Down
1 change: 0 additions & 1 deletion tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
]
EMBEDDED_SKIP_LIST = [
(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
(CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
Expand Down
Loading