Skip to content

Commit abd4b30

Browse files
committed
add NumpyArrayContext subclass and tests
1 parent a449133 commit abd4b30

13 files changed

+100
-37
lines changed

meshmode/array_context.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
from warnings import warn
2929

3030
from arraycontext import (
31+
NumpyArrayContext as NumpyArrayContextBase,
3132
PyOpenCLArrayContext as PyOpenCLArrayContextBase,
3233
PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase,
3334
)
3435
from arraycontext.pytest import (
36+
_PytestNumpyArrayContextFactory,
3537
_PytestPyOpenCLArrayContextFactoryWithClass,
3638
_PytestPytatoPyOpenCLArrayContextFactory,
3739
register_pytest_array_context_factory,
@@ -198,6 +200,24 @@ def _transform_with_element_and_dof_inames(t_unit, el_inames, dof_inames):
198200
# }}}
199201

200202

203+
# {{{ numpy array context subclass
204+
205+
class NumpyArrayContext(NumpyArrayContextBase):
206+
def transform_loopy_program(self, t_unit):
207+
default_ep = t_unit.default_entrypoint
208+
options = default_ep.options
209+
if not (options.return_dict and options.no_numpy):
210+
raise ValueError("Loopy kernel passed to call_loopy must "
211+
"have return_dict and no_numpy options set. "
212+
"Did you use arraycontext.make_loopy_program "
213+
"to create this kernel?")
214+
215+
import loopy as lp
216+
return lp.add_inames_for_unused_hw_axes(t_unit)
217+
218+
# }}}
219+
220+
201221
# {{{ pyopencl array context subclass
202222

203223
class PyOpenCLArrayContext(PyOpenCLArrayContextBase):
@@ -268,6 +288,11 @@ def transform_loopy_program(self, t_unit):
268288

269289
# {{{ pytest actx factory
270290

291+
class PytestNumpyArrayContextFactory(_PytestNumpyArrayContextFactory):
292+
def __call__(self):
293+
return NumpyArrayContext()
294+
295+
271296
class PytestPyOpenCLArrayContextFactory(
272297
_PytestPyOpenCLArrayContextFactoryWithClass):
273298
actx_class = PyOpenCLArrayContext
@@ -281,6 +306,8 @@ def actx_class(self):
281306
return PytatoPyOpenCLArrayContext
282307

283308

309+
register_pytest_array_context_factory("meshmode.numpy",
310+
PytestNumpyArrayContextFactory)
284311
register_pytest_array_context_factory("meshmode.pyopencl",
285312
PytestPyOpenCLArrayContextFactory)
286313
register_pytest_array_context_factory("meshmode.pytato_cl",

test/test_array.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from meshmode import _acf # noqa: F401
3939
from meshmode.array_context import (
40+
PytestNumpyArrayContextFactory,
4041
PytestPyOpenCLArrayContextFactory,
4142
PytestPytatoPyOpenCLArrayContextFactory,
4243
)
@@ -46,10 +47,11 @@
4647

4748

4849
logger = logging.getLogger(__name__)
49-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
50-
[PytestPytatoPyOpenCLArrayContextFactory,
51-
PytestPyOpenCLArrayContextFactory,
52-
])
50+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
51+
PytestNumpyArrayContextFactory,
52+
PytestPytatoPyOpenCLArrayContextFactory,
53+
PytestPyOpenCLArrayContextFactory,
54+
])
5355

5456

