diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py index ea6df6b056..e3a85d9fb5 100644 --- a/xformers/ops/fmha/triton_splitk.py +++ b/xformers/ops/fmha/triton_splitk.py @@ -641,11 +641,12 @@ def _splitK_reduce( G: tl.constexpr, WRITE_LSE: tl.constexpr, ): - off_zhg = tl.program_id(0).to(tl.int64) - off_z = off_zhg // (H * G) - off_h = (off_zhg // G) % H - off_g = off_zhg % G - off_m = tl.program_id(1) + # grid = (M, B, G * H) + off_m = tl.program_id(0).to(tl.int64) + off_z = tl.program_id(1).to(tl.int64) + off_gh = tl.program_id(2).to(tl.int64) + off_g = off_gh % G + off_h = off_gh // G Out_splitK_ptr = ( Out_splitK @@ -1173,7 +1174,9 @@ def merge_attentions( B == B3 and G == G3 and H == H3 and M == M3 ), f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}" - grid = (B * G * H, M, 1) + grid = (M, B, G * H) + if max(grid[1], grid[2]) > 2**16: + raise ValueError(f"Problem size too big: {grid}") _splitK_reduce[grid]( attn_split, lse_split,