From 7cef5ce87db7847539416cf36f07dca190e163b0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 5 Dec 2024 11:53:08 +0100 Subject: [PATCH] Update learned_round_optimizer.py --- .../common/learned_round/learned_round_optimizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index 70b0c688d..662f9142a 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -722,7 +722,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache): # Finally (!), we compute the quantized input of the next block block.eval() - block.cuda() + if torch.cuda.is_available(): + block.cuda() next_quant_input = [] pbar = tqdm(range(len(cache)), desc='', leave=False) with torch.no_grad(): @@ -731,8 +732,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache): out = block_forward(block, (args, kwargs)) out = send_to_device(out, 'cpu') next_quant_input.append((out,)) + pbar.close() cache['args'] = next_quant_input block.cpu() - pbar.close() return cache