Improve the heuristic logic for fp8 weight padding (#279)
* add heuristic logic for weight padding

* lint
charlifu authored Nov 14, 2024
1 parent 04aa1a7 commit 5362727
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -248,8 +248,11 @@ def process_weights_after_loading(self, layer: Module) -> None:
             )

             # Pad the weight
-            if envs.VLLM_FP8_PADDING:
-                weight = F.pad(weight, (0, 256), "constant", 0)[..., :-256]
+            if envs.VLLM_FP8_PADDING and weight.stride(-1) == 1 \
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0:
+                num_pad = 256 // weight.element_size()
+                weight = F.pad(weight, (0, num_pad), "constant",
+                               0)[..., :-num_pad]
                 torch.cuda.empty_cache()

             # Update layer with new values.
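
The condition added in this hunk can be read in isolation from the tensor code. The sketch below restates it in plain Python, with integers standing in for `weight.stride(-1)`, `weight.stride(-2)` and `weight.element_size()`. The function names (`should_pad`, `num_pad_elements`, `padded_row_stride`) are hypothetical helpers introduced for illustration and are not part of vLLM. It assumes a contiguous 2-D weight, so that the row stride equals the number of columns.

```python
PAD_BYTES = 256          # padding amount used in the diff, expressed in bytes
ROW_PITCH_MULTIPLE = 512  # row pitches that are multiples of this get padded


def should_pad(stride_inner: int, stride_row: int, element_size: int,
               padding_enabled: bool = True) -> bool:
    """Restate the new `if` condition from the diff.

    stride_inner ~ weight.stride(-1), stride_row ~ weight.stride(-2),
    element_size ~ weight.element_size(), padding_enabled ~ envs.VLLM_FP8_PADDING.
    """
    row_pitch_bytes = stride_row * element_size
    return (padding_enabled
            and stride_inner == 1
            and row_pitch_bytes % ROW_PITCH_MULTIPLE == 0)


def num_pad_elements(element_size: int) -> int:
    """Number of elements to pad so the padding is always 256 bytes wide."""
    return PAD_BYTES // element_size


def padded_row_stride(num_cols: int, element_size: int) -> int:
    """Row stride (in elements) of the view left after pad-then-slice.

    Assumes a contiguous weight: F.pad allocates rows of
    num_cols + num_pad elements, and slicing off the last num_pad columns
    keeps that stride while restoring the original shape.
    """
    return num_cols + num_pad_elements(element_size)


if __name__ == "__main__":
    # 1-byte (fp8) weight with 4096 columns: pitch is 4096 bytes, so pad.
    print(should_pad(1, 4096, 1), num_pad_elements(1), padded_row_stride(4096, 1))
    # 4000 columns: pitch is not a multiple of 512 bytes, so leave as-is.
    print(should_pad(1, 4000, 1))
    # 2-byte elements: the padding shrinks to 128 elements (still 256 bytes).
    print(should_pad(1, 1024, 2), num_pad_elements(2))
```

Compared with the removed code, which always padded by 256 elements whenever `VLLM_FP8_PADDING` was set, the new version pads by 256 bytes and skips weights whose innermost dimension is not contiguous or whose row pitch is not a multiple of 512 bytes.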