diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index e9430155f..e43c8e01b 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -709,11 +709,15 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_ # Then, we compute the floating point output of the current block next_float_input = [] + block.cuda() + pbar = tqdm(floating_point_datasets, desc='', leave=False) with torch.no_grad(): - for args, kwargs in floating_point_datasets: + for args, kwargs in pbar: out = block_forward(block, (args, kwargs)) out = send_to_device(out, 'cpu') next_float_input.append((out,)) + pbar.close() + block.cpu() # We use this new output to generate a new temporary dataloder for the next block # and to update our floating_point_dataset new_data_loader = [] @@ -725,8 +729,7 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_ floating_point_datasets[i] = (next_float_input[i], fp_dataset_kwargs) # Temporary cache - tmp_cache = copy.deepcopy(cache) - tmp_cache.clear_cache() + tmp_cache = type(cache)() # We compute the floating point output of the upcoming block next_block.cuda() @@ -754,13 +757,15 @@ def skip_full_execution(self, block, next_block, floating_point_datasets, block_ block.eval() block.cuda() next_quant_input = [] + pbar = tqdm(range(len(cache)), desc='', leave=False) with torch.no_grad(): - for i in range(len(cache)): + for i in pbar: (args, kwargs), _ = cache.sample_batch([i]) out = block_forward(block, (args, kwargs)) out = send_to_device(out, 'cpu') next_quant_input.append((out,)) - cache['args'] = copy.deepcopy(next_quant_input) + cache['args'] = next_quant_input block.cpu() + pbar.close() return cache, floating_point_datasets diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index ccf0c11b1..d28bd286e 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -23,6 +23,9 @@ class CacheLLM(Cache, dict): def __init__(self) -> None: super().__init__() + self["args"] = [] + self["kwargs"] = [] + self["output"] = [] def store_inputs(self, args, kwargs) -> None: self["args"].append(args)