Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching #317

Merged
merged 15 commits into from
Dec 17, 2024

Conversation

micah-wil
Copy link

@micah-wil micah-wil commented Dec 10, 2024

This PR implements a simple method for calculating k_scale and v_scale in the attention layer. This is especially useful in the absence of scale factors in the model checkpoints, where the previous solution was to default the scale factors to 1.0.

This feature necessitated changing k_scale and v_scale to tensors rather than floats, which should be useful for exploring different types of key & value scaling in the future (e.g. per-channel).

Here are a few PPL measurements taken using Llama 3.1 70B, demonstrating superior accuracy compared to using a scale factor of 1.0 for both k_scale and v_scale.

FP16 KV Cache: PPL=2.7317
FP8 KV Cache with k/v scale set to 1.0:  PPL=2.8874
**FP8 KV Cache with dynamic scale calculation: PPL=2.7484**

Copy link
Collaborator

@gshtras gshtras left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, aside from a few minor questions and comments.
Also pending conflict resolution

vllm/envs.py Outdated
@@ -17,7 +17,7 @@
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this switched?

@@ -82,6 +82,8 @@
VLLM_MOE_PADDING: bool = False
VLLM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
K_SCALE_CONSTANT: int = 200
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want different values?

@@ -47,7 +47,8 @@ def _init_attn_metadata_from_tensor_dict(
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
if field.name in tensor_dict:
if field.name in tensor_dict and field.name != \
'enable_kv_scales_calculation':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, why do we filter it out here?

engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))

llm = LLM(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed now with the **dataclasses.asdict(engine_args)

self._k_scale = 1.0
self._v_scale = 1.0
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly torch.ones is better

@gshtras gshtras merged commit d9fed26 into main Dec 17, 2024
9 of 10 checks passed
@gshtras gshtras deleted the kv-scales-on-the-fly branch December 17, 2024 23:43
gshtras added a commit that referenced this pull request Jan 7, 2025
…ching (#317)

* Changed _k_scale and _v_scale to tensors

* fixed rocm paged attention with tensor kv scales

* Added on the fly scale factor calculation

* trying to fix attn metadata

* fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py

* Changed K and V scale constants

* Removed unneeded comment

* Changes to pass format.sh, also fixed lingering k_scale/v_scale : float

* Fix for TP > 1

* Ran format.sh

* Removed legacy kv_scale loading from the json file

* Removed the outdated kv cache docs

* Revert some unwanted changes

---------

Co-authored-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants