diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 934e25649..c655d4a44 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1024,9 +1024,12 @@ def generate_token( next_token_logprob, ) in enumerate(iterator): if request.parameters.return_k_alternatives > 0: + # Limit the number of alternatives to the vocabulary size + num_alternatives = min(request.parameters.return_k_alternatives, len(alternative_token_ids[i])) + # Select top-k logprobs - batch_alternative_token_ids = alternative_token_ids[i][:request.parameters.return_k_alternatives] - batch_alternative_token_logprobs = alternative_token_logprobs[i][:request.parameters.return_k_alternatives] + batch_alternative_token_ids = alternative_token_ids[i][:num_alternatives] + batch_alternative_token_logprobs = alternative_token_logprobs[i][:num_alternatives] # Decode tokens alternative_token_texts = list()