From f23a9735052f50ffe69c81ae77f231d61b340693 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 8 Jan 2024 16:59:55 -0800 Subject: [PATCH] Fixed tests --- server/lorax_server/models/causal_lm.py | 3 +++ server/lorax_server/models/seq2seq_lm.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 5e00ea809..f215f3a2c 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -655,12 +655,15 @@ def generate_token( prefill_tokens = PrefillTokens( prefill_token_ids, prefill_logprobs, prefill_texts ) + prefill_tokens_length = len(prefill_tokens.token_ids) else: prefill_tokens = None + prefill_tokens_length = 0 generation = Generation( request.id, prefill_tokens, + prefill_tokens_length, next_token_id_squeezed, next_token_logprob, next_token_text, diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index 33587fae6..2004a2c2d 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -699,12 +699,15 @@ def generate_token( [float("nan")], [self.tokenizer.bos_token], ) + prefill_tokens_length = len(prefill_tokens.token_ids) else: prefill_tokens = None + prefill_tokens_length = 0 generation = Generation( request.id, prefill_tokens, + prefill_tokens_length, next_token_id_squeezed, next_token_logprob, next_token_text,