Skip to content

Commit

Permalink
fix empty channelwise scales for compressed tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Jul 25, 2024
1 parent 316a41a commit cdbe26d
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
output_size_per_partition = sum(output_partition_sizes)

# If group_size is -1, we are in channelwise case.
group_size = input_size if self.group_size == -1 else self.group_size
channelwise = (self.group_size == -1)
group_size = input_size if channelwise else self.group_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales = (row_parallel and not channelwise)

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
Expand All @@ -65,9 +70,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
weight_scale_dim = None
scales_and_zp_size = input_size // group_size

if (input_size != input_size_per_partition
and self.group_size is not None):
weight_scale_dim = 1
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size

weight = Parameter(
Expand Down

0 comments on commit cdbe26d

Please sign in to comment.