diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py
index 81aaf86f2..846d4f290 100644
--- a/src/brevitas/nn/quant_conv.py
+++ b/src/brevitas/nn/quant_conv.py
@@ -111,7 +111,7 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
     def max_acc_bit_width(self, input_bit_width, weight_bit_width):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
         max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
-        group_size = self.out_channels // self.groups
+        group_size = self.in_channels // self.groups
         max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
         max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
         return max_output_bit_width
@@ -206,7 +206,7 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
     def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
         max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
-        group_size = self.out_channels // self.groups
+        group_size = self.in_channels // self.groups
         kernel_size = self.kernel_size[0] * self.kernel_size[1]
         max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size
         max_output_bit_width = ceil_ste(torch.log2(max_uint_output))