From f0f8c560f99517f3474cb6789e3959b2dbfb9249 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 22 Mar 2024 14:57:48 +0100 Subject: [PATCH] Fix (nn): removed unused caching in quant adaptive avgpool2d (#911) --- src/brevitas/nn/quant_avg_pool.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 5d567d0ca..554504908 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -93,14 +93,10 @@ def __init__( output_size: Union[int, Tuple[int, int]], trunc_quant: Optional[AccQuantType] = RoundTo8bit, return_quant_tensor: bool = True, - cache_kernel_size_stride: bool = True, **kwargs): AdaptiveAvgPool2d.__init__(self, output_size=output_size) QuantLayerMixin.__init__(self, return_quant_tensor) TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs) - self.cache_kernel_size_stride = cache_kernel_size_stride - self._cached_kernel_size = None - self._cached_kernel_stride = None @property def channelwise_separable(self) -> bool: @@ -114,15 +110,8 @@ def requires_export_handler(self): def padding(self): return 0 - @property - def kernel_size(self): - return self._cached_kernel_size - - @property - def stride(self): - return self._cached_kernel_stride - - def compute_kernel_size_stride(self, input_shape, output_shape): + @staticmethod + def compute_kernel_size_stride(input_shape, output_shape): kernel_size_list = [] stride_list = [] for inp, out in zip(input_shape, output_shape): @@ -141,10 +130,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): self._set_global_is_quant_layer(False) return out - if self.cache_kernel_size_stride: - self._cached_kernel_size = k_size - self._cached_kernel_stride = stride - if isinstance(x, QuantTensor): y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])