Skip to content

Commit

Permalink
Fix (nn): removed unused caching in quant adaptive avgpool2d (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Mar 22, 2024
1 parent e45b0c3 commit f0f8c56
Showing 1 changed file with 2 additions and 17 deletions.
19 changes: 2 additions & 17 deletions src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,10 @@ def __init__(
output_size: Union[int, Tuple[int, int]],
trunc_quant: Optional[AccQuantType] = RoundTo8bit,
return_quant_tensor: bool = True,
cache_kernel_size_stride: bool = True,
**kwargs):
AdaptiveAvgPool2d.__init__(self, output_size=output_size)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
self.cache_kernel_size_stride = cache_kernel_size_stride
self._cached_kernel_size = None
self._cached_kernel_stride = None

@property
def channelwise_separable(self) -> bool:
Expand All @@ -114,15 +110,8 @@ def requires_export_handler(self):
def padding(self):
return 0

@property
def kernel_size(self):
return self._cached_kernel_size

@property
def stride(self):
return self._cached_kernel_stride

def compute_kernel_size_stride(self, input_shape, output_shape):
@staticmethod
def compute_kernel_size_stride(input_shape, output_shape):
kernel_size_list = []
stride_list = []
for inp, out in zip(input_shape, output_shape):
Expand All @@ -141,10 +130,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(False)
return out

if self.cache_kernel_size_stride:
self._cached_kernel_size = k_size
self._cached_kernel_stride = stride

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
Expand Down

0 comments on commit f0f8c56

Please sign in to comment.