From caaa471f9ec2fd3f71979ce563987ebc151896b0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 6 Sep 2023 23:50:15 -0700 Subject: [PATCH] fixup! Implemented BatchedEinsumArrayContext avoid loop-fusion errors associated with saturations of long dimensions --- arraycontext/impl/pytato/batched_einsum/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/batched_einsum/utils.py b/arraycontext/impl/pytato/batched_einsum/utils.py index 1a69299f..a8ff97d5 100644 --- a/arraycontext/impl/pytato/batched_einsum/utils.py +++ b/arraycontext/impl/pytato/batched_einsum/utils.py @@ -191,10 +191,14 @@ def apply_kennedy_fusion_with_batched_einsum_extension( if insn.reduction_inames(): einsum, _ = fnsm.get_a_matched_einsum( - t_unit, insn_match=lp_match.Id(insn.id)) + t_unit, insn_match=lp_match.Id(insn.id), + # only consider inames with same length for fusion + # => do not parametrize inames with very long loop-counts. + long_dim_length=np.inf) einsum = fnsm.canonicalize_einsum(einsum) subst_map = fnsm.match_t_unit_to_einsum( - t_unit, einsum, insn_match=lp_match.Id(insn.id)) + t_unit, einsum, insn_match=lp_match.Id(insn.id), + long_dim_length=np.inf) else: # we treat any non-reduction einsum as a copy-einsum assignee = insn.assignee