From f91dd5470fccdb4fd0a6efca5cc5302fdab00ad8 Mon Sep 17 00:00:00 2001 From: Jonas Schroeder Date: Tue, 5 Mar 2024 08:32:58 +0000 Subject: [PATCH] Limit number of alternatives to vocabulary size --- server/lorax_server/models/flash_causal_lm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 934e25649..c655d4a44 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1024,9 +1024,12 @@ def generate_token( next_token_logprob, ) in enumerate(iterator): if request.parameters.return_k_alternatives > 0: + # Limit the number of alternatives to the vocabulary size + num_alternatives = min(request.parameters.return_k_alternatives, len(alternative_token_ids[i])) + # Select top-k logprobs - batch_alternative_token_ids = alternative_token_ids[i][:request.parameters.return_k_alternatives] - batch_alternative_token_logprobs = alternative_token_logprobs[i][:request.parameters.return_k_alternatives] + batch_alternative_token_ids = alternative_token_ids[i][:num_alternatives] + batch_alternative_token_logprobs = alternative_token_logprobs[i][:num_alternatives] # Decode tokens alternative_token_texts = list()