From 12fe5bfb5c25d9404439a28d823a42632be350c6 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Wed, 23 Oct 2024 02:07:30 +0800 Subject: [PATCH] [Frontend] Support custom request_id from request (#9550) Co-authored-by: Yuhong Guo Signed-off-by: Tyler Michael Smith --- vllm/entrypoints/openai/protocol.py | 6 ++++++ vllm/entrypoints/openai/serving_chat.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 06114339b7c69..733decf80a711 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel): "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " "if the served model does not use priority scheduling.")) + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response.")) # doc: end-chat-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c3fa0e44e5e8d..b9b240b64850e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -38,7 +38,7 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import iterate_with_cancellation, random_uuid +from vllm.utils import iterate_with_cancellation logger = init_logger(__name__) @@ -176,7 +176,7 @@ async def create_chat_completion( "\"auto\" tool choice requires " "--enable-auto-tool-choice and --tool-call-parser to be set") - request_id = f"chat-{random_uuid()}" + request_id = f"chat-{request.request_id}" request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: