diff --git a/src/brevitas/utils/python_utils.py b/src/brevitas/utils/python_utils.py index fa69499c7..ae8845b48 100644 --- a/src/brevitas/utils/python_utils.py +++ b/src/brevitas/utils/python_utils.py @@ -54,3 +54,13 @@ def _getattr(obj, attr): return getattr(obj, attr) return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def hooked_on_a_function(function, prefunction): + + @functools.wraps(function) + def run(*args, **kwargs): + prefunction() + return function(*args, **kwargs) + + return run diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 65c5334d7..05d9bcb47 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -3,6 +3,7 @@ import argparse from copy import deepcopy +import functools import sys from warnings import warn @@ -19,8 +20,10 @@ from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module +from brevitas.utils.python_utils import hooked_on_a_function from brevitas_examples.common.accelerate_utils.accelerate import offload_model from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks +from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict from brevitas_examples.common.generative.quantize import generate_quant_maps from brevitas_examples.common.generative.quantize import generate_quantizers from brevitas_examples.common.parse_utils import quant_format_validator @@ -377,9 +380,24 @@ def main(args): model = offload_model(model) + dict_hooks = dict() + + def update_params_post_init(module): + update_internal_dict(module) + + for m in model.modules(): + if hasattr(m, '_hf_hook'): + if m._hf_hook.weights_map is not None: + dict_hooks[m] = m._hf_hook.post_forward + new_funct = functools.partial(update_params_post_init, m) + m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct) + with torch.no_grad(): model(**calibration_loader[0]) + for k, v in dict_hooks.items(): + k._hf_hook.post_forward = v + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader)