Skip to content

Commit

Permalink
[Bugfix] Fix empty (nullptr) channelwise scales when loading wNa16 using compressed tensors (vllm-project#6798)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson authored and cadedaniel committed Jul 27, 2024
1 parent 49b6e10 commit 38616f1
Showing 1 changed file with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
output_size_per_partition = sum(output_partition_sizes)

# If group_size is -1, we are in channelwise case.
group_size = input_size if self.group_size == -1 else self.group_size
channelwise = (self.group_size == -1)
group_size = input_size if channelwise else self.group_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales = (row_parallel and not channelwise)

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
Expand All @@ -66,8 +71,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
weight_scale_dim = None
scales_and_zp_size = input_size // group_size

if (input_size != input_size_per_partition
and self.group_size is not None):
if partition_scales:
assert input_size_per_partition % group_size == 0
weight_scale_dim = 1
scales_and_zp_size = input_size_per_partition // group_size

Expand Down

0 comments on commit 38616f1

Please sign in to comment.