From 551c33b7bd72ccf4d96cd723986cb6c4ca8c1f75 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Thu, 26 Sep 2024 16:15:48 -0300 Subject: [PATCH] Adds truncate_prompt_tokens param for embeddings creation Signed-off-by: Flavia Beo --- vllm/entrypoints/openai/protocol.py | 1 + vllm/entrypoints/openai/serving_embedding.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e9f53b1816d1..54aa5a04e0d28 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -614,6 +614,7 @@ class EmbeddingRequest(OpenAIBaseModel): encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: begin-embedding-pooling-params additional_data: Optional[Any] = None diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd62..47d7a08277d84 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -110,6 +110,20 @@ async def create_embedding( request_id = f"embd-{random_uuid()}" created_time = int(time.monotonic()) + truncate_prompt_tokens = None + + if request.truncate_prompt_tokens is not None: + if request.truncate_prompt_tokens < self.max_model_len: + truncate_prompt_tokens = request.truncate_prompt_tokens + elif request.truncate_prompt_tokens > self.max_model_len: + raise ValueError( + "truncate_prompt_tokens value is greater than max_model_len. " + "Please, select a smaller truncation size.") + else: + logger.warning( + "truncating input tokens to max_model_len") + truncate_prompt_tokens = self.max_model_len + # Schedule the request and get the result generator. generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] try: @@ -128,6 +142,7 @@ async def create_embedding( request, tokenizer, request.input, + truncate_prompt_tokens )) for i, prompt_inputs in enumerate(prompts):