diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 2a93f1226..06c365a67 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -27,7 +27,7 @@
     'DisableEnableQuantization',
     'bias_correction_mode',
     'calibration_mode',
-    'load_quant_model']
+    'load_quant_model_mode']
 
 _PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
@@ -106,28 +106,6 @@ def __exit__(self, type, value, traceback):
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
 
 
-class load_quant_model:
-
-    def __init__(self, model):
-        self.model = model
-        self.tracked_modules = []
-
-    def __enter__(self):
-        for module in self.model.modules():
-            if issubclass(type(module), QuantWBIOL):
-                if module.bias is None:
-                    module.register_parameter(
-                        'bias',
-                        nn.Parameter(torch.empty(module.weight.shape[0])).to(module.weight.device))
-                self.tracked_modules.append(module)
-
-    def __exit__(self, type, value, traceback):
-        for module in self.tracked_modules:
-            # empty tensor has a numel result of 0
-            if torch.numel(module.bias) == 0:
-                module.bias = None
-
-
 class bias_correction_mode:
 
     def __init__(self, model, enabled=True, skip_if_no_bias=False):
@@ -146,6 +124,24 @@ def __exit__(self, type, value, traceback):
             hook.remove()
 
 
+class load_quant_model_mode:
+
+    def __init__(self, model):
+        self.model = model
+        self.tracked_modules = []
+
+    def __enter__(self):
+        for module in self.model.modules():
+            if issubclass(type(module), QuantWBIOL):
+                module._quant_load_model_mode = True
+
+    def __exit__(self, *args, **kwargs):
+        for module in self.model.modules():
+            if issubclass(type(module), QuantWBIOL):
+                module._quant_load_model_mode = False
+        return True
+
+
 class ClipFloatWeights(Transform):
 
     def __init__(self, threshold=15., layers_to_clip=_LAYERS_TO_CLIP) -> None:
diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index cd5f48418..43d97a071 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -109,6 +109,7 @@ def __init__(
             **kwargs)
         QuantWeightMixin.__init__(self, weight_quant, **kwargs)
         QuantBiasMixin.__init__(self, bias_quant, **kwargs)
+        self._quant_load_model_mode = False
 
     @abstractmethod
     def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -158,3 +159,15 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
         quant_output = self.output_quant(output_tensor)
 
         return self.pack_output(quant_output)
+
+    def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+            error_msgs):
+        bias_key = prefix + 'bias'
+        # If the state dict has a bias and the module does not, bias correction was used
+        # We add a bias module to prevent failing during the load of the state dict
+        if bias_key in state_dict and self.bias is None and self._quant_load_model_mode:
+            self.register_parameter(
+                'bias', torch.nn.Parameter(torch.zeros(self.out_channels)).to(self.weight.device))
+        super(QuantWeightBiasInputOutputLayer, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 1775c68d6..af0d16deb 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -10,7 +10,7 @@
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
-from brevitas.graph.calibrate import load_quant_model
+from brevitas.graph.calibrate import load_quant_model_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
@@ -213,7 +213,8 @@ def forward(self, inp):
     assert m.bias is not None
 
     new_model = SimpleQuantLinearNet()
-    with load_quant_model(new_model):
+
+    with load_quant_model_mode(new_model):
         new_model.load_state_dict(model.state_dict())
 
     for m in new_model.modules():
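For context, here is a minimal usage sketch of the new context manager, mirroring the updated test. It is not part of the diff: the `QuantLinear` shapes are illustrative assumptions, and the bias-carrying `saved_model` stands in for a checkpoint produced after `bias_correction_mode` added a bias before saving.

```python
# Minimal sketch (not part of the diff): loading a state dict that contains
# bias entries into a freshly built model whose layers have no bias, as
# happens when bias correction registered one before the checkpoint was saved.
import torch
import brevitas.nn as qnn
from brevitas.graph.calibrate import load_quant_model_mode

# Illustrative stand-ins: a "saved" model that ended up with a bias, and a
# fresh model constructed without a bias parameter.
saved_model = qnn.QuantLinear(16, 8, bias=True)
new_model = qnn.QuantLinear(16, 8, bias=False)

# Inside the context, every QuantWBIOL layer has _quant_load_model_mode set,
# so _load_from_state_dict registers a zero bias before copying the tensors.
with load_quant_model_mode(new_model):
    new_model.load_state_dict(saved_model.state_dict())

assert new_model.bias is not None  # bias was materialized during the load
```

Without the context manager, the same `load_state_dict` call would fail under strict loading with an unexpected `bias` key, since the state dict entry has no matching parameter on the fresh module.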