From 00b2a8b39de16c64ec4d4c7b9c7d832d5b3488e4 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 31 Jul 2024 21:13:34 -0700 Subject: [PATCH] [Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954) --- tests/entrypoints/openai/test_serving_chat.py | 39 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 30 ++++++++++---- vllm/entrypoints/openai/serving_chat.py | 23 ++++------- vllm/entrypoints/openai/serving_completion.py | 27 +++++-------- vllm/entrypoints/openai/serving_engine.py | 17 ++++++-- 5 files changed, 92 insertions(+), 44 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 464465494b714..168ba7ba888ef 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,7 +1,12 @@ import asyncio +from contextlib import suppress from dataclasses import dataclass +from unittest.mock import MagicMock +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" @@ -42,3 +47,37 @@ async def _async_serving_chat_init(): def test_async_serving_chat_init(): serving_completion = asyncio.run(_async_serving_chat_init()) assert serving_completion.chat_template == CHAT_TEMPLATE + + +def test_serving_chat_should_set_correct_max_tokens(): + mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + + serving_chat = OpenAIServingChat(mock_engine, + MockModelConfig(), + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None) + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + guided_decoding_backend="outlines", + ) + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + # AsyncLLMEngine.generate(inputs, sampling_params, ...) + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + req.max_tokens = 10 + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 205860aa8e722..3b35ae1ebd705 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,7 +11,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid @@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_sampling_params(self, - tokenizer: PreTrainedTokenizer) -> SamplingParams: - # We now allow logprobs being true without top_logrobs. 
+ def to_sampling_params(
+ self, tokenizer: PreTrainedTokenizer,
+ guided_decode_logits_processor: Optional[LogitsProcessor],
+ default_max_tokens: int) -> SamplingParams:
+ max_tokens = self.max_tokens
+ if max_tokens is None:
+ max_tokens = default_max_tokens
+ # We now allow logprobs being true without top_logrobs.
 logits_processors = get_logits_processors(
 logit_bias=self.logit_bias,
 allowed_token_ids=None,
 tokenizer=tokenizer,
 )
+ if guided_decode_logits_processor:
+ logits_processors.append(guided_decode_logits_processor)

 return SamplingParams(
 n=self.n,
@@ -241,7 +248,7 @@ def to_sampling_params(self,
 logprobs=self.top_logprobs if self.logprobs else None,
 prompt_logprobs=self.top_logprobs if self.echo else None,
 ignore_eos=self.ignore_eos,
- max_tokens=self.max_tokens,
+ max_tokens=max_tokens,
 min_tokens=self.min_tokens,
 use_beam_search=self.use_beam_search,
 early_stopping=self.early_stopping,
@@ -395,7 +402,14 @@ class CompletionRequest(OpenAIBaseModel):

 # doc: end-completion-extra-params

- def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
+ def to_sampling_params(
+ self, tokenizer: PreTrainedTokenizer,
+ guided_decode_logits_processor: Optional[LogitsProcessor],
+ default_max_tokens: int) -> SamplingParams:
+ max_tokens = self.max_tokens
+ if max_tokens is None:
+ max_tokens = default_max_tokens
+
 echo_without_generation = self.echo and self.max_tokens == 0

 logits_processors = get_logits_processors(
@@ -403,6 +417,8 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
 allowed_token_ids=self.allowed_token_ids,
 tokenizer=tokenizer,
 )
+ if guided_decode_logits_processor:
+ logits_processors.append(guided_decode_logits_processor)

 return SamplingParams(
 n=self.n,
@@ -419,7 +435,7 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
 stop_token_ids=self.stop_token_ids,
 logprobs=self.logprobs,
 ignore_eos=self.ignore_eos,
- max_tokens=self.max_tokens if not echo_without_generation else 1,
+ max_tokens=max_tokens if not echo_without_generation else 1,
 min_tokens=self.min_tokens,
 use_beam_search=self.use_beam_search,
 early_stopping=self.early_stopping,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 01843930bf11d..c832cf2a24b50 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -25,8 +25,6 @@
 PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
- get_guided_decoding_logits_processor)
 from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
@@ -134,28 +132,23 @@ async def create_chat_completion(
 request_id = f"chat-{random_uuid()}"

 try:
- sampling_params = request.to_sampling_params(tokenizer)
- decoding_config = await self.engine.get_decoding_config()
- guided_decoding_backend = request.guided_decoding_backend \
- or decoding_config.guided_decoding_backend
 guided_decode_logits_processor = (
- await
- get_guided_decoding_logits_processor(guided_decoding_backend,
- request, tokenizer))
- if guided_decode_logits_processor:
- if sampling_params.logits_processors is None:
- sampling_params.logits_processors = []
- sampling_params.logits_processors.append(
- guided_decode_logits_processor)
+ await self._guided_decode_logits_processor(request, tokenizer))

 prompt_inputs = self._tokenize_prompt_input(
 request,
 tokenizer,
 prompt,
- truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+ truncate_prompt_tokens=request.truncate_prompt_tokens,
 add_special_tokens=request.add_special_tokens,
 )
+ sampling_params = request.to_sampling_params(
+ tokenizer,
+ guided_decode_logits_processor,
+ default_max_tokens=self.max_model_len -
+ len(prompt_inputs["prompt_token_ids"]))
+
 self._log_inputs(request_id,
 prompt_inputs,
 params=sampling_params,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 8548352791680..7765c5903f341 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -24,8 +24,6 @@
 OpenAIServing, PromptAdapterPath)
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
- get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -95,31 +93,24 @@ async def create_completion(self, request: CompletionRequest,
 tokenizer = await self.engine.get_tokenizer(lora_request)

- sampling_params = request.to_sampling_params(tokenizer)
- decoding_config = await self.engine.get_decoding_config()
- guided_decoding_backend = request.guided_decoding_backend \
- or decoding_config.guided_decoding_backend
- guided_decode_logit_processor = (
- await
- get_guided_decoding_logits_processor(guided_decoding_backend,
- request, tokenizer))
- if guided_decode_logit_processor is not None:
- if sampling_params.logits_processors is None:
- sampling_params.logits_processors = []
- sampling_params.logits_processors.append(
- guided_decode_logit_processor)
-
+ guided_decode_logits_processor = (
+ await self._guided_decode_logits_processor(request, tokenizer))
 prompts = list(
 self._tokenize_prompt_input_or_inputs(
 request,
 tokenizer,
 request.prompt,
- truncate_prompt_tokens=sampling_params.
- truncate_prompt_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, )) for i, prompt_inputs in enumerate(prompts): + sampling_params = request.to_sampling_params( + tokenizer, + guided_decode_logits_processor, + default_max_tokens=self.max_model_len - + len(prompt_inputs["prompt_token_ids"])) + request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b374a7946b11e..8c7929a12e9a0 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -25,9 +25,11 @@ from vllm.inputs import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer_group import AnyTokenizer @@ -150,6 +152,15 @@ def create_streaming_error_response( }) return json_str + async def _guided_decode_logits_processor( + self, request: Union[ChatCompletionRequest, CompletionRequest], + tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: + decoding_config = await self.engine.get_decoding_config() + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend + return await get_guided_decoding_logits_processor( + guided_decoding_backend, request, tokenizer) + async def _check_model( self, request: AnyRequest, @@ -254,9 +265,7 @@ def _validate_input( f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the messages, " f"Please reduce the length of the messages.") - request.max_tokens = self.max_model_len - token_num - - if token_num + request.max_tokens > self.max_model_len: + elif token_num + request.max_tokens > self.max_model_len: raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested "
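
For readers skimming the diff, the substance of the patch is this: the serving layers now compute a per-prompt default_max_tokens = max_model_len - len(prompt_token_ids) and pass it into to_sampling_params(), which falls back to it only when the request leaves max_tokens unset, while _validate_input() in serving_engine.py no longer mutates request.max_tokens. Below is a minimal sketch of that fallback logic, not vLLM code: FakeRequest, resolve_max_tokens, and the literal numbers are invented for illustration (the new test asserts 93 against its own mock model config, which is not shown in full here).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeRequest:
    """Hypothetical stand-in for an OpenAI-style request; only the one field used here."""
    max_tokens: Optional[int] = None


def resolve_max_tokens(request: FakeRequest, max_model_len: int,
                       prompt_token_ids: List[int]) -> int:
    # Mirrors the patch's logic: the server derives the headroom left in the
    # context window after the prompt and passes it as default_max_tokens;
    # an explicit request.max_tokens always takes precedence.
    default_max_tokens = max_model_len - len(prompt_token_ids)
    return default_max_tokens if request.max_tokens is None else request.max_tokens


# Illustrative numbers: a 100-token context and a 7-token prompt leave 93
# tokens, the same value the new test expects when max_tokens is omitted.
assert resolve_max_tokens(FakeRequest(), 100, list(range(7))) == 93
assert resolve_max_tokens(FakeRequest(max_tokens=10), 100, list(range(7))) == 10
```

Because the default is computed after tokenization, each request gets the full remaining context window rather than a fixed cap, and requests that explicitly ask for more tokens than fit are still rejected by the length check in serving_engine.py.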