Skip to content

Commit

Permalink
fix: empty prompt crashing the server (#912)
Browse files Browse the repository at this point in the history
Co-authored-by: Max de Bayser <[email protected]>
  • Loading branch information
AlpinDale and maxdebayser authored Dec 18, 2024
1 parent 673621a commit 16e5b2b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
9 changes: 9 additions & 0 deletions aphrodite/engine/aphrodite_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ def _add_processed_request(
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
# to prevent empty prompts crashing the engine
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
Expand Down Expand Up @@ -1467,5 +1469,12 @@ def is_encoder_decoder_model(self):
def is_embedding_model(self):
return self.model_config.is_embedding_model

def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                               EncoderDecoderLLMInputs]):
    """Reject processed inputs whose tokenized prompt is missing or empty.

    For encoder-decoder models the encoder-side token ids are checked;
    otherwise the plain prompt token ids are checked.  Raises ValueError
    so the request fails fast instead of crashing the engine later.
    """
    if self.is_encoder_decoder_model():
        prompt_key = "encoder_prompt_token_ids"
    else:
        prompt_key = "prompt_token_ids"
    # An absent key or an empty token list are both invalid.
    if not inputs.get(prompt_key):
        raise ValueError("Prompt cannot be empty")


# NOTE(review): module import-time side effect — presumably configures the
# package logger; confirm against the definition of setup_logger().
setup_logger()
9 changes: 9 additions & 0 deletions tests/endpoints/llm/test_prompt_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

from aphrodite import LLM


def test_empty_prompt():
    """An empty string prompt must raise ValueError instead of crashing."""
    engine = LLM(model="gpt2")
    with pytest.raises(ValueError, match='Prompt cannot be empty'):
        engine.generate([""])
21 changes: 21 additions & 0 deletions tests/endpoints/openai/test_prompt_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# imports for guided decoding tests
import re

import openai
import pytest

from ...utils import RemoteOpenAIServer


@pytest.mark.asyncio
async def test_empty_prompt():
    """The OpenAI-compatible server must reject an empty completion prompt
    with a 400 BadRequestError rather than crashing."""
    model_name = "gpt2"
    server_args = ["--enforce-eager"]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        # The server wraps the engine's ValueError into a 400 response.
        expected = re.compile('.+Prompt cannot be empty.+')
        with pytest.raises(openai.BadRequestError, match=expected):
            await client.completions.create(
                model=model_name,
                prompt="",
                max_tokens=5,
                temperature=0.0,
            )

0 comments on commit 16e5b2b

Please sign in to comment.