diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b827c77e..23a35844 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,6 +41,9 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 + CONDA_ENVIRONMENT=.test-conda-env-py3.yml + echo "- cupy" >> "$CONDA_ENVIRONMENT" + build_py_project_in_conda_env cipip install basedpyright diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0f0947a0..e4d02faf 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -33,6 +33,25 @@ Python 3 Nvidia Titan V: reports: junit: test/pytest.xml +Python 3 CuPy Nvidia Titan V: + script: | + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + CONDA_ENVIRONMENT=.test-conda-env-py3.yml + echo "- cupy" >> "$CONDA_ENVIRONMENT" + export PYOPENCL_TEST=port:cpu + build_py_project_in_conda_env + test_py_project + + tags: + - python3 + - nvidia-titan-v + except: + - tags + artifacts: + reports: + junit: test/pytest.xml + Python 3 POCL Nvidia Titan V: script: | curl -L -O https://tiker.net/ci-support-v0 diff --git a/README.rst b/README.rst index a704c122..06800fc3 100644 --- a/README.rst +++ b/README.rst @@ -15,7 +15,8 @@ GPU arrays? Deferred-evaluation arrays? Just plain ``numpy`` arrays? You'd like code to work with all of them? No problem! Comes with pre-made array context implementations for: -- numpy +- `numpy `__ +- `cupy `__ - `PyOpenCL `__ - `JAX `__ - `Pytato `__ (for lazy/deferred evaluation) @@ -24,7 +25,7 @@ implementations for: - Profiling ``arraycontext`` started life as an array abstraction for use with the -`meshmode `__ unstrucuted discretization +`meshmode `__ unstructured discretization package. Distributed under the MIT license. diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 4abe4a54..54ed689a 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -88,6 +88,7 @@ ScalarLike, tag_axes, ) +from .impl.cupy import CupyArrayContext from .impl.jax import EagerJAXArrayContext from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext @@ -118,6 +119,7 @@ "ArrayT", "BcastUntilActxArray", "CommonSubexpressionTag", + "CupyArrayContext", "EagerJAXArrayContext", "ElementwiseMapKernelTag", "NotAnArrayContainerError", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 1ebc2cb2..a44f9cc9 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -94,6 +94,8 @@ if TYPE_CHECKING: + from typing import Any + from pymbolic.geometric_algebra import MultiVector from arraycontext import ArrayOrContainer @@ -283,7 +285,7 @@ def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None: @serialize_container.register(np.ndarray) def _serialize_ndarray_container( - ary: numpy.ndarray) -> SerializedContainer: + ary: numpy.ndarray[Any, Any]) -> SerializedContainer: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") @@ -303,8 +305,8 @@ def _serialize_ndarray_container( @deserialize_container.register(np.ndarray) # https://github.com/python/mypy/issues/13040 def _deserialize_ndarray_container( # type: ignore[misc] - template: numpy.ndarray, - serialized: SerializedContainer) -> numpy.ndarray: + template: numpy.ndarray[Any, Any], + serialized: SerializedContainer) -> numpy.ndarray[Any, Any]: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 69d050f7..8620b7ec 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -917,7 +917,7 @@ def _flat_size(subary: ArrayOrContainer) -> Array | Integer: # {{{ numpy conversion def from_numpy( - ary: np.ndarray | ScalarLike, + ary: np.ndarray[Any, Any] | ScalarLike, actx: ArrayContext) -> ArrayOrContainerOrScalar: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. diff --git a/arraycontext/context.py b/arraycontext/context.py index f751413c..61baf8ae 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -282,7 +282,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") -NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike] +NumpyOrContainerOrScalar = Union[np.ndarray[Any, Any], "ArrayContainer", ScalarLike] # }}} @@ -358,7 +358,7 @@ def zeros(self, return self.np.zeros(shape, dtype) @overload - def from_numpy(self, array: np.ndarray) -> Array: + def from_numpy(self, array: np.ndarray[Any, Any]) -> Array: ... @overload @@ -379,7 +379,7 @@ def from_numpy(self, """ @overload - def to_numpy(self, array: Array) -> np.ndarray: + def to_numpy(self, array: Array) -> np.ndarray[Any, Any]: ... @overload diff --git a/arraycontext/impl/cupy/__init__.py b/arraycontext/impl/cupy/__init__.py new file mode 100644 index 00000000..08e24bee --- /dev/null +++ b/arraycontext/impl/cupy/__init__.py @@ -0,0 +1,155 @@ +""" +.. currentmodule:: arraycontext + +A :mod:`cupy`-based array context. + +.. autoclass:: CupyArrayContext +""" + +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2024 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Any, overload + +import numpy as np + +import loopy as lp +from pytools.tag import ToTagSetConvertible + +from arraycontext.container.traversal import rec_map_array_container, with_array_context +from arraycontext.context import ( + Array, + ArrayContext, + ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, + ContainerOrScalarT, + NumpyOrContainerOrScalar, +) + + +class CupyNonObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + import cupy as cp # type: ignore[import-untyped] + return isinstance(instance, cp.ndarray) and instance.dtype != object + + +class CupyNonObjectArray(metaclass=CupyNonObjectArrayMetaclass): + pass + + +class CupyArrayContext(ArrayContext): + """An :class:`ArrayContext` that uses :class:`cupy.ndarray` to represent arrays.""" + + array_types = (CupyNonObjectArray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import CupyFakeNumpyNamespace + return CupyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + @overload + def from_numpy(self, array: np.ndarray[Any, Any]) -> Array: + ... + + @overload + def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + + def from_numpy(self, + array: NumpyOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: + import cupy as cp + return with_array_context(rec_map_array_container(cp.array, array), + actx=self) + + @overload + def to_numpy(self, array: Array) -> np.ndarray[Any, Any]: + ... + + @overload + def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT: + ... + + def to_numpy(self, + array: ArrayOrContainerOrScalar + ) -> NumpyOrContainerOrScalar: + import cupy as cp + return with_array_context(rec_map_array_container(cp.asnumpy, array), + actx=None) + + def call_loopy( + self, + t_unit: lp.TranslationUnit, **kwargs: Any + ) -> dict[str, Array]: + raise NotImplementedError( + "Calling loopy on CuPy arrays is not supported. Maybe rewrite" + " the loopy kernel as numpy-flavored array operations using" + " ArrayContext.np.") + + def freeze(self, array): + import cupy as cp + # Note that we could use a non-blocking version of cp.asnumpy here, but + # it appears to have very little impact on performance. + return with_array_context(rec_map_array_container(cp.asnumpy, array), actx=None) + + def thaw(self, array): + import cupy as cp + return with_array_context(rec_map_array_container(cp.array, array), actx=self) + + # }}} + + def tag(self, + tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + # Cupy (like numpy) doesn't support tagging + return array + + def tag_axis(self, + iaxis: int, tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + # Cupy (like numpy) doesn't support tagging + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + import cupy as cp + return cp.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/cupy/fake_numpy.py b/arraycontext/impl/cupy/fake_numpy.py new file mode 100644 index 00000000..9ef85203 --- /dev/null +++ b/arraycontext/impl/cupy/fake_numpy.py @@ -0,0 +1,201 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2024 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container.traversal import ( + rec_map_array_container, + rec_map_reduce_array_container, + rec_multimap_array_container, + rec_multimap_reduce_array_container, +) +from arraycontext.context import Array, ArrayOrContainer +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, + BaseFakeNumpyNamespace, +) + + +class CupyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = frozenset({"concatenate", "reshape", "transpose", + "where", + *BaseFakeNumpyNamespace._numpy_math_functions + }) + + +class CupyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`cupy` mimic for :class:`CupyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return CupyFakeNumpyLinalgNamespace(self._array_context) + + def zeros(self, shape, dtype): + import cupy as cp # type: ignore[import-untyped] + return cp.zeros(shape, dtype) + + def __getattr__(self, name): + import cupy as cp + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(cp, name)) + + raise AttributeError(name) + + def sum(self, a, axis=None, dtype=None): + import cupy as cp + return rec_map_reduce_array_container(sum, partial(cp.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + import cupy as cp + return rec_map_reduce_array_container( + partial(reduce, cp.minimum), partial(cp.amin, axis=axis), a) + + def max(self, a, axis=None): + import cupy as cp + return rec_map_reduce_array_container( + partial(reduce, cp.maximum), partial(cp.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + import cupy as cp + return rec_multimap_array_container( + lambda *args: cp.stack(args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + import cupy as cp + return rec_map_array_container(partial(cp.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + import cupy as cp + return rec_multimap_array_container(cp.equal, x, y) + + def not_equal(self, x, y): + import cupy as cp + return rec_multimap_array_container(cp.not_equal, x, y) + + def greater(self, x, y): + import cupy as cp + return rec_multimap_array_container(cp.greater, x, y) + + def greater_equal(self, x, y): + import cupy as cp + return rec_multimap_array_container(cp.greater_equal, x, y) + + def less(self, x, y): + import cupy as cp + return rec_multimap_array_container(cp.less, x, y) + + def less_equal(self, x, y): + import cupy as cp + return rec_multimap_array_container(cp.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + import cupy as cp + return rec_map_array_container(partial(cp.ravel, order=order), a) + + def vdot(self, x, y): + import cupy as cp + return rec_multimap_reduce_array_container(sum, cp.vdot, x, y) + + def any(self, a): + import cupy as cp + return rec_map_reduce_array_container(partial(reduce, cp.logical_or), + lambda subary: cp.any(subary), a) + + def all(self, a): + import cupy as cp + return rec_map_reduce_array_container(partial(reduce, cp.logical_and), + lambda subary: cp.all(subary), a) + + def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: + import cupy as cp + false_ary = cp.array(False) + true_ary = cp.array(True) + if type(a) is not type(b): + return false_ary + + try: + serialized_x = serialize_container(a) + serialized_y = serialize_container(b) + except NotAnArrayContainerError: + assert isinstance(a, cp.ndarray) + assert isinstance(b, cp.ndarray) + return cp.array(cp.array_equal(a, b)) + else: + if len(serialized_x) != len(serialized_y): + return false_ary + return reduce( + cp.logical_and, + [(true_ary if kx_i == ky_i else false_ary) + and self.array_equal(x_i, y_i) + for (kx_i, x_i), (ky_i, y_i) + in zip(serialized_x, serialized_y, strict=True)], + true_ary) + + def arange(self, *args, **kwargs): + import cupy as cp + return cp.arange(*args, **kwargs) + + def linspace(self, *args, **kwargs): + import cupy as cp + return cp.linspace(*args, **kwargs) + + def zeros_like(self, ary): # pyright: ignore[reportIncompatibleMethodOverride] + import cupy as cp + if isinstance(ary, int | float | complex): + # Cupy does not support zeros_like with scalar arguments + ary = cp.array(ary) + return rec_map_array_container(cp.zeros_like, ary) + + def ones_like(self, ary): + import cupy as cp + if isinstance(ary, int | float | complex): + # Cupy does not support ones_like with scalar arguments + ary = cp.array(ary) + return rec_map_array_container(cp.ones_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + + +# vim: fdm=marker diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index f9d6c541..d29fac7a 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -63,7 +63,7 @@ class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass): class NumpyArrayContext(ArrayContext): """ - A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays. + An :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays. .. automethod:: __init__ """ @@ -86,7 +86,7 @@ def clone(self): return type(self)() @overload - def from_numpy(self, array: np.ndarray) -> Array: + def from_numpy(self, array: np.ndarray[Any, Any]) -> Array: ... @overload @@ -99,7 +99,7 @@ def from_numpy(self, return array @overload - def to_numpy(self, array: Array) -> np.ndarray: + def to_numpy(self, array: Array) -> np.ndarray[Any, Any]: ... @overload diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 582ccda9..b6702415 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -150,7 +150,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: return false_ary return np.logical_and.reduce( [(true_ary if kx_i == ky_i else false_ary) - and cast(np.ndarray, self.array_equal(x_i, y_i)) + and cast(np.ndarray, self.array_equal(x_i, y_i)) # pyright: ignore[reportMissingTypeArgument] for (kx_i, x_i), (ky_i, y_i) in zip(serialized_x, serialized_y, strict=True)], initial=true_ary) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 19c9faea..10ff1663 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -59,7 +59,7 @@ class PyOpenCLArrayContext(ArrayContext): """ - A :class:`ArrayContext` that uses :class:`pyopencl.array.Array` instances + An :class:`ArrayContext` that uses :class:`pyopencl.array.Array` instances for its base array class. .. attribute:: context diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 79a108e2..60ea3f53 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -212,7 +212,7 @@ def zeros( def to_device( queue: cl.CommandQueue, - ary: np.ndarray[Any], + ary: np.ndarray[Any, Any], *, axes: tuple[Axis, ...] | None = None, tags: frozenset[Tag] = _EMPTY_TAG_SET, allocator: cla.Allocator | None = None, diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e78c4e62..86e9bec1 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -100,12 +100,12 @@ def __hash__(self): @dataclass(frozen=True, eq=True) class ScalarInputDescriptor(AbstractInputDescriptor): - dtype: np.dtype + dtype: np.dtype[Any] @dataclass(frozen=True, eq=True) class LeafArrayDescriptor(AbstractInputDescriptor): - dtype: np.dtype + dtype: np.dtype[Any] shape: pt.array.ShapeType # }}} diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 760fc103..c85c7a22 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -227,6 +227,26 @@ def __str__(self): return "" +class _PytestCupyArrayContextFactory(PytestArrayContextFactory): + @classmethod + def is_available(cls) -> bool: + try: + import cupy # type: ignore[import-untyped] # noqa: F401 + return True + except ImportError: + return False + + def __call__(self): + from arraycontext import CupyArrayContext + return CupyArrayContext() + + def __str__(self): + import cupy # pylint: disable=import-error + d = cupy.cuda.runtime.getDeviceProperties(cupy.cuda.Device()) + name = d["name"].decode("utf-8") + return f" on {cupy.cuda.Device()}:{name}" + + # {{{ _PytestArrayContextFactory class _NumpyArrayContextForTests(NumpyArrayContext): @@ -253,6 +273,7 @@ def __str__(self): "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, "numpy": _PytestNumpyArrayContextFactory, + "cupy": _PytestCupyArrayContextFactory, } diff --git a/doc/Makefile b/doc/Makefile index d0ac5f2f..0568a00c 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +SPHINXOPTS ?= -W -n SPHINXBUILD ?= python $(shell which sphinx-build) SOURCEDIR = . BUILDDIR = _build diff --git a/doc/conf.py b/doc/conf.py index f01e4072..6af725ae 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -23,6 +23,7 @@ "pytest": ("https://docs.pytest.org/en/latest/", None), "python": ("https://docs.python.org/3/", None), "pytools": ("https://documen.tician.de/pytools", None), + "cupy": ("https://docs.cupy.dev/en/stable/", None), } # Some modules need to import things just so that sphinx can resolve symbols in diff --git a/doc/implementations.rst b/doc/implementations.rst index 4023e37c..2e6344e8 100644 --- a/doc/implementations.rst +++ b/doc/implementations.rst @@ -13,6 +13,12 @@ Array context based on :mod:`numpy` .. automodule:: arraycontext.impl.numpy + +Array context based on :mod:`cupy` +-------------------------------------------- + +.. automodule:: arraycontext.impl.cupy + Array context based on :mod:`pyopencl.array` -------------------------------------------- diff --git a/doc/index.rst b/doc/index.rst index d3f9854b..2445d2e2 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -6,6 +6,7 @@ code to work with all of them? No problem! Comes with pre-made array context implementations for: - :mod:`numpy` +- :mod:`cupy` - :mod:`pyopencl` - :mod:`jax.numpy` - :mod:`pytato` (for lazy/deferred evaluation) @@ -13,7 +14,7 @@ implementations for: - Profiling :mod:`arraycontext` started life as an array abstraction for use with the -:mod:`meshmode` unstrucuted discretization package. +:mod:`meshmode` unstructured discretization package. Design Guidelines ----------------- diff --git a/doc/make_numpy_coverage_table.py b/doc/make_numpy_coverage_table.py index 57f833d7..abaab57c 100644 --- a/doc/make_numpy_coverage_table.py +++ b/doc/make_numpy_coverage_table.py @@ -11,7 +11,7 @@ .. code:: - python make_numpy_support_table.py numpy_coverage.rst + python make_numpy_coverage_table.py numpy_coverage.rst """ from __future__ import annotations @@ -67,6 +67,8 @@ def initialize_contexts(): arraycontext.EagerJAXArrayContext(), arraycontext.PytatoPyOpenCLArrayContext(queue), arraycontext.PytatoJAXArrayContext(), + arraycontext.NumpyArrayContext(), + arraycontext.CupyArrayContext(), ] diff --git a/doc/numpy_coverage.rst b/doc/numpy_coverage.rst index 5a7b2918..679bceb1 100644 --- a/doc/numpy_coverage.rst +++ b/doc/numpy_coverage.rst @@ -18,12 +18,18 @@ Array creation routines - :class:`~arraycontext.EagerJAXArrayContext` - :class:`~arraycontext.PytatoPyOpenCLArrayContext` - :class:`~arraycontext.PytatoJAXArrayContext` + - :class:`~arraycontext.NumpyArrayContext` + - :class:`~arraycontext.CupyArrayContext` * - :func:`numpy.empty_like` - :green:`Yes` - :green:`Yes` + - :red:`No` + - :red:`No` + - :red:`No` + - :red:`No` + * - :func:`numpy.ones_like` - :green:`Yes` - :green:`Yes` - * - :func:`numpy.ones_like` - :green:`Yes` - :green:`Yes` - :green:`Yes` @@ -33,16 +39,22 @@ Array creation routines - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.full_like` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :red:`No` + - :red:`No` * - :func:`numpy.copy` - :green:`Yes` - :green:`Yes` - :red:`No` - :red:`No` + - :red:`No` + - :red:`No` Array manipulation routines ~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -55,36 +67,50 @@ Array manipulation routines - :class:`~arraycontext.EagerJAXArrayContext` - :class:`~arraycontext.PytatoPyOpenCLArrayContext` - :class:`~arraycontext.PytatoJAXArrayContext` + - :class:`~arraycontext.NumpyArrayContext` + - :class:`~arraycontext.CupyArrayContext` * - :func:`numpy.reshape` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.ravel` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.transpose` - :red:`No` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.broadcast_to` - :red:`No` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.concatenate` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.stack` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` Linear algebra ~~~~~~~~~~~~~~ @@ -97,11 +123,15 @@ Linear algebra - :class:`~arraycontext.EagerJAXArrayContext` - :class:`~arraycontext.PytatoPyOpenCLArrayContext` - :class:`~arraycontext.PytatoJAXArrayContext` + - :class:`~arraycontext.NumpyArrayContext` + - :class:`~arraycontext.CupyArrayContext` * - :func:`numpy.vdot` - :green:`Yes` - :green:`Yes` - - :red:`No` - - :red:`No` + - :green:`Yes` + - :green:`Yes` + - :green:`Yes` + - :green:`Yes` Logic Functions ~~~~~~~~~~~~~~~ @@ -114,46 +144,64 @@ Logic Functions - :class:`~arraycontext.EagerJAXArrayContext` - :class:`~arraycontext.PytatoPyOpenCLArrayContext` - :class:`~arraycontext.PytatoJAXArrayContext` + - :class:`~arraycontext.NumpyArrayContext` + - :class:`~arraycontext.CupyArrayContext` * - :func:`numpy.all` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.any` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.greater` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.greater_equal` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.less` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.less_equal` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.equal` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.not_equal` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` Mathematical functions ~~~~~~~~~~~~~~~~~~~~~~ @@ -166,133 +214,187 @@ Mathematical functions - :class:`~arraycontext.EagerJAXArrayContext` - :class:`~arraycontext.PytatoPyOpenCLArrayContext` - :class:`~arraycontext.PytatoJAXArrayContext` + - :class:`~arraycontext.NumpyArrayContext` + - :class:`~arraycontext.CupyArrayContext` * - :data:`numpy.sin` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.cos` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.tan` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.arcsin` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.arccos` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.arctan` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.arctan2` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.sinh` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.cosh` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.tanh` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.floor` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.ceil` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.sum` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.exp` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.log` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.log10` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.real` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.imag` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.conjugate` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.maximum` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.amax` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :red:`No` + - :red:`No` * - :data:`numpy.minimum` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :func:`numpy.amin` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :red:`No` + - :red:`No` * - :data:`numpy.sqrt` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.absolute` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` * - :data:`numpy.fabs` - :green:`Yes` - :green:`Yes` - :green:`Yes` - :green:`Yes` + - :green:`Yes` + - :green:`Yes` diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 31fa9e79..3b78ee73 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -35,6 +35,7 @@ from arraycontext import ( BcastUntilActxArray, + CupyArrayContext, EagerJAXArrayContext, NumpyArrayContext, PyOpenCLArrayContext, @@ -46,6 +47,7 @@ with_container_arithmetic, ) from arraycontext.pytest import ( + _PytestCupyArrayContextFactory, _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, @@ -100,6 +102,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestCupyArrayContextFactory, _PytestNumpyArrayContextFactory, ]) @@ -1028,7 +1031,7 @@ def test_numpy_conversion(actx_factory): assert np.allclose(ac.mass, ac_roundtrip.mass) assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) - if not isinstance(actx, NumpyArrayContext): + if not isinstance(actx, NumpyArrayContext | CupyArrayContext): from dataclasses import replace ac_with_cl = replace(ac, enthalpy=ac_actx.mass) with pytest.raises(TypeError): @@ -1466,7 +1469,7 @@ def test_to_numpy_on_frozen_arrays(actx_factory): def test_tagging(actx_factory): actx = actx_factory() - if isinstance(actx, NumpyArrayContext | EagerJAXArrayContext): + if isinstance(actx, NumpyArrayContext | EagerJAXArrayContext | CupyArrayContext): pytest.skip(f"{type(actx)} has no tagging support") from pytools.tag import Tag @@ -1517,6 +1520,9 @@ def test_linspace(actx_factory, args, kwargs): actx = actx_factory() + if isinstance(actx, CupyArrayContext) and kwargs.get("dtype") == np.complex128: + pytest.skip("CupyArrayContext does not support complex args to linspace") + actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs)) np_linspace = np.linspace(*args, **kwargs)