From 1c5229f47e27b08e383fdbcfdf7319383e067645 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 4 Dec 2024 16:43:50 +0000
Subject: [PATCH 01/12] Feat (llm/learned_round): fast block update

---
 .../learned_round/learned_round_optimizer.py  | 120 +++++++++++++++---
 .../llm/llm_quant/learned_round_utils.py      |   2 +-
 2 files changed, 101 insertions(+), 21 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index 935a38cba..e9430155f 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -586,7 +586,8 @@ def apply_learned_round(
             get_blocks_fn: Callable,
             model_prepare_fn: Optional[Callable] = None,
             model_finish_fn: Optional[Callable] = None,
-            keep_gpu: bool = True) -> None:
+            keep_gpu: bool = True,
+            partial_update: bool = False) -> None:

         # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
         model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
@@ -602,26 +603,28 @@ def apply_learned_round(

         # Initialize cache to store partial inputs and outputs for each block
         cache.initialize_cache()
-
+        floating_point_datasets = []
         # Iterate over blocks and optimise the rounding parameters within each of them
         for block_idx, block in enumerate(blocks):
             # Distribute the model across devices to run a forward pass to capture
-            # inputs/outputs to the given block
-            model = offload_model(model)
-            # Cache needs to be cleared before populating it with the inputs and outputs
-            # to the block under optimization.
-            self._populate_cache(
-                cache,
-                model,
-                model_forward,
-                block,
-                data_loader,
-                keep_gpu=keep_gpu,
-                capture_quant_input=True,
-                capture_quant_output=False,
-            )
-            # Remove hooks needed to offload the model blocks to cpu
-            remove_hooks(model)
+            # inputs/outputs to the given block#
+            if block_idx == 0 and partial_update:
+                cache.clear_cache()
+                model = offload_model(model)
+                # Cache needs to be cleared before populating it with the inputs and outputs
+                # to the block under optimization.
+                self._populate_cache(
+                    cache,
+                    model,
+                    model_forward,
+                    block,
+                    data_loader,
+                    keep_gpu=keep_gpu,
+                    capture_quant_input=True,
+                    capture_quant_output=False,
+                )
+                # Remove hooks needed to offload the model blocks to cpu
+                remove_hooks(model)

             # Retrieve scales
             scale_params = return_scale_parameters(block)
@@ -678,9 +681,86 @@ def apply_learned_round(
             # Move the block back to CPU
             block.cpu()

-            # Reset cache after optimisation
-            cache.clear_cache()
+            if block_idx + 1 < len(blocks) and partial_update:
+                cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx+1], floating_point_datasets, block_forward, cache)

         # The original configuration of the model is restored after finishing the optimization
         if model_finish_fn is not None:
             model_finish_fn(model, model_dict)
