From 38616f11591d683119d00714e325b6c70bc5a9bb Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 25 Jul 2024 18:05:09 -0400
Subject: [PATCH] [Bugfix] Fix empty (nullptr) channelwise scales when loading wNa16 using compressed tensors (#6798)

---
 .../schemes/compressed_tensors_wNa16.py      | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 996cba315c556..a41962ccd66d8 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -55,7 +55,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         output_size_per_partition = sum(output_partition_sizes)
 
         # If group_size is -1, we are in channelwise case.
-        group_size = input_size if self.group_size == -1 else self.group_size
+        channelwise = (self.group_size == -1)
+        group_size = input_size if channelwise else self.group_size
+        row_parallel = (input_size != input_size_per_partition)
+        # In the case of channelwise quantization, we need to replicate the
+        # scales across all gpus.
+        partition_scales = (row_parallel and not channelwise)
 
         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,
@@ -66,8 +71,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         weight_scale_dim = None
         scales_and_zp_size = input_size // group_size
 
-        if (input_size != input_size_per_partition
-                and self.group_size is not None):
+        if partition_scales:
+            assert input_size_per_partition % group_size == 0
             weight_scale_dim = 1
             scales_and_zp_size = input_size_per_partition // group_size
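
Reviewer note (not part of the patch): the sketch below mirrors the patched scale-layout logic in isolation, to make the fix easier to reason about. The helper name compute_scale_layout and the example sizes are hypothetical; only the branching logic comes from the diff above.

    def compute_scale_layout(input_size: int, input_size_per_partition: int,
                             configured_group_size: int):
        # group_size == -1 marks channelwise quantization (one group spanning
        # the whole input dimension).
        channelwise = (configured_group_size == -1)
        group_size = input_size if channelwise else configured_group_size
        row_parallel = (input_size != input_size_per_partition)
        # Channelwise scales are replicated across all GPUs, never partitioned.
        partition_scales = (row_parallel and not channelwise)

        weight_scale_dim = None
        scales_and_zp_size = input_size // group_size
        if partition_scales:
            assert input_size_per_partition % group_size == 0
            weight_scale_dim = 1
            scales_and_zp_size = input_size_per_partition // group_size
        return partition_scales, weight_scale_dim, scales_and_zp_size

    # Channelwise quantization on a 2-way row-parallel shard:
    print(compute_scale_layout(8192, 4096, -1))   # (False, None, 1)
    # Grouped quantization (group_size=128) on the same shard:
    print(compute_scale_layout(8192, 4096, 128))  # (True, 1, 32)

With the pre-patch condition (self.group_size is not None), the channelwise, row-parallel case fell into the partitioning branch and computed input_size_per_partition // input_size == 0, i.e. the empty scales the subject line refers to; gating on partition_scales keeps channelwise scales replicated instead.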