
Commit

Minor changes
pablomlago committed Nov 28, 2024
1 parent a9e9b52 commit ff65892
Showing 4 changed files with 22 additions and 125 deletions.
125 changes: 4 additions & 121 deletions src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -364,14 +364,14 @@ def __init__(
             scale_optimizer_class: Type[Optimizer] = SGD,
             lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR,
             optimizer_lr: float = 5e-3,
+            optimizer_scale_lr: float = 5e-3,
             batch_size: float = 8,
             iters: int = 200,
             learn_scale: bool = False,
             use_best_model: bool = True,
             use_amp: bool = True,
             amp_dtype: torch.dtype = torch.float16,
             loss_scaling_factor: float = 1000.,
-            use_accelerate: bool = False,
             learned_round_loss_kwargs: Optional[Dict] = None,
             optimizer_kwargs: Optional[Dict] = None,
             lr_scheduler_kwargs: Optional[Dict] = None,
@@ -381,6 +381,7 @@ def __init__(
         self.scale_optimizer_class = scale_optimizer_class
         self.lr_scheduler_class = lr_scheduler_class
         self.optimizer_lr = optimizer_lr
+        self.optimizer_scale_lr = optimizer_scale_lr
         self.batch_size = batch_size
         self.iters = iters
         self.learn_scale = learn_scale
@@ -396,10 +397,6 @@ def __init__(
         self.learned_round_loss_init = partial(
             learned_round_loss_class, **learned_round_loss_kwargs)

-        # TODO: Remove once validated and expose the flag
-        # self.use_accelerate = use_accelerate
-        self.use_accelerate = False
-
     @torch.no_grad()
     def _load_round_params(self, block: nn.Module, round_params: Dict) -> None:
         for n, m in block.named_modules():
@@ -508,8 +505,7 @@ def _optimize_learned_round_block(
         if self.learn_scale and scale_params is not None:
             optimizer_scale = self.scale_optimizer_class(
                 scale_params,
-                lr=self.optimizer_lr,
-                momentum=0.9,
+                lr=self.optimizer_scale_lr,
                 **self.optimizer_kwargs,
             )
             lr_scheduler_scale = (
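The change above gives the quantization-scale optimizer its own learning rate (`optimizer_scale_lr`) instead of reusing `optimizer_lr`, and drops the hard-coded `momentum=0.9` in favour of whatever arrives through `optimizer_kwargs`. A minimal sketch of the resulting pattern; the parameter tensors and LR values below are illustrative, not taken from this diff:

    import torch
    from torch.optim import SGD

    # Stand-ins for a block's learned-round and scale parameters.
    rounding_params = [torch.nn.Parameter(torch.zeros(8))]
    scale_params = [torch.nn.Parameter(torch.ones(8))]

    # Each group now gets its own optimizer and learning rate; momentum is
    # supplied via kwargs rather than hard-coded.
    optimizer_kwargs = {"momentum": 0.9}
    optimizer = SGD(rounding_params, lr=5e-3, **optimizer_kwargs)
    optimizer_scale = SGD(scale_params, lr=1e-2, **optimizer_kwargs)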
@@ -579,114 +575,6 @@ def _optimize_learned_round_block(

         return init_loss, best_loss, last_best_iter

-    # TODO: Enable saving best parameters
-    def _accelerate_optimize_learned_round_block(
-            self,
-            block: nn.Module,
-            block_learned_round_modules: List[nn.Module],
-            cache: Cache,
-            block_loss: LearnedRoundLoss,
-            block_forward: Callable,
-    ) -> Tuple[float, float, int]:
-        # Enable running in mixed precision
-        TORCH_DTYPE_TO_PRECISION_TYPE_MAP = {
-            torch.float16: PrecisionType.FP16,
-            torch.bfloat16: PrecisionType.BF16,}
-        raise_warning_dtype = False
-        if not self.use_amp:
-            mixed_precision_type = None
-        else:
-            if self.amp_dtype not in TORCH_DTYPE_TO_PRECISION_TYPE_MAP:
-                raise_warning_dtype = True
-                mixed_precision_type = None
-            else:
-                mixed_precision_type = TORCH_DTYPE_TO_PRECISION_TYPE_MAP[self.amp_dtype]
-        # Instantiate accelerator to run in a multi-GPU setting
-        accelerator = Accelerator(mixed_precision=mixed_precision_type)
-
-        # Raise warning if the AMP dtype was defaulted to float32. This warning is raised after
-        # the instantiation of accelerator, to use its print functionality so the message is only
-        # printed once.
-        if raise_warning_dtype:
-            accelerator.print(
-                f"The dtype {self.amp_dtype} cannot be used for AMP training with accelerate. Defaulting to float32."
-            )
-
-        # Initialize optimizer and LR scheduler
-        optimizer = self.optimizer_class(
-            itertools.chain(
-                *[
-                    block_learned_round_module.parameters()
-                    for block_learned_round_module in block_learned_round_modules]),
-            lr=self.optimizer_lr,
-            **self.optimizer_kwargs,
-        )
-        lr_scheduler = (
-            self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs)
-            if self.lr_scheduler_class else None)
-
-        # Prepare dataset from cache
-        cache_dataset = cache.cache_to_dataset()
-        cache_dataloader = DataLoader(
-            cache_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=cache.collate_fn)
-
-        # Prepare elements for training
-        cache_dataloader, block, optimizer, lr_scheduler = accelerator.prepare(cache_dataloader, block, optimizer, lr_scheduler)
-
-        # Variables needed for printing
-        best_loss = torch.finfo(torch.float).max
-        init_loss = -1.0
-        last_best_iter = self.iters
-
-        # Initialize an iterator to extract elements from the cache dataloader
-        cache_iterator = iter(cache_dataloader)
-
-        pbar = tqdm_accelerate(range(self.iters), desc='')
-        for i in pbar:
-            # Sample mini-batch from cache
-            inputs, fp_outs = next(cache_iterator)
-
-            # Run block forward to obtain quant outputs
-            quant_outs = block_forward(block, inputs)
-            # Compute loss using the block loss function
-            loss, loss_components = block_loss(quant_outs, fp_outs)
-
-            # Save best parameters before taking gradient step
-            curr_loss = loss.detach().cpu().item()
-            init_loss = curr_loss if i == 0 else init_loss
-            if loss < best_loss:
-                best_loss = curr_loss
-                last_best_iter = i + 1
-
-            # Scale loss and perform gradient step
-            # loss = loss * self.loss_scaling_factor
-            accelerator.backward(loss)
-            self._step(optimizer, lr_scheduler)
-
-            # Update progress bar
-            pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components)))
-
-        # Make sure no updates are received in the progress bar
-        pbar.close()
-
-        # TODO: Include support for saving the best configuration during training
-        if not self.use_best_model:
-            # Override if the model with the lowest training error is not used
-            best_loss = curr_loss
-            last_best_iter = self.iters
-
-        # TODO: Verify if this call is actually needed
-        # Wait for everyone before proceeding to next block
-        accelerator.wait_for_everyone()
-        # Remove all the wrappers around the block
-        block = accelerator.unwrap_model(block)
-        # Clear memory
-        accelerator.free_memory()
-        # Move the block back to CPU
-        block.cpu()
-
-        return init_loss, best_loss, last_best_iter
-
     def apply_learned_round(
             self,
             model: nn.Module,
@@ -764,11 +652,7 @@ def apply_learned_round(
             )

             # Optimize block rounding
-            init_loss, best_loss, last_best_iter = (
-                self._optimize_learned_round_block
-                if not self.use_accelerate
-                else self._accelerate_optimize_learned_round_block
-            )(
+            init_loss, best_loss, last_best_iter = self._optimize_learned_round_block(
                 block=block,
                 block_learned_round_modules=block_learned_round_modules,
                 cache=cache,
@@ -793,7 +677,6 @@ def apply_learned_round(
             # Move the block back to CPU
             block.cpu()

-            # TODO: This call might not be needed, check_clear and reset_cache methods
             # Reset cache after optimisation
             cache.clear_cache()

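Both the removed accelerate path and the surviving `_optimize_learned_round_block` return the same `(init_loss, best_loss, last_best_iter)` triple. A self-contained sketch of that bookkeeping, using synthetic loss values rather than real training losses:

    import torch

    # Synthetic per-iteration losses standing in for the block loss.
    losses = [0.9, 0.7, 0.8, 0.4, 0.6]
    iters = len(losses)

    best_loss = torch.finfo(torch.float).max
    init_loss, last_best_iter = -1.0, iters
    for i, curr_loss in enumerate(losses):
        # The first-iteration loss is reported as the initial loss.
        init_loss = curr_loss if i == 0 else init_loss
        if curr_loss < best_loss:
            best_loss = curr_loss
            last_best_iter = i + 1

    print(init_loss, best_loss, last_best_iter)  # 0.9 0.4 4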
8 changes: 4 additions & 4 deletions
@@ -120,14 +120,14 @@ def __len__(self):
             if isinstance(self["inputs"], list) else self["inputs"].shape[self.batch_dim])


-def cnn_forward(model: nn.Module, inputs: Any) -> None:
+def vision_forward(model: nn.Module, inputs: Any) -> None:
     device = next(model.parameters()).device
     img, _ = inputs
     img = send_to_device(img, device)
     model(img)


-def cnn_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor:
+def vision_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor:
     device = next(block.parameters()).device
     inputs = send_to_device(inputs, device)
     return block(inputs)
@@ -186,8 +186,8 @@ def apply_learned_round(
     cache = CacheVision()
     learned_round_optimizer.apply_learned_round(
         model=model,
-        model_forward=cnn_forward,
-        block_forward=cnn_block_forward,
+        model_forward=vision_forward,
+        block_forward=vision_block_forward,
         data_loader=calibration_loader,
         cache=cache,
         get_blocks_fn=get_blocks_fn,
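The `cnn_` to `vision_` rename presumably reflects that these helpers are not CNN-specific: any model consuming `(image, label)` batches works. A runnable sketch, redefining `vision_forward` locally with the same body shown above (the toy model and batch are illustrative):

    import torch
    import torch.nn as nn
    from accelerate.utils import send_to_device

    def vision_forward(model: nn.Module, inputs) -> None:
        # Move the image tensor to the model's device before the forward pass.
        device = next(model.parameters()).device
        img, _ = inputs
        img = send_to_device(img, device)
        model(img)

    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    batch = (torch.randn(2, 3, 8, 8), torch.tensor([0, 1]))
    vision_forward(model, batch)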
2 changes: 2 additions & 0 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -147,6 +147,7 @@ def apply_learned_round(
         block_name_attribute: str = "layers",
         optimizer: str = "sign_sgd",
         optimizer_lr: float = 5e-3,
+        optimizer_scale_lr: float = 5e-3,
         batch_size: int = 8,
         learn_scale: bool = False,
         use_best_model: bool = True,
@@ -176,6 +177,7 @@ def apply_learned_round(
         optimizer_class=optimizer_class,
         lr_scheduler_class=lr_scheduler_class,
         optimizer_lr=optimizer_lr,
+        optimizer_scale_lr=optimizer_scale_lr,
         batch_size=batch_size,
         iters=iters,
         learn_scale=learn_scale,
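On the LLM side, the new keyword is simply threaded through to the optimizer. A sketch of a call mirroring the one in `main.py` below, assuming `apply_learned_round` takes the model and calibration dataloader as its leading arguments and that `model` and `calibration_loader` are already prepared by the surrounding pipeline:

    from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

    # Illustrative values; every other argument keeps the defaults shown above.
    apply_learned_round(
        model,
        calibration_loader,
        optimizer_lr=5e-3,
        optimizer_scale_lr=1e-2,  # new: may now differ from the rounding LR
        learn_scale=True,
    )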
12 changes: 12 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -376,6 +376,8 @@ def main(args):
             calibration_loader,
             iters=args.learned_round_iters,
             block_name_attribute=args.gpxq_block_name,
+            optimizer_lr=args.learned_round_lr,
+            optimizer_scale_lr=args.learned_round_scale_lr,
             learn_scale=args.learned_round_scale,
         )
         print("Learned round applied.")
@@ -566,6 +568,16 @@ def parse_args(args):
         type=int,
         default=64,
         help='Group size for per_group input quantization. Default: 64.')
+    parser.add_argument(
+        '--learned-round-lr',
+        type=float,
+        default=5e-3,
+        help='Learning rate for learned round parameter optimization. Default: %(default)s')
+    parser.add_argument(
+        '--learned-round-scale-lr',
+        type=float,
+        default=5e-3,
+        help='Learning rate for scale optimization during round learning. Default: %(default)s')
     parser.add_argument(
         '--learned-round-iters',
         type=int,
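Together with the existing options, the new flags expose both learning rates on the command line. An illustrative invocation (model selection and the remaining quantization arguments are elided; `--learned-round-scale` is the flag that `args.learned_round_scale` above implies):

    python src/brevitas_examples/llm/main.py \
        --learned-round-lr 5e-3 \
        --learned-round-scale-lr 1e-2 \
        --learned-round-scale \
        --learned-round-iters 200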
