
Conversation


@petercad petercad commented Oct 4, 2025

This PR updates FlashAttention to the new copy/MMA atoms.

Changes:

  • Prefill and decode unified into a single implementation, allowing simultaneous subgroup-level parallelization over K and Q rather than one or the other.
  • GEMMs and softmax grouped together, and the full K loop consolidated into an FMHA mainloop class.
    • This will facilitate further manual pipelining/overlap of GEMM with softmax.
  • New copy/MMA atoms and reorders used to transparently support arbitrary data types.
  • Automatic copy/MMA operator selection.

Current status: the prefill/decode examples are almost all working, with performance similar to or better than the old examples.

Known issues:

  • Head size 192 decode config doesn't compile yet -- to be fixed.
  • Strange SYCL compiler behavior/bug with tSrS->tArP reorder. Apparently the compiler believes there is UB somewhere and will omit a large section of the kernel as a result. For the moment, there's a direct copy as a workaround while I pin down the issue. I'm not able to reproduce this behavior with the reorder in isolation.

Additional features (causal masking, variable sequence lengths, etc.) to be added later.

Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.

@petercad changed the title from "[Umbrella commit] Re-implement FlashAttention with new Xe atoms" to "Re-implement FlashAttention with new Xe atoms" on Oct 4, 2025

petercad commented Oct 4, 2025

I will break up this large commit into self-contained smaller commits after review is complete.

Review comment:

Why is this here? This isn't FlashAttention-specific, is it?

petercad (Author):

No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).
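
For reference, a minimal sketch of the kind of helper being discussed (illustrative only, not the PR's code; it assumes CuTe tensors with compatible shapes):

#include <cute/tensor.hpp>

// Hypothetical helper: stage a tile into/out of SLM with a plain element-wise
// cute::copy. A production version would choose block-copy sizes and fall back
// to scatter/gather automatically, as noted above.
template <class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void
copy_to_slm(cute::Tensor<SrcEngine, SrcLayout> const& src,
            cute::Tensor<DstEngine, DstLayout>      & dst)
{
  cute::copy(src, dst);   // element-wise copy; src and dst must have the same shape
}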

// No diagnostics/error will be issued by the compiler if it is not.
template <typename T>
CUTE_HOST_DEVICE void
set_wi_value(T &x, int i, T val)

Review comment:

Why don't you take i as a compile-time value to make this safer? The usage is on line 137, where the input comes from the unrolled loop index. If you replace the loop with for_each, you get a compile-time constant.
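
As a sketch of that suggestion (placeholder names; set_wi_value is the helper quoted above, and this assumes cute::for_each over cute::make_int_sequence):

#include <cute/tensor.hpp>   // cute::for_each, cute::make_int_sequence

// Illustrative only: iterate at compile time so every call site sees a
// constant index; a set_wi_value overload taking an integral constant could
// then check the index statically.
template <int N, class T>
CUTE_HOST_DEVICE void
fill_wi_values(T& x, T val)
{
  cute::for_each(cute::make_int_sequence<N>{}, [&](auto i) {
    set_wi_value(x, i, val);   // i is an integral constant, convertible to int
  });
}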

petercad (Author):

That is an option -- I did it this way since compile-time unrolling of the loop is IMO harder to use and harder to read.

I opened a compiler ticket for the lack of diagnostics, and they have a patch under review now to address it.

Review comment:

I see. As long as we have a diagnostic, that's fine. The current solution won't compile at -O0; not sure whether that matters.

for (int VV = 0; VV < VTiles; VV++) {
  copy(copy_v, tVgV(_,_,_,VV,K), tVrV);             // load the VV-th V tile into registers
  reorder(tVrV, tArV);                              // reorder/convert into the MMA's register layout
  cute::gemm(mma_pv, tArP, tArV, tArA(_,_,_,VV));   // accumulate P*V for this V tile

Review comment:

why the namespace?

petercad (Author):

Sometimes it's required to disambiguate the gemm name. I can't remember the exact ambiguity here, but I had to add it.
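
As a generic illustration of the kind of ambiguity (made-up namespaces and types, not the actual CuTe overload set): when the argument types come from namespaces that each declare a gemm, argument-dependent lookup finds both and an unqualified call is ambiguous.

// Toy ADL example; everything here is illustrative.
namespace na { struct A {}; template <class X, class Y> void gemm(X const&, Y const&) {} }
namespace nb { struct B {}; template <class X, class Y> void gemm(X const&, Y const&) {} }

void call(na::A const& a, nb::B const& b) {
  // gemm(a, b);    // error: ambiguous -- ADL finds both na::gemm and nb::gemm
  na::gemm(a, b);   // qualifying the call selects the intended overload
}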

for (; tile_scheduler.is_valid(); ++tile_scheduler) {
  auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
  auto blk_qv = make_coord(blk_q, blk_v);
  int head_q = head / head_group_q;
@wuxun-zhang wuxun-zhang commented Oct 10, 2025

In line 65 of xe_tile_scheduler.hpp, grid.z is set to batch * num_heads_q, so head here stands for the query head index; it seems we need to calculate head_kv instead of head_q:

int head_group_q = s.num_heads_q / s.num_heads_kv;
int head_kv = head / head_group_q;
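
For example (hypothetical sizes): with num_heads_q = 32 and num_heads_kv = 8, head_group_q = 32 / 8 = 4, so query head 13 maps to head_kv = 13 / 4 = 3.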

Edit:

In my local test, after applying all of the suggested changes, it works well and the correctness check passes.


auto &p = params.kernel;
ProblemShape const& s = p.shape;
int head_group_q = s.num_heads_kv / s.num_heads_q;


Suggested change:
- int head_group_q = s.num_heads_kv / s.num_heads_q;
+ int head_group_q = s.num_heads_q / s.num_heads_kv;

Comment on lines +189 to +191
auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
auto blk_qv = make_coord(blk_q, blk_v);
int head_q = head / head_group_q;


Suggested change:
- auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
- auto blk_qv = make_coord(blk_q, blk_v);
- int head_q = head / head_group_q;
+ auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
+ auto blk_qv = make_coord(blk_q, blk_v);
+ int head = head_q / head_group_q;


// Epilogue
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(O(_,_,head,idx_b),


Suggested change:
- epilogue(O(_,_,head,idx_b),
+ epilogue(O(_,_,head_q,idx_b),
