From e7ef917aa7b36560ae090b6907c63679c3fc9163 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 18 Dec 2024 22:53:27 +0000 Subject: [PATCH] Fix zero point --- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 5 +++-- src/brevitas/quant_tensor/groupwise_int_quant_tensor.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index a4099b785..fe40e5319 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -116,8 +116,9 @@ def expand(self): torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) new_scale = torch.stack( torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) - new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + if self.zero_point_.shape != (): + new_zp = torch.stack( + torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) return new_value, new_scale, new_zp diff --git a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py index 65fc9b73f..67e6e769f 100644 --- a/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_int_quant_tensor.py @@ -102,8 +102,9 @@ def expand(self): torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim) new_scale = torch.stack( torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim) - new_zp = torch.stack( - torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) + if self.zero_point_.shape != (): + new_zp = torch.stack( + torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim) return new_value, new_scale, new_zp