+
+    def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):
+
+        # We need to propagate two datasets, one is a floating point dataset to compute float out
+        # The second is a quantized dataset to create the quantized input of the next blocks
+
+        # First, we disable quantization
+        disable_quant_class = DisableEnableQuantization()
+        disable_quant_class.disable_act_quantization(block, False)
+        disable_quant_class.disable_param_quantization(block, False)
+        return_quant_tensor_state = disable_return_quant_tensor(block)
+
+        # If we don't have a floating_point_dataset, we retrieve it from the cache
+        # The idea is that the cache contains the input to the very first block, and there is nothing
+        # quantized before that. This is a moderately strong assumption
+        if len(floating_point_datasets) <= 0:
+            for i in range(len(cache)):
+                (args, kwargs), _ = cache.sample_batch([i])
+                floating_point_datasets.append((args, kwargs))
+
+        # Then, we compute the floating point output of the current block
+        next_float_input = []
+        with torch.no_grad():
+            for args, kwargs in floating_point_datasets:
+                out = block_forward(block, (args, kwargs))
+                out = send_to_device(out, 'cpu')
+                next_float_input.append((out,))
+        # We use this new output to generate a new temporary dataloder for the next block
+        # and to update our floating_point_dataset
+        new_data_loader = []
+        for i in range(len(cache)):
+            (args, kwargs), _ = cache.sample_batch([i])
+            new_data_loader.append((next_float_input[i], kwargs))
+
+            _, fp_dataset_kwargs = floating_point_datasets[i]
+            floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs)
+
+        # Temporary cache
+        tmp_cache = copy.deepcopy(cache)
+        tmp_cache.clear_cache()
+
+        # We compute the floating point output of the upcoming block
+        next_block.cuda()
+        save_inputs_output(
+            next_block,
+            block_forward,
+            next_block,
+            new_data_loader,
+            tmp_cache,
+            store_inputs=False,
+            store_output=True,
+            keep_gpu=False,
+            disable_quant=True,
+        )
+        next_block.cpu()
+
+        cache['output'] = tmp_cache['output']
+
+        # Re-enable quantization
+        disable_quant_class.enable_act_quantization(block, False)
+        disable_quant_class.enable_param_quantization(block, False)
+        restore_return_quant_tensor(block, return_quant_tensor_state)
+
+        # Finally (!), we compute the quantized input of the next block
+        block.eval()
+        block.cuda()
+        next_quant_input = []
+        with torch.no_grad():
+            for i in range(len(cache)):
+                (args, kwargs), _ = cache.sample_batch([i])
+                out = block_forward(block, (args, kwargs))
+                out = send_to_device(out, 'cpu')
+                next_quant_input.append((out,))
+        cache['args'] = copy.deepcopy(next_quant_input)
+        block.cpu()
+
+        return cache, floating_point_datasets
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index 2a2850eec..ccf0c11b1 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -166,4 +166,4 @@ def apply_learned_round(
         model_prepare_fn=llm_learned_round_prepare_fn,
         model_finish_fn=llm_learned_round_finish_fn,
         keep_gpu=False,
-    )
+        partial_update=True)

From f0fc1910f02d049d3147ceeedd5676ef14d50747 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 4 Dec 2024 18:26:40 +0000
Subject: [PATCH 02/12] Review

---
 .../learned_round/learned_round_optimizer.py  | 15 ++++++++++-----
 .../llm/llm_quant/learned_round_utils.py      |  3 +++
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index e9430155f..e43c8e01b 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -709,11 +709,15 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_

         # Then, we compute the floating point output of the current block
         next_float_input = []
+        block.cuda()
+        pbar = tqdm(floating_point_datasets, desc='', leave=False)
         with torch.no_grad():
-            for args, kwargs in floating_point_datasets:
+            for args, kwargs in pbar:
                 out = block_forward(block, (args, kwargs))
                 out = send_to_device(out, 'cpu')
                 next_float_input.append((out,))
+        pbar.close()
+        block.cpu()
         # We use this new output to generate a new temporary dataloder for the next block
         # and to update our floating_point_dataset
         new_data_loader = []
@@ -725,8 +729,7 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
             floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs)

         # Temporary cache
-        tmp_cache = copy.deepcopy(cache)
-        tmp_cache.clear_cache()
+        tmp_cache = type(cache)()

         # We compute the floating point output of the upcoming block
         next_block.cuda()
@@ -754,13 +757,15 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
         block.eval()
         block.cuda()
         next_quant_input = []
+        pbar = tqdm(range(len(cache)), desc='', leave=False)
         with torch.no_grad():
-            for i in range(len(cache)):
+            for i in pbar:
                 (args, kwargs), _ = cache.sample_batch([i])
                 out = block_forward(block, (args, kwargs))
                 out = send_to_device(out, 'cpu')
                 next_quant_input.append((out,))
-        cache['args'] = copy.deepcopy(next_quant_input)
+        cache['args'] = next_quant_input
         block.cpu()
+        pbar.close()

         return cache, floating_point_datasets
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index ccf0c11b1..d28bd286e 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -23,6 +23,9 @@ class CacheLLM(Cache, dict):

     def __init__(self) -> None:
         super().__init__()
