[Kernel] Flash attention 2 #275

Open · wants to merge 14 commits into main
Conversation

@remi-or commented Sep 26, 2024

Summary

This PR adds a Flash Attention 2 Triton kernel and monkey-patches SDPA attention layers with our FA kernel.
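
To make the patching concrete, the sketch below shows the general pattern on a toy module; every name here is an illustrative placeholder, not an identifier introduced by this PR.

```python
# Sketch of the monkey-patching pattern only; all names are placeholders,
# not identifiers from this PR.
import torch
import torch.nn.functional as F

class ToySdpaAttention(torch.nn.Module):
    """Stand-in for a transformers SDPA attention layer."""
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def flash_attention_forward(self, q, k, v):
    # The real patch would dispatch to the Triton FA2 kernel here; SDPA is
    # used as a placeholder so this sketch stays runnable.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Patching is just swapping the attention layer's forward method in place.
ToySdpaAttention.forward = flash_attention_forward

q = k = v = torch.randn(1, 8, 16, 64)
print(ToySdpaAttention()(q, k, v).shape)  # torch.Size([1, 8, 16, 64])
```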

Details

The kernel supports fp16 and bfloat16, attention masking, attention bias, GQA, causal masking, and different sequence lengths for Q and KV. There are no restrictions on the sequence length or on the head dimension.
Dropout is implemented for the forward pass only, as I am having trouble computing the gradient when dropout is enabled. Since most models no longer use attention dropout, I think this can remain outside the scope of this PR.
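
As a concrete reference for the GQA and unequal-sequence-length cases, the plain-PyTorch snippet below spells out the shapes involved; this is only an eager-mode reference with example dimensions, not the Triton kernel itself.

```python
# Eager-mode GQA reference to illustrate the supported shapes; this is not
# the Triton kernel, just a readable description of what it computes.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

batch, q_heads, kv_heads, head_dim = 4, 32, 8, 128
q_len, kv_len = 128, 256  # Q and KV lengths may differ

q = torch.randn(batch, q_heads, q_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, kv_heads, kv_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, kv_heads, kv_len, head_dim, device=device, dtype=dtype)

# A GQA reference repeats each KV head across its group of query heads;
# a fused kernel can index into the shared KV heads instead of materializing this.
k_rep = k.repeat_interleave(q_heads // kv_heads, dim=1)
v_rep = v.repeat_interleave(q_heads // kv_heads, dim=1)
out = F.scaled_dot_product_attention(q, k_rep, v_rep)
print(out.shape)  # torch.Size([4, 32, 128, 128])
```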

Testing Done

Extensive testing was conducted before drafting this PR, but I chose not to include every test so as not to clutter this repository. The kernel's errors are comparable to those of the official Flash Attention kernel, and you can refer to the included tests to see how testing was conducted.
The added unit test is parameterized to still cover a broad range of cases, and all tests passed, though the absolute tolerance on the forward-pass output is slightly higher than the one used in the FA repo.
Convergence tests in which SDPA attention was monkey-patched with FA all passed.
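
For illustration, a parameterized check along these lines looks like the sketch below; `liger_flash_attention` is a hypothetical name standing in for the kernel's Python entry point, and the tolerances are placeholders rather than the values used in the actual tests.

```python
# Sketch of a parameterized forward-correctness test. `liger_flash_attention`
# is a hypothetical placeholder for the kernel's Python wrapper, and the
# tolerances are illustrative, not the values used in the real tests.
import pytest
import torch
import torch.nn.functional as F

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("seq_len", [128, 1000, 4096])
@pytest.mark.parametrize("causal", [True, False])
def test_forward_matches_sdpa(dtype, seq_len, causal):
    q, k, v = (torch.randn(2, 8, seq_len, 64, dtype=dtype, device="cuda") for _ in range(3))

    ref = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    out = liger_flash_attention(q, k, v, causal=causal)  # hypothetical entry point

    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```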

  • Hardware Type: AMD Mi210, ROCm 5.6.0
  • run make test to ensure correctness (some unrelated tests failed; see below)
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Test failures - unrelated

Additionally, I noticed that all tests for the liger_cross_entropy_kernel failed because of num_warps=32, and some convergence tests with dtype=float32 (where no monkey-patching of attention took place) also failed. This deserves its own issue and is probably related to running the tests on an Mi210, which is why I did not include the test logs.

Benchmarking

(Figure: attention_speed, forward and backward pass speed vs. sequence length)

This figure shows the speed of a forward plus backward pass (mode=full in the benchmark script).
I benchmarked the kernel with the parameters of Llama 3 (32 attention heads, 8 KV heads, 4096 hidden dim), a batch size of 4, and fp16 data, varying the sequence length from 2^5 to 2^14.
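
For reference, the benchmarked configuration spelled out as tensor shapes (this simply restates the setup above; the loop bounds match the 2^5 to 2^14 sweep):

```python
# The benchmark setup above, spelled out as tensor shapes.
hidden_dim, num_heads, num_kv_heads, batch = 4096, 32, 8, 4
head_dim = hidden_dim // num_heads  # 128

for exp in range(5, 15):  # sequence lengths 2**5 .. 2**14
    seq_len = 2 ** exp
    q_shape = (batch, num_heads, seq_len, head_dim)
    kv_shape = (batch, num_kv_heads, seq_len, head_dim)
    print(f"seq_len={seq_len:>5}  q={q_shape}  kv={kv_shape}")
```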

Misc.

I have noticed that auto-tuning takes quite a while; I would be happy to hear any recommendations on this. Also, I think this kernel is a good placeholder for flex attention, but once that is added to PyTorch (it is only supported in PyTorch nightly for now) it might be worthwhile to compare the two.
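
On the auto-tuning time, one generic mitigation (sketched below on a toy kernel, not code from this PR) is to keep the @triton.autotune config list small and key it only on the arguments that actually change, since tuning results are cached per key:

```python
# Generic Triton autotune sketch on a toy kernel (not code from this PR):
# a short config list keyed on the arguments that vary keeps tuning time down,
# because results are cached per (SEQ_LEN, HEAD_DIM) combination.
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=2),
    ],
    key=["SEQ_LEN", "HEAD_DIM"],  # re-tune only when these change
)
@triton.jit
def _toy_kernel(x_ptr, out_ptr, SEQ_LEN, HEAD_DIM,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Body elided; only the autotuning decoration matters for this note.
    pid = tl.program_id(0)
```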

@ByronHsu (Collaborator) commented

exciting!! Let me take some time to review this

@ByronHsu mentioned this pull request Sep 30, 2024
@xzuyn commented Oct 18, 2024

This is very nice, thank you so much for the work. Running on a 7900XTX within Axolotl reduced the memory usage by a good amount. (I may be wrong about it working in Axolotl, I'm not sure yet, but the benchmark below works.) I look forward to this being merged.

Here are the results from benchmark_attention.py as well. I powercap and undervolt/underclock a little, so this may not be fully representative of the performance.
(Attachments: Figure_1, Figure_2, all_benchmark_data.csv)

@winglian commented

Would be great to see comparison graphs vs. the official FA2 implementation too.
