Fix (quant): improvements to quantization
Giuseppe5 committed Mar 4, 2025
1 parent 7af9749 commit a214b47
Showing 3 changed files with 11 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/brevitas/graph/quantize_impl.py
@@ -18,6 +18,9 @@
from brevitas.graph.base import ModuleToModuleByInstance
from brevitas.graph.utils import del_module
from brevitas.graph.utils import get_module
from brevitas.utils.logging import setup_logger

logging = setup_logger(__name__)

ADD_FNS = [torch.add, operator.add, operator.iadd]

@@ -522,6 +525,7 @@ def find_module(
    for name, module in model.named_children():
        full_name = prefix + '.' + name if prefix != '' else name
        if name_blacklist is not None and full_name in name_blacklist:
            logging.info(f"Skipping {full_name} module from quantization")
            continue
        find_module(module, layer_map, module_to_replace, name_blacklist, full_name)
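
For reference, a minimal self-contained sketch of the blacklist-skip pattern used by find_module above; the stdlib logging module stands in for brevitas.utils.logging.setup_logger, and the model and blacklist entries below are made up for illustration:

import logging

import torch.nn as nn

logger = logging.getLogger(__name__)


def walk_modules(model, name_blacklist, prefix=''):
    # Mirrors the traversal above: build the dotted name of each child and
    # skip (with a log message) anything listed in the blacklist.
    for name, module in model.named_children():
        full_name = prefix + '.' + name if prefix != '' else name
        if name_blacklist is not None and full_name in name_blacklist:
            logger.info(f"Skipping {full_name} module from quantization")
            continue
        walk_modules(module, name_blacklist, full_name)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
walk_modules(model, name_blacklist=['2'])  # skips the last Linear, registered as "2"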

3 changes: 3 additions & 0 deletions src/brevitas/graph/utils.py
@@ -121,6 +121,9 @@ def name_from_module(model, module):


def replace_module(model, old_module, new_module):
    old_module_is_training = old_module.training
    if isinstance(new_module, nn.Module):
        new_module = new_module.train() if old_module_is_training else new_module.eval()
    name = name_from_module(model, old_module)
    set_module(model, new_module, name)
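
As a sanity check, a small self-contained sketch of the behaviour added above: the replacement module inherits the train/eval state of the module it replaces. Direct indexing stands in for the name_from_module/set_module helpers, so only the state-copy logic mirrors the real function:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))
model.eval()  # the module being swapped out is in eval mode

old_module = model[0]
new_module = nn.Linear(4, 4)  # a fresh module defaults to training mode

# Same logic as replace_module: copy the train/eval state before swapping.
old_module_is_training = old_module.training
if isinstance(new_module, nn.Module):
    new_module = new_module.train() if old_module_is_training else new_module.eval()
model[0] = new_module

assert not model[0].training  # state carried over from the replaced module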

10 changes: 4 additions & 6 deletions src/brevitas_examples/llm/main.py
@@ -419,12 +419,6 @@ def quantize_llm(args, extra_args=None):
    quantization_cm = nullcontext()

    with quantization_cm:
        with torch.no_grad():
            model(**calibration_loader[0])

        # We restore the original behaviour of the post-forward.
        for k, v in dict_hooks.items():
            k._hf_hook.post_forward = v

if args.optimize_rotations:
apply_rotation_optimization(
@@ -504,6 +498,10 @@
        apply_bias_correction(model, calibration_loader)
        print("Bias correction applied.")

    # We restore the original behaviour of the post-forward.
    for k, v in dict_hooks.items():
        k._hf_hook.post_forward = v

    if args.eval and not args.no_quantize:
        print("Model eval...")
        with torch.no_grad(), quant_inference_mode(model):
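
For context, a self-contained sketch of the stash-and-restore pattern the moved lines implement. The Hook and Wrapped classes are stand-ins for accelerate's hooks (only dict_hooks, _hf_hook and post_forward mirror the names in the hunk), and the step that neutralises the hooks before calibration is an assumption, since it is not shown in this diff:

class Hook:
    # Stand-in for an accelerate hook attached to a module.

    def post_forward(self, module, output):
        return output


class Wrapped:

    def __init__(self):
        self._hf_hook = Hook()


modules = [Wrapped(), Wrapped()]

# Stash the original post-forward callables and neutralise them for
# calibration/PTQ (assumed behaviour, not shown in this diff).
dict_hooks = {}
for m in modules:
    dict_hooks[m] = m._hf_hook.post_forward
    m._hf_hook.post_forward = lambda module, output: output

# ... calibration, GPTQ and bias correction would run here ...

# Restore the original behaviour of the post-forward, only after bias
# correction has completed, as in the hunk above.
for k, v in dict_hooks.items():
    k._hf_hook.post_forward = v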
