From a3b0efab661dd88d4105094250b887aa65a460fa Mon Sep 17 00:00:00 2001
From: LucasWilkinson
Date: Thu, 25 Jul 2024 20:02:30 +0000
Subject: [PATCH] fix empty channelwise scales for compressed tensors

---
 .../schemes/compressed_tensors_wNa16.py       | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index e4cf0c0b5d95b..26cf7e4d4d04b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -54,7 +54,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         output_size_per_partition = sum(output_partition_sizes)
 
         # If group_size is -1, we are in channelwise case.
-        group_size = input_size if self.group_size == -1 else self.group_size
+        channelwise = (self.group_size == -1)
+        group_size = input_size if channelwise else self.group_size
+        row_parallel = (input_size != input_size_per_partition)
+        # In the case of channelwise quantization, we need to replicate the
+        # scales across all gpus.
+        partition_scales = (row_parallel and not channelwise)
 
         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,
@@ -65,8 +70,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
 
         weight_scale_dim = None
         scales_and_zp_size = input_size // group_size
 
-        if (input_size != input_size_per_partition
-                and self.group_size is not None):
+        if partition_scales:
+            assert input_size_per_partition % group_size == 0
             weight_scale_dim = 1
             scales_and_zp_size = input_size_per_partition // group_size
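
The sketch below (not vLLM code; the function name decide_scale_shape and the example shapes are hypothetical) mirrors the scale-shape logic the patch introduces. It shows why the old condition produced empty channelwise scales: with group_size == -1 on a row-parallel layer, the effective group size equals the full input_size, so input_size_per_partition // group_size collapsed to 0. The new partition_scales flag only shards the scales when the quantization is grouped.

    def decide_scale_shape(input_size: int, input_size_per_partition: int,
                           group_size: int):
        """Return (weight_scale_dim, scales_and_zp_size) for a wNa16 layer.

        group_size == -1 denotes channelwise quantization: one scale per
        output channel, replicated on every GPU rather than partitioned.
        """
        channelwise = (group_size == -1)
        effective_group_size = input_size if channelwise else group_size
        row_parallel = (input_size != input_size_per_partition)
        # Only partition the scales when the layer is row-parallel AND the
        # quantization is grouped; channelwise scales are replicated.
        partition_scales = (row_parallel and not channelwise)

        weight_scale_dim = None
        scales_and_zp_size = input_size // effective_group_size
        if partition_scales:
            assert input_size_per_partition % effective_group_size == 0
            weight_scale_dim = 1
            scales_and_zp_size = (input_size_per_partition
                                  // effective_group_size)
        return weight_scale_dim, scales_and_zp_size

    # Channelwise (group_size=-1) on a row-parallel layer: scales replicated,
    # so scales_and_zp_size stays 1 instead of collapsing to 0.
    print(decide_scale_shape(4096, 2048, -1))   # (None, 1)
    # Grouped (group_size=128) on the same layer: scales are partitioned.
    print(decide_scale_shape(4096, 2048, 128))  # (1, 16)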