From 7b0bfd4703b51a3d13db2af8ffc91ec261209bb7 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Sun, 27 Oct 2024 22:31:55 -0700
Subject: [PATCH] [rfc][dont merge] Use the skip_guard_eval stance to remove
 torch.compile guard overhead

---
 hqq/utils/generation_hf.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/hqq/utils/generation_hf.py b/hqq/utils/generation_hf.py
index 031f3a0..10a3810 100755
--- a/hqq/utils/generation_hf.py
+++ b/hqq/utils/generation_hf.py
@@ -162,6 +162,7 @@ def __init__(
         if hasattr(self, "decode_one_token") is False:
             self.decode_one_token = decode_one_token
 
+        self.is_warmup_done = False
         self.init()  # check this: move this before setup_cache?
 
     @torch.no_grad()
@@ -202,6 +203,7 @@ def warmup(self):
         if self.is_compiled:
             for prompt in WARMUP_PROMPTS:
                 self.generate(prompt, print_tokens=False)
+        self.is_warmup_done = True
         return self
 
     def next_multiple(self, val):  # next power of 2
@@ -348,6 +350,7 @@ def next_token_iterator(
 
         for i in tqdm(range(1, max_new_tokens), disable=(not verbose or print_tokens)):
             next_token = self.gen_next_token(next_token)
+
             if next_token[0].item() == self.tokenizer.eos_token_id:
                 break
 
@@ -381,9 +384,16 @@ def generate(
             inputs=self.tokenize_prompt(prompt, use_chat_template=use_chat_template),
             max_new_tokens=self.max_new_tokens,
         )
-        return self.next_token_iterator(
-            self.prefill(), self.max_new_tokens, verbose, print_tokens
-        )
+        if self.is_warmup_done:
+            with torch.compiler.skip_guard_eval_unsafe():
+                return self.next_token_iterator(
+                    self.prefill(), self.max_new_tokens, verbose, print_tokens
+                )
+        else:
+            return self.next_token_iterator(
+                self.prefill(), self.max_new_tokens, verbose, print_tokens
+            )
+
     def generate_(
         self, prompt, use_chat_template=True, verbose=False, print_tokens=False
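
Note (not part of the patch): below is a minimal standalone sketch of the same warmup-then-skip-guards pattern, outside of hqq. The decode_step function, tensor shapes, and warmup loop are placeholders invented for illustration; the sketch assumes a PyTorch build that ships torch.compiler.skip_guard_eval_unsafe (recent 2.6-era builds or nightlies).

    # Minimal sketch: warm up a compiled function, then skip guard
    # re-evaluation at steady state. Placeholder names and shapes;
    # assumes torch.compiler.skip_guard_eval_unsafe is available.
    import torch

    @torch.compile
    def decode_step(x, w):
        # Stand-in for a compiled decode step.
        return torch.nn.functional.silu(x @ w)

    w = torch.randn(64, 64)

    # Warmup: exercise the input variants expected at steady state so that
    # all guards have been evaluated at least once and the graphs exist.
    for _ in range(3):
        decode_step(torch.randn(1, 64), w)

    # Steady state: assume the guarded properties (dtypes, shapes, flags)
    # no longer change, and skip re-evaluating guards on each call to cut
    # per-call framework overhead. "Unsafe" because an input change that a
    # guard would have caught can now silently run the wrong compiled code.
    with torch.compiler.skip_guard_eval_unsafe():
        out = decode_step(torch.randn(1, 64), w)
    print(out.shape)

This mirrors why the patch gates the stance on is_warmup_done: warmup() runs generate() over WARMUP_PROMPTS with guards fully evaluated, and only afterwards does generate() wrap the decode loop in the stance, on the assumption that the guarded state stays stable across subsequent calls.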