+        self["args"] = []
+        self["kwargs"] = []
+        self["output"] = []

     def store_inputs(self, args, kwargs) -> None:
         self["args"].append(args)

From fbb91094d3006bdcbdc98a61816184f4d38a4d0a Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 4 Dec 2024 22:05:34 +0000
Subject: [PATCH 03/12] review

---
 .../learned_round/learned_round_optimizer.py  | 29 +++++---------
 .../llm/llm_quant/learned_round_utils.py      | 40 +++++++++----------
 src/brevitas_examples/llm/main.py             |  8 +++-
 3 files changed, 36 insertions(+), 41 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index e43c8e01b..dad75ccb0 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -587,7 +587,7 @@ def apply_learned_round(
             model_prepare_fn: Optional[Callable] = None,
             model_finish_fn: Optional[Callable] = None,
             keep_gpu: bool = True,
-            partial_update: bool = False) -> None:
+            fast_update: bool = False) -> None:

         # Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
         model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
@@ -607,8 +607,8 @@ def apply_learned_round(
         # Iterate over blocks and optimise the rounding parameters within each of them
         for block_idx, block in enumerate(blocks):
             # Distribute the model across devices to run a forward pass to capture
-            # inputs/outputs to the given block#
-            if block_idx == 0 and partial_update:
+            # inputs/outputs to the given block
+            if block_idx == 0 or not fast_update:
                 cache.clear_cache()
                 model = offload_model(model)
                 # Cache needs to be cleared before populating it with the inputs and outputs
@@ -681,7 +681,7 @@ def apply_learned_round(
             # Move the block back to CPU
             block.cpu()

-            if block_idx + 1 < len(blocks) and partial_update:
+            if block_idx + 1 < len(blocks) and fast_update:
                 cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx+1], floating_point_datasets, block_forward, cache)

         # The original configuration of the model is restored after finishing the optimization
@@ -707,32 +707,21 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
                 (args, kwargs), _ = cache.sample_batch([i])
                 floating_point_datasets.append((args, kwargs))

-        # Then, we compute the floating point output of the current block
-        next_float_input = []
-        block.cuda()
-        pbar = tqdm(floating_point_datasets, desc='', leave=False)
-        with torch.no_grad():
-            for args, kwargs in pbar:
-                out = block_forward(block, (args, kwargs))
-                out = send_to_device(out, 'cpu')
-                next_float_input.append((out,))
-        pbar.close()
-        block.cpu()
-        # We use this new output to generate a new temporary dataloder for the next block
+        # We use the cache output to generate a new temporary dataloder for the next block
         # and to update our floating_point_dataset
         new_data_loader = []
         for i in range(len(cache)):
-            (args, kwargs), _ = cache.sample_batch([i])
-            new_data_loader.append((next_float_input[i], kwargs))
+            (args, kwargs), output = cache.sample_batch([i])
+            new_data_loader.append(((output,), kwargs))

-            _, fp_dataset_kwargs = floating_point_datasets[i]
-            floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs)
+            floating_point_datasets[i] = ((output,), kwargs)

         # Temporary cache
         tmp_cache = type(cache)()

         # We compute the floating point output of the upcoming block
