diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 85f485364a411..e88d6c3c67829 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
+    # Setting server's max_tokens in the generation_config.json
+    # lower than context_window - prompt_tokens
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "max_tokens": 10  # Setting server-side max_tokens limit
+    }
+
+    # Reinitialize the engine with new settings
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test Case 1: No max_tokens specified in request
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 2: Request's max_tokens set higher than server accepts
+    req.max_tokens = 15
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 3: Request's max_tokens set lower than server accepts
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
+
+    # Setting server's max_tokens in the generation_config.json
+    # higher than context_window - prompt_tokens
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {
+        "max_tokens": 200  # Setting server-side max_tokens limit
+    }
+
+    # Reinitialize the engine with new settings
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test Case 1: No max_tokens specified, defaults to context_window
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    # Test Case 2: Request's max_tokens set higher than server accepts
+    req.max_tokens = 100
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    # Test Case 3: Request's max_tokens set lower than server accepts
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
+
 
 def test_serving_chat_could_load_correct_generation_config():
diff --git a/vllm/config.py b/vllm/config.py
index 11c6f853b2b45..7a58d64bcc6e2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -910,12 +910,18 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
             "top_k",
             "top_p",
             "min_p",
+            "max_new_tokens",
         ]
         if any(p in config for p in available_params):
             diff_sampling_param = {
                 p: config.get(p)
                 for p in available_params if config.get(p) is not None
             }
+            # Huggingface definition of max_new_tokens is equivalent
+            # to vLLM's max_tokens
+            if "max_new_tokens" in diff_sampling_param:
+                diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
+                    "max_new_tokens")
         else:
             diff_sampling_param = {}
         return diff_sampling_param
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f16e8e6df76bd..ba96484e3fce9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -939,7 +939,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "Defaults to None, will use the default generation config in vLLM. "
             "If set to 'auto', the generation config will be automatically "
             "loaded from model. If set to a folder path, the generation config "
-            "will be loaded from the specified folder path.")
+            "will be loaded from the specified folder path. If "
+            "`max_new_tokens` is specified, then it sets a server-wide limit "
+            "on the number of output tokens for all requests.")
 
         parser.add_argument("--enable-sleep-mode",
                             action="store_true",
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 80403f77d5375..6f546aaec442a 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -380,13 +380,17 @@ def to_beam_search_params(
     ) -> BeamSearchParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1
 
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@@ -406,11 +410,16 @@ def to_sampling_params(
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
+
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(
@@ -740,13 +749,17 @@ def to_beam_search_params(
             default_sampling_params: Optional[dict] = None
     ) -> BeamSearchParams:
         max_tokens = self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1
 
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get("temperature", 1.0)
 
@@ -764,11 +777,16 @@ def to_sampling_params(
             logits_processor_pattern: Optional[str],
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
         max_tokens = self.max_tokens
-        if max_tokens is None:
-            max_tokens = default_max_tokens
 
         if default_sampling_params is None:
             default_sampling_params = {}
+
+        # Use minimum of context window, user request & server limit.
+        max_tokens = min(
+            val for val in (default_max_tokens, max_tokens,
+                            default_sampling_params.get("max_tokens", None))
+            if val is not None)
+
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(
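
Note on the resolution rule the protocol.py hunks implement: the effective
max_tokens is now the minimum over every limit that is actually set, i.e.
min(context_window - prompt_tokens, the request's max_tokens, and the
server-side max_new_tokens loaded from generation_config.json). The sketch
below is a minimal standalone illustration of that rule, not part of the
diff; the helper name resolve_max_tokens and the literal limit values are
hypothetical, chosen to mirror the expectations in the test above.

    from typing import Optional

    def resolve_max_tokens(default_max_tokens: int,
                           request_max_tokens: Optional[int],
                           server_limit: Optional[int]) -> int:
        # Same logic as the added protocol.py code: take the minimum over
        # all limits that are not None. default_max_tokens stands in for
        # context_window - prompt_tokens and is always set.
        return min(val
                   for val in (default_max_tokens, request_max_tokens,
                               server_limit) if val is not None)

    # With 93 tokens of room left in the context window and a server limit
    # of 200, an unset or oversized request clamps to 93; a smaller one wins.
    assert resolve_max_tokens(93, None, 200) == 93
    assert resolve_max_tokens(93, 100, 200) == 93
    assert resolve_max_tokens(93, 5, 200) == 5
    # With a server limit of 10, a request for 15 clamps to 10.
    assert resolve_max_tokens(93, 15, 10) == 10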