diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index a427a4d25..79d493c43 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -59,7 +59,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): return self.export_handler(x.value) x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) if self.is_trunc_quant_enabled: - assert isinstance(x, QuantTensor) # check input quant tensor is filled with values + assert isinstance(x, QuantTensor) # remove avg scaling rescaled_value = x.value * self._avg_scaling x = x.set(value=rescaled_value) @@ -138,7 +138,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): self._cached_kernel_size = k_size self._cached_kernel_stride = stride if self.is_trunc_quant_enabled: - assert isinstance(y, QuantTensor) # check input quant tensor is filled with values + assert isinstance(y, QuantTensor) reduce_size = reduce(mul, k_size, 1) rescaled_value = y.value * reduce_size # remove avg scaling y = y.set(value=rescaled_value)