diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7e9f53b1816d1..54aa5a04e0d28 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -614,6 +614,7 @@ class EmbeddingRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
 
     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 12ec6be03cd62..0cc740e27efed 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -110,6 +110,13 @@ async def create_embedding(
         request_id = f"embd-{random_uuid()}"
         created_time = int(time.monotonic())
 
+        truncate_prompt_tokens = None
+
+        if request.truncate_prompt_tokens is not None:
+            if request.truncate_prompt_tokens < self.max_model_len:
+                truncate_prompt_tokens = request.truncate_prompt_tokens
+            else:
+                truncate_prompt_tokens = self.max_model_len
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
@@ -128,6 +135,7 @@ async def create_embedding(
                     request,
                     tokenizer,
                     request.input,
+                    truncate_prompt_tokens
                 ))
 
             for i, prompt_inputs in enumerate(prompts):
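
For context, here is a minimal client-side sketch of how the new field could be exercised once this change is applied. It is not part of the diff: the base URL, API key, and model name are placeholders, and the field is sent through the openai SDK's `extra_body` passthrough because `truncate_prompt_tokens` is a vLLM-specific extension rather than a native OpenAI parameter.

```python
# Sketch only (not part of this diff): calling the OpenAI-compatible
# embeddings endpoint of a vLLM server with the new truncate_prompt_tokens
# field. Base URL, API key, and model name below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input="A very long document that may exceed the model's context window ...",
    # Passed via extra_body since this is a vLLM extension, not an OpenAI field.
    extra_body={"truncate_prompt_tokens": 512},
)

# The prompt is truncated server-side to at most 512 tokens instead of the
# request being rejected for exceeding the model's length limit.
print(len(response.data[0].embedding))
```

Note that the server-side logic in `serving_embedding.py` additionally caps the requested value at `self.max_model_len`, so a client can pass a large value without knowing the exact context size of the deployed model.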