Feat (brevitas_examples/llm): correct scale init with CPU offloading
Giuseppe5 committed Dec 16, 2024
1 parent 482531c commit 3902b6c
Showing 2 changed files with 28 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/brevitas/utils/python_utils.py
@@ -54,3 +54,13 @@ def _getattr(obj, attr):
         return getattr(obj, attr)
 
     return functools.reduce(_getattr, [obj] + attr.split("."))
+
+
+def hooked_on_a_function(function, prefunction):
+
+    @functools.wraps(function)
+    def run(*args, **kwargs):
+        prefunction()
+        return function(*args, **kwargs)
+
+    return run
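
For reference, the new helper returns a wrapper that calls prefunction (with no arguments) before delegating to the wrapped function, preserving its metadata via functools.wraps. A minimal usage sketch (the function names below are illustrative, not part of the commit):

    from brevitas.utils.python_utils import hooked_on_a_function

    def add(a, b):
        return a + b

    def announce():
        # Runs before every call to the wrapped function
        print("about to call add")

    wrapped = hooked_on_a_function(add, announce)
    wrapped(1, 2)  # prints "about to call add", then returns 3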
18 changes: 18 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -3,6 +3,7 @@
 
 import argparse
 from copy import deepcopy
+import functools
 import sys
 from warnings import warn
 
@@ -19,8 +20,10 @@
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.utils import get_module
+from brevitas.utils.python_utils import hooked_on_a_function
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
+from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict
 from brevitas_examples.common.generative.quantize import generate_quant_maps
 from brevitas_examples.common.generative.quantize import generate_quantizers
 from brevitas_examples.common.parse_utils import quant_format_validator
@@ -377,9 +380,24 @@ def main(args):
 
     model = offload_model(model)
 
+    dict_hooks = dict()
+
+    def update_params_post_init(module):
+        update_internal_dict(module)
+
+    for m in model.modules():
+        if hasattr(m, '_hf_hook'):
+            if m._hf_hook.weights_map is not None:
+                dict_hooks[m] = m._hf_hook.post_forward
+                new_funct = functools.partial(update_params_post_init, m)
+                m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
+
     with torch.no_grad():
         model(**calibration_loader[0])
 
+    for k, v in dict_hooks.items():
+        k._hf_hook.post_forward = v
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)
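
Why this is needed: with accelerate's CPU offloading, a module's _hf_hook.post_forward typically moves weights back off the device after each forward call, so quantization scales that are lazily initialized during this first calibration forward would otherwise be lost. The added code saves each original post_forward, wraps it so update_internal_dict records the freshly initialized parameters first, runs one forward pass, and then restores the original hooks. A toy sketch of this save/wrap/restore pattern, using a stand-in hook object (FakeHook and sync_params are illustrative, not accelerate's API):

    import functools

    from brevitas.utils.python_utils import hooked_on_a_function

    class FakeHook:
        # Stand-in for accelerate's hook; the real one also offloads weights
        def post_forward(self, module, output):
            return output

    def sync_params(module_name):
        # Stand-in for update_internal_dict: would copy freshly initialized
        # parameters back into the hook's offloaded weights map
        print(f"syncing params of {module_name}")

    hook = FakeHook()
    saved = hook.post_forward  # keep the original so it can be restored
    hook.post_forward = hooked_on_a_function(
        hook.post_forward, functools.partial(sync_params, "model.layer0"))

    hook.post_forward(None, 42)  # syncs first, then runs the original hook
    hook.post_forward = saved    # restore after the calibration forward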
