Skip to content

Commit

Permalink
Test BatchedEinsumPytatoPyOpenCLArrayContext
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Jan 24, 2023
1 parent 1d8c2f2 commit fb42c99
Showing 1 changed file with 217 additions and 0 deletions.
217 changes: 217 additions & 0 deletions test/test_batched_einsum_actx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import pytest


try:
import feinsum # noqa: F401
except ModuleNotFoundError:
pytest.skip(reason="BatchedEinsumActx imposes feinsum as a hard dep.",
allow_module_level=True)

try:
from loopy import get_kennedy_unweighted_fusion_candidates # noqa: F401
from loopy import rename_inames_in_batch # noqa: F401
except ImportError:
pytest.skip(reason="BatchedEinsumActx imposes loop-fusion support in "
"loopy as a hard dep.", allow_module_level=True)

import numpy as np

from pytools.tag import UniqueTag

from arraycontext import (
BatchedEinsumPytatoPyOpenCLArrayContext, PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext, tag_axes)
from arraycontext.pytest import (
_PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory,
_PytestSplitPytatoPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts)


# {{{ axes tag types for image processing

class ImageDimensionTag(UniqueTag):
"""
An abstract tag type that is tagged to an array's axis indexing along an image's
axis.
"""


class XDimension(ImageDimensionTag):
"""
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
x-dimension of an image.
"""


class YDimension(ImageDimensionTag):
"""
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
y-dimension of an image.
"""


class ChannelDimension(ImageDimensionTag):
"""
A tag that is attached to a :class:`pytato.array.Axis` that indexes along the
channels of an image.
"""

# }}}


# {{{ array context fixture

class ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc(
BatchedEinsumPytatoPyOpenCLArrayContext):
def __init__(self, queue, allocator=None):
super().__init__(queue, allocator,
fallback_to_no_fusion=False,
loop_fusion_axis_tag_t=ImageDimensionTag)


class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext):
"""Like :class:`PyOpenCLArrayContext`, but applies no program transformations
whatsoever. Only to be used for testing internal to :mod:`arraycontext`.
"""

def transform_loopy_program(self, t_unit):
return t_unit


class _PytatoPyOpenCLArrayContextForTests(PytatoPyOpenCLArrayContext):
"""Like :class:`PytatoPyOpenCLArrayContext`, but applies no program
transformations whatsoever. Only to be used for testing internal to
:mod:`arraycontext`.
"""

def transform_loopy_program(self, t_unit):
return t_unit


class _PytatoPyOpenCLArrayContextForTestsFactory(
_PytestPytatoPyOpenCLArrayContextFactory):
actx_class = _PytatoPyOpenCLArrayContextForTests


class _PyOpenCLArrayContextForTestsFactoryWithHostScalars(
_PytestPyOpenCLArrayContextFactoryWithClass):
force_device_scalars = True
actx_class = _PyOpenCLArrayContextForTests


class _PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory(
_PytestPytatoPyOpenCLArrayContextFactory):
@property
def actx_class(self):
return ImageProcessingFusionPytatoPyOpenCLArrayContextForImageProc


pytest_generate_tests = pytest_generate_tests_for_array_contexts([
_PyOpenCLArrayContextForTestsFactoryWithHostScalars,
_PytatoPyOpenCLArrayContextForTestsFactory,
_PytestEagerJaxArrayContextFactory,
_PytestPytatoJaxArrayContextFactory,
_PytestSplitPytatoPyOpenCLArrayContextFactory,
_PytestBatchedEinsumPytatoPyOpenCLArrayContextFactory,
])

# }}}


def test_simple_add(actx_factory):
# Lesson 01 of Halide Tutorial
actx = actx_factory()

rng = np.random.default_rng(0)
a_np = rng.random((800, 600))
b_np = rng.random((800, 600))
a = actx.from_numpy(a_np)
b = actx.from_numpy(b_np)

a = tag_axes(actx, {0: XDimension(), 1: YDimension()}, a)
b = tag_axes(actx, {0: XDimension(), 1: YDimension()}, b)

out = actx.to_numpy(a + b)
ref_out = a_np + b_np

np.testing.assert_allclose(out, ref_out)


def test_brighten_image(actx_factory):
# Lesson 02 of Halide Tutorial
actx = actx_factory()

rng = np.random.default_rng(0)

img_np = 255*rng.random((800, 600, 3), dtype=np.float32)

img = actx.from_numpy(img_np)
img = tag_axes(actx,
{0: XDimension(), 1: YDimension(), 2: ChannelDimension()},
img)

brightened_img = 1.5*img
clamped_brightened_img = actx.np.minimum(brightened_img, np.float32(255))

out = actx.to_numpy(clamped_brightened_img)
ref_out = np.minimum(1.5*img_np, np.float32(255))

np.testing.assert_allclose(out, ref_out)


def test_simple_einsum(actx_factory):
actx = actx_factory()

rng = np.random.default_rng()

a_np = rng.random((10, 4))
a = actx.from_numpy(a_np)
a = tag_axes(actx,
{0: XDimension(), 1: YDimension()}, a)

out1 = actx.einsum("ij,ij->i", a, a+1)
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7)

ref_out = (np.einsum("ij,ij->i", a_np, a_np + 1)
+ np.einsum("ij,ij->i", 2*a_np, 3*a_np+7))
out = actx.to_numpy(out1 + out2)

np.testing.assert_allclose(ref_out, out)


def test_nested_einsum(actx_factory):
actx = actx_factory()

rng = np.random.default_rng()

a_np = rng.random((10, 4))

# {{{ compute out

a = actx.from_numpy(a_np)
a = tag_axes(actx,
{0: XDimension(), 1: YDimension()}, a)
b = a + 1

out1 = actx.einsum("ij,ij->i", a, b)
out2 = actx.einsum("ij,ij->i", 2*a, 3*a+7)
out3 = actx.einsum("ij,i->i", 3*b, 2*out1)

out = actx.to_numpy(out1 + out2 + out3)

# }}}

# {{{ compute ref_out

b_np = a_np + 1
out1_np = np.einsum("ij,ij->i", a_np, a_np+1)
out2_np = np.einsum("ij,ij->i", 2*a_np, 3*a_np+7)
out3_np = np.einsum("ij,i->i", 3*b_np, 2*out1_np)
ref_out = out1_np + out2_np + out3_np

# }}}

np.testing.assert_allclose(ref_out, out)

# vim: fdm=marker

0 comments on commit fb42c99

Please sign in to comment.