From 85f0e00d76d3012e5590b96832a261c0ed2c99ef Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 22 Jan 2023 15:05:08 -0600 Subject: [PATCH] Test BatchedEinsumPytatoPyOpenCLArrayContext --- test/test_batched_einsum_actx.py | 217 +++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 test/test_batched_einsum_actx.py diff --git a/test/test_batched_einsum_actx.py b/test/test_batched_einsum_actx.py new file mode 100644 index 00000000..1e00b1e3 --- /dev/null +++ b/test/test_batched_einsum_actx.py @@ -0,0 +1,217 @@ +import pytest + + +try: + import feinsum # noqa: F401 +except ModuleNotFoundError: + pytest.skip(reason="BatchedEinsumActx imposes feinsum as a hard dep.", + allow_module_level=True) + +try: + from loopy import get_kennedy_unweighted_fusion_candidates # noqa: F401 + from loopy import rename_inames_in_batch # noqa: F401 +except ImportError: + pytest.skip(reason="BatchedEinsumActx imposes loop-fusion support in " + "loopy as a hard dep.", allow_module_level=True) + +import numpy as np + +from pytools.tag import UniqueTag + +from arraycontext import ( + BatchedEinsumPytatoPyOpenCLArrayContext, PyOpenCLArrayContext, + PytatoPyOpenCLArrayContext, tag_axes) +from arraycontext.pytest import ( + _PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, + _PytestSplitPytatoPyOpenCLArrayContextFactory, + pytest_generate_tests_for_array_contexts) + + +# {{{ axes tag types for image processing + +class ImageDimensionTag(UniqueTag): + """ + An abstract tag type that is tagged to an array's axis indexing along an image's + axis. + """ + + +class XDimension(ImageDimensionTag): + """ + A tag that is attached to a :class:`pytato.array.Axis` that indexes along the + x-dimension of an image. + """ + + +class YDimension(ImageDimensionTag): + """ + A tag that is attached to a :class:`pytato.array.Axis` that indexes along the + y-dimension of an image. + """ + + +class ChannelDimension(ImageDimensionTag): + """ + A tag that is attached to a :class:`pytato.array.Axis` that indexes along the + channels of an image. + """ + +# }}} + + +# {{{ array context fixture + +class ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc( + BatchedEinsumPytatoPyOpenCLArrayContext): + def __init__(self, queue, allocator=None): + super().__init__(queue, allocator, + fallback_to_no_fusion=False, + loop_fusion_axis_tag_t=ImageDimensionTag) + + +class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext): + """Like :class:`PyOpenCLArrayContext`, but applies no program transformations + whatsoever. Only to be used for testing internal to :mod:`arraycontext`. + """ + + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext): + """Like :class:`PytatoPyOpenCLArrayContext`, but applies no program + transformations whatsoever. Only to be used for testing internal to + :mod:`arraycontext`. + """ + + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytatoPyOpenCLArrayContextForTestsFactory( + _PytestPytatoPyOpenCLArrayContextFactory): + actx_class = _PytatoPyOpenCLArrayContextForTests + + +class _PyOpenCLArrayContextForTestsFactoryWithHostScalars( + _PytestPyOpenCLArrayContextFactoryWithClass): + force_device_scalars = True + actx_class = _PyOpenCLArrayContextForTests + + +class _PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory( + _PytestPytatoPyOpenCLArrayContextFactory): + @property + def actx_class(self): + return ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc + + +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + _PyOpenCLArrayContextForTestsFactoryWithHostScalars, + _PytatoPyOpenCLArrayContextForTestsFactory, + _PytestEagerJaxArrayContextFactory, + _PytestPytatoJaxArrayContextFactory, + _PytestSplitPytatoPyOpenCLArrayContextFactory, + _PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory, + ]) + +# }}} + + +def test_simple_add(actx_factory): + # Lesson 01 of Halide Tutorial + actx = actx_factory() + + rng = np.random.default_rng(0) + a_np = rng.random((800, 600)) + b_np = rng.random((800, 600)) + a = actx.from_numpy(a_np) + b = actx.from_numpy(b_np) + + a = tag_axes(actx, {0: XDimension(), 1: YDimension()}, a) + b = tag_axes(actx, {0: XDimension(), 1: YDimension()}, b) + + out = actx.to_numpy(a + b) + ref_out = a_np + b_np + + np.testing.assert_allclose(out, ref_out) + + +def test_brighten_image(actx_factory): + # Lesson 02 of Halide Tutorial + actx = actx_factory() + + rng = np.random.default_rng(0) + + img_np = 255*rng.random((800, 600, 3), dtype=np.float32) + + img = actx.from_numpy(img_np) + img = tag_axes(actx, + {0: XDimension(), 1: YDimension(), 2: ChannelDimension()}, + img) + + brightened_img = 1.5*img + clamped_brightened_img = actx.np.minimum(brightened_img, np.float32(255)) + + out = actx.to_numpy(clamped_brightened_img) + ref_out = np.minimum(1.5*img_np, np.float32(255)) + + np.testing.assert_allclose(out, ref_out) + + +def test_simple_einsum(actx_factory): + actx = actx_factory() + + rng = np.random.default_rng() + + a_np = rng.random((10, 4)) + a = actx.from_numpy(a_np) + a = tag_axes(actx, + {0: XDimension(), 1: YDimension()}, a) + + out1 = actx.einsum("ij,ij->i", a, a+1) + out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7) + + ref_out = (np.einsum("ij,ij->i", a_np, a_np + 1) + + np.einsum("ij,ij->i", 2*a_np, 3*a_np+7)) + out = actx.to_numpy(out1 + out2) + + np.testing.assert_allclose(ref_out, out) + + +def test_nested_einsum(actx_factory): + actx = actx_factory() + + rng = np.random.default_rng() + + a_np = rng.random((10, 4)) + + # {{{ compute out + + a = actx.from_numpy(a_np) + a = tag_axes(actx, + {0: XDimension(), 1: YDimension()}, a) + b = a + 1 + + out1 = actx.einsum("ij,ij->i", a, b) + out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7) + out3 = actx.einsum("ij,i->i", 3*b, 2*out1) + + out = actx.to_numpy(out1 + out2 + out3) + + # }}} + + # {{{ compute ref_out + + b_np = a_np + 1 + out1_np = np.einsum("ij,ij->i", a_np, a_np+1) + out2_np = np.einsum("ij,ij->i", 2*a_np, 3*a_np+7) + out3_np = np.einsum("ij,i->i", 3*b_np, 2*out1_np) + ref_out = out1_np + out2_np + out3_np + + # }}} + + np.testing.assert_allclose(ref_out, out) + +# vim: fdm=marker