From 9efb5fe2937206c07e0c6a7db6c05e4139358e17 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 31 Oct 2024 19:09:46 -0600 Subject: [PATCH] [Bugfix][Frontend] Reject guided decoding in multistep mode (#9892) Signed-off-by: Joe Runde Signed-off-by: Richard Liu --- docs/source/serving/compatibility_matrix.rst | 2 +- .../openai/test_prompt_validation.py | 20 +++++++++++++++++++ vllm/engine/llm_engine.py | 7 +++++++ vllm/sampling_params.py | 4 ++-- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index 20a81f4cad1d1..cab19e4ec5b6c 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -283,7 +283,7 @@ Feature x Feature - ✅ - ✅ - ✅ - - `✗ `__ + - `✗ `__ - ? - ✅ - ✅ diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index 58075f7023821..1ae64ef492d5b 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -35,3 +35,23 @@ async def test_out_of_vocab_token_ids(): prompt=[999999], max_tokens=5, temperature=0.0) + + +@pytest.mark.asyncio +async def test_reject_multistep_with_guided_decoding(): + model_name = "gpt2" + server_args = ["--enforce-eager", "--num-scheduler-steps", "8"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + with pytest.raises(openai.BadRequestError, + match=re.compile( + '.*Guided decoding .* multi-step decoding.*')): + await client.completions.create( + model=model_name, + prompt="Hello", + max_tokens=5, + temperature=0.0, + extra_body={"response_format": { + "type": "json_object" + }}) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3fd34fadee1ca..edef1f30a9e91 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -829,6 +829,13 @@ def add_request( raise ValueError(f"Got priority {priority} but " "Priority scheduling is not enabled.") + if isinstance(params, SamplingParams) \ + and (params.guided_decoding or params.logits_processors) \ + and self.scheduler_config.num_scheduler_steps > 1: + raise ValueError( + "Guided decoding and logits processors are not supported " + "in multi-step decoding") + if arrival_time is None: arrival_time = time.time() diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5e191c6e715e0..5c6df5aaf5446 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -485,8 +485,8 @@ def __repr__(self) -> str: f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " - f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " - f"guided_decoding={self.guided_decoding}") + f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " + f"guided_decoding={self.guided_decoding})") class BeamSearchParams(