diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index e43c8e01b..dad75ccb0 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -587,7 +587,7 @@ def apply_learned_round(
             model_prepare_fn: Optional[Callable] = None,
             model_finish_fn: Optional[Callable] = None,
             keep_gpu: bool = True,
-            partial_update: bool = False) -> None:
+            fast_update: bool = False) -> None:
         # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
         model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
 
@@ -607,8 +607,8 @@ def apply_learned_round(
         # Iterate over blocks and optimise the rounding parameters within each of them
         for block_idx, block in enumerate(blocks):
             # Distribute the model across devices to run a forward pass to capture
-            # inputs/outputs to the given block#
-            if block_idx == 0 and partial_update:
+            # inputs/outputs to the given block
+            if block_idx == 0 or not fast_update:
                 cache.clear_cache()
                 model = offload_model(model)
                 # Cache needs to be cleared before populating it with the inputs and outputs
@@ -681,7 +681,7 @@ def apply_learned_round(
 
             # Move the block back to CPU
             block.cpu()
-            if block_idx + 1 < len(blocks) and partial_update:
+            if block_idx + 1 < len(blocks) and fast_update:
                 cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx+1], floating_point_datasets, block_forward, cache)
 
         # The original configuration of the model is restored after finishing the optimization
@@ -707,32 +707,21 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
             (args, kwargs), _ = cache.sample_batch([i])
             floating_point_datasets.append((args, kwargs))
 
-        # Then, we compute the floating point output of the current block
-        next_float_input = []
-        block.cuda()
-        pbar = tqdm(floating_point_datasets, desc='', leave=False)
-        with torch.no_grad():
-            for args, kwargs in pbar:
-                out = block_forward(block, (args, kwargs))
-                out = send_to_device(out, 'cpu')
-                next_float_input.append((out,))
-        pbar.close()
-        block.cpu()
 
         # We use this new output to generate a new temporary dataloder for the next block
         # and to update our floating_point_dataset
         new_data_loader = []
         for i in range(len(cache)):
-            (args, kwargs), _ = cache.sample_batch([i])
-            new_data_loader.append((next_float_input[i], kwargs))
+            (args, kwargs), output = cache.sample_batch([i])
+            new_data_loader.append(((output,), kwargs))
 
-            _, fp_dataset_kwargs = floating_point_datasets[i]
-            floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs)
+            floating_point_datasets[i] = ((output,), kwargs)
 
         # Temporary cache
         tmp_cache = type(cache)()
         # We compute the floating point output of the upcoming block
-        next_block.cuda()
+        if torch.cuda.is_available():
+            next_block.cuda()
         save_inputs_output(
             next_block,
             block_forward,
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index d28bd286e..54a8510f4 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -110,25 +110,25 @@ def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:
 
 
 def apply_learned_round(
-    model: nn.Module,
-    calibration_loader: DataLoader,
-    iters: int = 200,
-    learned_round: str = "linear_round",
-    learned_round_loss: str = "mse",
-    block_name_attribute: str = "layers",
-    optimizer: str = "sign_sgd",
-    batch_size: int = 8,
-    learn_scale: bool = False,
-    use_best_model: bool = True,
-    amp_dtype: torch.dtype = torch.float16,
-    loss_scaling_factor: float = 1000,
-    lr_scheduler: Optional[str] = "linear",
-    optimizer_kwargs: Optional[Dict] = None,
-    lr_scheduler_kwargs: Optional[Dict] = None,
-    learned_round_loss_kwargs: Optional[Dict] = None,
-    scale_optimizer_class: Optional[str] = None,
-    scale_optimizer_kwargs: Optional[Dict] = None,
-) -> None:
+        model: nn.Module,
+        calibration_loader: DataLoader,
+        iters: int = 200,
+        learned_round: str = "linear_round",
+        learned_round_loss: str = "mse",
+        block_name_attribute: str = "layers",
+        optimizer: str = "sign_sgd",
+        batch_size: int = 8,
+        learn_scale: bool = False,
+        use_best_model: bool = True,
+        amp_dtype: torch.dtype = torch.float16,
+        loss_scaling_factor: float = 1000,
+        lr_scheduler: Optional[str] = "linear",
+        optimizer_kwargs: Optional[Dict] = None,
+        lr_scheduler_kwargs: Optional[Dict] = None,
+        learned_round_loss_kwargs: Optional[Dict] = None,
+        scale_optimizer_class: Optional[str] = None,
+        scale_optimizer_kwargs: Optional[Dict] = None,
+        fast_update: bool = False) -> None:
     # Parse strings to obtain the arguments for the optimizer
     learned_round = parse_learned_round(learned_round)
     learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
@@ -169,4 +169,4 @@ def apply_learned_round(
         model_prepare_fn=llm_learned_round_prepare_fn,
         model_finish_fn=llm_learned_round_finish_fn,
         keep_gpu=False,
-        partial_update=True)
+        fast_update=fast_update)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 6004ec97d..203a86e89 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -385,7 +385,8 @@ def main(args):
             scale_optimizer_class='sgd',
             optimizer_kwargs={'lr': args.learned_round_lr},
             scale_optimizer_kwargs={
-                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum})
+                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
+            fast_update=args.learned_round_fast_update)
         print("Learned round applied.")
 
         model = offload_model(model)
@@ -705,6 +706,11 @@ def parse_args(args):
         default=None,
         choices=[None, 'linear_round'],
         help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
+    parser.add_argument(
+        '--learned-round-fast-update',
+        default=False,
+        type=bool,
+        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     return parser.parse_args(args)
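For context, a minimal sketch of the idea behind `fast_update`: instead of re-running the whole model to capture the inputs of every block, the cached outputs of the block that was just optimised are reused as the inputs of the next block, so only that next block needs a forward pass to refresh its reference outputs. The helper below is illustrative only (its name, the toy blocks and the plain-tensor batches are not from the repo); the real logic lives in `skip_full_execution`, which additionally tracks kwargs, the floating-point datasets and CPU/GPU placement.

```python
import torch
import torch.nn as nn


def run_blocks_fast_update(blocks, first_block_inputs):
    """Illustrative only: reuse each block's cached outputs as the next block's inputs."""
    cached_inputs = list(first_block_inputs)  # captured once, for block 0 only
    for block in blocks:
        # ... the rounding parameters of `block` would be optimised here ...
        with torch.no_grad():
            cached_outputs = [block(x) for x in cached_inputs]
        # fast_update: no extra full-model forward pass; the outputs of the
        # current block become the inputs of the next one.
        cached_inputs = cached_outputs
    return cached_inputs


if __name__ == "__main__":
    toy_blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(3))
    calib_batches = [torch.randn(2, 16) for _ in range(4)]  # stand-in calibration data
    outs = run_blocks_fast_update(toy_blocks, calib_batches)
    print(len(outs), outs[0].shape)
```

At the API level this corresponds to calling `apply_learned_round(model, calibration_loader, fast_update=True)` from `brevitas_examples.llm.llm_quant.learned_round_utils`, which is what the new `--learned-round-fast-update` flag wires up in `main.py`.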