Feat (bias_correction): better handling of bias-corrected quant models (#920)

---------

Co-authored-by: Giuseppe Franco <[email protected]>
costigt-dev and Giuseppe5 authored May 14, 2024
1 parent 386c7c9 commit a1926f0
Showing 3 changed files with 35 additions and 25 deletions.
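
In short: the load_quant_model context manager, which eagerly registered an empty placeholder bias parameter on every QuantWBIOL module, is replaced by load_quant_model_mode, which only toggles a _quant_load_model_mode flag on each module. A new _load_from_state_dict override in QuantWeightBiasInputOutputLayer checks that flag and registers a zero-initialized bias whenever the checkpoint contains a bias key the module lacks, as happens for checkpoints saved after bias correction.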
42 changes: 19 additions & 23 deletions src/brevitas/graph/calibrate.py
@@ -27,7 +27,7 @@
    'DisableEnableQuantization',
    'bias_correction_mode',
    'calibration_mode',
-    'load_quant_model']
+    'load_quant_model_mode']

_PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)

@@ -106,28 +106,6 @@ def __exit__(self, type, value, traceback):
        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)


-class load_quant_model:
-
-    def __init__(self, model):
-        self.model = model
-        self.tracked_modules = []
-
-    def __enter__(self):
-        for module in self.model.modules():
-            if issubclass(type(module), QuantWBIOL):
-                if module.bias is None:
-                    module.register_parameter(
-                        'bias',
-                        nn.Parameter(torch.empty(module.weight.shape[0])).to(module.weight.device))
-                    self.tracked_modules.append(module)
-
-    def __exit__(self, type, value, traceback):
-        for module in self.tracked_modules:
-            # empty tensor has a numel result of 0
-            if torch.numel(module.bias) == 0:
-                module.bias = None


class bias_correction_mode:

    def __init__(self, model, enabled=True, skip_if_no_bias=False):
@@ -146,6 +124,24 @@ def __exit__(self, type, value, traceback):
            hook.remove()


+class load_quant_model_mode:
+
+    def __init__(self, model):
+        self.model = model
+        self.tracked_modules = []
+
+    def __enter__(self):
+        for module in self.model.modules():
+            if issubclass(type(module), QuantWBIOL):
+                module._quant_load_model_mode = True
+
+    def __exit__(self, *args, **kwargs):
+        for module in self.model.modules():
+            if issubclass(type(module), QuantWBIOL):
+                module._quant_load_model_mode = False
+        return True


class ClipFloatWeights(Transform):

    def __init__(self, threshold=15., layers_to_clip=_LAYERS_TO_CLIP) -> None:
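A minimal usage sketch of the new context manager, modeled on the updated test further below; the layer shapes are illustrative assumptions, and real bias correction should be driven by representative calibration data rather than random inputs:

import torch

import brevitas.nn as qnn
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import load_quant_model_mode

# Bias-correct a layer that was built without a bias; the correction
# leaves a bias parameter registered on the layer.
model = qnn.QuantLinear(16, 8, bias=False)
model.eval()
with bias_correction_mode(model):
    model(torch.randn(4, 16))

# Load the resulting checkpoint into a fresh, bias-less copy of the model.
# Inside the context, _load_from_state_dict registers a zero placeholder
# bias so the extra 'bias' key in the state dict no longer breaks loading.
new_model = qnn.QuantLinear(16, 8, bias=False)
with load_quant_model_mode(new_model):
    new_model.load_state_dict(model.state_dict())

One behavioral note: load_quant_model_mode.__exit__ returns True, which in Python suppresses any exception raised inside the with block.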
13 changes: 13 additions & 0 deletions src/brevitas/nn/quant_layer.py
@@ -109,6 +109,7 @@ def __init__(
            **kwargs)
        QuantWeightMixin.__init__(self, weight_quant, **kwargs)
        QuantBiasMixin.__init__(self, bias_quant, **kwargs)
+        self._quant_load_model_mode = False

    @abstractmethod
    def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
@@ -158,3 +159,15 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe

        quant_output = self.output_quant(output_tensor)
        return self.pack_output(quant_output)

+    def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+            error_msgs):
+        bias_key = prefix + 'bias'
+        # If the state dict has a bias and the module does not, bias correction was used
+        # We add a bias module to prevent failing during the load of the state dict
+        if bias_key in state_dict and self.bias is None and self._quant_load_model_mode:
+            self.register_parameter(
+                'bias', torch.nn.Parameter(torch.zeros(self.out_channels)).to(self.weight.device))
+        super(QuantWeightBiasInputOutputLayer, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
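
The zero-initialized parameter is only a placeholder: the super()._load_from_state_dict call that follows copies the checkpoint's actual bias values over it. Below is a rough manual equivalent of what the flag-guarded branch automates, sketched for a single qnn.QuantLinear built without a bias; the shapes and toy checkpoint are assumptions, and weight.shape[0] is used for the bias length as in the removed load_quant_model:

import torch

import brevitas.nn as qnn

layer = qnn.QuantLinear(16, 8, bias=False)
# Toy stand-in for a state dict saved after bias correction.
checkpoint = {'weight': torch.randn(8, 16), 'bias': torch.randn(8)}

# Register a placeholder so the checkpoint's 'bias' entry has a destination,
# then load as usual; the zeros are overwritten by the loaded values.
if 'bias' in checkpoint and layer.bias is None:
    layer.register_parameter(
        'bias',
        torch.nn.Parameter(
            torch.zeros(layer.weight.shape[0], device=layer.weight.device)))
layer.load_state_dict(checkpoint, strict=False)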
5 changes: 3 additions & 2 deletions tests/brevitas/graph/test_calibration.py
@@ -10,7 +10,7 @@

from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
-from brevitas.graph.calibrate import load_quant_model
+from brevitas.graph.calibrate import load_quant_model_mode
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
@@ -213,7 +213,8 @@ def forward(self, inp):
        assert m.bias is not None

    new_model = SimpleQuantLinearNet()
-    with load_quant_model(new_model):
+
+    with load_quant_model_mode(new_model):
        new_model.load_state_dict(model.state_dict())

    for m in new_model.modules():
