From 6b62d0eb9a0d51bfff411309a2c8e4ffa910371a Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 26 Sep 2024 16:15:48 -0300 Subject: [PATCH] Adds truncate_prompt_tokens param for embeddings creation Signed-off-by: Flavia Beo --- tests/entrypoints/openai/test_embedding.py | 48 ++++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 1 + vllm/entrypoints/openai/serving_embedding.py | 19 ++++++-- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 3baaeab2feeaf..0ba4a4eb474b0 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -144,3 +144,51 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, 0].embedding assert responses_float.data[1].embedding == responses_default.data[ 1].embedding + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_single_embedding_truncation( + embedding_client: openai.AsyncOpenAI, model_name: str): + input_texts = [ + "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", + ] + + # test single embedding + embeddings = await embedding_client.embeddings.create( + model=model_name, + input=input_texts, + extra_body={"truncate_prompt_tokens": 10}) + assert embeddings.id is not None + assert len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 10 + assert embeddings.usage.total_tokens == 10 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_single_embedding_truncation_invalid( + embedding_client: openai.AsyncOpenAI, model_name: str): + input_texts = [ + "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", + ] + + with pytest.raises(openai.BadRequestError): + embeddings = await embedding_client.embeddings.create( + model=model_name, + input=input_texts, + extra_body={"truncate_prompt_tokens": 8193}) + assert "error" in embeddings.object + assert "truncate_prompt_tokens value is greater than max_model_len. "\ + "Please, select a smaller truncation size." in embeddings.message + + +## working on these tests -> run the async server with the PR changes? diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e9f53b1816d1..54aa5a04e0d28 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -614,6 +614,7 @@ class EmbeddingRequest(OpenAIBaseModel): encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: begin-embedding-pooling-params additional_data: Optional[Any] = None diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd62..2ca35efb328a4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -110,6 +110,17 @@ async def create_embedding( request_id = f"embd-{random_uuid()}" created_time = int(time.monotonic()) + truncate_prompt_tokens = None + + if request.truncate_prompt_tokens is not None: + if request.truncate_prompt_tokens <= self.max_model_len: + truncate_prompt_tokens = request.truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + # Schedule the request and get the result generator. generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] try: @@ -124,11 +135,9 @@ async def create_embedding( pooling_params = request.to_pooling_params() prompts = list( - self._tokenize_prompt_input_or_inputs( - request, - tokenizer, - request.input, - )) + self._tokenize_prompt_input_or_inputs(request, tokenizer, + request.input, + truncate_prompt_tokens)) for i, prompt_inputs in enumerate(prompts): request_id_item = f"{request_id}-{i}"