Move the average-pooling rescale out of the QuantTensor torch handlers and
into the forward passes of the AvgPool2d- and AdaptiveAvgPool2d-based modules
in quant_avg_pool.py.

The handlers in torch_handler.py now store the pooled value unchanged and only
widen bit_width.  The modules multiply the pooled output by the pooling-window
size ("remove avg scaling") right before calling trunc_quant.

Review changes relative to the original patch:
  * AvgPool2d-based forward: the rescale and trunc_quant now act on the pooled
    output `y`.  The original added lines rescaled and truncated the input `x`
    and left `y` untouched in that branch.
  * Both forwards: dropped the inner `if self.is_trunc_quant_enabled:`; the
    enclosing condition already tests it (assumes the attribute is stable
    between the two reads - TODO confirm).
  * Hunk headers were recomputed for the new added-line counts.

NOTE(review): line breaks, indentation and blank context lines were
reconstructed from the hunk line counts - confirm against the working tree.
NOTE(review): the post-image blob id on the first `index` line predates the
review changes above.
NOTE(review): this patch adds no imports to quant_avg_pool.py - confirm that
`reduce` and `mul` are already in scope there.

diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py
index 28d09b529..73520598c 100644
--- a/src/brevitas/nn/quant_avg_pool.py
+++ b/src/brevitas/nn/quant_avg_pool.py
@@ -47,6 +47,14 @@ def channelwise_separable(self) -> bool:
     def requires_export_handler(self):
         return True
 
+    @property
+    def _avg_scaling(self):
+        """Return the product of the two kernel dimensions (kernel area)."""
+        if isinstance(self.kernel_size, tuple):
+            return self.kernel_size[0] * self.kernel_size[1]
+        else:
+            return self.kernel_size * self.kernel_size
+
     def forward(self, input: Union[Tensor, QuantTensor]):
         x = self.unpack_input(input)
 
@@ -55,6 +63,8 @@ def forward(self, input: Union[Tensor, QuantTensor]):
 
         if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
             y = AvgPool2d.forward(self, x)
+            rescaled_value = y.value * self._avg_scaling  # remove avg scaling
+            y = y.set(value=rescaled_value)
             y = self.trunc_quant(y)
         else:
             y = AvgPool2d.forward(self, _unpack_quant_tensor(x))
@@ -113,6 +123,10 @@ def forward(self, input: Union[Tensor, QuantTensor]):
 
         if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
             y = AdaptiveAvgPool2d.forward(self, x)
+            k_size, _ = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
+            reduce_size = reduce(mul, k_size, 1)
+            rescaled_value = y.value * reduce_size  # remove avg scaling
+            y = y.set(value=rescaled_value)
             y = self.trunc_quant(y)
         else:
             y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))
diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py
index 9ec6a41af..fcbe35a42 100644
--- a/src/brevitas/quant_tensor/torch_handler.py
+++ b/src/brevitas/quant_tensor/torch_handler.py
@@ -245,8 +245,8 @@ def avg_pool2d_handler(
         avg_scaling = kernel_size[0] * kernel_size[1]
     else:
         avg_scaling = kernel_size * kernel_size
-    rescaled_value = x * avg_scaling
-    quant_input = quant_input.set(value=rescaled_value)
+
+    quant_input = quant_input.set(value=x)
     quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling))
     return quant_input
 
@@ -264,9 +264,8 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):
 
     max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d]
     reduce_size = reduce(mul, k_size, 1)
-    rescaled_value = x * reduce_size  # remove avg scaling
 
-    quant_input = quant_input.set(value=rescaled_value)
+    quant_input = quant_input.set(value=x)
     quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size))
     return quant_input
 