Skip to content

Commit 716565a

Browse files
committed
fixup! Test BatchedEinsumPytatoPyOpenCLArrayContext
Adds a failing test for the dimension mismatch error.
1 parent 8211a9c commit 716565a

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

test/test_batched_einsum_actx.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
_PytestSplitPytatoPyOpenCLArrayContextFactory,
3030
pytest_generate_tests_for_array_contexts)
3131

32+
from pytools.obj_array import make_obj_array
33+
3234

3335
# {{{ axes tag types for image processing
3436

@@ -265,4 +267,32 @@ def test_dg_3d_divergence(actx_factory):
265267

266268
np.testing.assert_allclose(ref_out, actx.to_numpy(out))
267269

270+
271+
def test_multiple_large_sized_outputs(actx_factory):
272+
actx = actx_factory()
273+
rng = np.random.default_rng(0)
274+
n1 = 1_000_000
275+
n2 = 2_000_000
276+
277+
x1_np = rng.random((n1, 1))
278+
x2_np = rng.random((n2, 1))
279+
280+
x1 = actx.from_numpy(x1_np)
281+
x2 = actx.from_numpy(x2_np)
282+
283+
x1 = tag_axes(actx, {0: NamedAxis("e"),
284+
1: NamedAxis("i")},
285+
x1)
286+
x2 = tag_axes(actx, {0: NamedAxis("e"),
287+
1: NamedAxis("i")},
288+
x2)
289+
290+
out = make_obj_array([actx.einsum("ij->i", 3 * x1),
291+
actx.einsum("ij->i", 4 * x2)])
292+
ref_out = make_obj_array([np.einsum("ij->i", 3 * x1_np),
293+
np.einsum("ij->i", 4 * x2_np)])
294+
295+
np.testing.assert_allclose(ref_out[0], actx.to_numpy(out)[0])
296+
np.testing.assert_allclose(ref_out[1], actx.to_numpy(out)[1])
297+
268298
# vim: fdm=marker

0 commit comments

Comments
 (0)