Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CupyArrayContext #251

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ jobs:
run: |
USE_CONDA_BUILD=1
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh

CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- cupy" >> "$CONDA_ENVIRONMENT"

. ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" examples/*.py test/test_*.py

mypy:
Expand All @@ -55,6 +59,7 @@ jobs:

build_py_project_in_conda_env
python -m pip install mypy pytest
conda install cupy
./run-mypy.sh

pytest3_pocl:
Expand Down
2 changes: 2 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
ScalarLike,
tag_axes,
)
from .impl.cupy import CupyArrayContext
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
Expand All @@ -105,6 +106,7 @@
"ArrayOrContainerT",
"ArrayT",
"CommonSubexpressionTag",
"CupyArrayContext",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
"NotAnArrayContainerError",
Expand Down
174 changes: 174 additions & 0 deletions arraycontext/impl/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from __future__ import annotations


"""
.. currentmodule:: arraycontext

A mod :`cupy`-based array context.

.. autoclass:: CupyArrayContext
"""

__copyright__ = """
Copyright (C) 2024 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from typing import Any

import loopy as lp
from pytools.tag import ToTagSetConvertible

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)


class CupyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
import cupy as cp # type: ignore[import-untyped]
return isinstance(instance, cp.ndarray) and instance.dtype != object


class CupyNonObjectArray(metaclass=CupyNonObjectArrayMetaclass):
pass


class CupyArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses :class:`cupy.ndarray` to represent arrays.

.. automethod:: __init__
"""

_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]

def __init__(self) -> None:
super().__init__()
self._loopy_transform_cache = {}

array_types = (CupyNonObjectArray,)

def _get_fake_numpy_namespace(self):
from .fake_numpy import CupyFakeNumpyNamespace
return CupyFakeNumpyNamespace(self)

# {{{ ArrayContext interface

def clone(self):
return type(self)()

def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
import cupy as cp
return cp.array(array)

def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
import cupy as cp
return cp.asnumpy(array)

def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
executor = self._loopy_transform_cache[t_unit]
except KeyError:
executor = self.transform_loopy_program(t_unit).executor()
self._loopy_transform_cache[t_unit] = executor

_, result = executor(**kwargs)

return result

def freeze(self, array):
import cupy as cp

def _freeze(ary):
return cp.asnumpy(ary)

return with_array_context(rec_map_array_container(_freeze, array), actx=None)

def thaw(self, array):
import cupy as cp

def _thaw(ary):
return cp.array(ary)

return with_array_context(rec_map_array_container(_thaw, array), actx=self)

# }}}

def transform_loopy_program(self, t_unit):
from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performant."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)

return t_unit

def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Cupy (like numpy) doesn't support tagging
return array

def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Cupy (like numpy) doesn't support tagging
return array

def einsum(self, spec, *args, arg_names=None, tagged=()):
import cupy as cp
return cp.einsum(spec, *args)

@property
def permits_inplace_modification(self):
return True

@property
def supports_nonscalar_broadcasting(self):
return True

@property
def permits_advanced_indexing(self):
return True
Loading
Loading