Attempt to pipe logit_bias to sampler's embedding_bias #1279

Closed
benbot wants to merge 1 commit

Conversation

@benbot commented Oct 6, 2023

logit_bias is an important feature of the OpenAI API that vLLM seems to have implemented but not exposed in its actual API.

This is me taking a crack at exposing that functionality.

For the life of me, I can't get my CUDA versions to all agree so this will build locally, so while I try to sort that out I'm opening the PR for others to try out.

Should resolve: #379

@benbot (Author) commented Oct 6, 2023

I think I'm passing an array of arrays of logit biases, so that will probably need to be changed.
But I can't get vLLM to build locally yet, so I can't verify that.

@viktor-ferenczi (Contributor) commented

Please rebase this branch to current main, then I will build and test this. I also want to get LMQL working with vLLM, because performance is not good with any of the other backends.

@benbot (Author) commented Nov 1, 2023

Sure thing. I'll rebase tomorrow.
:)

@benbot reopened this Nov 1, 2023
@benbot (Author) commented Nov 1, 2023

No idea how this got closed. Reopened and rebased.

Finally got it built too!
Haven't had a chance to test whether logit_bias works yet, though.

@viktor-ferenczi (Contributor) left a review comment

Looks good, added minor comments.

The client must know the vocabulary and the vocab_size in order to pass a logits_bias array which works with the model loaded. Right now the client has to load the tokenizer corresponding to the model loaded into the vLLM server to achieve this, which makes clients more complex and risks using the wrong tokenizer for the model, resulting in errors.

I suggest exposing two new REST API calls:

  • vocabulary Returns the vocabulary (array of strings). The vocab_size is the length of this array. It may also return the special tokens, stop tokens, start/end of conversation tokens, etc.
  • tokenize Accepts text and returns an array of integers (token values).

This way no tokenizer needs to be built on the client side and there is no way to use the wrong vocabulary. The client can build the logit_bias based on the information returned (the client only has to fetch the vocabulary once on initialization).

Having these would allow for a pretty straightforward test case and easier integration with LMQL.
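A minimal sketch of what those two endpoints could look like, assuming a FastAPI app and an already-loaded Hugging Face tokenizer; the route names, model name, and `tokenizer` variable are illustrative suggestions, not part of this PR:

```python
# Hypothetical sketch only -- these routes are a suggestion, not code from this PR.
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer

app = FastAPI()
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model


class TokenizeRequest(BaseModel):
    text: str


@app.get("/vocabulary")
def vocabulary():
    # Vocabulary ordered by token id; vocab_size is simply len(vocab).
    vocab = tokenizer.get_vocab()  # dict: token string -> token id
    ordered = [token for token, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
    return {"vocabulary": ordered, "vocab_size": len(ordered)}


@app.post("/tokenize")
def tokenize(request: TokenizeRequest):
    # Token ids for the given text, using the server's own tokenizer.
    return {"token_ids": tokenizer.encode(request.text)}
```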

@@ -92,6 +93,7 @@ def __init__(
stop_token_ids: Optional[List[int]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logit_bias: float = [],
@viktor-ferenczi (Contributor):

Proper type is: Optional[List[float]]

@@ -122,6 +124,7 @@ def __init__(
self.max_tokens = max_tokens
self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs
self.logit_bias = logit_bias
@viktor-ferenczi (Contributor):

Define the type here as well.

@viktor-ferenczi (Contributor) commented Nov 3, 2023:

Also add self.logit_bias to __repr__ below.
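Putting the review points above together, a condensed sketch of what SamplingParams could look like (illustrative only; the real class has many more fields, and a None default is used here to avoid a shared mutable []):

```python
from typing import List, Optional


class SamplingParams:  # heavily condensed, illustrative only
    def __init__(
        self,
        max_tokens: int = 16,
        logit_bias: Optional[List[float]] = None,  # typed; avoids a mutable default
    ) -> None:
        self.max_tokens = max_tokens
        # Normalize None to an empty list so downstream code can iterate safely.
        self.logit_bias: List[float] = logit_bias if logit_bias is not None else []

    def __repr__(self) -> str:
        return (f"SamplingParams(max_tokens={self.max_tokens}, "
                f"logit_bias={self.logit_bias})")
```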

logit_biases: any = []
for seq_group in input_metadata.seq_groups:
set_ids, sampling_params = seq_group
logit_biases += [sampling_params.logit_bias]
@viktor-ferenczi (Contributor) commented Nov 3, 2023:

Add validation of the size of the logit_bias array received from sampling_params. Having an explicit error with a clear explanation here is better than ending up with a cryptic PyTorch error message while trying to add arrays of mismatching sizes later.

(It cannot be validated inside SamplingParams, because the right size is not known there.)
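One way to do that, sketched as a standalone helper (the function name and call site are hypothetical, not part of this PR); it could be invoked once per seq_group right before the biases are stacked into a tensor:

```python
from typing import List, Optional


def validate_logit_bias(logit_bias: Optional[List[float]], vocab_size: int) -> None:
    """Fail with a clear message instead of a cryptic PyTorch size-mismatch error."""
    if logit_bias and len(logit_bias) != vocab_size:
        raise ValueError(
            f"logit_bias has {len(logit_bias)} entries, but the model's vocabulary "
            f"has {vocab_size} tokens; the bias array must be exactly vocab_size long."
        )
```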

@benbot (Author) commented Nov 5, 2023

@viktor-ferenczi I've been trying to test this locally, but it doesn't seem like the logit_bias parameter is actually working :(

Were you able to see the logit_bias taking effect?

@benbot (Author) commented Nov 6, 2023

> The client must know the vocabulary and the vocab_size in order to pass a logits_bias array which works with the model loaded.

@viktor-ferenczi I may be missing something, but why do they need to know the vocab size? Isn't logit_bias a sparse mapping of token ids to bias to apply?

@viktor-ferenczi (Contributor) commented

> The client must know the vocabulary and the vocab_size in order to pass a logits_bias array which works with the model loaded.
>
> @viktor-ferenczi I may be missing something, but why do they need to know the vocab size? Isn't logit_bias a sparse mapping of token ids to bias to apply?

I did not know it was a sparse mapping; I expected it to be an array of floats with vocab_size items.

Anyway, the client must be able to retrieve the vocabulary of the currently loaded model for the solution to be fully usable. Otherwise the client must instantiate the tokenizer of the exact same model loaded into the vLLM server just to get the vocabulary, which means a lot more dependencies and room for error, not to mention the added loading time.

@benbot (Author) commented Nov 20, 2023

Looks like on OAI's API it's a map of token_ids to bias values

https://platform.openai.com/docs/api-reference/chat/create#chat-create-logit_bias
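For reference, a small sketch (PyTorch assumed; not code from this PR) of how such a sparse token_id → bias mapping can be applied to a logits vector:

```python
from typing import Dict

import torch


def apply_logit_bias(logits: torch.Tensor, logit_bias: Dict[int, float]) -> torch.Tensor:
    # Add each bias to the logit of its token id; unmentioned tokens are untouched.
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits


# Example: boost token 1337 and effectively ban token 42 on a dummy vocab of 32000.
logits = torch.zeros(32000)
apply_logit_bias(logits, {1337: 5.0, 42: -100.0})
```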

@creatorrr commented
@viktor-ferenczi @benbot did you guys come to a consensus on how to go about implementing this? Loading the tokenizer vocab will be tricky, but then again this feature is meant for rather advanced use cases, and we could just leave it to the API consumer to figure out the mapping between token_ids and tokens (that's what OpenAI did anyway).

@viktor-ferenczi (Contributor) commented

There is a logits_processors parameter in SamplingParams already; it was added by #1469, which has already been merged to main.
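A sketch of how logit_bias could be layered on top of that parameter; this assumes each logits processor is a callable taking the previously generated token ids and the next-token logits tensor, as introduced by #1469 (model name and bias values are placeholders):

```python
from typing import Dict, List

import torch
from vllm import LLM, SamplingParams


def make_logit_bias_processor(logit_bias: Dict[int, float]):
    # Returns a callable matching the logits-processor interface:
    # (previously generated token ids, next-token logits) -> modified logits.
    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        for token_id, bias in logit_bias.items():
            logits[token_id] += bias
        return logits

    return processor


llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(
    max_tokens=16,
    logits_processors=[make_logit_bias_processor({50256: -100.0})],
)
outputs = llm.generate(["Hello, my name is"], params)
```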

@viktor-ferenczi (Contributor) commented

Depending on your use case, using a grammar would be a much cleaner alternative to trying to manipulate logits; it also moves the problem of handling tokens entirely into the server.

See #2105: Add grammars

@benbot (Author) commented Mar 20, 2024

Covered by #3027

@benbot closed this Mar 20, 2024
Successfully merging this pull request may close these issues.

[Feature] Add support for logit_bias