Skip to content

Commit

Permalink
Fix weight residual computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 18, 2024
1 parent e7ef917 commit 33bd3f7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def expand(self):
# Then, we unbind the tensor along the group_dim shape, and drop the padded columns
# Finally, we stack the remaining tensors
unpadding_shape = final_shape[self.group_dim]
residual = curr_shape[self.group_dim] - unpadding_shape
residual = new_value.shape[self.group_dim] - unpadding_shape

if residual > 0:
new_value = torch.stack(
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def expand(self):
# Then, we unbind the tensor along the group_dim shape, and drop the padded columns
# Finally, we stack the remaining tensors
unpadding_shape = final_shape[self.group_dim]
residual = curr_shape[self.group_dim] - unpadding_shape
residual = new_value.shape[self.group_dim] - unpadding_shape

if residual > 0:
new_value = torch.stack(
Expand Down

0 comments on commit 33bd3f7

Please sign in to comment.