-        next_block.cuda()
+        if torch.cuda.is_available():
+            next_block.cuda()
         save_inputs_output(
             next_block,
             block_forward,
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index d28bd286e..54a8510f4 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -110,25 +110,25 @@ def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:


 def apply_learned_round(
-    model: nn.Module,
-    calibration_loader: DataLoader,
-    iters: int = 200,
-    learned_round: str = "linear_round",
-    learned_round_loss: str = "mse",
-    block_name_attribute: str = "layers",
-    optimizer: str = "sign_sgd",
-    batch_size: int = 8,
-    learn_scale: bool = False,
-    use_best_model: bool = True,
-    amp_dtype: torch.dtype = torch.float16,
-    loss_scaling_factor: float = 1000,
-    lr_scheduler: Optional[str] = "linear",
-    optimizer_kwargs: Optional[Dict] = None,
-    lr_scheduler_kwargs: Optional[Dict] = None,
-    learned_round_loss_kwargs: Optional[Dict] = None,
-    scale_optimizer_class: Optional[str] = None,
-    scale_optimizer_kwargs: Optional[Dict] = None,
-) -> None:
+        model: nn.Module,
+        calibration_loader: DataLoader,
+        iters: int = 200,
+        learned_round: str = "linear_round",
+        learned_round_loss: str = "mse",
+        block_name_attribute: str = "layers",
+        optimizer: str = "sign_sgd",
+        batch_size: int = 8,
+        learn_scale: bool = False,
+        use_best_model: bool = True,
+        amp_dtype: torch.dtype = torch.float16,
+        loss_scaling_factor: float = 1000,
+        lr_scheduler: Optional[str] = "linear",
+        optimizer_kwargs: Optional[Dict] = None,
+        lr_scheduler_kwargs: Optional[Dict] = None,
+        learned_round_loss_kwargs: Optional[Dict] = None,
+        scale_optimizer_class: Optional[str] = None,
+        scale_optimizer_kwargs: Optional[Dict] = None,
+        fast_update: bool = False) -> None:
     # Parse strings to obtain the arguments for the optimizer
     learned_round = parse_learned_round(learned_round)
     learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
@@ -169,4 +169,4 @@ def apply_learned_round(
         model_prepare_fn=llm_learned_round_prepare_fn,
         model_finish_fn=llm_learned_round_finish_fn,
         keep_gpu=False,
-        partial_update=True)
+        partial_update=fast_update)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 6004ec97d..203a86e89 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -385,7 +385,8 @@ def main(args):
             scale_optimizer_class='sgd',
             optimizer_kwargs={'lr': args.learned_round_lr},
             scale_optimizer_kwargs={
-                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum})
+                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
+            fast_update=args.learned_round_fast_update)
         print("Learned round applied.")

         model = offload_model(model)
@@ -705,6 +706,11 @@ def parse_args(args):
         default=None,
         choices=[None, 'linear_round'],
         help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
+    parser.add_argument(
+        '--learned-round-fast-update',
+        default=False,
+        type=bool,
+        help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     return parser.parse_args(args)


From b3ca8d100772078d18c9e4ffa9a4a084c915e1d9 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 4 Dec 2024 22:13:32 +0000
Subject: [PATCH 04/12] update flag

---
 src/brevitas_examples/llm/llm_quant/learned_round_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index 54a8510f4..0c71455aa 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -169,4 +169,4 @@ def apply_learned_round(
         model_prepare_fn=llm_learned_round_prepare_fn,
         model_finish_fn=llm_learned_round_finish_fn,
         keep_gpu=False,
-        partial_update=fast_update)
+        fast_update=fast_update)

From 23de059f46a9f30f9b8ceada6d5520a5f41dd1bb Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 10:03:45 +0000
Subject: [PATCH 05/12] Fix comments

---
 .../common/learned_round/learned_round_optimizer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index dad75ccb0..e03bfa9f1 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -707,13 +707,13 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
                 (args, kwargs), _ = cache.sample_batch([i])
                 floating_point_datasets.append((args, kwargs))

-        # We use the cache output to generate a new temporary dataloder for the next block
+        # We use the cache output to generate a new temporary dataloder for the next block
         # and to update our floating_point_dataset
         new_data_loader = []
         for i in range(len(cache)):
             (args, kwargs), output = cache.sample_batch([i])
-            new_data_loader.append(((output,), kwargs))
+            new_data_loader.append(((output,), kwargs))

             floating_point_datasets[i] = ((output,), kwargs)

         # Temporary cache

From 2a1e32429e8ca598349deec9e0dacb0478f253b5 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 10:14:33 +0000
Subject: [PATCH 06/12] Update flag and readme

