Attempt to pipe logit_bias to sampler's embedding_bias #1279
```diff
@@ -62,6 +62,7 @@ class SamplingParams:
             the stop tokens are special tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
+        logit_bias: Bias adjustment for different logits.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
             Note that the implementation follows the OpenAI API: The return
```
```diff
@@ -92,6 +93,7 @@ def __init__(
         stop_token_ids: Optional[List[int]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
+        logit_bias: float = [],
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
```

Review comment on the added `logit_bias` parameter: "Proper type is: …"
```diff
@@ -122,6 +124,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        self.logit_bias = logit_bias
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
```

Review comments on the added line: "Define type here as well." and "Also add …"
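Both type comments above are truncated, so the reviewer's exact suggestion is unknown. A minimal sketch of what they appear to ask for, assuming `logit_bias` is a dense per-vocabulary bias vector (the `Optional[List[float]]` annotation and the `None` default are my assumptions, not the reviewer's confirmed wording):

```python
from typing import List, Optional


class SamplingParams:
    """Fragment-level sketch: only the parameters touched by this diff."""

    def __init__(
        self,
        max_tokens: int = 16,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        # Assumed shape: one bias value per vocabulary token, or None when
        # no bias is applied. A None default also avoids the mutable
        # default-argument pitfall of `logit_bias: float = []`.
        logit_bias: Optional[List[float]] = None,
    ) -> None:
        self.max_tokens = max_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        # Annotated on the attribute as well, per the second comment.
        self.logit_bias: Optional[List[float]] = logit_bias
```

Note that the length of the vector still cannot be checked here; that has to wait until the sampler, where the vocabulary size is known, as the next comment points out.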
Review comment at the point in the sampler where the bias is applied:

"Add validation of the size of the `logit_bias` array received from `sampling_params`. Having an explicit error with a clear explanation here is better than ending up with a cryptic PyTorch error message while trying to add arrays of mismatching sizes later. (It cannot be validated inside `SamplingParams`, because the right size is not known there.)"
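A sketch of the check this comment asks for, under the assumption that the bias is applied by adding a vector to the logits tensor; the helper name and call site are hypothetical, not the PR's actual code:

```python
from typing import List

import torch


def apply_logit_bias(logits: torch.Tensor,
                     logit_bias: List[float]) -> torch.Tensor:
    """Add a per-token bias to the logits (hypothetical helper).

    logits: shape (num_seqs, vocab_size).
    """
    vocab_size = logits.shape[-1]
    if len(logit_bias) != vocab_size:
        # Fail with a readable message here instead of letting torch raise
        # a cryptic broadcasting error on the addition below.
        raise ValueError(
            f"logit_bias has {len(logit_bias)} entries, but the model's "
            f"vocabulary size is {vocab_size}; the sizes must match.")
    bias = torch.tensor(logit_bias, dtype=logits.dtype, device=logits.device)
    return logits + bias
```

Raising at this boundary turns a shape mismatch into an actionable message pointing at the user-supplied `logit_bias`, rather than a stack trace from deep inside the sampler.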