
Feat (ptq/learned_round): fast amp training
Giuseppe5 committed Dec 3, 2024
1 parent cacdd70 commit acee4d3
Showing 2 changed files with 32 additions and 15 deletions.
@@ -193,6 +193,12 @@
 from accelerate.utils.operations import send_to_device
 import torch
 from torch import autocast
+
+try:
+    from torch import GradScaler
+except:
+    from torch.cuda.amp import GradScaler
+
 from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data.dataloader import DataLoader
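The try/except import above keeps the module working across PyTorch versions: recent releases expose a device-agnostic GradScaler at the top level of torch, while older ones only provide the CUDA-specific torch.cuda.amp.GradScaler. Note that the bare except also swallows unrelated errors; a narrower variant, shown here as a suggestion rather than as part of the commit, would catch only ImportError:

# Sketch: version-tolerant GradScaler import, narrowed to ImportError.
try:
    from torch import GradScaler  # newer PyTorch: device-agnostic scaler
except ImportError:
    from torch.cuda.amp import GradScaler  # older PyTorch: CUDA-only scaler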
@@ -351,7 +357,6 @@ def __init__(
         iters: int = 200,
         learn_scale: bool = False,
         use_best_model: bool = True,
-        use_amp: bool = True,
         amp_dtype: torch.dtype = torch.float16,
         loss_scaling_factor: float = 1000.,
         learned_round_loss_kwargs: Optional[Dict] = None,
@@ -370,7 +375,6 @@
         self.iters = iters
         self.learn_scale = learn_scale
         self.use_best_model = use_best_model
-        self.use_amp = use_amp
         self.amp_dtype = amp_dtype
         self.loss_scaling_factor = loss_scaling_factor
         self.optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
@@ -396,10 +400,13 @@ def _collect_round_params(self, block: nn.Module) -> Dict:
             params[n] = copy.deepcopy(m.state_dict())
         return params
 
-    def _optim_step(self, *optimizers: Optimizer) -> None:
+    def _optim_step(self, *optimizers: Optimizer, scaler: Optional[Any] = None) -> None:
         for optimizer in optimizers:
             if optimizer:
-                optimizer.step()
+                if scaler:
+                    scaler.step(optimizer)
+                else:
+                    optimizer.step()
                 optimizer.zero_grad()
 
     def _lr_sched_step(self, *lr_schedulers: Any) -> None:
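The scaler keyword added to _optim_step matters because scaler.step(optimizer) is not a drop-in alias for optimizer.step(): it first unscales the gradients and silently skips the update when they contain infs or NaNs, which is what makes dynamic loss scaling safe. A minimal runnable sketch of the pattern, assuming a CUDA device and using stand-in model and data:

# Sketch: GradScaler-driven optimizer step (stand-in model/data, CUDA assumed).
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales grads, skips step on inf/NaN
scaler.update()                # adjusts the scale factor for the next iteration
optimizer.zero_grad()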
@@ -503,25 +510,32 @@ def _optimize_learned_round_block(
         init_loss = -1.0
         last_best_iter = self.iters
 
+        scaler = None
+        use_amp = next(block.parameters()).dtype == torch.float32
+        if use_amp:
+            scaler = GradScaler()
+
         # Dictionary to store the rounding parameters yielding the lowest
         # training loss
         optimal_rounding_params = {}
+        torch.autograd.set_detect_anomaly(True)
         n_samples = len(cache)
         pbar = tqdm(range(self.iters), desc='')
         for i in pbar:
             # Sample mini-batch from cache
             idxs = torch.randperm(n_samples)[:self.batch_size]
             inputs, fp_outs = cache.sample_batch(idxs)
 
-            # Run block forward to obtain quant outputs
-            quant_outs = block_forward(block, inputs)
-            fp_outs = send_to_device(fp_outs, quant_outs.device)
-            if self.use_amp:
+            if use_amp:
                 with autocast(device_type="cuda" if torch.cuda.is_available() else "cpu",
                               dtype=self.amp_dtype):
+                    # Run block forward to obtain quant outputs
+                    quant_outs = block_forward(block, inputs)
+                    fp_outs = send_to_device(fp_outs, quant_outs.device)
                     loss, loss_components = block_loss(quant_outs, fp_outs)
             else:
+                # Run block forward to obtain quant outputs
+                quant_outs = block_forward(block, inputs)
+                fp_outs = send_to_device(fp_outs, quant_outs.device)
                 loss, loss_components = block_loss(quant_outs.to(torch.float32), fp_outs.to(torch.float32))
 
             # Save best parameters before taking gradient step
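Rather than trusting a user-supplied use_amp flag, the optimizer now infers it from the block: mixed precision with a GradScaler is only engaged when the parameters are fp32, since autocast over an already half-precision block would be redundant, and in that case the loss is simply upcast to fp32 and scaled manually. A condensed sketch of that decision, where block is any nn.Module and compute_loss is a hypothetical stand-in for the block forward plus loss above:

# Sketch: dtype-driven AMP enablement (compute_loss is a hypothetical helper).
use_amp = next(block.parameters()).dtype == torch.float32
scaler = GradScaler() if use_amp else None
device_type = "cuda" if torch.cuda.is_available() else "cpu"

if use_amp:
    with torch.autocast(device_type=device_type, dtype=torch.float16):
        loss = compute_loss(block)
else:
    loss = compute_loss(block).to(torch.float32)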
@@ -534,10 +548,15 @@
                     optimal_rounding_params = self._collect_round_params(block)
 
             # Scale loss and perform gradient step
-            loss = loss * self.loss_scaling_factor
-            loss.backward()
-            self._optim_step(optimizer, optimizer_scale)
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss = loss * self.loss_scaling_factor
+                loss.backward()
+            self._optim_step(optimizer, optimizer_scale, scaler=scaler)
             self._lr_sched_step(lr_scheduler, lr_scheduler_scale)
+            if scaler:
+                scaler.update()
 
             # Update progress bar
             pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components)))
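Taken together, the training step now has two gradient paths: with a scaler, dynamic loss scaling and the guarded scaler.step/scaler.update pair; without one, the fixed loss_scaling_factor applied by hand. A self-contained sketch of the same control flow, with train_step as an illustrative name that does not appear in the commit:

# Sketch: the two gradient paths introduced by this commit (illustrative name).
def train_step(loss, optimizer, scaler=None, loss_scaling_factor=1000.0):
    if scaler:
        scaler.scale(loss).backward()            # dynamic loss scaling
        scaler.step(optimizer)                   # unscale + guarded step
    else:
        (loss * loss_scaling_factor).backward()  # fixed manual scaling
        optimizer.step()
    optimizer.zero_grad()
    if scaler:
        scaler.update()                          # grow/shrink the scale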
4 changes: 1 addition & 3 deletions — src/brevitas_examples/llm/llm_quant/learned_round_utils.py

@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import functools
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Union
 
 from accelerate.utils.operations import send_to_device
 import torch
@@ -117,7 +117,6 @@ def apply_learned_round(
         batch_size: int = 8,
         learn_scale: bool = False,
         use_best_model: bool = True,
-        use_amp: bool = True,
         amp_dtype: torch.dtype = torch.float16,
         loss_scaling_factor: float = 1000,
         lr_scheduler: Optional[str] = "linear",
Expand Down Expand Up @@ -149,7 +148,6 @@ def apply_learned_round(
         iters=iters,
         learn_scale=learn_scale,
         use_best_model=use_best_model,
-        use_amp=use_amp,
         amp_dtype=amp_dtype,
         loss_scaling_factor=loss_scaling_factor,
         learned_round_loss_kwargs=learned_round_loss_kwargs,
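For callers of the LLM entry point the change is pure plumbing: use_amp disappears from apply_learned_round, and mixed precision is decided per block from its parameter dtype. A hypothetical call after this commit; the positional arguments are assumptions, and only the keyword names visible in the diff are taken from the source:

# Hypothetical usage after this commit: no use_amp argument is passed.
apply_learned_round(
    model,                     # assumed positional argument
    calibration_dataloader,    # assumed positional argument
    iters=200,
    learn_scale=False,
    use_best_model=True,
    amp_dtype=torch.float16,   # AMP dtype is still configurable
    loss_scaling_factor=1000,
    lr_scheduler="linear",
)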