From ff658927541d3a8bdda45bc23c19a19381b04155 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Thu, 28 Nov 2024 18:15:52 +0000
Subject: [PATCH] Learned round: expose separate scale LR, remove disabled accelerate path

---
 .../learned_round/learned_round_optimizer.py | 125 +-----------------
 .../ptq/learned_round_utils.py               |   8 +-
 .../llm/llm_quant/learned_round_utils.py     |   2 +
 src/brevitas_examples/llm/main.py            |  12 ++
 4 files changed, 22 insertions(+), 125 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index ea068474e..5db0bdd83 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -364,6 +364,7 @@ def __init__(
             scale_optimizer_class: Type[Optimizer] = SGD,
             lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR,
             optimizer_lr: float = 5e-3,
+            optimizer_scale_lr: float = 5e-3,
             batch_size: float = 8,
             iters: int = 200,
             learn_scale: bool = False,
@@ -371,7 +372,6 @@ def __init__(
             use_amp: bool = True,
             amp_dtype: torch.dtype = torch.float16,
             loss_scaling_factor: float = 1000.,
-            use_accelerate: bool = False,
             learned_round_loss_kwargs: Optional[Dict] = None,
             optimizer_kwargs: Optional[Dict] = None,
             lr_scheduler_kwargs: Optional[Dict] = None,
@@ -381,6 +381,7 @@ def __init__(
         self.scale_optimizer_class = scale_optimizer_class
         self.lr_scheduler_class = lr_scheduler_class
         self.optimizer_lr = optimizer_lr
+        self.optimizer_scale_lr = optimizer_scale_lr
         self.batch_size = batch_size
         self.iters = iters
         self.learn_scale = learn_scale
@@ -396,10 +397,6 @@ def __init__(
         self.learned_round_loss_init = partial(
             learned_round_loss_class, **learned_round_loss_kwargs)

-        # TODO: Remove once validated and expose the flag
-        # self.use_accelerate = use_accelerate
-        self.use_accelerate = False
-
     @torch.no_grad()
     def _load_round_params(self, block: nn.Module, round_params: Dict) -> None:
         for n, m in block.named_modules():
@@ -508,8 +505,7 @@ def _optimize_learned_round_block(
         if self.learn_scale and scale_params is not None:
             optimizer_scale = self.scale_optimizer_class(
                 scale_params,
-                lr=self.optimizer_lr,
-                momentum=0.9,
+                lr=self.optimizer_scale_lr,
                 **self.optimizer_kwargs,
             )
             lr_scheduler_scale = (
@@ -579,114 +575,6 @@ def _optimize_learned_round_block(

         return init_loss, best_loss, last_best_iter

-    # TODO: Enable saving best parameters
-    def _accelerate_optimize_learned_round_block(
-        self,
-        block: nn.Module,
-        block_learned_round_modules: List[nn.Module],
-        cache: Cache,
-        block_loss: LearnedRoundLoss,
-        block_forward: Callable,
-    ) -> Tuple[float, float, int]:
-        # Enable running in mixed precision
-        TORCH_DTYPE_TO_PRECISION_TYPE_MAP = {
-            torch.float16: PrecisionType.FP16,
-            torch.bfloat16: PrecisionType.BF16,}
-        raise_warning_dtype = False
-        if not self.use_amp:
-            mixed_precision_type = None
-        else:
-            if self.amp_dtype not in TORCH_DTYPE_TO_PRECISION_TYPE_MAP:
-                raise_warning_dtype = True
-                mixed_precision_type = None
-            else:
-                mixed_precision_type = TORCH_DTYPE_TO_PRECISION_TYPE_MAP[self.amp_dtype]
-        # Instantiate accelerator to run in a multi-GPU setting
-        accelerator = Accelerator(mixed_precision=mixed_precision_type)
-
-        # Raise warning if the AMP dtype was defaulted to float32. This warning is raised after
-        # the instantiation of accelerator, to use its print functionality so the message is only
-        # printed once.
-        if raise_warning_dtype:
-            accelerator.print(
-                f"The dtype {self.amp_dtype} cannot be used for AMP training with accelerate. Defaulting to float32."
-            )
-
-        # Initilalize optimizer and LR scheduler
-        optimizer = self.optimizer_class(
-            itertools.chain(
-                *[
-                    block_learned_round_module.parameters()
-                    for block_learned_round_module in block_learned_round_modules]),
-            lr=self.optimizer_lr,
-            **self.optimizer_kwargs,
-        )
-        lr_scheduler = (
-            self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs)
-            if self.lr_scheduler_class else None)
-
-        # Prepare dataset from cache
-        cache_dataset = cache.cache_to_dataset()
-        cache_dataloader = DataLoader(
-            cache_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=cache.collate_fn)
-
-        # Prepare elements for training
-        cache_dataloader, block, optimizer, lr_scheduler = accelerator.prepare(cache_dataloader, block, optimizer, lr_scheduler)
-
-        # Variables needed for printing
-        best_loss = torch.finfo(torch.float).max
-        init_loss = -1.0
-        last_best_iter = self.iters
-
-        # Initialize an iterator to extract elements from the cache dataloader
-        cache_iterator = iter(cache_dataloader)
-
-        pbar = tqdm_accelerate(range(self.iters), desc='')
-        for i in pbar:
-            # Sample mini-batch from cache
-            inputs, fp_outs = next(cache_iterator)
-
-            # Run block forward to obtain quant outputs
-            quant_outs = block_forward(block, inputs)
-            # Compute loss using the block loss function
-            loss, loss_components = block_loss(quant_outs, fp_outs)
-
-            # Save best parameters before taking gradient step
-            curr_loss = loss.detach().cpu().item()
-            init_loss = curr_loss if i == 0 else init_loss
-            if loss < best_loss:
-                best_loss = curr_loss
-                last_best_iter = i + 1
-
-            # Scale loss and perform gradient step
-            # loss = loss * self.loss_scaling_factor
-            accelerator.backward(loss)
-            self._step(optimizer, lr_scheduler)
-
-            # Update progress bar
-            pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components)))
-
-        # Make sure no updates are received in the progress bar
-        pbar.close()
-
-        # TODO: Include support for saving the best configuration during training
-        if not self.use_best_model:
-            # Override if the model with the lowest training error is not used
-            best_loss = curr_loss
-            last_best_iter = self.iters
-
-        # TODO: Verify if this call is actually needed
-        # Wait for everyone before proceding to next block
-        accelerator.wait_for_everyone()
-        # Remove all the wrapper around the block
-        block = accelerator.unwrap_model(block)
-        # Clear memory
-        accelerator.free_memory()
-        # Move the block back to CPU
-        block.cpu()
-
-        return init_loss, best_loss, last_best_iter
-
     def apply_learned_round(
             self,
             model: nn.Module,
@@ -764,11 +652,7 @@ def apply_learned_round(
             )

             # Optimize block rounding
-            init_loss, best_loss, last_best_iter = (
-                self._optimize_learned_round_block
-                if not self.use_accelerate
-                else self._accelerate_optimize_learned_round_block
-            )(
+            init_loss, best_loss, last_best_iter = self._optimize_learned_round_block(
                 block=block,
                 block_learned_round_modules=block_learned_round_modules,
                 cache=cache,
@@ -793,7 +677,6 @@ def apply_learned_round(
             # Move the block back to CPU
             block.cpu()

-            # TODO: This call might not be needed, check_clear and reset_cache methods
             # Reset cache after optimisation
             cache.clear_cache()

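The hunks above split the learning rates used by the learned round optimizer: rounding parameters keep optimizer_lr, while quantizer scales are now driven by the new optimizer_scale_lr, and the previously hard-coded momentum=0.9 on the scale optimizer is dropped. A minimal sketch (not taken from the patch) of the resulting setup, using plain torch.optim.SGD and dummy tensors as stand-ins for the parameter groups that LearnedRoundOptimizer collects per block:

import torch
from torch.optim import SGD

# Dummy stand-ins for the per-block learned rounding parameters and the
# quantizer scale parameters handled by LearnedRoundOptimizer.
round_params = [torch.nn.Parameter(torch.zeros(16))]
scale_params = [torch.nn.Parameter(torch.ones(16))]

# Rounding parameters keep their own learning rate (optimizer_lr) ...
optimizer = SGD(round_params, lr=5e-3)
# ... while scales now use the independent optimizer_scale_lr; any momentum
# would have to be supplied explicitly (e.g. via optimizer_kwargs) rather
# than being hard-coded as before.
optimizer_scale = SGD(scale_params, lr=1e-2)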
diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
index 8994bbc25..4c9b05ea8 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
@@ -120,14 +120,14 @@ def __len__(self):
             if isinstance(self["inputs"], list) else self["inputs"].shape[self.batch_dim])


-def cnn_forward(model: nn.Module, inputs: Any) -> None:
+def vision_forward(model: nn.Module, inputs: Any) -> None:
     device = next(model.parameters()).device
     img, _ = inputs
     img = send_to_device(img, device)
     model(img)


-def cnn_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor:
+def vision_block_forward(block: nn.Module, inputs: Any) -> torch.Tensor:
     device = next(block.parameters()).device
     inputs = send_to_device(inputs, device)
     return block(inputs)
@@ -186,8 +186,8 @@ def apply_learned_round(
     cache = CacheVision()
     learned_round_optimizer.apply_learned_round(
         model=model,
-        model_forward=cnn_forward,
-        block_forward=cnn_block_forward,
+        model_forward=vision_forward,
+        block_forward=vision_block_forward,
         data_loader=calibration_loader,
         cache=cache,
         get_blocks_fn=get_blocks_fn,
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index 6dad361ab..ad6d43693 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -147,6 +147,7 @@ def apply_learned_round(
         block_name_attribute: str = "layers",
         optimizer: str = "sign_sgd",
         optimizer_lr: float = 5e-3,
+        optimizer_scale_lr: float = 5e-3,
         batch_size: int = 8,
         learn_scale: bool = False,
         use_best_model: bool = True,
@@ -176,6 +177,7 @@ def apply_learned_round(
         optimizer_class=optimizer_class,
         lr_scheduler_class=lr_scheduler_class,
         optimizer_lr=optimizer_lr,
+        optimizer_scale_lr=optimizer_scale_lr,
         batch_size=batch_size,
         iters=iters,
         learn_scale=learn_scale,
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index ec15fe1f2..aa5365b38 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -376,6 +376,8 @@ def main(args):
             calibration_loader,
             iters=args.learned_round_iters,
             block_name_attribute=args.gpxq_block_name,
+            optimizer_lr=args.learned_round_lr,
+            optimizer_scale_lr=args.learned_round_scale_lr,
             learn_scale=args.learned_round_scale,
         )
         print("Learned round applied.")
@@ -566,6 +568,16 @@ def parse_args(args):
         type=int,
         default=64,
         help='Group size for per_group input quantization. Default: 64.')
+    parser.add_argument(
+        '--learned-round-lr',
+        type=float,
+        default=5e-3,
+        help='Learning rate for learned round parameter optimization. Default: %(default)s')
+    parser.add_argument(
+        '--learned-round-scale-lr',
+        type=float,
+        default=5e-3,
+        help='Learning rate for scale optimization during round learning. Default: %(default)s')
     parser.add_argument(
         '--learned-round-iters',
         type=int,
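Taken together, the LLM entry point now exposes both learning rates on the command line (--learned-round-lr and --learned-round-scale-lr, both defaulting to 5e-3) and programmatically. A minimal usage sketch, assuming model and calibration_loader are prepared as in brevitas_examples/llm/main.py and that apply_learned_round keeps its existing leading arguments:

from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

# Keyword names below are the ones added or plumbed through by this patch;
# model and calibration_loader are assumed to come from the usual LLM example setup.
apply_learned_round(
    model,
    calibration_loader,
    iters=200,
    block_name_attribute="layers",
    optimizer_lr=5e-3,        # corresponds to --learned-round-lr
    optimizer_scale_lr=5e-3,  # corresponds to --learned-round-scale-lr
    learn_scale=True,         # corresponds to --learned-round-scale
)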