diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 8ac55caaa..f72c1e6e0 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -286,11 +286,18 @@ def apply_correction(self, model): for name, module in model.named_modules(): if name in self.correction_map.keys(): correction = self.correction_map[name] / self.iterations[name] + # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter. + if hasattr(module, 'allocate_params'): + module.allocate_params(module) if module.bias is not None: module.bias.data += correction elif self.skip_if_no_bias is False: + # If accelerate is enabled, bias will be on the same execution device as the weights, but won't be managed properly by accelerate module.register_parameter( 'bias', nn.Parameter(correction).to(module.weight.device)) + # Offload params again + if hasattr(module, 'offload_params'): + module.offload_params(module) def compute_correct_bias(self, module, inp, name): inp = self.unpack_input(inp)