From cae94b920c2a226e70b775be4af5063388c8951e Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 21 Jul 2024 21:01:05 -0400
Subject: [PATCH] API: Add ability to use request IDs

Identify which request is being processed to help users disambiguate
which logs correspond to which request.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py            | 14 +++--
 common/gen_logging.py                  |  6 +-
 endpoints/OAI/router.py                | 10 ++--
 endpoints/OAI/utils/chat_completion.py | 77 ++++++++++++++------------
 endpoints/OAI/utils/completion.py      | 50 +++++++++++++----
 endpoints/server.py                    | 12 +++-
 6 files changed, 112 insertions(+), 57 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 74cf7130..200be6b3 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -828,10 +828,10 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
 
         return dict(zip_longest(top_tokens, cleaned_values))
 
-    async def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, request_id: str, **kwargs):
         """Generate a response to a prompt"""
         generations = []
-        async for generation in self.generate_gen(prompt, **kwargs):
+        async for generation in self.generate_gen(prompt, request_id, **kwargs):
             generations.append(generation)
 
         joined_generation = {
@@ -881,7 +881,11 @@ def check_unsupported_settings(self, **kwargs):
         return kwargs
 
     async def generate_gen(
-        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: Optional[asyncio.Event] = None,
+        **kwargs,
     ):
         """
         Create generator function for prompt completion.
@@ -1116,6 +1120,7 @@ async def generate_gen(
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
+            request_id=request_id,
             max_tokens=max_tokens,
             min_tokens=min_tokens,
             stream=kwargs.get("stream"),
@@ -1138,9 +1143,10 @@ async def generate_gen(
         )
 
         # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        log_prompt(prompt, request_id, negative_prompt)
 
         # Create and add a new job
+        # Don't use the request ID here as there can be multiple jobs per request
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
diff --git a/common/gen_logging.py b/common/gen_logging.py
index fbf10f66..94c4405b 100644
--- a/common/gen_logging.py
+++ b/common/gen_logging.py
@@ -51,11 +51,13 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")
 
 
-def log_prompt(prompt: str, negative_prompt: Optional[str]):
+def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
-        logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
+        logger.info(
+            f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
+        )
 
     if negative_prompt:
         formatted_negative_prompt = "\n" + negative_prompt
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 1c0a7c63..1297d879 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -107,12 +107,14 @@ async def completion_request(
             ping=maxsize,
         )
     else:
-        generate_task = asyncio.create_task(generate_completion(data, model_path))
+        generate_task = asyncio.create_task(
+            generate_completion(data, request, model_path)
+        )
 
         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Completion generation cancelled by user.",
+            disconnect_message=f"Completion {request.state.id} cancelled by user.",
cancelled by user.", ) return response @@ -161,13 +163,13 @@ async def chat_completion_request( ) else: generate_task = asyncio.create_task( - generate_chat_completion(prompt, data, model_path) + generate_chat_completion(prompt, data, request, model_path) ) response = await run_with_request_disconnect( request, generate_task, - disconnect_message="Chat completion generation cancelled by user.", + disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 10f25cdd..b9c6f71c 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -5,7 +5,6 @@ from asyncio import CancelledError from copy import deepcopy from typing import List, Optional -from uuid import uuid4 from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -30,9 +29,12 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.utils.completion import _stream_collector -def _create_response(generations: List[dict], model_name: Optional[str]): +def _create_response( + request_id: str, generations: List[dict], model_name: Optional[str] +): """Create a chat completion response from the provided text.""" prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0) @@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]): choices.append(choice) response = ChatCompletionResponse( + id=f"chatcmpl-{request_id}", choices=choices, model=unwrap(model_name, ""), usage=UsageStats( @@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]): def _create_stream_chunk( - const_id: str, + request_id: str, generation: Optional[dict] = None, model_name: Optional[str] = None, is_usage_chunk: bool = False, @@ -150,7 +153,7 @@ def _create_stream_chunk( choices.append(choice) chunk = ChatCompletionStreamChunk( - id=const_id, + id=f"chatcmpl-{request_id}", choices=choices, model=unwrap(model_name, ""), usage=usage_stats, @@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest): raise HTTPException(400, error_message) from exc -async def _stream_collector( - task_idx: int, - gen_queue: asyncio.Queue, - prompt: str, - abort_event: asyncio.Event, - **kwargs, -): - """Collects a stream and places results in a common queue""" - - try: - new_generation = model.container.generate_gen(prompt, abort_event, **kwargs) - async for generation in new_generation: - generation["index"] = task_idx - - await gen_queue.put(generation) - - if "finish_reason" in generation: - break - except Exception as e: - await gen_queue.put(e) - - async def stream_generate_chat_completion( prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path ): """Generator for the generation process.""" - const_id = f"chatcmpl-{uuid4().hex}" abort_event = asyncio.Event() gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: + logger.info(f"Recieved chat completion streaming request {request.state.id}") + gen_params = data.to_gen_params() for n in range(0, data.n): @@ -277,7 +259,14 @@ async def stream_generate_chat_completion( task_gen_params = gen_params gen_task = asyncio.create_task( - _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params) + _stream_collector( + n, + gen_queue, + prompt, + request.state.id, + abort_event, + 
+                    **task_gen_params,
+                )
             )
 
             gen_tasks.append(gen_task)
@@ -286,7 +275,9 @@
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+                )
 
             generation = await gen_queue.get()
 
@@ -294,7 +285,9 @@
             if isinstance(generation, Exception):
                 raise generation
 
-            response = _create_stream_chunk(const_id, generation, model_path.name)
+            response = _create_stream_chunk(
+                request.state.id, generation, model_path.name
+            )
             yield response.model_dump_json()
 
             # Check if all tasks are completed
@@ -302,10 +295,17 @@
                 # Send a usage chunk
                 if data.stream_options and data.stream_options.include_usage:
                     usage_chunk = _create_stream_chunk(
-                        const_id, generation, model_path.name, is_usage_chunk=True
+                        request.state.id,
+                        generation,
+                        model_path.name,
+                        is_usage_chunk=True,
                     )
                     yield usage_chunk.model_dump_json()
 
+                logger.info(
+                    f"Finished chat completion streaming request {request.state.id}"
+                )
+
                 yield "[DONE]"
                 break
         except CancelledError:
@@ -320,7 +320,7 @@ async def stream_generate_chat_completion(
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()
@@ -335,16 +335,23 @@
                 task_gen_params = gen_params
 
             gen_tasks.append(
-                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
+                asyncio.create_task(
+                    model.container.generate(
+                        prompt, request.state.id, **task_gen_params
+                    )
+                )
             )
 
         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished chat completion request {request.state.id}")
 
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Maybe the model was unloaded? "
+            f"Chat completion {request.state.id} aborted. "
+            "Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
 
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 2b5dfbf2..23f26922 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -7,6 +7,8 @@
 from fastapi import HTTPException, Request
 from typing import List, Union
 
+from loguru import logger
+
 from common import model
 from common.networking import (
     get_generator_error,
@@ -24,7 +26,9 @@
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
+def _create_response(
+    request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
+):
     """Create a completion response from the provided choices."""
 
     # Convert the single choice object into a list
@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
 
     response = CompletionResponse(
+        id=f"cmpl-{request_id}",
         choices=choices,
         model=model_name,
         usage=UsageStats(
@@ -77,13 +82,16 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
 ):
     """Collects a stream and places results in a common queue"""
 
     try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+        new_generation = model.container.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        )
         async for generation in new_generation:
             generation["index"] = task_idx
 
@@ -106,6 +114,8 @@ async def stream_generate_completion(
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
+        logger.info(f"Received streaming completion request {request.state.id}")
+
         gen_params = data.to_gen_params()
 
         for n in range(0, data.n):
@@ -116,7 +126,12 @@
 
             gen_task = asyncio.create_task(
                 _stream_collector(
-                    n, gen_queue, data.prompt, abort_event, **task_gen_params
+                    n,
+                    gen_queue,
+                    data.prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
                 )
             )
             gen_tasks.append(gen_task)
@@ -126,7 +141,9 @@
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Completion generation {request.state.id} cancelled by user."
+                )
 
             generation = await gen_queue.get()
 
@@ -134,31 +151,38 @@
 
             if isinstance(generation, Exception):
                 raise generation
 
-            response = _create_response(generation, model_path.name)
+            response = _create_response(request.state.id, generation, model_path.name)
             yield response.model_dump_json()
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 yield "[DONE]"
+                logger.info(f"Finished streaming completion request {request.state.id}")
                 break
     except CancelledError:
         # Get out if the request gets disconnected
         abort_event.set()
-        handle_request_disconnect("Completion generation cancelled by user.")
+        handle_request_disconnect(
+            f"Completion generation {request.state.id} cancelled by user."
+        )
     except Exception:
         yield get_generator_error(
-            "Completion aborted. Please check the server console."
+            f"Completion {request.state.id} aborted. Please check the server console."
         )
 
 
-async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+async def generate_completion(
+    data: CompletionRequest, request: Request, model_path: pathlib.Path
+):
     """Non-streaming generate for completions"""
 
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()
 
     try:
+        logger.info(f"Received completion request {request.state.id}")
+
         for n in range(0, data.n):
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
@@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
 
             gen_tasks.append(
                 asyncio.create_task(
-                    model.container.generate(data.prompt, **task_gen_params)
+                    model.container.generate(
+                        data.prompt, request.state.id, **task_gen_params
+                    )
                 )
             )
 
         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished completion request {request.state.id}")
 
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Maybe the model was unloaded? "
+            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
 
diff --git a/endpoints/server.py b/endpoints/server.py
index 7ceb2086..ec59455f 100644
--- a/endpoints/server.py
+++ b/endpoints/server.py
@@ -1,5 +1,6 @@
+from uuid import uuid4
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
 
@@ -25,6 +26,15 @@
 )
 
 
+@app.middleware("http")
+async def add_request_id(request: Request, call_next):
+    """Middleware to append an ID to a request"""
+
+    request.state.id = uuid4().hex
+    response = await call_next(request)
+    return response
+
+
 def setup_app():
     """Includes the correct routers for startup"""
 
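Note for reviewers (illustration only, not part of the patch): the sketch below mirrors the middleware and logging pattern introduced above, assuming a FastAPI app with loguru installed. The /v1/completions handler, its body, and its log messages are hypothetical stand-ins; only the add_request_id middleware and the request.state.id usage are taken from the patch.

    # Illustrative sketch, not part of the patch. The endpoint and messages are
    # hypothetical; the middleware mirrors the one added in endpoints/server.py.
    from uuid import uuid4

    from fastapi import FastAPI, Request
    from loguru import logger

    app = FastAPI()


    @app.middleware("http")
    async def add_request_id(request: Request, call_next):
        """Attach a unique ID to every incoming request."""
        request.state.id = uuid4().hex
        response = await call_next(request)
        return response


    @app.post("/v1/completions")
    async def completion_request(request: Request):
        # Both log lines carry the same ID, so interleaved console output from
        # concurrent requests can be matched back to the request that produced it.
        logger.info(f"Received completion request {request.state.id}")
        result = {"id": f"cmpl-{request.state.id}", "choices": []}
        logger.info(f"Finished completion request {request.state.id}")
        return result

The per-job UUID in backends/exllamav2/model.py intentionally stays separate from the request ID: as the added comment there notes, a single request can fan out into multiple generation jobs (for example when data.n > 1).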