
Ingest FP8 attn scales and use them in ROCm FlashAttention #338

Merged: 7 commits into main from ingest_fp8_attn_scales on Dec 20, 2024

Conversation

@mawong-amd commented Dec 19, 2024

Thanks to the work of @ilia-cher in #301, Triton FA supports per-tensor quantized FP8 for almost everything: the first and second GEMMs and the attention output are quantized, which requires per-tensor quantized Q, K, V, softmax(QK^T) and the corresponding scales.

This PR enables the aforementioned quantization routines in Triton FA and ROCm PA if a quantized (text-only) Llama model contains attention output scales and an appropriate environment variable is set (VLLM_USE_ROCM_FP8_FLASH_ATTN={True/1}, off by default). Extending this to other model architectures is straightforward but not done for now. Accuracy might dip for Triton FA if not all scales are present in the model.
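
For context, a minimal sketch of per-tensor FP8 quantization with an explicit scale, which is the operation these attention scales feed. This is not the PR's implementation; all names below are illustrative, and the environment-variable check is only a rough approximation of how the gate behaves.

```python
# Illustrative sketch only, not the PR's code: per-tensor FP8 quantization
# with an explicit scale, as applied conceptually to Q, K, V and softmax(QK^T).
import os
import torch

# Max of float8_e4m3fn; ROCm hardware may use the fnuz variant with a smaller range.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quant_fp8_per_tensor(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize x to FP8 using a single per-tensor scale."""
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dequant_fp8_per_tensor(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an approximation of the original tensor."""
    return x_fp8.to(torch.float32) * scale

# Rough approximation of the gate described above (off by default).
use_fp8_flash_attn = os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "0").lower() in ("1", "true")
```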

@mawong-amd force-pushed the ingest_fp8_attn_scales branch from 4e42946 to 9ba2fab on December 19, 2024 01:42
@@ -428,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
         param = params_dict[scale_name]
         weight_loader = getattr(param, "weight_loader",
                                 default_weight_loader)
-        loaded_weight = loaded_weight[0]
+        if loaded_weight.shape:
+            loaded_weight = loaded_weight[0]
Collaborator
Same for mllama.py?

Author (@mawong-amd)

I targeted only the models which unconditionally do this logic
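
To illustrate the guarded indexing in the hunk above: checkpoint scale entries can be stored either as 0-d scalar tensors or as shape-(1,) tensors, and indexing a 0-d tensor with [0] raises an IndexError. A standalone sketch follows; the helper name is hypothetical, not vLLM code.

```python
# Hypothetical standalone sketch of the shape guard, not the vLLM loader itself.
import torch

def to_scalar_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
    if loaded_weight.shape:              # e.g. torch.Size([1]) is truthy
        loaded_weight = loaded_weight[0]
    return loaded_weight                 # 0-d input: torch.Size([]) is falsy

print(to_scalar_scale(torch.tensor([0.5])))  # tensor(0.5000)
print(to_scalar_scale(torch.tensor(0.5)))    # tensor(0.5000)
```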

@gshtras (Collaborator) left a comment:

Without explicitly disabling VLLM_USE_ROCM_FP8_ATTN, quark quantized models (amd/Meta-Llama-3.1-70B-Instruct-FP8-KV) now fail with a Triton exception:

python: /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::detail::TypedValue<mlir::RankedTensorType>; From = mlir::OpResult]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

@mawong-amd (Author) commented Dec 19, 2024

> Without explicitly disabling VLLM_USE_ROCM_FP8_ATTN, quark quantized models (amd/Meta-Llama-3.1-70B-Instruct-FP8-KV) now fail with a Triton exception:
>
> python: /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::detail::TypedValue<mlir::RankedTensorType>; From = mlir::OpResult]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

This requires a newer Triton, preferably one that includes triton-lang/triton#5362 by @ilia-cher, to obtain the full performance benefits.

EDIT: @ilia-cher identified the issue and provided a simple fix that works on older Triton. Upgrading to the latest Triton is still recommended.

@mawong-amd force-pushed the ingest_fp8_attn_scales branch from 9639307 to 1ed1389 on December 19, 2024 23:15
@mawong-amd changed the title from "Ingest FP8 attn scales and use them in ROCm Flash/Paged Attention" to "Ingest FP8 attn scales and use them in ROCm FlashAttention" on Dec 19, 2024
@gshtras previously approved these changes on Dec 19, 2024
@gshtras merged commit 1dcd9fe into main on Dec 20, 2024
7 of 8 checks passed
@gshtras deleted the ingest_fp8_attn_scales branch on December 20, 2024 00:21