[Feature][Hardware][AMD] Enable Scaled FP8 GEMM on ROCm #6006

Open · wants to merge 94 commits into base: main
Conversation

HaiShaw
Contributor

@HaiShaw commented Jun 30, 2024

Enable Scaled FP8 GEMM on ROCm (AMD GPU)

As part of a series of FP8 developments in vLLM, this pull request introduces the latest acceleration via FP8 computation on newer AMD hardware (MI30x and later).

  • Uses the OCP FP8 inference data type (float8_e4m3fn) at the interface and file-exchange level, compatible with OCP FP8 quantized model checkpoints.
  • In addition to PTQ weights, static scaling factors are used for activations (and KV caches), obtained via a calibration process with Quark (the AMD quantizer) or Nvidia's AMMO.
  • This is a ROCm/hipBLASLt-based implementation of scaled FP8 GEMM, adding to the previously implemented scaled FP8 KV cache. When multiple weight matrices are concatenated into one larger GEMM for performance, this implementation suboptimally uses a single scaling factor rather than one per matrix (to be addressed later). Note: AMD Quark can be configured so that certain matrices are virtually merged prior to quantization.
  • Largely follows the vLLM FP8 RFC: FP8 in vLLM #2461. Specifically, linear and projection layers are covered, while FP8 computation within self-attention itself is left for future extension. The current GEMM takes FP8 inputs and defaults to float16/bfloat16 output (see the sketch after this list). Further optimizations are work in progress, including but not limited to float16/bfloat16 ingress (in-kernel conversion), direct FP8 egress to the KV cache, etc.
  • Note: this feature will not work on MI2xx or older GPUs, which lack FP8 MFMA instructions.
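
For illustration only, below is a minimal sketch of a per-tensor, static-scale FP8 GEMM in plain PyTorch. This is not this PR's hipBLASLt path: it uses the private torch._scaled_mm API discussed later in this thread, assumes PyTorch 2.4+ semantics for that call, and the helper names and calibration stand-in are hypothetical.

```python
# Sketch of per-tensor static-scale FP8 GEMM (illustrative, not the PR's
# hipBLASLt kernel). Assumes PyTorch >= 2.4 semantics for torch._scaled_mm:
# scale_a/scale_b are dequantization scales and a single tensor is returned.
import torch

FP8_DTYPE = torch.float8_e4m3fn  # OCP FP8, as used at the checkpoint interface
FP8_MAX = torch.finfo(FP8_DTYPE).max


def quantize_static(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize with a pre-calibrated (static) per-tensor scale."""
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)


def scaled_fp8_gemm(a: torch.Tensor, w: torch.Tensor,
                    a_scale: torch.Tensor, w_scale: torch.Tensor,
                    out_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # a: [M, K] activations, w: [N, K] weights; K should be a multiple of 16.
    a_fp8 = quantize_static(a, a_scale)
    w_fp8 = quantize_static(w, w_scale)
    # torch._scaled_mm expects the second operand in column-major layout,
    # hence the transpose of the row-major weight matrix.
    return torch._scaled_mm(a_fp8, w_fp8.t(),
                            scale_a=a_scale, scale_b=w_scale,
                            out_dtype=out_dtype)


if __name__ == "__main__":
    m, k, n = 16, 4096, 4096
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    w = torch.randn(n, k, dtype=torch.float16, device="cuda")
    # Static scales would normally come from calibration (Quark/AMMO);
    # amax / FP8_MAX here merely stands in for that step.
    a_scale = (a.abs().max() / FP8_MAX).float()
    w_scale = (w.abs().max() / FP8_MAX).float()
    out = scaled_fp8_gemm(a, w, a_scale, w_scale)
    print(out.shape, out.dtype)
```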

Design Reference:

  • Note: Quark may add AutoFP8-compatible export; when it does, we will extend the support accordingly.
  • RFC: FP8 Quantization Schema in vLLM update #5802
  • RFC: FP8 Quantization Schema in vLLM #3218
  • RFC: FP8 in vLLM #2461

Introducing Quark - AMD Quantizer:

Please refer to: AMD Quark landing page

Performance Tuning:

Please refer to: AMD vLLM performance tuning guide

Usage and Examples:

To get started, please refer to:

./examples/fp8/quantizer/README.md
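
For orientation, here is a hypothetical serving sketch, not taken from the README above: it assumes a Quark- or AMMO-exported FP8 checkpoint directory and vLLM's offline LLM API with the "fp8" quantization backend and FP8 KV cache enabled. The exact flags for this ROCm path may differ; the README remains the supported flow.

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/llama2-70b-fp8-checkpoint",  # placeholder checkpoint path
    quantization="fp8",          # scaled FP8 GEMM for linear/projection layers
    kv_cache_dtype="fp8",        # scaled FP8 KV cache, as in the earlier work
    tensor_parallel_size=8,      # e.g. TP=8 across MI300X GPUs
)

params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["What is FP8 quantization?"], params)
print(outputs[0].outputs[0].text)
```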

Performance and Accuracy:

Together with the FP8 KV cache, we observed up to ~50% performance increase over the FP16 Llama2 baseline, with larger gains at larger batch sizes and sequence lengths, even with the quantized 70B model served on a single MI300X.

LLM-Q&A, Llama2-70b, dataset: OPEN ORCA on 8 MI300X GPUs (TP=8)

| GEMM Type  | Rouge-1 | Rouge-2 | Rouge-L |
|------------|---------|---------|---------|
| FP16       | 44.4860 | N/A     | 28.6992 |
| FP8 scaled | 44.5001 | 22.0853 | 28.7140 |

@mgoin
Collaborator

mgoin commented Jul 1, 2024

Hi @HaiShaw thanks for pushing up this chunk of work. Is there a reason you haven't tried enabling AMD explicitly through the existing "fp8" quantization backend with the current checkpoint format? It seems within your "Fp8Fnuz" method that torch._scaled_mm is actually a valid else case, so could you take advantage of its usage already in the "fp8" backend for an easier starting point?

@HaiShaw
Contributor Author

HaiShaw commented Jul 1, 2024

> Hi @HaiShaw thanks for pushing up this chunk of work. Is there a reason you haven't tried enabling AMD explicitly through the existing "fp8" quantization backend with the current checkpoint format? It seems within your "Fp8Fnuz" method that torch._scaled_mm is actually a valid else case, so could you take advantage of its usage already in the "fp8" backend for an easier starting point?

@mgoin thanks for your question! There were a couple of reasons we did not reuse exactly the same backend. Besides the different internal (HW) format and GEMM implementations, a main reason is that we do not consider dynamic scaling (and we prefer not to mix too much into the CUDA backend code). In terms of model loading, we started with AMMO support, now AMD Quark, and will extend to AutoFP8-compatible checkpoint support once RFC #5802 lands in Quark. Some of the discrepancy here comes from the moving nature, and varying completeness, of the several quantizers we deal with, which arise from different design ideas.
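
A small sketch of the hardware-format point, assuming "Fp8Fnuz" refers to the float8_e4m3fnuz dtype used natively on MI300-class hardware, as opposed to the OCP float8_e4m3fn used at the checkpoint interface; both dtypes are exposed by recent PyTorch builds.

```python
import torch

# Compare the OCP FP8 type (checkpoint/interface format) with the fnuz
# variant (MI300-class native compute format, per the discussion above).
for dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, min={info.min}, eps={info.eps}")

# The narrower fnuz range (max 240 vs. 448 for the OCP type) is one reason a
# backend consuming OCP-format checkpoints on this hardware must also adjust
# the per-tensor scales rather than reinterpret the bytes directly.
```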
