diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index 935a38cba..c9927d23b 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -586,7 +586,8 @@ def apply_learned_round(
             get_blocks_fn: Callable,
             model_prepare_fn: Optional[Callable] = None,
             model_finish_fn: Optional[Callable] = None,
-            keep_gpu: bool = True) -> None:
+            keep_gpu: bool = True,
+            fast_update: bool = False) -> None:
 
         # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
         model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
@@ -602,26 +603,27 @@ def apply_learned_round(
 
         # Initialize cache to store partial inputs and outputs for each block
        cache.initialize_cache()
-
         # Iterate over blocks and optimise the rounding parameters within each of them
         for block_idx, block in enumerate(blocks):
             # Distribute the model across devices to run a forward pass to capture
             # inputs/outputs to the given block
-            model = offload_model(model)
-            # Cache needs to be cleared before populating it with the inputs and outputs
-            # to the block under optimization.
-            self._populate_cache(
-                cache,
-                model,
-                model_forward,
-                block,
-                data_loader,
-                keep_gpu=keep_gpu,
-                capture_quant_input=True,
-                capture_quant_output=False,
-            )
-            # Remove hooks needed to offload the model blocks to cpu
-            remove_hooks(model)
+            if block_idx == 0 or not fast_update:
+                cache.clear_cache()
+                model = offload_model(model)
+                # Cache needs to be cleared before populating it with the inputs and outputs
+                # to the block under optimization.
+                self._populate_cache(
+                    cache,
+                    model,
+                    model_forward,
+                    block,
+                    data_loader,
+                    keep_gpu=keep_gpu,
+                    capture_quant_input=True,
+                    capture_quant_output=False,
+                )
+                # Remove hooks needed to offload the model blocks to cpu
+                remove_hooks(model)
 
             # Retrieve scales
             scale_params = return_scale_parameters(block)
@@ -678,9 +680,60 @@ def apply_learned_round(
             # Move the block back to CPU
             block.cpu()
 
-            # Reset cache after optimisation
-            cache.clear_cache()
+            if block_idx + 1 < len(blocks) and fast_update:
+                cache = self.skip_full_execution(block, blocks[block_idx + 1], block_forward, cache)
 
         # The original configuration of the model is restored after finishing the optimization
         if model_finish_fn is not None:
             model_finish_fn(model, model_dict)
+
+    def skip_full_execution(self, block, next_block, block_forward, cache):
+
+        # We need to compute two inputs: a floating-point one to compute the float output,
+        # and a quantized one to create the quantized input of the next block
+
+        # We use the cache output to generate a new temporary dataloader for the next block
+        tmp_data_loader = []
+        for i in range(len(cache)):
+            (args, kwargs), output = cache.sample_batch([i])
+
+            tmp_data_loader.append(((output,), kwargs))
+
+        # Temporary cache
+        tmp_cache = type(cache)()
+
+        # We compute the floating point output of the upcoming block
+        if torch.cuda.is_available():
+            next_block.cuda()
+        save_inputs_output(
+            next_block,
+            block_forward,
+            next_block,
+            tmp_data_loader,
+            tmp_cache,
+            store_inputs=False,
+            store_output=True,
+            keep_gpu=False,
+            disable_quant=True,
+        )
+        next_block.cpu()
+
+        cache['output'] = tmp_cache['output']
+
+        # Finally (!), we compute the quantized input of the next block
+        block.eval()
+        if torch.cuda.is_available():
+            block.cuda()
+        next_quant_input = []
+        pbar = tqdm(range(len(cache)), desc='', leave=False)
+        with torch.no_grad():
+            for i in pbar:
+                (args, kwargs), _ = cache.sample_batch([i])
+                out = block_forward(block, (args, kwargs))
+                out = send_to_device(out, 'cpu')
+                next_quant_input.append((out,))
+            pbar.close()
+        cache['args'] = next_quant_input
+        block.cpu()
+
+        return cache
diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
index 3df861a93..de367bfdd 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
@@ -69,6 +69,7 @@ class CacheVision(Cache, dict):
     def __init__(self) -> None:
         super().__init__()
         self.batch_dim = 0
+        self.initialize_cache()
 
     def store_inputs(self, args, kwargs) -> None:
         input_batch = args[0]
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index a97ad548b..64c82c80a 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -54,6 +54,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
                [--learned-round {None,linear_round}]
+               [--learned-round-fast-update]
 
 options:
   -h, --help            show this help message and exit
@@ -196,5 +197,8 @@ options:
   --learned-round {None,linear_round}
                         Whether to use learned round. If `None`, RTN is used
                         (default: None)
+  --learned-round-fast-update
+                        Whether to use fast update with learned round.
+                        Prototype (default: False)
 ```
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index 2a2850eec..91d4ef405 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -23,6 +23,7 @@ class CacheLLM(Cache, dict):
 
     def __init__(self) -> None:
         super().__init__()
+        self.initialize_cache()
 
     def store_inputs(self, args, kwargs) -> None:
         self["args"].append(args)
@@ -107,25 +108,25 @@ def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:
 
 
 def apply_learned_round(
-    model: nn.Module,
-    calibration_loader: DataLoader,
-    iters: int = 200,
-    learned_round: str = "linear_round",
-    learned_round_loss: str = "mse",
-    block_name_attribute: str = "layers",
-    optimizer: str = "sign_sgd",
-    batch_size: int = 8,
-    learn_scale: bool = False,
-    use_best_model: bool = True,
-    amp_dtype: torch.dtype = torch.float16,
-    loss_scaling_factor: float = 1000,
-    lr_scheduler: Optional[str] = "linear",
-    optimizer_kwargs: Optional[Dict] = None,
-    lr_scheduler_kwargs: Optional[Dict] = None,
-    learned_round_loss_kwargs: Optional[Dict] = None,
-    scale_optimizer_class: Optional[str] = None,
-    scale_optimizer_kwargs: Optional[Dict] = None,
-) -> None:
+        model: nn.Module,
+        calibration_loader: DataLoader,
+        iters: int = 200,
+        learned_round: str = "linear_round",
+        learned_round_loss: str = "mse",
+        block_name_attribute: str = "layers",
+        optimizer: str = "sign_sgd",
+        batch_size: int = 8,
+        learn_scale: bool = False,
+        use_best_model: bool = True,
+        amp_dtype: torch.dtype = torch.float16,
+        loss_scaling_factor: float = 1000,
+        lr_scheduler: Optional[str] = "linear",
+        optimizer_kwargs: Optional[Dict] = None,
+        lr_scheduler_kwargs: Optional[Dict] = None,
+        learned_round_loss_kwargs: Optional[Dict] = None,
+        scale_optimizer_class: Optional[str] = None,
+        scale_optimizer_kwargs: Optional[Dict] = None,
+        fast_update: bool = False) -> None:
     # Parse strings to obtain the arguments for the optimizer
     learned_round = parse_learned_round(learned_round)
     learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
@@ -166,4 +167,4 @@ def apply_learned_round(
         model_prepare_fn=llm_learned_round_prepare_fn,
         model_finish_fn=llm_learned_round_finish_fn,
         keep_gpu=False,
-    )
+        fast_update=fast_update)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 6004ec97d..c8c76d4a1 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -385,7 +385,8 @@ def main(args):
             scale_optimizer_class='sgd',
             optimizer_kwargs={'lr': args.learned_round_lr},
             scale_optimizer_kwargs={
-                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum})
+                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
+            fast_update=args.learned_round_fast_update)
         print("Learned round applied.")
 
         model = offload_model(model)
@@ -705,6 +706,11 @@ def parse_args(args):
         default=None,
         choices=[None, 'linear_round'],
         help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
+    parser.add_argument(
+        '--learned-round-fast-update',
+        default=False,
+        action="store_true",
+        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     return parser.parse_args(args)
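
The sketch below is a minimal illustration of the cache hand-off that `fast_update` / `skip_full_execution` performs: only the first block is captured with a full-model forward pass; afterwards, the float regression targets for block i+1 are obtained by running just that block with quantization disabled on the cached float outputs of block i, while the quantized inputs for block i+1 are obtained by running the just-optimised quantized block i on its cached quantized inputs. The helper names (`advance_cache`, `run_float_block`, `run_quant_block`) are assumptions for illustration only, not Brevitas APIs.

```python
from typing import Callable, List, Tuple

import torch


@torch.no_grad()
def advance_cache(
        block: torch.nn.Module,          # block i, rounding already optimised
        next_block: torch.nn.Module,     # block i + 1
        float_outs: List[torch.Tensor],  # cached float outputs of block i
        quant_ins: List[torch.Tensor],   # cached quantized inputs of block i
        run_float_block: Callable[[torch.nn.Module, torch.Tensor], torch.Tensor],
        run_quant_block: Callable[[torch.nn.Module, torch.Tensor], torch.Tensor],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Float targets for block i + 1: run only that block, with quantization
    # disabled, on the cached float outputs of block i, so the float path
    # never sees quantized activations.
    next_float_outs = [run_float_block(next_block, x) for x in float_outs]
    # Quantized inputs for block i + 1: run the just-optimised quantized
    # block i on its cached quantized inputs.
    next_quant_ins = [run_quant_block(block, x) for x in quant_ins]
    return next_float_outs, next_quant_ins
```

From the command line, this path is enabled by passing `--learned-round-fast-update` together with `--learned-round linear_round`; without `--learned-round`, the flag has no effect.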