diff --git a/arraycontext/impl/pytato/batched_einsum/utils.py b/arraycontext/impl/pytato/batched_einsum/utils.py index 1a69299f..a8ff97d5 100644 --- a/arraycontext/impl/pytato/batched_einsum/utils.py +++ b/arraycontext/impl/pytato/batched_einsum/utils.py @@ -191,10 +191,14 @@ def apply_kennedy_fusion_with_batched_einsum_extension( if insn.reduction_inames(): einsum, _ = fnsm.get_a_matched_einsum( - t_unit, insn_match=lp_match.Id(insn.id)) + t_unit, insn_match=lp_match.Id(insn.id), + # only consider inames with same length for fusion + # => do not parametrize inames with very long loop-counts. + long_dim_length=np.inf) einsum = fnsm.canonicalize_einsum(einsum) subst_map = fnsm.match_t_unit_to_einsum( - t_unit, einsum, insn_match=lp_match.Id(insn.id)) + t_unit, einsum, insn_match=lp_match.Id(insn.id), + long_dim_length=np.inf) else: # we treat any non-reduction einsum as a copy-einsum assignee = insn.assignee