diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index a0adf4d0b8c03..3b1ed7617c9b4 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -199,9 +199,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return error_check_ret
 
     if request.logit_bias is not None and len(request.logit_bias) > 0:
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "logit_bias is not currently supported")
+        pass # TODO: support logit_bias in vLLM engine.
+        #return create_error_response(HTTPStatus.BAD_REQUEST,
+        #                             "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
     token_ids, error_check_ret = await check_length(request, prompt=prompt)
@@ -228,6 +229,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
+            logit_bias=request.logit_bias
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 6a29f1afd368b..688826fd8be64 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -44,7 +44,7 @@ def forward(
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)
 
         # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias,
+        logits = _get_logits(hidden_states, embedding, embedding_bias if embedding_bias is not None else _get_logit_bias(input_metadata),
                              self.vocab_size)
 
         # Apply presence and frequency penalties.
@@ -217,6 +217,11 @@ def _apply_penalties(
         logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits
 
+def _get_logit_bias(input_metadata: InputMetadata) -> any:
+    logit_biases: any = []
+    for seq_group in input_metadata.seq_groups:
+        set_ids, sampling_params = seq_group
+        logit_biases += [sampling_params.logit_bias]
 
 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 00a9135a5ca7f..7fe4aed20f8f6 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -62,6 +62,7 @@ class SamplingParams:
             the stop tokens are sepcial tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
+        logit_bias: Bias adjustment for different logits
        max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
             Note that the implementation follows the OpenAI API: The return
@@ -92,6 +93,7 @@ def __init__(
         stop_token_ids: Optional[List[int]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
+        logit_bias: float = [],
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
@@ -122,6 +124,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        self.logit_bias = logit_bias
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
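
Note on the sampler change: as written, _get_logit_bias builds logit_biases but never returns it, and the collected per-sequence biases are passed to _get_logits in place of the embedding_bias tensor rather than being added to the computed logits. Below is a minimal, self-contained sketch of the intended OpenAI-style semantics (an additive per-token bias applied to each sequence's row of logits). It is not part of the diff: the helper name apply_logit_bias and the (num_seqs, vocab_size) logits layout are assumptions for illustration only, not vLLM API.

# Sketch only -- not part of the diff above. Assumes logit_bias is a mapping
# from token id to an additive bias value, one mapping per sequence.
from typing import Dict, List

import torch


def apply_logit_bias(
    logits: torch.Tensor,                  # shape: (num_seqs, vocab_size)
    logit_biases: List[Dict[int, float]],  # one mapping per sequence
) -> torch.Tensor:
    # Row i of `logits` gets the i-th mapping's biases; empty mappings are skipped.
    for row, bias in enumerate(logit_biases):
        if not bias:
            continue
        token_ids = torch.tensor(list(bias.keys()),
                                 dtype=torch.long,
                                 device=logits.device)
        values = torch.tensor(list(bias.values()),
                              dtype=logits.dtype,
                              device=logits.device)
        logits[row, token_ids] += values
    return logits


# Example: boost token 3 for the first of two sequences over an 8-token vocabulary.
biased = apply_logit_bias(torch.zeros(2, 8), [{3: 2.0}, {}])

With a helper along these lines, _get_logit_bias would return its collected list and the bias would be applied to the logits after _get_logits (similar to how _apply_penalties operates), instead of being threaded through the embedding_bias argument.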