Skip to content

Commit

Permalink
fix (nn/avg_pool): Fix for trunc quant not being applied
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Aug 29, 2024
1 parent 21537ef commit dd179d4
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def forward(self, input: Union[Tensor, QuantTensor]):

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AvgPool2d.forward(self, x)
if self.is_trunc_quant_enabled:
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = self.trunc_quant(x)
rescaled_value = y.value * self._avg_scaling
y = y.set(value=rescaled_value)
y = self.trunc_quant(y)
else:
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))

Expand Down Expand Up @@ -123,11 +122,10 @@ def forward(self, input: Union[Tensor, QuantTensor]):

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AdaptiveAvgPool2d.forward(self, x)
if self.is_trunc_quant_enabled:
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = self.trunc_quant(y)
else:
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))
Expand Down

0 comments on commit dd179d4

Please sign in to comment.