diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 9eed8b38e..bb435b7ef 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -24,7 +24,11 @@
 from .base import Transform
 
 __all__ = [
-    'ClipFloatWeights', 'DisableEnableQuantization', 'bias_correction_mode', 'calibration_mode']
+    'ClipFloatWeights',
+    'DisableEnableQuantization',
+    'bias_correction_mode',
+    'calibration_mode',
+    'load_quant_model']
 
 _PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
@@ -85,11 +89,33 @@ def __exit__(self, type, value, traceback):
             self.model, is_training=self.previous_training_state, quantization_enabled=True)
 
 
+class load_quant_model:
+
+    def __init__(self, model):
+        self.model = model
+        self.tracked_modules = []
+
+    def __enter__(self):
+        for module in self.model.modules():
+            if issubclass(type(module), QuantWBIOL):
+                if module.bias is None:
+                    module.register_parameter(
+                        'bias',
+                        nn.Parameter(torch.empty(module.weight.shape[0])).to(module.weight.device))
+                    self.tracked_modules.append(module)
+
+    def __exit__(self, type, value, traceback):
+        for module in self.tracked_modules:
+            # empty tensor has a numel result of 0
+            if torch.numel(module.bias) == 0:
+                module.bias = None
+
+
 class bias_correction_mode:
 
-    def __init__(self, model, enabled=True):
+    def __init__(self, model, enabled=True, skip_if_no_bias=False):
         self.model = model
-        self.bias_correction = _BiasCorrection()
+        self.bias_correction = _BiasCorrection(skip_if_no_bias=skip_if_no_bias)
         self.enabled = enabled
         self.hooks = []
 
@@ -209,7 +235,7 @@ class _BiasCorrection(DisableEnableQuantization):
 
     LAYERS = (QuantWBIOL,)
 
-    def __init__(self, layers=LAYERS):
+    def __init__(self, layers=LAYERS, skip_if_no_bias=False):
         super(_BiasCorrection, self).__init__()
         self.layers = layers
         self.iterations = {}
@@ -217,6 +243,7 @@ def __init__(self, layers=LAYERS):
         self.float_mean_map = {}
         self.collect_float_mean_hooks = []
         self.correct_bias_hooks = []
+        self.skip_if_no_bias = skip_if_no_bias
 
     def compute_mean(self, inp, transpose_dim):
         inp = inp.transpose(0, transpose_dim)
@@ -248,7 +275,7 @@ def apply_correction(self, model):
                 correction = self.correction_map[name] / self.iterations[name]
                 if module.bias is not None:
                     module.bias.data += correction
-                else:
+                elif self.skip_if_no_bias is False:
                     module.register_parameter(
                         'bias', nn.Parameter(correction).to(module.weight.device))
 
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 6580d971b..0b6303a8b 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -10,6 +10,7 @@
 
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
+from brevitas.graph.calibrate import load_quant_model
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
@@ -189,3 +190,53 @@ def simple_hook(mod, inp, out):
     ) # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
     assert (inputs[1] == fp_outs[1, 0, :]).all(
     ) # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
+
+
+def test_import_bias_correction():
+
+    class SimpleQuantLinearNet(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.net = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))
+
+        def forward(self, inp):
+            return self.net(inp)
+
+    model = SimpleQuantLinearNet()
+
+    with bias_correction_mode(model):
+        model(torch.randn((1, IN_CH)))
+
+    for m in model.modules():
+        if isinstance(m, qnn.QuantLinear):
+            assert m.bias is not None
+
+    new_model = SimpleQuantLinearNet()
+    with load_quant_model(new_model):
+        new_model.load_state_dict(model.state_dict())
+
+    for m in new_model.modules():
+        if isinstance(m, qnn.QuantLinear):
+            assert m.bias is not None
+
+
+def test_bias_correction_flag():
+
+    class SimpleQuantLinearNet(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.net = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))
+
+        def forward(self, inp):
+            return self.net(inp)
+
+    model = SimpleQuantLinearNet()
+
+    with bias_correction_mode(model, skip_if_no_bias=True):
+        model(torch.randn((1, IN_CH)))
+
+    for m in model.modules():
+        if isinstance(m, qnn.QuantLinear):
+            assert m.bias is None
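
Usage sketch (not part of the patch): the snippet below illustrates the intended workflow under this change, i.e. bias correction registering a bias on a bias-less layer, saving the corrected state dict, and then reloading it into a fresh model through the new load_quant_model context manager. The layer sizes, batch size and checkpoint path are illustrative assumptions, not values taken from the patch.

import torch
import brevitas.nn as qnn
from brevitas.graph.calibrate import bias_correction_mode, load_quant_model

# Quant layer created without a bias; sizes are arbitrary for illustration.
model = qnn.QuantLinear(16, 32, bias=False)

# With the default skip_if_no_bias=False, bias correction registers a bias
# parameter holding the computed correction on layers that had none.
with bias_correction_mode(model):
    model(torch.randn(8, 16))

torch.save(model.state_dict(), 'corrected.pth')  # illustrative path

# A fresh instance still has bias=None, so its parameters do not match the
# corrected checkpoint; load_quant_model registers a placeholder bias so the
# state dict loads cleanly.
new_model = qnn.QuantLinear(16, 32, bias=False)
with load_quant_model(new_model):
    new_model.load_state_dict(torch.load('corrected.pth'))

assert new_model.bias is not None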