Commit
[merge_attentions] Fix for sequences longer than 65k (fairinternal/xf…
danthe3rd authored and xFormers Bot committed Mar 13, 2024
1 parent 2c719c6 commit 503a5d7
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions xformers/ops/fmha/triton_splitk.py
@@ -641,11 +641,12 @@ def _splitK_reduce(
     G: tl.constexpr,
     WRITE_LSE: tl.constexpr,
 ):
-    off_zhg = tl.program_id(0).to(tl.int64)
-    off_z = off_zhg // (H * G)
-    off_h = (off_zhg // G) % H
-    off_g = off_zhg % G
-    off_m = tl.program_id(1)
+    # grid = (M, B, G * H)
+    off_m = tl.program_id(0).to(tl.int64)
+    off_z = tl.program_id(1).to(tl.int64)
+    off_gh = tl.program_id(2).to(tl.int64)
+    off_g = off_gh % G
+    off_h = off_gh // G
 
     Out_splitK_ptr = (
         Out_splitK
@@ -1173,7 +1174,9 @@ def merge_attentions(
         B == B3 and G == G3 and H == H3 and M == M3
     ), f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}"
 
-    grid = (B * G * H, M, 1)
+    grid = (M, B, G * H)
+    if max(grid[1], grid[2]) > 2**16:
+        raise ValueError(f"Problem size too big: {grid}")
     _splitK_reduce[grid](
         attn_split,
         lse_split,
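The grid reordering is the core of the fix: CUDA caps grid axes 1 and 2 at 65535 blocks each, while axis 0 allows up to 2**31 - 1, so the old layout grid = (B * G * H, M, 1) hit the axis-1 limit once the number of queries M exceeded ~65k. The new layout moves M to axis 0 and keeps B and G * H on the limited axes, with an explicit guard on those. A minimal sketch of that layout and the matching offset decomposition (the helpers make_grid and decode_program_ids are hypothetical names for illustration, not part of xFormers):

    def make_grid(M: int, B: int, G: int, H: int) -> tuple[int, int, int]:
        # Mirrors the patched merge_attentions: M (potentially > 65535) goes on
        # axis 0; B and G * H sit on the size-limited axes 1 and 2.
        grid = (M, B, G * H)
        if max(grid[1], grid[2]) > 2**16:
            raise ValueError(f"Problem size too big: {grid}")
        return grid

    def decode_program_ids(pid0: int, pid1: int, pid2: int, G: int):
        # Mirrors the patched _splitK_reduce, which assumes grid = (M, B, G * H).
        off_m = pid0   # query position
        off_z = pid1   # batch index
        off_gh = pid2  # fused (group, head) index
        off_g = off_gh % G
        off_h = off_gh // G
        return off_m, off_z, off_g, off_h

Placing B and G * H on the limited axes is a reasonable trade-off, since batch size and head count rarely approach 65k while sequence length routinely does.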
