From cacdd70834022426954cb673887bd11e988a5063 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 3 Dec 2024 10:48:02 +0000
Subject: [PATCH] Fix (brevitas_examples/llm): fix learned_round entrypoint

---
 .../learned_round/learned_round_optimizer.py | 13 ++++---------
 .../llm/llm_quant/learned_round_utils.py     | 11 ++++++-----
 src/brevitas_examples/llm/main.py            | 19 ++++++++++---------
 3 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index d7763f957..ab9206a66 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -347,8 +347,6 @@ def __init__(
             *,
             scale_optimizer_class: Optional[Type[Optimizer]] = None,
             lr_scheduler_class: Optional[Type] = None,
-            optimizer_lr: float = 5e-3,
-            optimizer_scale_lr: float = 5e-3,
             batch_size: float = 8,
             iters: int = 200,
             learn_scale: bool = False,
@@ -358,6 +356,7 @@
             loss_scaling_factor: float = 1000.,
             learned_round_loss_kwargs: Optional[Dict] = None,
             optimizer_kwargs: Optional[Dict] = None,
+            scale_optimizer_kwargs: Optional[Dict] = None,
             lr_scheduler_kwargs: Optional[Dict] = None,
     ) -> None:
         # Verify that an optimizer is passed for optimizing the scale if learn_scale=True
@@ -367,8 +366,6 @@
         self.optimizer_class = optimizer_class
         self.scale_optimizer_class = scale_optimizer_class
         self.lr_scheduler_class = lr_scheduler_class
-        self.optimizer_lr = optimizer_lr
-        self.optimizer_scale_lr = optimizer_scale_lr
         self.batch_size = batch_size
         self.iters = iters
         self.learn_scale = learn_scale
@@ -377,6 +374,7 @@
         self.amp_dtype = amp_dtype
         self.loss_scaling_factor = loss_scaling_factor
         self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+        self.scale_optimizer_kwargs = {} if scale_optimizer_kwargs is None else scale_optimizer_kwargs
         self.lr_scheduler_kwargs = {} if lr_scheduler_kwargs is None else lr_scheduler_kwargs
         self.lr_scheduler_kwargs["total_iters"] = self.iters
 
@@ -481,7 +479,6 @@ def _optimize_learned_round_block(
                 *[
                     block_learned_round_module.parameters()
                     for block_learned_round_module in block_learned_round_modules]),
-            lr=self.optimizer_lr,
             **self.optimizer_kwargs,
         )
         lr_scheduler = (
@@ -492,12 +489,10 @@
         if self.learn_scale and scale_params is not None:
             optimizer_scale = self.scale_optimizer_class(
                 scale_params,
-                lr=self.optimizer_scale_lr,
-                **self.optimizer_kwargs,
+                **self.scale_optimizer_kwargs,
             )
             lr_scheduler_scale = (
-                self.lr_scheduler_class(
-                    optimizer_scale, start_factor=1, end_factor=0, total_iters=600)
+                self.lr_scheduler_class(optimizer_scale, **self.lr_scheduler_kwargs)
                 if self.lr_scheduler_class else None)
         else:
             optimizer_scale = None
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index fa3ca0048..395423012 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -114,8 +114,6 @@ def apply_learned_round(
         learned_round_loss: str = "mse",
         block_name_attribute: str = "layers",
         optimizer: str = "sign_sgd",
-        optimizer_lr: float = 5e-3,
-        optimizer_scale_lr: float = 5e-3,
         batch_size: int = 8,
         learn_scale: bool = False,
         use_best_model: bool = True,
@@ -126,11 +124,14 @@
         optimizer_kwargs: Optional[Dict] = None,
         lr_scheduler_kwargs: Optional[Dict] = None,
         learned_round_loss_kwargs: Optional[Dict] = None,
+        scale_optimizer_class: Optional[str] = None,
+        scale_optimizer_kwargs: Optional[dict] = None,
 ) -> None:
     # Parse strings to obtain the arguments for the optimizer
     learned_round = parse_learned_round(learned_round)
     learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
     optimizer_class = parse_optimizer_class(optimizer)
+    scale_optimizer_class = parse_optimizer_class(scale_optimizer_class)
     lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler)
 
     llm_block_check_fn = functools.partial(get_blocks, block_name_attribute=block_name_attribute)
@@ -144,8 +145,6 @@ def apply_learned_round(
         learned_round_loss_class=learned_round_loss_class,
         optimizer_class=optimizer_class,
         lr_scheduler_class=lr_scheduler_class,
-        optimizer_lr=optimizer_lr,
-        optimizer_scale_lr=optimizer_scale_lr,
         batch_size=batch_size,
         iters=iters,
         learn_scale=learn_scale,
@@ -155,7 +154,9 @@
         loss_scaling_factor=loss_scaling_factor,
         learned_round_loss_kwargs=learned_round_loss_kwargs,
         optimizer_kwargs=optimizer_kwargs,
-        lr_scheduler_kwargs=lr_scheduler_kwargs)
+        lr_scheduler_kwargs=lr_scheduler_kwargs,
+        scale_optimizer_kwargs=scale_optimizer_kwargs,
+        scale_optimizer_class=scale_optimizer_class)
     cache = CacheLLM()
     learned_round_optimizer.apply_learned_round(
         model=model,
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 46d974d23..dd1c5ca91 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -368,6 +368,11 @@ def main(args):
         with torch.no_grad():
             model(**calibration_loader[0])
 
+    if args.act_calibration:
+        print("Apply act calibration...")
+        apply_calibration(model, calibration_loader)
+        print("Act calibration applied.")
+
     if args.learned_round:
         print("Applying learned round...")
         remove_hooks(model)
@@ -376,18 +381,14 @@
             calibration_loader,
             iters=args.learned_round_iters,
             block_name_attribute=args.gpxq_block_name,
-            optimizer_lr=args.learned_round_lr,
-            optimizer_scale_lr=args.learned_round_scale_lr,
             learn_scale=args.learned_round_scale,
-        )
+            scale_optimizer_class='sgd',
+            optimizer_kwargs={'lr': args.learned_round_lr},
+            scale_optimizer_kwargs={
+                'lr': 1e-2, 'momentum': 0.9})
         print("Learned round applied.")
 
-    model = offload_model(model)
-
-    if args.act_calibration:
-        print("Apply act calibration...")
-        apply_calibration(model, calibration_loader)
-        print("Act calibration applied.")
+    model = offload_model(model)
 
     if args.gptq:
         print("Applying GPTQ...")
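-- 
Usage note (below the signature delimiter, so git am ignores it; not part of
the patch itself): after this change, learning rates travel through the
optimizer_kwargs / scale_optimizer_kwargs dicts instead of the removed
optimizer_lr / optimizer_scale_lr scalars, and the scale optimizer is selected
by name via parse_optimizer_class. The sketch below mirrors the updated call in
main.py; the wrapper name run_learned_round, the type hints, and the literal
values are illustrative assumptions, not part of the patch.

from typing import Dict, List

import torch.nn as nn

from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round


def run_learned_round(model: nn.Module, calibration_loader: List[Dict]) -> None:
    # Hypothetical wrapper standing in for the learned-round branch of main();
    # `model` is the quantized model and `calibration_loader` the list of
    # tokenized calibration samples prepared earlier in main().
    apply_learned_round(
        model,
        calibration_loader,
        iters=200,  # args.learned_round_iters in main.py
        block_name_attribute="layers",  # args.gpxq_block_name in main.py
        learn_scale=True,  # args.learned_round_scale in main.py
        scale_optimizer_class='sgd',  # parsed with parse_optimizer_class
        optimizer_kwargs={'lr': 5e-3},  # replaces the old optimizer_lr scalar
        scale_optimizer_kwargs={'lr': 1e-2, 'momentum': 0.9})  # replaces optimizer_scale_lr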