diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 67e68f511..b57ae4716 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -160,7 +160,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_inference_quant_bias = False self.cache_inference_quant_bias_metadata_only = False self.requires_input_scale = self.quant_injector.requires_input_scale - self.skip_create_quant_tensor = True + self.skip_create_quant_tensor = False @property def tracked_parameter_list(self): diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 4d1afa165..cff192490 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -60,10 +60,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: @runtime_checkable class AccQuantProxyProtocol(QuantProxyProtocol, Protocol): - def __init__(self): - super().__init__() - self.skip_create_quant_tensor = False - def forward(self, x: QuantTensor) -> QuantTensor: ... @@ -251,6 +247,10 @@ def zero_point(self, force_eval=True): class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): + def __init__(self): + super().__init__() + self.skip_create_quant_tensor = False + def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width) @@ -264,6 +264,10 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.skip_create_quant_tensor = False + def bit_width(self): if not self.is_quant_enabled: return None