Skip to content

Commit

Permalink
Limit number of alternatives to vocabulary size
Browse files Browse the repository at this point in the history
  • Loading branch information
JTS22 committed Mar 6, 2024
1 parent 55b1f3c commit f91dd54
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions server/lorax_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,9 +1024,12 @@ def generate_token(
     next_token_logprob,
 ) in enumerate(iterator):
     if request.parameters.return_k_alternatives > 0:
+        # Limit the number of alternatives to the vocabulary size
+        num_alternatives = min(request.parameters.return_k_alternatives, len(alternative_token_ids[i]))
+
         # Select top-k logprobs
-        batch_alternative_token_ids = alternative_token_ids[i][:request.parameters.return_k_alternatives]
-        batch_alternative_token_logprobs = alternative_token_logprobs[i][:request.parameters.return_k_alternatives]
+        batch_alternative_token_ids = alternative_token_ids[i][:num_alternatives]
+        batch_alternative_token_logprobs = alternative_token_logprobs[i][:num_alternatives]
 
         # Decode tokens
         alternative_token_texts = list()
Expand Down

0 comments on commit f91dd54

Please sign in to comment.