5557
@with_container_arithmetic(bcast_obj_array=False,
@@ -163,6 +165,10 @@ class FooAxisTag2(Tag):
163165
def test_dof_array_pickling_tags(actx_factory):
164166
actx = actx_factory()
165167

168+
from meshmode.array_context import NumpyArrayContext
169+
if isinstance(actx, NumpyArrayContext):
170+
pytest.skip(f"{type(actx).__name__} does not support tags")
171+
166172
from pickle import dumps, loads
167173

168174
state = DOFArray(actx, (

test/test_chained.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@
2828
from arraycontext import flatten, pytest_generate_tests_for_array_contexts
2929

3030
from meshmode import _acf # noqa: F401
31-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
31+
from meshmode.array_context import (
32+
PytestPyOpenCLArrayContextFactory,
33+
)
3234
from meshmode.dof_array import flat_norm
3335

3436

3537
logger = logging.getLogger(__name__)
36-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
37-
[PytestPyOpenCLArrayContextFactory])
38+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
39+
PytestPyOpenCLArrayContextFactory,
40+
])
3841

3942

4043
def create_discretization(actx, ndim,

test/test_connection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929

3030
import meshmode.mesh.generation as mgen
3131
from meshmode import _acf # noqa: F401
32-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
32+
from meshmode.array_context import (
33+
PytestPyOpenCLArrayContextFactory,
34+
)
3335
from meshmode.discretization import Discretization
3436
from meshmode.discretization.connection import FACE_RESTR_ALL
3537
from meshmode.discretization.poly_element import (
@@ -43,8 +45,9 @@
4345

4446

4547
logger = logging.getLogger(__name__)
46-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
47-
[PytestPyOpenCLArrayContextFactory])
48+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
49+
PytestPyOpenCLArrayContextFactory,
50+
])
4851

4952

5053
@pytest.mark.parametrize("group_factory", [

test/test_discretization.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@
2626

2727
import meshmode.mesh.generation as mgen
2828
from meshmode import _acf # noqa: F401
29-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
29+
from meshmode.array_context import (
30+
PytestPyOpenCLArrayContextFactory,
31+
)
3032
from meshmode.discretization import Discretization
3133
from meshmode.discretization.poly_element import (
3234
InterpolatoryQuadratureSimplexGroupFactory,
3335
)
3436

3537

36-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
37-
[PytestPyOpenCLArrayContextFactory])
38+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
39+
PytestPyOpenCLArrayContextFactory,
40+
])
3841

3942

4043
def test_discr_nodes_caching(actx_factory):

test/test_firedrake_interop.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@
4343

4444

4545
logger = logging.getLogger(__name__)
46-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
47-
[PytestPyOpenCLArrayContextFactory])
46+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
47+
PytestPyOpenCLArrayContextFactory,
48+
])
4849

4950
CLOSE_ATOL = 1e-12
5051

test/test_interop.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@
2828
from arraycontext import pytest_generate_tests_for_array_contexts
2929

3030
from meshmode import _acf # noqa: F401
31-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
31+
from meshmode.array_context import (
32+
PytestPyOpenCLArrayContextFactory,
33+
)
3234
from meshmode.dof_array import flat_norm
3335

3436

3537
logger = logging.getLogger(__name__)
36-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
37-
[PytestPyOpenCLArrayContextFactory])
38+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
39+
PytestPyOpenCLArrayContextFactory,
40+
])
3841

3942

4043
@pytest.mark.parametrize("dim", [1, 2, 3])

test/test_mesh.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
import meshmode.mesh.io as mio
3939
import meshmode.mesh.processing as mproc
4040
from meshmode import _acf # noqa: F401
41-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
41+
from meshmode.array_context import (
42+
PytestPyOpenCLArrayContextFactory,
43+
)
4244
from meshmode.discretization.poly_element import (
4345
LegendreGaussLobattoTensorProductGroupFactory,
4446
default_simplex_group_factory,
@@ -54,8 +56,9 @@
5456

5557

5658
logger = logging.getLogger(__name__)
57-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
58-
[PytestPyOpenCLArrayContextFactory])
59+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
60+
PytestPyOpenCLArrayContextFactory,
61+
])
5962

6063
thisdir = pathlib.Path(__file__).parent
6164

