diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index ab9206a66..498d3eedc 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -193,6 +193,7 @@
 from accelerate.utils.operations import send_to_device
 import torch
 from torch import autocast
+from torch import GradScaler
 from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data.dataloader import DataLoader
@@ -351,7 +352,6 @@ def __init__(
             iters: int = 200,
             learn_scale: bool = False,
             use_best_model: bool = True,
-            use_amp: bool = True,
             amp_dtype: torch.dtype = torch.float16,
             loss_scaling_factor: float = 1000.,
             learned_round_loss_kwargs: Optional[Dict] = None,
@@ -370,7 +370,6 @@ def __init__(
         self.iters = iters
         self.learn_scale = learn_scale
         self.use_best_model = use_best_model
-        self.use_amp = use_amp
         self.amp_dtype = amp_dtype
         self.loss_scaling_factor = loss_scaling_factor
         self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
@@ -396,10 +395,13 @@ def _collect_round_params(self, block: nn.Module) -> Dict:
                 params[n] = copy.deepcopy(m.state_dict())
         return params

-    def _optim_step(self, *optimizers: Optimizer) -> None:
+    def _optim_step(self, *optimizers: Optimizer, scaler: Optional[Any] = None) -> None:
         for optimizer in optimizers:
             if optimizer:
-                optimizer.step()
+                if scaler:
+                    scaler.step(optimizer)
+                else:
+                    optimizer.step()
                 optimizer.zero_grad()

     def _lr_sched_step(self, *lr_schedulers: Any) -> None:
@@ -503,10 +505,14 @@ def _optimize_learned_round_block(
         init_loss = -1.0
         last_best_iter = self.iters

+        scaler = None
+        use_amp = next(block.parameters()).dtype == torch.float32
+        if use_amp:
+            scaler = GradScaler()
+
         # Dictionary to store the rounding parameters yielding the lowest
         # training loss
         optimal_rounding_params = {}
-        torch.autograd.set_detect_anomaly(True)
         n_samples = len(cache)
         pbar = tqdm(range(self.iters), desc='')
         for i in pbar:
@@ -514,14 +520,17 @@ def _optimize_learned_round_block(
             idxs = torch.randperm(n_samples)[:self.batch_size]
             inputs, fp_outs = cache.sample_batch(idxs)

-            # Run block forward to obtain quant outputs
-            quant_outs = block_forward(block, inputs)
-            fp_outs = send_to_device(fp_outs, quant_outs.device)
-            if self.use_amp:
+            if use_amp:
                 with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu",
                               dtype=self.amp_dtype):
+                    # Run block forward to obtain quant outputs
+                    quant_outs = block_forward(block, inputs)
+                    fp_outs = send_to_device(fp_outs, quant_outs.device)
                     loss, loss_components = block_loss(quant_outs, fp_outs)
             else:
+                # Run block forward to obtain quant outputs
+                quant_outs = block_forward(block, inputs)
+                fp_outs = send_to_device(fp_outs, quant_outs.device)
                 loss, loss_components = block_loss(
                     quant_outs.to(torch.float32), fp_outs.to(torch.float32))

             # Save best parameters before taking gradient step
@@ -534,10 +543,15 @@ def _optimize_learned_round_block(
                 optimal_rounding_params = self._collect_round_params(block)

             # Scale loss and perform gradient step
-            loss = loss * self.loss_scaling_factor
-            loss.backward()
-            self._optim_step(optimizer, optimizer_scale)
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss = loss * self.loss_scaling_factor
+                loss.backward()
+            self._optim_step(optimizer, optimizer_scale, scaler=scaler)
             self._lr_sched_step(lr_scheduler, lr_scheduler_scale)
+            if scaler:
+                scaler.update()

             # Update progress bar
             pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components)))
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index 395423012..4adee6e2a 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -117,7 +117,6 @@ def apply_learned_round(
         batch_size: int = 8,
         learn_scale: bool = False,
         use_best_model: bool = True,
-        use_amp: bool = True,
         amp_dtype: torch.dtype = torch.float16,
         loss_scaling_factor: float = 1000,
         lr_scheduler: Optional[str] = "linear",
@@ -149,7 +148,6 @@ def apply_learned_round(
         iters=iters,
         learn_scale=learn_scale,
         use_best_model=use_best_model,
-        use_amp=use_amp,
         amp_dtype=amp_dtype,
         loss_scaling_factor=loss_scaling_factor,
         learned_round_loss_kwargs=learned_round_loss_kwargs,
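
Note: the training loop changed above follows the standard torch.amp mixed-precision recipe (scale the loss, run backward, step the optimizer through the scaler, then update the scale factor). The sketch below is a minimal, self-contained illustration of that recipe and is not Brevitas code: the model, optimizer, and data are made-up stand-ins, and the top-level GradScaler import assumes a recent PyTorch release, as in the patch (older versions expose it as torch.amp.GradScaler or torch.cuda.amp.GradScaler).

import torch
from torch import GradScaler, autocast, nn  # top-level aliases; assumes a recent PyTorch

# Illustrative stand-ins, not part of this patch.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Created once and reused across iterations; disabled on CPU, where the
# scale/step/update calls become pass-throughs.
scaler = GradScaler(enabled=(device == "cuda"))

for _ in range(10):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 16, device=device)
    with autocast(device_type=device,
                  dtype=torch.float16 if device == "cuda" else torch.bfloat16):
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step if any are inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration
    optimizer.zero_grad()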