From dd179d45927fc969425f5601294b2888ba3d5952 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 29 Aug 2024 11:29:05 +0100 Subject: [PATCH] fix (nn/avg_pool): Fix for trunc quant not being applied --- src/brevitas/nn/quant_avg_pool.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 73520598c..7a3f108da 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -62,10 +62,9 @@ def forward(self, input: Union[Tensor, QuantTensor]): if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: y = AvgPool2d.forward(self, x) - if self.is_trunc_quant_enabled: - rescaled_value = x.value * self._avg_scaling - x = x.set(value=rescaled_value) - x = self.trunc_quant(x) + rescaled_value = y.value * self._avg_scaling + y = y.set(value=rescaled_value) + y = self.trunc_quant(y) else: y = AvgPool2d.forward(self, _unpack_quant_tensor(x)) @@ -123,11 +122,10 @@ def forward(self, input: Union[Tensor, QuantTensor]): if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: y = AdaptiveAvgPool2d.forward(self, x) - if self.is_trunc_quant_enabled: - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) - reduce_size = reduce(mul, k_size, 1) - rescaled_value = y.value * reduce_size # remove avg scaling - y = y.set(value=rescaled_value) + k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) + reduce_size = reduce(mul, k_size, 1) + rescaled_value = y.value * reduce_size # remove avg scaling + y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))