-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Test BatchedEinsumPytatoPyOpenCLArrayContext
- Loading branch information
1 parent
1d8c2f2
commit fb42c99
Showing
1 changed file
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import pytest | ||
|
||
|
||
try: | ||
import feinsum # noqa: F401 | ||
except ModuleNotFoundError: | ||
pytest.skip(reason="BatchedEinsumActx imposes feinsum as a hard dep.", | ||
allow_module_level=True) | ||
|
||
try: | ||
from loopy import get_kennedy_unweighted_fusion_candidates # noqa: F401 | ||
from loopy import rename_inames_in_batch # noqa: F401 | ||
except ImportError: | ||
pytest.skip(reason="BatchedEinsumActx imposes loop-fusion support in " | ||
"loopy as a hard dep.", allow_module_level=True) | ||
|
||
import numpy as np | ||
|
||
from pytools.tag import UniqueTag | ||
|
||
from arraycontext import ( | ||
BatchedEinsumPytatoPyOpenCLArrayContext, PyOpenCLArrayContext, | ||
PytatoPyOpenCLArrayContext, tag_axes) | ||
from arraycontext.pytest import ( | ||
_PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, | ||
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, | ||
_PytestSplitPytatoPyOpenCLArrayContextFactory, | ||
pytest_generate_tests_for_array_contexts) | ||
|
||
|
||
# {{{ axes tag types for image processing | ||
|
||
class ImageDimensionTag(UniqueTag): | ||
""" | ||
An abstract tag type that is tagged to an array's axis indexing along an image's | ||
axis. | ||
""" | ||
|
||
|
||
class XDimension(ImageDimensionTag): | ||
""" | ||
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the | ||
x-dimension of an image. | ||
""" | ||
|
||
|
||
class YDimension(ImageDimensionTag): | ||
""" | ||
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the | ||
y-dimension of an image. | ||
""" | ||
|
||
|
||
class ChannelDimension(ImageDimensionTag): | ||
""" | ||
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the | ||
channels of an image. | ||
""" | ||
|
||
# }}} | ||
|
||
|
||
# {{{ array context fixture | ||
|
||
class ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc( | ||
BatchedEinsumPytatoPyOpenCLArrayContext): | ||
def __init__(self, queue, allocator=None): | ||
super().__init__(queue, allocator, | ||
fallback_to_no_fusion=False, | ||
loop_fusion_axis_tag_t=ImageDimensionTag) | ||
|
||
|
||
class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext): | ||
"""Like :class:`PyOpenCLArrayContext`, but applies no program transformations | ||
whatsoever. Only to be used for testing internal to :mod:`arraycontext`. | ||
""" | ||
|
||
def transform_loopy_program(self, t_unit): | ||
return t_unit | ||
|
||
|
||
class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext): | ||
"""Like :class:`PytatoPyOpenCLArrayContext`, but applies no program | ||
transformations whatsoever. Only to be used for testing internal to | ||
:mod:`arraycontext`. | ||
""" | ||
|
||
def transform_loopy_program(self, t_unit): | ||
return t_unit | ||
|
||
|
||
class _PytatoPyOpenCLArrayContextForTestsFactory( | ||
_PytestPytatoPyOpenCLArrayContextFactory): | ||
actx_class = _PytatoPyOpenCLArrayContextForTests | ||
|
||
|
||
class _PyOpenCLArrayContextForTestsFactoryWithHostScalars( | ||
_PytestPyOpenCLArrayContextFactoryWithClass): | ||
force_device_scalars = True | ||
actx_class = _PyOpenCLArrayContextForTests | ||
|
||
|
||
class _PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory( | ||
_PytestPytatoPyOpenCLArrayContextFactory): | ||
@property | ||
def actx_class(self): | ||
return ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc | ||
|
||
|
||
pytest_generate_tests = pytest_generate_tests_for_array_contexts([ | ||
_PyOpenCLArrayContextForTestsFactoryWithHostScalars, | ||
_PytatoPyOpenCLArrayContextForTestsFactory, | ||
_PytestEagerJaxArrayContextFactory, | ||
_PytestPytatoJaxArrayContextFactory, | ||
_PytestSplitPytatoPyOpenCLArrayContextFactory, | ||
_PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory, | ||
]) | ||
|
||
# }}} | ||
|
||
|
||
def test_simple_add(actx_factory): | ||
# Lesson 01 of Halide Tutorial | ||
actx = actx_factory() | ||
|
||
rng = np.random.default_rng(0) | ||
a_np = rng.random((800, 600)) | ||
b_np = rng.random((800, 600)) | ||
a = actx.from_numpy(a_np) | ||
b = actx.from_numpy(b_np) | ||
|
||
a = tag_axes(actx, {0: XDimension(), 1: YDimension()}, a) | ||
b = tag_axes(actx, {0: XDimension(), 1: YDimension()}, b) | ||
|
||
out = actx.to_numpy(a + b) | ||
ref_out = a_np + b_np | ||
|
||
np.testing.assert_allclose(out, ref_out) | ||
|
||
|
||
def test_brighten_image(actx_factory): | ||
# Lesson 02 of Halide Tutorial | ||
actx = actx_factory() | ||
|
||
rng = np.random.default_rng(0) | ||
|
||
img_np = 255*rng.random((800, 600, 3), dtype=np.float32) | ||
|
||
img = actx.from_numpy(img_np) | ||
img = tag_axes(actx, | ||
{0: XDimension(), 1: YDimension(), 2: ChannelDimension()}, | ||
img) | ||
|
||
brightened_img = 1.5*img | ||
clamped_brightened_img = actx.np.minimum(brightened_img, np.float32(255)) | ||
|
||
out = actx.to_numpy(clamped_brightened_img) | ||
ref_out = np.minimum(1.5*img_np, np.float32(255)) | ||
|
||
np.testing.assert_allclose(out, ref_out) | ||
|
||
|
||
def test_simple_einsum(actx_factory): | ||
actx = actx_factory() | ||
|
||
rng = np.random.default_rng() | ||
|
||
a_np = rng.random((10, 4)) | ||
a = actx.from_numpy(a_np) | ||
a = tag_axes(actx, | ||
{0: XDimension(), 1: YDimension()}, a) | ||
|
||
out1 = actx.einsum("ij,ij->i", a, a+1) | ||
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7) | ||
|
||
ref_out = (np.einsum("ij,ij->i", a_np, a_np + 1) | ||
+ np.einsum("ij,ij->i", 2*a_np, 3*a_np+7)) | ||
out = actx.to_numpy(out1 + out2) | ||
|
||
np.testing.assert_allclose(ref_out, out) | ||
|
||
|
||
def test_nested_einsum(actx_factory): | ||
actx = actx_factory() | ||
|
||
rng = np.random.default_rng() | ||
|
||
a_np = rng.random((10, 4)) | ||
|
||
# {{{ compute out | ||
|
||
a = actx.from_numpy(a_np) | ||
a = tag_axes(actx, | ||
{0: XDimension(), 1: YDimension()}, a) | ||
b = a + 1 | ||
|
||
out1 = actx.einsum("ij,ij->i", a, b) | ||
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7) | ||
out3 = actx.einsum("ij,i->i", 3*b, 2*out1) | ||
|
||
out = actx.to_numpy(out1 + out2 + out3) | ||
|
||
# }}} | ||
|
||
# {{{ compute ref_out | ||
|
||
b_np = a_np + 1 | ||
out1_np = np.einsum("ij,ij->i", a_np, a_np+1) | ||
out2_np = np.einsum("ij,ij->i", 2*a_np, 3*a_np+7) | ||
out3_np = np.einsum("ij,i->i", 3*b_np, 2*out1_np) | ||
ref_out = out1_np + out2_np + out3_np | ||
|
||
# }}} | ||
|
||
np.testing.assert_allclose(ref_out, out) | ||
|
||
# vim: fdm=marker |