Update learned_round_optimizer.py
Giuseppe5 authored Dec 5, 2024
1 parent 0f0dd43 commit 7cef5ce
Showing 1 changed file with 3 additions and 2 deletions.
learned_round_optimizer.py

@@ -722,7 +722,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache):
 
         # Finally (!), we compute the quantized input of the next block
         block.eval()
-        block.cuda()
+        if torch.cuda.is_available():
+            block.cuda()
         next_quant_input = []
         pbar = tqdm(range(len(cache)), desc='', leave=False)
         with torch.no_grad():
@@ -731,8 +732,8 @@ def skip_full_execution(self, block, next_block, block_forward, cache):
                 out = block_forward(block, (args, kwargs))
                 out = send_to_device(out, 'cpu')
                 next_quant_input.append((out,))
-                pbar.close()
         cache['args'] = next_quant_input
         block.cpu()
+        pbar.close()
 
         return cache
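
The two edits above make this step usable on CPU-only machines (the block is moved to GPU only when torch.cuda.is_available() returns True) and close the tqdm progress bar once, after the loop has finished and the block has been moved back to CPU. The snippet below is a minimal, self-contained sketch of that pattern, not the Brevitas implementation; the function and variable names (collect_next_block_inputs, blk, cached_batches) are hypothetical.

# Minimal sketch of the pattern in this commit (hypothetical names, not Brevitas code):
# move a block to GPU only when CUDA is available, run cached batches under
# torch.no_grad() with a tqdm bar, keep outputs on CPU, and close the bar after the loop.
import torch
import torch.nn as nn
from tqdm import tqdm


def collect_next_block_inputs(blk: nn.Module, cached_batches):
    blk.eval()
    if torch.cuda.is_available():  # fall back to CPU when no GPU is present
        blk.cuda()
    device = next(blk.parameters()).device

    outputs = []
    pbar = tqdm(range(len(cached_batches)), desc='', leave=False)
    with torch.no_grad():
        for i in pbar:
            x = cached_batches[i].to(device)
            out = blk(x)
            outputs.append(out.cpu())  # keep results on CPU to bound GPU memory use
    blk.cpu()
    pbar.close()  # close the progress bar once, after the loop
    return outputs


if __name__ == "__main__":
    blk = nn.Linear(8, 8)
    batches = [torch.randn(4, 8) for _ in range(3)]
    outs = collect_next_block_inputs(blk, batches)
    print([o.shape for o in outs])

Keeping each batch's output on CPU mirrors the send_to_device(out, 'cpu') call in the diff, so GPU memory holds at most one batch at a time.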
