[vllm] Add support for FP8 in Triton FA kernel #301

Merged: 2 commits into develop from attn_fp8 on Dec 4, 2024

Conversation

@ilia-cher commented on Dec 4, 2024

Adds support for FP8 (E4M3) in the Triton FA kernel, including per-tensor scaling factors.
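
For context, here is a minimal PyTorch sketch of the per-tensor scaling idea (illustrative only, not the Triton kernel itself; the helper names are made up for this example):

    import torch

    FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

    def quantize_per_tensor_fp8(x: torch.Tensor):
        # One scale for the whole tensor, chosen so its largest value maps
        # onto the FP8 E4M3 range.
        scale = x.abs().max().float().clamp(min=1e-12) / FP8_E4M3_MAX
        x_fp8 = (x.float() / scale).to(torch.float8_e4m3fn)
        return x_fp8, scale

    def attention_fp8_reference(q8, k8, v8, q_scale, k_scale, v_scale):
        # The fused Triton kernel keeps q/k/v in FP8 and folds the per-tensor
        # scales into the matmul results; this reference simply dequantizes
        # up front and runs plain attention in float32.
        q = q8.to(torch.float32) * q_scale
        k = k8.to(torch.float32) * k_scale
        v = v8.to(torch.float32) * v_scale
        p = torch.softmax((q @ k.transpose(-1, -2)) * q.shape[-1] ** -0.5, dim=-1)
        return p @ v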

Test:

1. Patched rocm_flash_attn.py to call the FA kernel with scaling factors
   (https://gist.github.com/ilia-cher/216762889331cefeb158634a651b2fac)

2. Ran the benchmark:
   python3 benchmark_latency.py --model \
      /data/models/Llama-3.1-8B-Instruct-FP8-KV \
      --input-len 8192 \
      --output-len 1 \
      --batch-size 32 \
      --enforce-eager \
      --num-iters 10 \
      --num-iters-warmup 2 \
      --enable-chunked-prefill False \
      --dtype float16
Before:
Avg latency: 6.418297152221203 seconds
10% percentile latency: 6.380122036673129 seconds
25% percentile latency: 6.390297698322684 seconds
50% percentile latency: 6.404989298898727 seconds
75% percentile latency: 6.421127524343319 seconds
90% percentile latency: 6.4394324975088235 seconds
99% percentile latency: 6.562963163470849 seconds

After (average latency down ~20%, from 6.42 s to 5.16 s):
Avg latency: 5.162057781498879 seconds
10% percentile latency: 5.1219399653142315 seconds
25% percentile latency: 5.135780334530864 seconds
50% percentile latency: 5.151887209853157 seconds
75% percentile latency: 5.158517300733365 seconds
90% percentile latency: 5.184290232090279 seconds
99% percentile latency: 5.314461483638734 seconds

3. Sanity check using
   https://gist.github.com/ilia-cher/951a3d011a8bafa7c5180fbc3a151a57
   (see the illustrative comparison sketch after this list)

4. P3L perplexity check (follow-up in the scaling-factors loading PR)
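
The sanity check in step 3 boils down to comparing the FP8 path against a full-precision reference. A rough, self-contained illustration of that kind of comparison (reusing the helper functions from the sketch above; this is not the linked gist and does not call the Triton kernel):

    import torch

    torch.manual_seed(0)
    B, H, S, D = 2, 8, 128, 64
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)

    # Full-precision reference: scales of 1.0 turn the helper into plain
    # float32 attention.
    ref = attention_fp8_reference(q, k, v, 1.0, 1.0, 1.0)

    # FP8 (E4M3) path with per-tensor scales.
    q8, qs = quantize_per_tensor_fp8(q)
    k8, ks = quantize_per_tensor_fp8(k)
    v8, vs = quantize_per_tensor_fp8(v)
    out = attention_fp8_reference(q8, k8, v8, qs, ks, vs)

    # FP8 introduces quantization error, so compare magnitudes rather than
    # expecting exact equality.
    print("max abs diff:", (out - ref).abs().max().item())
    print("mean abs diff:", (out - ref).abs().mean().item())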

@shajrawi (Collaborator) left a comment:

Brilliant work Ilia, Kudos!

ilia-cher merged commit 97fd542 into develop on Dec 4, 2024
7 of 8 checks passed
gshtras deleted the attn_fp8 branch on December 7, 2024