Skip to content

Commit

Permalink
fixup! Implemented BatchedEinsumArrayContext
Browse files Browse the repository at this point in the history
avoid loop-fusion errors associated with saturations of long dimensions
  • Loading branch information
kaushikcfd committed Sep 7, 2023
1 parent 716565a commit 8bc3e14
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions arraycontext/impl/pytato/batched_einsum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,14 @@ def apply_kennedy_fusion_with_batched_einsum_extension(

if insn.reduction_inames():
einsum, _ = fnsm.get_a_matched_einsum(
t_unit, insn_match=lp_match.Id(insn.id))
t_unit, insn_match=lp_match.Id(insn.id),
# only consider inames with same length for fusion
# => do not parametrize inames with very long loop-counts.
long_dim_length=np.inf)
einsum = fnsm.canonicalize_einsum(einsum)
subst_map = fnsm.match_t_unit_to_einsum(
t_unit, einsum, insn_match=lp_match.Id(insn.id))
t_unit, einsum, insn_match=lp_match.Id(insn.id),
long_dim_length=np.inf)
else:
# we treat any non-reduction einsum as a copy-einsum
assignee = insn.assignee
Expand Down

0 comments on commit 8bc3e14

Please sign in to comment.