---
 .../learned_round/learned_round_optimizer.py | 15 ++-------------
 src/brevitas_examples/llm/README.md          |  4 ++++
 src/brevitas_examples/llm/main.py            |  2 +-
 3 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index e03bfa9f1..3fe864a4d 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -690,14 +690,8 @@ def apply_learned_round(

     def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):

-        # We need to propagate two datasets, one is a floating point dataset to compute float out
-        # The second is a quantized dataset to create the quantized input of the next blocks
-
-        # First, we disable quantization
-        disable_quant_class = DisableEnableQuantization()
-        disable_quant_class.disable_act_quantization(block, False)
-        disable_quant_class.disable_param_quantization(block, False)
-        return_quant_tensor_state = disable_return_quant_tensor(block)
+        # We need to compute two inputs, one is a floating point one to compute float out
+        # The second is a quantized one to create the quantized input of the next blocks

         # If we don't have a floating_point_dataset, we retrieve it from the cache
         # The idea is that the cache contains the input to the very first block, and there is nothing
@@ -737,11 +731,6 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_

         cache['output'] = tmp_cache['output']

-        # Re-enable quantization
-        disable_quant_class.enable_act_quantization(block, False)
-        disable_quant_class.enable_param_quantization(block, False)
-        restore_return_quant_tensor(block, return_quant_tensor_state)
-
         # Finally (!), we compute the quantized input of the next block
         block.eval()
         block.cuda()
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index a97ad548b..64c82c80a 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -54,6 +54,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
                [--learned-round {None,linear_round}]
+               [--learned-round-fast-update]

 options:
   -h, --help            show this help message and exit
@@ -196,5 +197,8 @@ options:
   --learned-round {None,linear_round}
                         Whether to use learned round. If `None`, RTN is used
                         (default: None)
