
[rfc][dont merge] Use the skip_guard_eval stance to remove torch.compile guard overhead
anijain2305 committed Nov 4, 2024
1 parent 57a093a commit 7b0bfd4
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions hqq/utils/generation_hf.py
@@ -162,6 +162,7 @@ def __init__(
         if hasattr(self, "decode_one_token") is False:
             self.decode_one_token = decode_one_token

+        self.is_warmup_done = False
         self.init()  # check this: move this before setup_cache?

     @torch.no_grad()
@@ -202,6 +203,7 @@ def warmup(self):
         if self.is_compiled:
             for prompt in WARMUP_PROMPTS:
                 self.generate(prompt, print_tokens=False)
+        self.is_warmup_done = True
         return self

     def next_multiple(self, val):  # next power of 2
@@ -348,6 +350,7 @@ def next_token_iterator(
         for i in tqdm(range(1, max_new_tokens), disable=(not verbose or print_tokens)):
             next_token = self.gen_next_token(next_token)

+
            if next_token[0].item() == self.tokenizer.eos_token_id:
                break

@@ -381,9 +384,16 @@ def generate(
             inputs=self.tokenize_prompt(prompt, use_chat_template=use_chat_template),
             max_new_tokens=self.max_new_tokens,
         )
-        return self.next_token_iterator(
-            self.prefill(), self.max_new_tokens, verbose, print_tokens
-        )
+        if self.is_warmup_done:
+            with torch.compiler.skip_guard_eval_unsafe():
+                return self.next_token_iterator(
+                    self.prefill(), self.max_new_tokens, verbose, print_tokens
+                )
+        else:
+            return self.next_token_iterator(
+                self.prefill(), self.max_new_tokens, verbose, print_tokens
+            )
+

     def generate_(
         self, prompt, use_chat_template=True, verbose=False, print_tokens=False
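For context on the change above: the generator sets is_warmup_done once the warmup pass has compiled the decode step and exercised its guards, and later generate() calls then run under the skip-guard-eval stance so torch.compile stops re-evaluating guards on every token. Below is a minimal standalone sketch of that warmup-then-skip pattern. It mirrors the torch.compiler.skip_guard_eval_unsafe() context manager used in this diff (which requires a PyTorch build that includes this RFC's stance); the decode_step function, tensor shapes, and loop counts are hypothetical and not taken from the hqq code.

import torch

# Hypothetical stand-in for the compiled decode step (decode_one_token in hqq);
# the function body and shapes are illustrative only.
@torch.compile
def decode_step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2.0

x = torch.randn(8, 16)

# 1) Warmup: compile the function and let Dynamo install and evaluate its guards
#    at least once on representative inputs.
for _ in range(3):
    decode_step(x)

# 2) Steady state: skip guard re-evaluation on every call to cut Python-side
#    overhead. This is only safe while the inputs keep satisfying the guards
#    established during warmup (same dtypes, shapes, flags, etc.).
with torch.compiler.skip_guard_eval_unsafe():
    for _ in range(100):
        decode_step(x)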
