diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 70b0c688d..662f9142a 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -722,7 +722,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache): # Finally (!), we compute the quantized input of the next block block.eval() - block.cuda() + if torch.cuda.is_available(): + block.cuda() next_quant_input = [] pbar = tqdm(range(len(cache)), desc='', leave=False) with torch.no_grad(): @@ -731,8 +732,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache): out = block_forward(block, (args, kwargs)) out = send_to_device(out, 'cpu') next_quant_input.append((out,)) + pbar.close() cache['args'] = next_quant_input block.cpu() - pbar.close() return cache