Skip to content

Commit

Permalink
Fix for avgpool export
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 20, 2024
1 parent 838b402 commit 21f055d
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from brevitas.inject.defaults import RoundTo8bit
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

from .mixin.acc import AccQuantType
from .mixin.acc import TruncMixin
Expand All @@ -38,6 +39,9 @@ def __init__(
AvgPool2d.__init__(self, kernel_size=kernel_size, stride=stride)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
self.cache_inference_quant_act = False
self.cache_quant_io_metadata_only = True
self.cache_class = None

@property
def channelwise_separable(self) -> bool:
Expand All @@ -60,7 +64,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
if self.export_mode:
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
if (isinstance(x, QuantTensor) or
self.cache_class is not None) and self.is_trunc_quant_enabled:
if self.cache_inference_quant_act:
self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AvgPool2d.forward(self, x)
rescaled_value = y.value * self._avg_scaling
y = y.set(value=rescaled_value)
Expand All @@ -87,6 +96,9 @@ def __init__(
AdaptiveAvgPool2d.__init__(self, output_size=output_size)
QuantLayerMixin.__init__(self, return_quant_tensor)
TruncMixin.__init__(self, trunc_quant=trunc_quant, **kwargs)
self.cache_inference_quant_act = False
self.cache_quant_io_metadata_only = True
self.cache_class = None

@property
def channelwise_separable(self) -> bool:
Expand Down Expand Up @@ -120,7 +132,12 @@ def forward(self, input: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(False)
return out

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
if (isinstance(x, QuantTensor) or
self.cache_class is not None) and self.is_trunc_quant_enabled:
if self.cache_inference_quant_act:
self.cache_class = _CachedIO(x, self.cache_quant_io_metadata_only)
if not isinstance(x, QuantTensor):
x = self.cache_class.quant_tensor.set(value=x)
y = AdaptiveAvgPool2d.forward(self, x)
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
reduce_size = reduce(mul, k_size, 1)
Expand Down

0 comments on commit 21f055d

Please sign in to comment.