+  --learned-round-fast-update
+                        Whether to use fast update with learned round.
+                        Prototype (default: False)

 ```
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 203a86e89..c8c76d4a1 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -709,7 +709,7 @@ def parse_args(args):
     parser.add_argument(
         '--learned-round-fast-update',
         default=False,
-        type=bool,
+        action="store_true",
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     return parser.parse_args(args)


From db0d5d006fe5d7b0caa7064137b5f6d132b8724d Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 10:19:36 +0000
Subject: [PATCH 07/12] init cache

---
 src/brevitas_examples/llm/llm_quant/learned_round_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index 0c71455aa..91d4ef405 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -23,9 +23,7 @@ class CacheLLM(Cache, dict):

     def __init__(self) -> None:
         super().__init__()
-        self["args"] = []
-        self["kwargs"] = []
-        self["output"] = []
+        self.initialize_cache()

     def store_inputs(self, args, kwargs) -> None:
         self["args"].append(args)

From 7ed24fb7d363a8dc4933015b51e0e868511b619d Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 10:20:16 +0000
Subject: [PATCH 08/12] Init cache vision

---
 .../imagenet_classification/ptq/learned_round_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
index 3df861a93..de367bfdd 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py
@@ -69,6 +69,7 @@ class CacheVision(Cache, dict):
     def __init__(self) -> None:
         super().__init__()
         self.batch_dim = 0
+        self.initialize_cache()

     def store_inputs(self, args, kwargs) -> None:
         input_batch = args[0]

From d369baa7c9b58874df722c1b3d5a6a257ef07230 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 10:48:12 +0000
Subject: [PATCH 09/12] Simplification

---
 .../learned_round/learned_round_optimizer.py | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index 3fe864a4d..72db86f5f 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -603,7 +603,6 @@ def apply_learned_round(

         # Initialize cache to store partial inputs and outputs for each block
         cache.initialize_cache()
-        floating_point_datasets = []
         # Iterate over blocks and optimise the rounding parameters within each of them
         for block_idx, block in enumerate(blocks):
             # Distribute the model across devices to run a forward pass to capture
@@ -682,33 +681,24 @@ def apply_learned_round(
             block.cpu()

             if block_idx + 1 < len(blocks) and fast_update:
-                cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx+1], floating_point_datasets, block_forward, cache)
+                cache = self.skip_full_execution(block, blocks[block_idx + 1], block_forward, cache)

         # The original configuration of the model is restored after finishing the optimization
         if model_finish_fn is not None:
             model_finish_fn(model, model_dict)

-    def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):
+    def skip_full_execution(self, block, next_block, block_forward, cache):

         # We need to compute two inputs, one is a floating point one to compute float out
         # The second is a quantized one to create the quantized input of the next blocks

-        # If we don't have a floating_point_dataset, we retrieve it from the cache
-        # The idea is that the cache contains the input to the very first block, and there is nothing
-        # quantized before that. This is a moderately strong assumption
-        if len(floating_point_datasets) <= 0:
-            for i in range(len(cache)):
-                (args, kwargs), _ = cache.sample_batch([i])
-                floating_point_datasets.append((args, kwargs))
-
-        # We use the cache output to generate a new temporary dataloder for the next block
+        # We use the cache output to generate a new temporary dataloder for the next block
         # and to update our floating_point_dataset
-        new_data_loader = []
+        tmp_data_loader = []
         for i in range(len(cache)):
             (args, kwargs), output = cache.sample_batch([i])
-            new_data_loader.append(((output,), kwargs))
+            tmp_data_loader.append(((output,), kwargs))

-            floating_point_datasets[i] = ((output,), kwargs)

         # Temporary cache
         tmp_cache = type(cache)()
@@ -720,7 +710,7 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_
             next_block,
             block_forward,
             next_block,
-            new_data_loader,
+            tmp_data_loader,
             tmp_cache,
             store_inputs=False,
             store_output=True,

From 0f0dd43461edff2251e676cb14409adde3983c47 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 11:51:04 +0100
Subject: [PATCH 10/12] Update learned_round_optimizer.py

---
 .../common/learned_round/learned_round_optimizer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index 72db86f5f..70b0c688d 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -693,7 +693,6 @@ def skip_full_execution(self, block, next_block, block_forward, cache):
         # The second is a quantized one to create the quantized input of the next blocks

         # We use the cache output to generate a new temporary dataloder for the next block
-        # and to update our floating_point_dataset
         tmp_data_loader = []
         for i in range(len(cache)):
             (args, kwargs), output = cache.sample_batch([i])

From 7cef5ce87db7847539416cf36f07dca190e163b0 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 11:53:08 +0100
Subject: [PATCH 11/12] Update learned_round_optimizer.py

---
 .../common/learned_round/learned_round_optimizer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index 70b0c688d..662f9142a 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -722,7 +722,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache):

         # Finally (!), we compute the quantized input of the next block
         block.eval()
-        block.cuda()
+        if torch.cuda.is_available():
+            block.cuda()
         next_quant_input = []
         pbar = tqdm(range(len(cache)), desc='', leave=False)
         with torch.no_grad():
@@ -731,8 +732,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache):
                 out = block_forward(block, (args, kwargs))
                 out = send_to_device(out, 'cpu')
                 next_quant_input.append((out,))
+        pbar.close()
         cache['args'] = next_quant_input
         block.cpu()
-        pbar.close()

         return cache
From 79df5c10a660f8a87a2e26a3e08bf66cebfab273 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 5 Dec 2024 10:54:39 +0000
Subject: [PATCH 12/12] precommit

---
 .../common/learned_round/learned_round_optimizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
index 662f9142a..c9927d23b 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -723,7 +723,7 @@ def skip_full_execution(self, block, next_block, block_forward, cache):
         # Finally (!), we compute the quantized input of the next block
         block.eval()
         if torch.cuda.is_available():
-            block.cuda()
+            block.cuda()
         next_quant_input = []
         pbar = tqdm(range(len(cache)), desc='', leave=False)
         with torch.no_grad():