Skip to content

Commit

Permalink
Fix (nn): QuantConv group calculation in acc bit width (#703)
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert authored Nov 10, 2023
1 parent fc7ff8e commit 32186be
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/nn/quant_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
group_size = self.in_channels // self.groups
max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
Expand Down Expand Up @@ -206,7 +206,7 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.out_channels // self.groups
group_size = self.in_channels // self.groups
kernel_size = self.kernel_size[0] * self.kernel_size[1]
max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
Expand Down

0 comments on commit 32186be

Please sign in to comment.