Re-implement FlashAttention with new Xe atoms #547
base: main
Conversation
I will break up this large commit into self-contained smaller commits after review is complete.
why is this here? This isn't flash attention specific, is it?
No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).
// No diagnostics/error will be issued by the compiler if it is not.
template <typename T>
CUTE_HOST_DEVICE void
set_wi_value(T &x, int i, T val)
why don't you take i as a compile-time value to make this safer? The usage is on line 137, where the input comes from the unrolled loop index. If you replace the loop with for_each, you have a compile-time constant.
That is an option -- I did it this way since compile-time unrolling of the loop is IMO harder to use and harder to read.
I opened a compiler ticket for the lack of diagnostics, and they have a patch under review now to address it.
I see. As long as we have a diagnostic, that's fine. The current solution won't compile at -O0; not sure whether that matters.
for (int VV = 0; VV < VTiles; VV++) {
  copy(copy_v, tVgV(_,_,_,VV,K), tVrV);
  reorder(tVrV, tArV);
  cute::gemm(mma_pv, tArP, tArV, tArA(_,_,_,VV));
why the namespace?
Sometimes it's required to disambiguate the gemm name. I can't remember the exact ambiguity here, but I had to add it.
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
  auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
  auto blk_qv = make_coord(blk_q, blk_v);
  int head_q = head / head_group_q;
In line 65 of xe_tile_scheduler.hpp, grid.z is set to batch * num_heads_q, so here head should stand for the index of query heads; it seems we need to calculate head_kv instead of head_q?

int head_group_q = s.num_heads_q / s.num_heads_kv;
int head_kv = head / head_group_q;
Edit: In my local test, after applying all the suggested changes, it now works well and the correctness check passes.
auto &p = params.kernel;
ProblemShape const& s = p.shape;
int head_group_q = s.num_heads_kv / s.num_heads_q;
Suggested change:
- int head_group_q = s.num_heads_kv / s.num_heads_q;
+ int head_group_q = s.num_heads_q / s.num_heads_kv;
auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
auto blk_qv = make_coord(blk_q, blk_v);
int head_q = head / head_group_q;
Suggested change:
- auto [blk_q, blk_v, head, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
- auto blk_qv = make_coord(blk_q, blk_v);
- int head_q = head / head_group_q;
+ auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
+ auto blk_qv = make_coord(blk_q, blk_v);
+ int head = head_q / head_group_q;
// Epilogue
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(O(_,_,head,idx_b),
Suggested change:
- epilogue(O(_,_,head,idx_b),
+ epilogue(O(_,_,head_q,idx_b),
This PR updates FlashAttention to the new copy/MMA atoms.
Changes:
Current status: prefill/decode examples are almost all working, with similar or better performance compared to the old examples.
Known issues:
Additional features (causal masking, variable sequence lengths, etc.) to be added later.
Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.