[BugFix] Fix server crash on empty prompt (vllm-project#7746)
Signed-off-by: Max de Bayser <[email protected]>
maxdebayser authored Aug 23, 2024
1 parent faeddb5 commit e25fee5
Showing 3 changed files with 39 additions and 0 deletions.
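
For context, a minimal sketch of the behavior this commit introduces, assuming a local vLLM installation (the model choice and the exact error text follow the tests below):

from vllm import LLM

llm = LLM(model="gpt2")
try:
    llm.generate([""])  # an empty prompt could previously crash the engine
except ValueError as e:
    print(e)  # now rejected up front: "Prompt cannot be empty"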
9 changes: 9 additions & 0 deletions tests/entrypoints/llm/test_prompt_validation.py
@@ -0,0 +1,9 @@
import pytest

from vllm import LLM


def test_empty_prompt():
    llm = LLM(model="gpt2")
    with pytest.raises(ValueError, match='Prompt cannot be empty'):
        llm.generate([""])
22 changes: 22 additions & 0 deletions tests/entrypoints/openai/test_prompt_validation.py
@@ -0,0 +1,22 @@
# imports for guided decoding tests
import re

import openai
import pytest

from ...utils import RemoteOpenAIServer


@pytest.mark.asyncio
async def test_empty_prompt():
    model_name = "gpt2"
    server_args = ["--enforce-eager"]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()

        with pytest.raises(openai.BadRequestError,
                           match=re.compile('.+Prompt cannot be empty.+')):
            await client.completions.create(model=model_name,
                                            prompt="",
                                            max_tokens=5,
                                            temperature=0.0)
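
Outside the test harness, the same engine-side ValueError surfaces to API clients as an HTTP 400. A hedged client-side sketch, assuming a vLLM OpenAI-compatible server is already running locally (the base URL and api_key value are assumptions, not part of this commit):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.completions.create(model="gpt2", prompt="", max_tokens=5)
except openai.BadRequestError as e:
    print(e)  # the server now returns 400 instead of crashing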
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -591,6 +591,7 @@ def _add_processed_request(
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
@@ -1647,3 +1648,10 @@ def is_encoder_decoder_model(self):

    def is_embedding_model(self):
        return self.model_config.is_embedding_model

    def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                   EncoderDecoderLLMInputs]):
        prompt_key = "encoder_prompt_token_ids" \
            if self.is_encoder_decoder_model() else "prompt_token_ids"
        if not inputs.get(prompt_key):
            raise ValueError("Prompt cannot be empty")
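
A standalone sketch (illustrative dicts, not vLLM's LLMInputs type) of why a single `inputs.get(prompt_key)` check covers both failure modes: a missing key and an empty token-ID list are both falsy in Python.

# Decoder-only request whose prompt tokenized to nothing:
decoder_inputs = {"prompt_token_ids": []}
# Encoder-decoder request with no encoder prompt at all:
encoder_inputs = {}

for inputs, key in [(decoder_inputs, "prompt_token_ids"),
                    (encoder_inputs, "encoder_prompt_token_ids")]:
    if not inputs.get(key):  # falsy for both [] and a missing key
        print(f"{key}: would raise ValueError('Prompt cannot be empty')")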
