From 5362727ec366c1542b2be7a520e7c44e5cc3ce30 Mon Sep 17 00:00:00 2001
From: Charlie Fu
Date: Thu, 14 Nov 2024 10:34:21 -0600
Subject: [PATCH] Improve the heuristic logic for fp8 weight padding (#279)

* add heuristic logic for weight padding

* lint
---
 vllm/model_executor/layers/quantization/fp8.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 0803dcba5cbd2..205a7e19811e8 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -248,8 +248,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )
 
                 # Pad the weight
-                if envs.VLLM_FP8_PADDING:
-                    weight = F.pad(weight, (0, 256), "constant", 0)[..., :-256]
+                if envs.VLLM_FP8_PADDING and weight.stride(-1) == 1 \
+                    and (weight.stride(-2) * weight.element_size()) % 512 == 0:
+                    num_pad = 256 // weight.element_size()
+                    weight = F.pad(weight, (0, num_pad), "constant",
+                                   0)[..., :-num_pad]
                     torch.cuda.empty_cache()
 
                 # Update layer with new values.