Fix (brevitas_example/llm): fix learned_round entrypoint
Giuseppe5 committed Dec 3, 2024
1 parent 8e0c399 commit cacdd70
Showing 3 changed files with 20 additions and 23 deletions.
@@ -347,8 +347,6 @@ def __init__(
             *,
             scale_optimizer_class: Optional[Type[Optimizer]] = None,
             lr_scheduler_class: Optional[Type] = None,
-            optimizer_lr: float = 5e-3,
-            optimizer_scale_lr: float = 5e-3,
             batch_size: float = 8,
             iters: int = 200,
             learn_scale: bool = False,
@@ -358,6 +356,7 @@ def __init__(
             loss_scaling_factor: float = 1000.,
             learned_round_loss_kwargs: Optional[Dict] = None,
             optimizer_kwargs: Optional[Dict] = None,
+            scale_optimizer_kwargs: Optional[Dict] = None,
             lr_scheduler_kwargs: Optional[Dict] = None,
     ) -> None:
         # Verify that an optimizer is passed for optimizing the scale if learn_scale=True
@@ -367,8 +366,6 @@ def __init__(
         self.optimizer_class = optimizer_class
         self.scale_optimizer_class = scale_optimizer_class
         self.lr_scheduler_class = lr_scheduler_class
-        self.optimizer_lr = optimizer_lr
-        self.optimizer_scale_lr = optimizer_scale_lr
         self.batch_size = batch_size
         self.iters = iters
         self.learn_scale = learn_scale
@@ -377,6 +374,7 @@ def __init__(
         self.amp_dtype = amp_dtype
         self.loss_scaling_factor = loss_scaling_factor
         self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+        self.scale_optimizer_kwargs = {} if scale_optimizer_kwargs is None else scale_optimizer_kwargs
         self.lr_scheduler_kwargs = {} if lr_scheduler_kwargs is None else lr_scheduler_kwargs
         self.lr_scheduler_kwargs["total_iters"] = self.iters
 
@@ -481,7 +479,6 @@ def _optimize_learned_round_block(
                 *[
                     block_learned_round_module.parameters()
                     for block_learned_round_module in block_learned_round_modules]),
-            lr=self.optimizer_lr,
             **self.optimizer_kwargs,
         )
         lr_scheduler = (
@@ -492,12 +489,10 @@ def _optimize_learned_round_block(
         if self.learn_scale and scale_params is not None:
             optimizer_scale = self.scale_optimizer_class(
                 scale_params,
-                lr=self.optimizer_scale_lr,
-                **self.optimizer_kwargs,
+                **self.scale_optimizer_kwargs,
             )
             lr_scheduler_scale = (
-                self.lr_scheduler_class(
-                    optimizer_scale, start_factor=1, end_factor=0, total_iters=600)
+                self.lr_scheduler_class(optimizer_scale, **self.lr_scheduler_kwargs)
                 if self.lr_scheduler_class else None)
         else:
             optimizer_scale = None
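For readers following the refactor in this file: below is a minimal, hypothetical sketch of the kwargs plumbing it now relies on. The learning rates travel through optimizer_kwargs / scale_optimizer_kwargs (and the scale scheduler options through lr_scheduler_kwargs) instead of the removed optimizer_lr / optimizer_scale_lr arguments. torch.optim.SGD and LinearLR stand in for whatever optimizer and scheduler classes are actually configured, and the parameter tensors are placeholders.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

# Placeholder tensors standing in for learned-round and scale parameters.
round_params = [torch.nn.Parameter(torch.zeros(4))]
scale_params = [torch.nn.Parameter(torch.ones(4))]

# Learning rates are now plain entries in the kwargs dicts.
optimizer_kwargs = {"lr": 5e-3}                         # replaces optimizer_lr
scale_optimizer_kwargs = {"lr": 1e-2, "momentum": 0.9}  # replaces optimizer_scale_lr
lr_scheduler_kwargs = {"start_factor": 1.0, "end_factor": 0.0, "total_iters": 200}

# Same shape as the constructions in _optimize_learned_round_block after this change:
# the dicts are unpacked directly into the optimizer and scheduler constructors.
optimizer = SGD(round_params, **optimizer_kwargs)
optimizer_scale = SGD(scale_params, **scale_optimizer_kwargs)
lr_scheduler_scale = LinearLR(optimizer_scale, **lr_scheduler_kwargs)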
11 changes: 6 additions & 5 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -114,8 +114,6 @@ def apply_learned_round(
         learned_round_loss: str = "mse",
         block_name_attribute: str = "layers",
         optimizer: str = "sign_sgd",
-        optimizer_lr: float = 5e-3,
-        optimizer_scale_lr: float = 5e-3,
         batch_size: int = 8,
         learn_scale: bool = False,
         use_best_model: bool = True,
@@ -126,11 +124,14 @@ def apply_learned_round(
         optimizer_kwargs: Optional[Dict] = None,
         lr_scheduler_kwargs: Optional[Dict] = None,
         learned_round_loss_kwargs: Optional[Dict] = None,
+        scale_optimizer_class: Optional[str] = None,
+        scale_optimizer_kwargs: Optional[dict] = None,
 ) -> None:
     # Parse strings to obtain the arguments for the optimizer
     learned_round = parse_learned_round(learned_round)
     learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
     optimizer_class = parse_optimizer_class(optimizer)
+    scale_optimizer_class = parse_optimizer_class(scale_optimizer_class)
     lr_scheduler_class = parse_lr_scheduler_class(lr_scheduler)
 
     llm_block_check_fn = functools.partial(get_blocks, block_name_attribute=block_name_attribute)
@@ -144,8 +145,6 @@ def apply_learned_round(
         learned_round_loss_class=learned_round_loss_class,
         optimizer_class=optimizer_class,
         lr_scheduler_class=lr_scheduler_class,
-        optimizer_lr=optimizer_lr,
-        optimizer_scale_lr=optimizer_scale_lr,
         batch_size=batch_size,
         iters=iters,
         learn_scale=learn_scale,
@@ -155,7 +154,9 @@ def apply_learned_round(
         loss_scaling_factor=loss_scaling_factor,
         learned_round_loss_kwargs=learned_round_loss_kwargs,
         optimizer_kwargs=optimizer_kwargs,
-        lr_scheduler_kwargs=lr_scheduler_kwargs)
+        lr_scheduler_kwargs=lr_scheduler_kwargs,
+        scale_optimizer_kwargs=scale_optimizer_kwargs,
+        scale_optimizer_class=scale_optimizer_class)
     cache = CacheLLM()
     learned_round_optimizer.apply_learned_round(
         model=model,
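For context, a hedged sketch of how the updated entrypoint is meant to be called after this change; model, calibration_loader, and args are assumed to come from the surrounding LLM example, and the argument values simply mirror the main.py hunk shown below.

from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

def run_learned_round(model, calibration_loader, args):
    # Learning rates now travel through the kwargs dicts; the dedicated
    # optimizer_lr / optimizer_scale_lr arguments no longer exist.
    apply_learned_round(
        model,
        calibration_loader,
        iters=args.learned_round_iters,
        block_name_attribute=args.gpxq_block_name,
        learn_scale=args.learned_round_scale,
        scale_optimizer_class='sgd',
        optimizer_kwargs={'lr': args.learned_round_lr},
        scale_optimizer_kwargs={'lr': 1e-2, 'momentum': 0.9})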
19 changes: 10 additions & 9 deletions src/brevitas_examples/llm/main.py
@@ -368,6 +368,11 @@ def main(args):
     with torch.no_grad():
         model(**calibration_loader[0])
 
+    if args.act_calibration:
+        print("Apply act calibration...")
+        apply_calibration(model, calibration_loader)
+        print("Act calibration applied.")
+
     if args.learned_round:
         print("Applying learned round...")
         remove_hooks(model)
@@ -376,18 +381,14 @@ def main(args):
             calibration_loader,
             iters=args.learned_round_iters,
             block_name_attribute=args.gpxq_block_name,
-            optimizer_lr=args.learned_round_lr,
-            optimizer_scale_lr=args.learned_round_scale_lr,
             learn_scale=args.learned_round_scale,
-        )
+            scale_optimizer_class='sgd',
+            optimizer_kwargs={'lr': args.learned_round_lr},
+            scale_optimizer_kwargs={
+                'lr': 1e-2, 'momentum': 0.9})
         print("Learned round applied.")
 
-    model = offload_model(model)
-
-    if args.act_calibration:
-        print("Apply act calibration...")
-        apply_calibration(model, calibration_loader)
-        print("Act calibration applied.")
+    model = offload_model(model)
 
     if args.gptq:
         print("Applying GPTQ...")
