[SYCL-TLA] Integrate FlashAttention fwd/bwd kernels #2341
base: main
Conversation
EikanWang left a comment:
TBH, I cannot quite understand the detailed implementation. I need to take more time to understand the logic.
  file(GLOB xpu_cpp "xpu/*.cpp")
- file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
+ file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
Nit: I think we should also install the header files under flash_attn into PyTorch, as is done at line 42.
May I know what the purpose of installing the header files is?
It gives users a chance to use them in a C++ extension.
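For context on what "use them in a C++ extension" could mean in practice, here is a minimal, hypothetical sketch of an out-of-tree extension that would consume the installed headers. The commented-out include path and the sycltla entry point are assumptions for illustration only, not the actual header or API added by this PR; the call to at::scaled_dot_product_attention is just a stand-in so the sketch compiles.

```cpp
// Hypothetical out-of-tree C++ extension (illustrative sketch only).
#include <torch/extension.h>
// Assumed install location of the flash_attn headers; the real path/name may differ:
// #include <ATen/native/transformers/xpu/flash_attn/flash_api.h>

at::Tensor flash_fwd_wrapper(const at::Tensor& q,
                             const at::Tensor& k,
                             const at::Tensor& v) {
  // With the headers installed, an extension could call the kernel directly, e.g.
  //   return sycltla::flash_attention_forward(q, k, v, /*...*/);  // hypothetical signature
  // Stand-in so this sketch compiles: dispatch through the public SDPA entry point.
  return at::scaled_dot_product_attention(q, k, v);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("flash_fwd", &flash_fwd_wrapper, "FlashAttention forward wrapper (sketch)");
}
```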
I see.
Done.
@guangyey, I think PyTorch does not expose flash_attn because it is the underlying logic of SDPA, which is exposed as a backend. Meanwhile, I don't believe users would invoke PyTorch's flash_attn directly, because Dao-AILab/flash-attention is the better choice.
Meanwhile, the namespace of these functions is sycltla. It is weird to let users invoke sycl-tla-specific functions.
The branch was force-pushed from 770035a to 442c445.
    out = at::empty({batch_size, numhead_qo, seqlen_qo, headsize_vo}, opts);
  } else if (layout == ATTN_TENSOR_LAYOUT::BSHD) {
    out = at::empty({batch_size, seqlen_qo, numhead_qo, headsize_vo}, opts)
              .permute({0, 2, 1, 3});
Why do we need to permute here?
The output is initialized as BSHD-contiguous, but SDPA expects the shape to be BHSD, so the seqlen and numhead dimensions need to be permuted.
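To make the layout point concrete, here is a small standalone ATen snippet (an illustrative sketch, not code from this PR) showing that the permute only changes the logical view: the sizes become BHSD while the strides still describe the BSHD-contiguous storage, with no data copy.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  const int64_t B = 2, S = 8, H = 4, D = 16;  // batch, seqlen, numhead, headsize
  // Allocate BSHD-contiguous storage, then view it as BHSD via permute (no copy).
  at::Tensor out = at::empty({B, S, H, D}, at::kFloat).permute({0, 2, 1, 3});
  std::cout << out.sizes() << "\n";          // [2, 4, 8, 16] -> logical BHSD shape
  std::cout << out.strides() << "\n";        // [512, 16, 64, 1] -> strides of the BSHD storage
  std::cout << out.is_contiguous() << "\n";  // 0: not contiguous in the BHSD order
  return 0;
}
```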
This PR moves the sycltla kernels from pytorch/pytorch#167056 into torch-xpu-ops.
This PR is based on #2030; once that build PR merges, I will rebase this PR.