-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching #317
Conversation
…cales flag in arg_utils.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, aside from a few minor questions and comments.
Also pending conflict resolution
vllm/envs.py
Outdated
@@ -17,7 +17,7 @@ | |||
VLLM_USE_TRITON_FLASH_ATTN: bool = True | |||
VLLM_USE_ROCM_SKINNY_GEMM: bool = True | |||
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True | |||
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True | |||
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this switched?
@@ -82,6 +82,8 @@ | |||
VLLM_MOE_PADDING: bool = False | |||
VLLM_FP8_PADDING: bool = True | |||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False | |||
K_SCALE_CONSTANT: int = 200 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want different values?
@@ -47,7 +47,8 @@ def _init_attn_metadata_from_tensor_dict( | |||
# Extract the fields used to create AttentionMetadata. | |||
valid_attn_kwargs = {} | |||
for field in dataclasses.fields(attn_backend.get_metadata_cls()): | |||
if field.name in tensor_dict: | |||
if field.name in tensor_dict and field.name != \ | |||
'enable_kv_scales_calculation': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure, why do we filter it out here?
benchmarks/P3L.py
Outdated
engine_args = EngineArgs.from_cli_args(args) | ||
llm = LLM(**dataclasses.asdict(engine_args)) | ||
|
||
llm = LLM( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not needed now with the **dataclasses.asdict(engine_args)
self._k_scale = 1.0 | ||
self._v_scale = 1.0 | ||
self.calculate_kv_scales = calculate_kv_scales | ||
self._k_scale = torch.tensor(1.0, dtype=torch.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possibly torch.ones is better
…ching (#317) * Changed _k_scale and _v_scale to tensors * fixed rocm paged attention with tensor kv scales * Added on the fly scale factor calculation * trying to fix attn metadata * fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py * Changed K and V scale constants * Removed unneeded comment * Changes to pass format.sh, also fixed lingering k_scale/v_scale : float * Fix for TP > 1 * Ran format.sh * Removed legacy kv_scale loading from the json file * Removed the outdated kv cache docs * Revert some unwanted changes --------- Co-authored-by: Gregory Shtrasberg <[email protected]> Signed-off-by: Gregory Shtrasberg <[email protected]>
This PR implements a simple method for calculating k_scale and v_scale in the attention layer. This is especially useful in the absence of scale factors in the model checkpoints, where the previous solution was to default the scale factors to 1.0.
This feature necessitated changing k_scale and v_scale to tensors rather than floats, which should be useful for exploring different types of key & value scaling in the future (e.g. per-channel).
Here are a few PPL measurements taken using Llama 3.1 70B, demonstrating superior accuracy compared to using a scale factor of 1.0 for both k_scale and v_scale.