From 4a8c0c3a968878a747a051d0789848c1277f05d2 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 6 Dec 2024 16:12:34 -0700
Subject: [PATCH] :sparkles: Use request id from header

Signed-off-by: Joe Runde
---
 vllm/entrypoints/openai/api_server.py           |  4 ++--
 vllm/entrypoints/openai/serving_chat.py         |  3 ++-
 vllm/entrypoints/openai/serving_completion.py   |  4 ++--
 vllm/entrypoints/openai/serving_embedding.py    |  4 ++--
 vllm/entrypoints/openai/serving_engine.py       | 10 +++++++++-
 vllm/entrypoints/openai/serving_score.py        |  4 ++--
 vllm/entrypoints/openai/serving_tokenization.py |  9 ++++++---
 7 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 6bc31ef83ded4..b5d4ab16a6750 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -305,7 +305,7 @@ async def health(raw_request: Request) -> Response:
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
 
-    generator = await handler.create_tokenize(request)
+    generator = await handler.create_tokenize(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -319,7 +319,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
 
-    generator = await handler.create_detokenize(request)
+    generator = await handler.create_detokenize(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 54ca0463bcab1..0af7613a473a4 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -176,7 +176,8 @@ async def create_chat_completion(
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-        request_id = f"chatcmpl-{request.request_id}"
+        request_id = "chatcmpl-" \
+            f"{self._base_request_id(raw_request, request.request_id)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index fc1c4908d6650..c54d5f07cf58c 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -30,7 +30,7 @@
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import merge_async_iterators, random_uuid
+from vllm.utils import merge_async_iterators
 
 logger = init_logger(__name__)
 
@@ -86,7 +86,7 @@ async def create_completion(
                 "suffix is not currently supported")
 
         model_name = self.base_model_paths[0].name
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"cmpl-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 2cbb252610e39..3f7b75e893cad 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -19,7 +19,7 @@
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
-from vllm.utils import merge_async_iterators, random_uuid
+from vllm.utils import merge_async_iterators
 
 logger = init_logger(__name__)
 
@@ -110,7 +110,7 @@ async def create_embedding(
                 "dimensions is currently not supported")
 
         model_name = request.model
-        request_id = f"embd-{random_uuid()}"
+        request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.monotonic())
 
         truncate_prompt_tokens = None
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8232c6116c1bd..908885ee78491 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -6,6 +6,7 @@
 from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                     Optional, Sequence, Tuple, TypedDict, Union)
 
+from fastapi import Request
 from pydantic import Field
 from starlette.datastructures import Headers
 from typing_extensions import Annotated
@@ -47,7 +48,7 @@
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import AtomicCounter, is_list_of, make_async
+from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
 
 logger = init_logger(__name__)
 
@@ -565,6 +566,13 @@ async def _get_trace_headers(
 
         return None
 
+    @staticmethod
+    def _base_request_id(raw_request: Request,
+                         default: Optional[str] = None) -> Optional[str]:
+        """Pulls the request id to use from a header, if provided"""
+        default = default or random_uuid()
+        return raw_request.headers.get("X-Request-Id", default)
+
     @staticmethod
     def _get_decoded_token(logprob: Logprob,
                            token_id: int,
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index a1f14449ba9c3..fed06fa452955 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -15,7 +15,7 @@
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
-from vllm.utils import make_async, merge_async_iterators, random_uuid
+from vllm.utils import make_async, merge_async_iterators
 
 logger = init_logger(__name__)
 
@@ -102,7 +102,7 @@ async def create_score(
             return error_check_ret
 
         model_name = request.model
-        request_id = f"score-{random_uuid()}"
+        request_id = f"score-{self._base_request_id(raw_request)}"
         created_time = int(time.monotonic())
 
         truncate_prompt_tokens = request.truncate_prompt_tokens
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 9c3dc2c98b2dd..2e849333680d4 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -1,5 +1,7 @@
 from typing import Final, List, Optional, Union
 
+from fastapi import Request
+
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
@@ -17,7 +19,6 @@
                                                     LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
 
@@ -48,12 +49,13 @@ def __init__(
     async def create_tokenize(
         self,
         request: TokenizeRequest,
+        raw_request: Request,
     ) -> Union[TokenizeResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        request_id = f"tokn-{random_uuid()}"
+        request_id = f"tokn-{self._base_request_id(raw_request)}"
 
         try:
             (
@@ -112,12 +114,13 @@ async def create_tokenize(
     async def create_detokenize(
        self,
         request: DetokenizeRequest,
+        raw_request: Request,
     ) -> Union[DetokenizeResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        request_id = f"tokn-{random_uuid()}"
+        request_id = f"tokn-{self._base_request_id(raw_request)}"
 
         (
             lora_request,
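
Note (not part of the patch): a minimal sketch of how a client could exercise the new header-based request ID. It assumes an OpenAI-compatible vLLM server already listening at http://localhost:8000, the `requests` package installed, and a placeholder model name; none of these come from the patch itself.

    # Sketch only: server URL, model name, and header value are illustrative.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        # The X-Request-Id header is what _base_request_id() reads.
        headers={"X-Request-Id": "trace-abc-123"},
        json={"model": "facebook/opt-125m", "prompt": "Hello", "max_tokens": 8},
    )

    # With this patch applied, the completion id echoes the header value,
    # e.g. "cmpl-trace-abc-123"; without the header, a random UUID is used.
    print(resp.json()["id"])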