Skip to content

Commit

Permalink
Fix (quant_tensor): fix AvgPool functional implementation (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Apr 30, 2024
1 parent 66f28b2 commit cdc1811
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
17 changes: 16 additions & 1 deletion src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def channelwise_separable(self) -> bool:
def requires_export_handler(self):
return True

@property
def _avg_scaling(self):
if isinstance(self.kernel_size, tuple):
return self.kernel_size[0] * self.kernel_size[1]
else:
return self.kernel_size * self.kernel_size

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

Expand All @@ -55,7 +62,10 @@ def forward(self, input: Union[Tensor, QuantTensor]):

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AvgPool2d.forward(self, x)
y = self.trunc_quant(y)
if self.is_trunc_quant_enabled:
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = self.trunc_quant(x)
else:
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))

Expand Down Expand Up @@ -113,6 +123,11 @@ def forward(self, input: Union[Tensor, QuantTensor]):

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AdaptiveAvgPool2d.forward(self, x)
if self.is_trunc_quant_enabled:
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = self.trunc_quant(y)
else:
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))
Expand Down
7 changes: 3 additions & 4 deletions src/brevitas/quant_tensor/torch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def avg_pool2d_handler(
avg_scaling = kernel_size[0] * kernel_size[1]
else:
avg_scaling = kernel_size * kernel_size
rescaled_value = x * avg_scaling
quant_input = quant_input.set(value=rescaled_value)

quant_input = quant_input.set(value=x)
quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling))
return quant_input

Expand All @@ -264,9 +264,8 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):

max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d]
reduce_size = reduce(mul, k_size, 1)
rescaled_value = x * reduce_size # remove avg scaling

quant_input = quant_input.set(value=rescaled_value)
quant_input = quant_input.set(value=x)
quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size))
return quant_input

Expand Down

0 comments on commit cdc1811

Please sign in to comment.