
[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin #5975

Merged: 19 commits into vllm-project:main on Jul 3, 2024

Conversation


@mgoin (Collaborator) commented on Jun 28, 2024

This work expands FP8 support in vLLM from GPUs with native FP8 hardware (Hopper and Ada Lovelace) to GPUs without it (currently Ampere) by introducing FP8 Marlin, a fast matrix-multiplication kernel that fuses FP8-to-BF16/FP16 weight dequantization into the GEMM.

Key features:

  • Enables FP8 quantization on a wider range of GPUs (SM 8.0 and 8.6, i.e. Ampere)
  • Improves performance by up to 2x in memory-bound scenarios
  • Maintains accuracy comparable to FP16 baselines
  • Reduces weight memory usage by 2x, allowing larger batches
  • Simple to use: specify quantization="fp8" at runtime or use a pre-quantized FP8 checkpoint (see the usage sketch below)
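
For reference, a minimal usage sketch against vLLM's offline API; the model name is only an example, and any FP16 checkpoint (quantized on the fly) or pre-quantized FP8 checkpoint should work the same way on Ampere after this PR:

```python
from vllm import LLM, SamplingParams

# Dynamically quantize an FP16 model to FP8 weights at load time.
# With this PR the same flag also works on Ampere GPUs (e.g. A100, A10),
# not just Hopper / Ada Lovelace.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", quantization="fp8")

outputs = llm.generate(
    ["The quick brown fox"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```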

Implementation details:

  • Based on existing 8-bit integer support in GPTQ Marlin kernel
  • Packs FP8 weights into int32 words (GPTQ format), then permutes them into the Marlin layout
  • Efficient 4xFP8 to 4xFP16/BF16 dequantization using bit arithmetic and SIMT operations (sketched below)
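
To make the two bullets above concrete, here is a NumPy sketch of the packing and of one common FP8 (E4M3) to FP16 bit trick. This is only an illustration of the idea, not the kernel code: the actual kernel performs the equivalent steps with fused PTX instructions inside the Marlin mainloop, also applies the Marlin-specific weight permutation that is omitted here, and the helper names below are hypothetical.

```python
import numpy as np

def pack_fp8_to_int32(fp8_bytes: np.ndarray) -> np.ndarray:
    """Pack raw FP8 E4M3 bytes (uint8, last dim divisible by 4) into 32-bit words."""
    b = fp8_bytes.astype(np.uint32).reshape(*fp8_bytes.shape[:-1], -1, 4)
    words = b[..., 0] | (b[..., 1] << 8) | (b[..., 2] << 16) | (b[..., 3] << 24)
    return words.view(np.int32)  # GPTQ-style int32 storage (Marlin permutation omitted)

def dequant_fp8_word_to_fp16(words: np.ndarray) -> np.ndarray:
    """Expand each int32 word (4 packed FP8 values) to 4 FP16 values with bit tricks."""
    words = words.view(np.uint32)
    halves = []
    for i in range(4):
        byte = ((words >> (8 * i)) & 0xFF).astype(np.uint16)
        # Move the sign to bit 15 and the 7 exponent/mantissa bits to bits 13..7.
        # Reinterpreted as FP16 this equals the FP8 value scaled by 2^-8
        # (E4M3 bias 7 vs FP16 bias 15), so one multiply by 256 (which a kernel
        # can fold into the weight scale) restores the magnitude.
        # Note: the E4M3 NaN encoding is not handled in this sketch.
        bits = ((byte & 0x80) << 8) | ((byte & 0x7F) << 7)
        halves.append(bits.view(np.float16) * np.float16(256))
    return np.stack(halves, axis=-1)
```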

End-to-end performance and accuracy results (graphs attached in the original PR):
  • FP8 Marlin A10 E2E Latency in vLLM
  • FP8 Marlin A100 E2E Latency in vLLM
  • GSM8k lm-eval with FP8 Marlin in vLLM

Individual layer sweeps (graphs attached in the original PR):
  • A10 Layer-wise Sweep: PyTorch FP16 vs FP8 Marlin MatMul
  • A100 Layer-wise Sweep: PyTorch FP16 vs FP8 Marlin MatMul

As shown in the graphs, FP8 Marlin can provide significant speedups with minimal accuracy impact. Performance gains are larger on GPUs with lower memory bandwidth (A10, RTX 3090) and for larger models.

Notes:

  • This weight-only approach differs slightly from the existing W8A8 FP8 quantization and offers higher accuracy, because the activations do not need to be quantized
  • The per-tensor weight scales are currently expanded to channelwise scales (see the sketch after this list); future work will revert to true per-tensor scales
  • This does not include support for MoE models.
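
As a minimal sketch of the scale-handling note above (helper name hypothetical): the single per-tensor weight scale is simply broadcast to one scale per output channel so the kernel's existing channelwise epilogue can be reused.

```python
import torch

def expand_scale_to_channelwise(per_tensor_scale: torch.Tensor,
                                out_features: int) -> torch.Tensor:
    """Broadcast one per-tensor weight scale to a per-output-channel scale vector."""
    # Hypothetical helper: trades a little extra scale storage for reuse of the
    # existing channelwise dequantization path in the Marlin kernel.
    return per_tensor_scale.reshape(1).expand(out_features).contiguous()
```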

Testing:

  • Tested on H100, A100, and A10 GPUs

This enhancement enables more users to benefit from FP8 quantization without hardware restrictions, improving vLLM's performance and efficiency across a broader range of setups!

@mgoin changed the title from "[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin #331" to "[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin" on Jun 28, 2024
@robertgshaw2-neuralmagic (Collaborator) commented:

This is an awesome feature!

@comaniac (Collaborator) left a comment:
Overall LGTM. Thanks!

tests/kernels/test_marlin_gemm.py (review thread resolved)
@mgoin enabled auto-merge (squash) on July 3, 2024 at 16:30
@mgoin merged commit 47f0954 into vllm-project:main on Jul 3, 2024
70 checks passed