feat[next]: Enable tests for embedded with cupy (#1372)
Introduces a mechanism in the tests for using different allocators with the same (`None`) backend.
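
A minimal sketch of what such a parametrization can look like, assuming a hypothetical `ALLOCATORS` list and test (the real suite wires this up through its own fixtures, e.g. `exec_alloc_descriptor`):

```python
import numpy as np
import pytest

try:
    import cupy as cp
except ImportError:
    cp = None

# Hypothetical list of array namespaces standing in for the allocators of the
# embedded (`None`) backend: NumPy always, CuPy only when it is installed.
ALLOCATORS = [np] + ([cp] if cp is not None else [])


@pytest.fixture(params=ALLOCATORS, ids=lambda ns: ns.__name__)
def allocator(request):
    return request.param


def test_elementwise_add(allocator):
    # Same embedded execution path, different buffer allocator per parameter.
    a = allocator.asarray([1.0, 2.0, 3.0])
    b = allocator.asarray([4.0, 5.0, 6.0])
    assert allocator.allclose(a + b, allocator.asarray([5.0, 7.0, 9.0]))
```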

Fixes:
- The resulting buffer type for `scan` is deduced from the buffer type of the arguments; if there are no field arguments, we fall back to NumPy (which may break). We still need a proper mechanism for this corner case. Currently these tests are excluded with `pytest.mark.uses_scan_without_field_args` for cupy embedded execution.
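
  The deduction follows the `_get_array_ns` helper added in `src/gt4py/next/embedded/operators.py` in this commit; a simplified standalone sketch of the idea:

  ```python
  from types import ModuleType

  import numpy as np


  def get_array_ns(*args) -> ModuleType:
      """Pick the array namespace (numpy, cupy, ...) from the first argument
      that carries one; fall back to numpy for scans without field arguments."""
      for arg in args:
          if hasattr(arg, "array_ns"):
              return arg.array_ns
      return np  # the corner case described above
  ```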

Refactoring:
- make common.field and common.connectivity private
- rename next_tests.exclusion_matrices to definitions

TODOs for later:
- `broadcast` of a scalar ignores the broadcast
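
  The workaround taken in `gt4py/next/ffront/fbuiltins.py` (see the diff below) is to return the scalar unchanged and rely on the array library's implicit scalar broadcasting. A tiny NumPy illustration of why that is sufficient for element-wise operations, and of the dimension check that is lost:

  ```python
  import numpy as np

  field = np.arange(3.0)  # stand-in for a 1-D field
  scalar = 2.0            # the requested broadcast of this scalar is ignored

  # NumPy broadcasts the scalar implicitly, so the element-wise result is
  # still correct without an explicit broadcast ...
  assert np.array_equal(field + scalar, np.array([2.0, 3.0, 4.0]))

  # ... but no error is raised if the requested broadcast dimensions were
  # incompatible with the other operands (the lost check mentioned above).
  ```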
---------

Co-authored-by: Enrique González Paredes <[email protected]>
havogt and egparedes authored Jan 29, 2024
1 parent 8c3b3d7 commit f0986bb
Showing 29 changed files with 274 additions and 194 deletions.
8 changes: 4 additions & 4 deletions docs/development/ADRs/0015-Test_Exclusion_Matrices.md
@@ -1,13 +1,13 @@
---
tags: []
tags: [testing]
---

# Test-Exclusion Matrices

- **Status**: valid
- **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes)
- **Created**: 2023-09-21
- **Updated**: 2023-09-21
- **Updated**: 2024-01-25

In the context of Field View testing, lacking support for specific ITIR features while a certain backend
is being developed, we decided to use `pytest` fixtures to exclude unsupported tests.
@@ -22,7 +22,7 @@ the supported backends, while keeping the test code clean.
## Decision

It was decided to apply fixtures and markers from `pytest` module. The fixture is the same used to execute the test
on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the available feature markers.
on different backends (`exec_alloc_descriptor` and `program_processor`), but it is extended with a check on the available feature markers.
If a test is annotated with a feature marker, the fixture will check if this feature is supported on the selected backend.
If no marker is specified, the test is supposed to run on all backends.

@@ -33,7 +33,7 @@ In the example below, `test_offset_field` requires the backend to support dynami
def test_offset_field(cartesian_case):
```

In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX`
In order to selectively enable the backends, the dictionary `next_tests.definitions.BACKEND_SKIP_TEST_MATRIX`
lists for each backend the features that are not supported.
The fixture will check if the annotated feature is present in the exclusion-matrix for the selected backend.
If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`).
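
For illustration only, an entry in that dictionary can be thought of as pairing a feature marker with the action to take and a reason (the exact layout in `next_tests.definitions` may differ):

```python
# Illustrative sketch, not the actual definitions module.
BACKEND_SKIP_TEST_MATRIX = {
    "embedded_cupy": [
        # (feature marker, action, reason reported by pytest)
        ("uses_scan_without_field_args", "XFAIL", "scan output buffer cannot be deduced"),
    ],
}
```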
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -340,7 +340,11 @@ markers = [
'uses_origin: tests that require backend support for domain origin',
'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
'uses_scan: tests that uses scan',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
'uses_scan_nested: tests that use nested scans',
'uses_scan_requiring_projector: tests need a projector implementation in gtfn',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
Expand All @@ -349,7 +353,7 @@ markers = [
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_scan: tests that uses scan',
'uses_max_over: tests that use the max_over builtin',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
1 change: 1 addition & 0 deletions src/gt4py/next/allocators.py
@@ -231,6 +231,7 @@ def __init__(self) -> None:

device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator()


assert is_field_allocator(device_allocators[core_defs.DeviceType.CPU])


9 changes: 6 additions & 3 deletions src/gt4py/next/common.py
@@ -843,8 +843,10 @@ def is_connectivity_field(
return isinstance(v, ConnectivityField) # type: ignore[misc] # we use extended_runtime_checkable


# Utility function to construct a `Field` from different buffer representations.
# Consider removing this function and using `Field` constructor directly. See also `_connectivity`.
@functools.singledispatch
def field(
def _field(
definition: Any,
/,
*,
Expand All @@ -854,8 +856,9 @@ def field(
raise NotImplementedError


# See comment for `_field`.
@functools.singledispatch
def connectivity(
def _connectivity(
definition: Any,
/,
codomain: Dimension,
@@ -980,7 +983,7 @@ def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar:
__getitem__ = restrict


connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)
_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)


@enum.unique
6 changes: 3 additions & 3 deletions src/gt4py/next/constructors.py
@@ -87,7 +87,7 @@ def empty(
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
res = common.field(buffer.ndarray, domain=domain)
res = common._field(buffer.ndarray, domain=domain)
assert common.is_mutable_field(res)
assert isinstance(res, nd_array_field.NdArrayField)
return res
@@ -356,9 +356,9 @@ def as_connectivity(
if (allocator is None) and (device is None) and xtyping.supports_dlpack(data):
device = core_defs.Device(*data.__dlpack_device__())
buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device)
# TODO(havogt): consider addin MutableNDArrayObject
# TODO(havogt): consider adding MutableNDArrayObject
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index]
connectivity_field = common.connectivity(
connectivity_field = common._connectivity(
buffer.ndarray, codomain=codomain, domain=actual_domain
)
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)
22 changes: 12 additions & 10 deletions src/gt4py/next/embedded/nd_array_field.py
@@ -95,9 +95,7 @@ class NdArrayField(
_domain: common.Domain
_ndarray: core_defs.NDArrayObject

array_ns: ClassVar[
ModuleType
] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol
array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol

@property
def domain(self) -> common.Domain:
@@ -197,7 +195,11 @@ def remap(
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)

return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)
return self.__class__.from_array(
new_buffer,
domain=new_domain,
dtype=self.dtype,
)

__call__ = remap # type: ignore[assignment]

@@ -510,15 +512,15 @@ class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np


common.field.register(np.ndarray, NumPyArrayField.from_array)
common._field.register(np.ndarray, NumPyArrayField.from_array)


@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = np


common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)
common._connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)

# CuPy
if cp:
@@ -528,13 +530,13 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField):
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

common.field.register(cp.ndarray, CuPyArrayField.from_array)
common._field.register(cp.ndarray, CuPyArrayField.from_array)

@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = cp

common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)
common._connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)

# JAX
if jnp:
@@ -552,7 +554,7 @@ def __setitem__(
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.")

common.field.register(jnp.ndarray, JaxArrayField.from_array)
common._field.register(jnp.ndarray, JaxArrayField.from_array)


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
@@ -565,7 +567,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
else:
domain_slice.append(np.newaxis)
named_ranges.append((dim, common.UnitRange.infinite()))
return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))


def _builtins_broadcast(
33 changes: 27 additions & 6 deletions src/gt4py/next/embedded/operators.py
@@ -13,11 +13,14 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from types import ModuleType
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar

import numpy as np

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, constructors, errors, utils
from gt4py.next import common, errors, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context


@@ -43,7 +46,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]
scan_axis = scan_range[0]
domain_intersection = _intersect_scan_args(*args, *kwargs.values())
all_args = [*args, *kwargs.values()]
domain_intersection = _intersect_scan_args(*all_args)
non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis])

out_domain = common.Domain(
@@ -53,7 +57,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
# even if the scan dimension is not in the input, we can scan over it
out_domain = common.Domain(*out_domain, (scan_range))

res = _construct_scan_array(out_domain)(self.init)
xp = _get_array_ns(*all_args)
res = _construct_scan_array(out_domain, xp)(self.init)

def scan_loop(hpos):
acc = self.init
@@ -128,7 +133,11 @@ def _tuple_assign_field(
):
@utils.tree_map
def impl(target: common.MutableField, source: common.Field):
target[domain] = source[domain]
if common.is_field(source):
target[domain] = source[domain]
else:
assert core_defs.is_scalar_type(source)
target[domain] = source

impl(target, source)

@@ -141,10 +150,21 @@ def _intersect_scan_args(
)


def _construct_scan_array(domain: common.Domain):
def _get_array_ns(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
) -> ModuleType:
for arg in utils.flatten_nested_tuple(args):
if hasattr(arg, "array_ns"):
return arg.array_ns
return np


def _construct_scan_array(
domain: common.Domain, xp: ModuleType
): # TODO(havogt) introduce a NDArrayNamespace protocol
@utils.tree_map
def impl(init: core_defs.Scalar) -> common.Field:
return constructors.empty(domain, dtype=type(init))
return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain)

return impl

@@ -168,6 +188,7 @@ def _tuple_at(
@utils.tree_map
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
res = field[pos] if common.is_field(field) else field
res = res.item() if hasattr(res, "item") else res # extract scalar value from array
assert core_defs.is_scalar_type(res)
return res

8 changes: 2 additions & 6 deletions src/gt4py/next/ffront/fbuiltins.py
@@ -188,12 +188,8 @@ def broadcast(
assert core_defs.is_scalar_type(
field
) # default implementation for scalars, Fields are handled via dispatch
return common.field(
np.asarray(field)[
tuple([np.newaxis] * len(dims))
], # TODO(havogt) use FunctionField once available
domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))),
)
# TODO(havogt) implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions
return field # type: ignore[return-value] # see comment above


@WhereBuiltinFunction
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/embedded.py
@@ -1035,7 +1035,7 @@ def _maker(a) -> common.Field:
offset = origin.get(d, 0)
ranges.append(common.UnitRange(-offset, s - offset))

res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
res = common._field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
return res

return _maker
4 changes: 2 additions & 2 deletions tests/next_tests/__init__.py
@@ -12,10 +12,10 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from . import exclusion_matrices
from . import definitions


__all__ = ["exclusion_matrices", "get_processor_id"]
__all__ = ["definitions", "get_processor_id"]


def get_processor_id(processor):