Skip to content

Commit

Permalink
Fix zero point
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 18, 2024
1 parent d7b5036 commit e7ef917
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def expand(self):
torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim)
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim)
if self.zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim)

return new_value, new_scale, new_zp

Expand Down
5 changes: 3 additions & 2 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def expand(self):
torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim)
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim)
if self.zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim)

return new_value, new_scale, new_zp

Expand Down

0 comments on commit e7ef917

Please sign in to comment.