Skip to content

Commit

Permalink
[Bugfix] Fix empty (nullptr) channelwise scales when loading wNa16 us…
Browse files Browse the repository at this point in the history
…ing compressed tensors (vllm-project#6798)
  • Loading branch information
LucasWilkinson authored and kylesayrs committed Aug 17, 2024
1 parent 0c5f59a commit c124b20
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
is_row_parallel = input_size != input_size_per_partition

# If group_size is -1, we are in channelwise case.
group_size = input_size if self.group_size == -1 else self.group_size
channelwise = (self.group_size == -1)
group_size = input_size if channelwise else self.group_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales = (row_parallel and not channelwise)

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
Expand Down

0 comments on commit c124b20

Please sign in to comment.