test/test_meshmode.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import meshmode.mesh.generation as mgen
3434
from meshmode import _acf # noqa: F401
3535
from meshmode.array_context import (
36+
PytestNumpyArrayContextFactory,
3637
PytestPyOpenCLArrayContextFactory,
3738
PytestPytatoPyOpenCLArrayContextFactory,
3839
)
@@ -58,10 +59,11 @@
5859

5960

6061
logger = logging.getLogger(__name__)
61-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
62-
[PytestPytatoPyOpenCLArrayContextFactory,
63-
PytestPyOpenCLArrayContextFactory,
64-
])
62+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
63+
PytestNumpyArrayContextFactory,
64+
PytestPytatoPyOpenCLArrayContextFactory,
65+
PytestPyOpenCLArrayContextFactory,
66+
])
6567

6668
thisdir = pathlib.Path(__file__).parent
6769

test/test_modal.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030

3131
import meshmode.mesh.generation as mgen
3232
from meshmode import _acf # noqa: F401
33-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
33+
from meshmode.array_context import (
34+
PytestPyOpenCLArrayContextFactory,
35+
)
3436
from meshmode.discretization import Discretization
3537
from meshmode.discretization.connection.modal import (
3638
ModalToNodalDiscretizationConnection,
@@ -51,8 +53,9 @@
5153
from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup
5254

5355

54-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
55-
[PytestPyOpenCLArrayContextFactory])
56+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
57+
PytestPyOpenCLArrayContextFactory,
58+
])
5659

5760

5861
@pytest.mark.parametrize("nodal_group_factory", [

test/test_partition.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
from arraycontext import flatten, pytest_generate_tests_for_array_contexts, unflatten
3434

3535
from meshmode import _acf # noqa: F401
36-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
36+
from meshmode.array_context import (
37+
PytestPyOpenCLArrayContextFactory,
38+
)
3739
from meshmode.discretization.poly_element import default_simplex_group_factory
3840
from meshmode.dof_array import flat_norm
3941
from meshmode.mesh import (
@@ -45,8 +47,9 @@
4547

4648

4749
logger = logging.getLogger(__name__)
48-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
49-
[PytestPyOpenCLArrayContextFactory])
50+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
51+
PytestPyOpenCLArrayContextFactory,
52+
])
5053

5154
# Is there a smart way of choosing this number?
5255
# Currently it is the same as the base from MPIBoundaryCommSetupHelper

test/test_refinement.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
import meshmode.mesh.generation as mgen
3333
from meshmode import _acf # noqa: F401
34-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
34+
from meshmode.array_context import (
35+
PytestPyOpenCLArrayContextFactory,
36+
)
3537
from meshmode.discretization.poly_element import (
3638
GaussLegendreTensorProductGroupFactory,
3739
InterpolatoryQuadratureSimplexGroupFactory,
@@ -46,8 +48,9 @@
4648

4749

4850
logger = logging.getLogger(__name__)
49-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
50-
[PytestPyOpenCLArrayContextFactory])
51+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
52+
PytestPyOpenCLArrayContextFactory,
53+
])
5154

5255
thisdir = pathlib.Path(__file__).parent
5356

test/test_visualization.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333

3434
import meshmode.mesh.generation as mgen
3535
from meshmode import _acf # noqa: F401
36-
from meshmode.array_context import PytestPyOpenCLArrayContextFactory
36+
from meshmode.array_context import (
37+
PytestPyOpenCLArrayContextFactory,
38+
)
3739
from meshmode.discretization.poly_element import (
3840
InterpolatoryQuadratureSimplexGroupFactory,
3941
LegendreGaussLobattoTensorProductGroupFactory,
@@ -43,8 +45,9 @@
4345

4446

4547
logger = logging.getLogger(__name__)
46-
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
47-
[PytestPyOpenCLArrayContextFactory])
48+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
49+
PytestPyOpenCLArrayContextFactory,
50+
])
4851

4952
thisdir = pathlib.Path(__file__).parent
5053

0 commit comments

Comments
 (0)