From 758acb89ec4659ccc39f951b6312ff9b0b04fd7b Mon Sep 17 00:00:00 2001 From: icolbert Date: Thu, 7 Sep 2023 14:30:43 -0700 Subject: [PATCH] Update quant_conv.py --- src/brevitas/nn/quant_conv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index fd6ab2e39..30a4046d4 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -113,7 +113,7 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option def max_acc_bit_width(self, input_bit_width, weight_bit_width): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) - group_size = self.out_channels // self.groups + group_size = self.in_channels // self.groups max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output)) return max_output_bit_width @@ -209,7 +209,7 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor): max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False) max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width) - group_size = self.out_channels // self.groups + group_size = self.in_channels // self.groups kernel_size = self.kernel_size[0] * self.kernel_size[1] max_uint_output = max_uint_input * max_kernel_val * kernel_size * group_size max_output_bit_width = ceil_ste(torch.log2(max_uint_output))