From fb9b2571fbc6294a6f81439dadc12219d9edd07c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 29 May 2024 15:50:44 +0000 Subject: [PATCH 1/4] Fix (graph/bias_correction): Fix when layer parameters are offloaded to `accelerate` --- src/brevitas/graph/calibrate.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 8ac55caaa..ad849c5e7 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -289,8 +289,14 @@ def apply_correction(self, model): if module.bias is not None: module.bias.data += correction elif self.skip_if_no_bias is False: + # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter. + if hasattr(self.layer, 'allocate_params'): + self.layer.allocate_params(self.layer) module.register_parameter( 'bias', nn.Parameter(correction).to(module.weight.device)) + # Offload params again + if hasattr(self.layer, 'offload_params'): + self.layer.offload_params(self.layer) def compute_correct_bias(self, module, inp, name): inp = self.unpack_input(inp) From d264a7c2f2eccad7d43b31493b6007e8092c0fde Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 30 May 2024 17:09:01 +0000 Subject: [PATCH 2/4] Fix (bias_correction): Typo fix --- src/brevitas/graph/calibrate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index ad849c5e7..0fd01a2eb 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -290,13 +290,13 @@ def apply_correction(self, model): module.bias.data += correction elif self.skip_if_no_bias is False: # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter. - if hasattr(self.layer, 'allocate_params'): - self.layer.allocate_params(self.layer) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) module.register_parameter( 'bias', nn.Parameter(correction).to(module.weight.device)) # Offload params again - if hasattr(self.layer, 'offload_params'): - self.layer.offload_params(self.layer) + if hasattr(module, 'offload_params'): + module.offload_params(module) def compute_correct_bias(self, module, inp, name): inp = self.unpack_input(inp) From 930296eb3c4d77a15a28461981758b9a6889dd6d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 31 May 2024 13:46:28 +0000 Subject: [PATCH 3/4] Fix (bias_correction): Apply accelerate fix to entire `if/elif` block. --- src/brevitas/graph/calibrate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 0fd01a2eb..d9c0b970e 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -286,17 +286,17 @@ def apply_correction(self, model): for name, module in model.named_modules(): if name in self.correction_map.keys(): correction = self.correction_map[name] / self.iterations[name] + # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter. + if hasattr(module, 'allocate_params'): + module.allocate_params(module) if module.bias is not None: module.bias.data += correction elif self.skip_if_no_bias is False: - # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter. - if hasattr(module, 'allocate_params'): - module.allocate_params(module) module.register_parameter( 'bias', nn.Parameter(correction).to(module.weight.device)) - # Offload params again - if hasattr(module, 'offload_params'): - module.offload_params(module) + # Offload params again + if hasattr(module, 'offload_params'): + module.offload_params(module) def compute_correct_bias(self, module, inp, name): inp = self.unpack_input(inp) From b1d7d5a29a4306c3d65127a51beca3a19df5633c Mon Sep 17 00:00:00 2001 From: nickfraser Date: Mon, 8 Jul 2024 15:19:52 +0100 Subject: [PATCH 4/4] fix (bias_corr/accelerate): Added comment --- src/brevitas/graph/calibrate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index d9c0b970e..f72c1e6e0 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -292,6 +292,7 @@ def apply_correction(self, model): if module.bias is not None: module.bias.data += correction elif self.skip_if_no_bias is False: + # If accelerate is enabled, bias will be on the same execution device as the weights, but won't be managed properly by accelerate module.register_parameter( 'bias', nn.Parameter(correction).to(module.weight.device)) # Offload params again