From c124b2094c2a71608c1875ea404a15d1ddb5059b Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 25 Jul 2024 18:05:09 -0400
Subject: [PATCH] [Bugfix] Fix empty (nullptr) channelwise scales when loading
 wNa16 using compressed tensors (#6798)

---
 .../compressed_tensors/schemes/compressed_tensors_wNa16.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index a39462334e732..36bbcb76d8846 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -64,7 +64,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         is_row_parallel = input_size != input_size_per_partition
 
         # If group_size is -1, we are in channelwise case.
-        group_size = input_size if self.group_size == -1 else self.group_size
+        channelwise = (self.group_size == -1)
+        group_size = input_size if channelwise else self.group_size
+        row_parallel = (input_size != input_size_per_partition)
+        # In the case of channelwise quantization, we need to replicate the
+        # scales across all gpus.
+        partition_scales = (row_parallel and not channelwise)
 
         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,