diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 8b60b7a70b121..3b35ae1ebd705 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -216,9 +216,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: end-chat-completion-extra-params

     def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
-            guided_decode_logits_processor: Optional[LogitsProcessor]
-    ) -> SamplingParams:
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         # We now allow logprobs being true without top_logrobs.
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
@@ -244,7 +248,7 @@ def to_sampling_params(
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
@@ -399,9 +403,13 @@ class CompletionRequest(OpenAIBaseModel):
     # doc: end-completion-extra-params

     def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
-            guided_decode_logits_processor: Optional[LogitsProcessor]
-    ) -> SamplingParams:
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         echo_without_generation = self.echo and self.max_tokens == 0

         logits_processors = get_logits_processors(
@@ -427,7 +435,7 @@ def to_sampling_params(
             stop_token_ids=self.stop_token_ids,
             logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 1dee798858a16..c832cf2a24b50 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -144,10 +144,10 @@ async def create_chat_completion(
             )

             sampling_params = request.to_sampling_params(
-                tokenizer, guided_decode_logits_processor)
-            if sampling_params.max_tokens is None:
-                sampling_params.max_tokens = \
-                    self.max_model_len - len(prompt_inputs["prompt_token_ids"])
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))

             self._log_inputs(request_id,
                              prompt_inputs,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 51de2f1826554..7765c5903f341 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -106,10 +106,10 @@ async def create_completion(self, request: CompletionRequest,

             for i, prompt_inputs in enumerate(prompts):
                 sampling_params = request.to_sampling_params(
-                    tokenizer, guided_decode_logits_processor)
-                if sampling_params.max_tokens is None:
-                    sampling_params.max_tokens = self.max_model_len - \
-                        len(prompt_inputs["prompt_token_ids"])
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))

                 request_id_item = f"{request_id}-{i}"
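Illustrative sketch (not part of the patch): the fallback behaviour the diff moves into to_sampling_params(). The resolve_max_tokens helper below is a hypothetical stand-in for that logic; in the patch the serving layers pass default_max_tokens = self.max_model_len - len(prompt_inputs["prompt_token_ids"]), and a max_tokens value set on the request still takes precedence.

from typing import List, Optional

def resolve_max_tokens(requested_max_tokens: Optional[int],
                       max_model_len: int,
                       prompt_token_ids: List[int]) -> int:
    # Mirror of the new fallback: honour the client-supplied max_tokens when
    # present, otherwise allow generation up to the remaining context window.
    default_max_tokens = max_model_len - len(prompt_token_ids)
    if requested_max_tokens is None:
        return default_max_tokens
    return requested_max_tokens

# 4-token prompt against a 4096-token context window.
assert resolve_max_tokens(None, 4096, [11, 22, 33, 44]) == 4092
# An explicit max_tokens from the request is passed through unchanged.
assert resolve_max_tokens(128, 4096, [11, 22, 33, 44]) == 128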