From 46f316f1c16473ad33848783fb68a82cf691ffbb Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Fri, 1 Nov 2024 04:14:39 +0000
Subject: [PATCH] Optionally initialize request handlers

---
 vllm/entrypoints/openai/api_server.py        | 95 ++++++++++++--------
 vllm/entrypoints/openai/run_batch.py         | 33 +++++--
 vllm/entrypoints/openai/serving_embedding.py | 15 ----
 3 files changed, 82 insertions(+), 61 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index a51d030016fa2..95fd56d916050 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -11,7 +11,7 @@
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import AsyncIterator, Set
+from typing import AsyncIterator, Optional, Set
 
 import uvloop
 from fastapi import APIRouter, FastAPI, Request
@@ -51,7 +51,7 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -248,20 +248,25 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)
 
 
-def chat(request: Request) -> OpenAIServingChat:
+def base(request: Request) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(request)
+
+
+def chat(request: Request) -> Optional[OpenAIServingChat]:
     return request.app.state.openai_serving_chat
 
 
-def completion(request: Request) -> OpenAIServingCompletion:
+def completion(request: Request) -> Optional[OpenAIServingCompletion]:
     return request.app.state.openai_serving_completion
 
 
-def tokenization(request: Request) -> OpenAIServingTokenization:
-    return request.app.state.openai_serving_tokenization
+def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
+    return request.app.state.openai_serving_embedding
 
 
-def embedding(request: Request) -> OpenAIServingEmbedding:
-    return request.app.state.openai_serving_embedding
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
 
 
 def engine_client(request: Request) -> EngineClient:
@@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:
 
 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_tokenize(request)
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
 
 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_detokenize(request)
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
 
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
-    models = await completion(raw_request).show_available_models()
+    handler = base(raw_request)
+
+    models = await handler.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
@@ -314,9 +325,12 @@ async def show_version():
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
+    handler = chat(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Chat Completions API")
 
-    generator = await chat(raw_request).create_chat_completion(
-        request, raw_request)
+    generator = await handler.create_chat_completion(request, raw_request)
 
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await completion(raw_request).create_completion(
-        request, raw_request)
+    handler = completion(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await embedding(raw_request).create_embedding(
-        request, raw_request)
+    handler = embedding(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Embeddings API")
+
+    generator = await handler.create_embedding(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -382,30 +404,26 @@ async def stop_profile(raw_request: Request):
 @router.post("/v1/load_lora_adapter")
 async def load_lora_adapter(request: LoadLoraAdapterRequest,
                             raw_request: Request):
-    response = await chat(raw_request).load_lora_adapter(request)
-    if isinstance(response, ErrorResponse):
-        return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
-
-    response = await completion(raw_request).load_lora_adapter(request)
-    if isinstance(response, ErrorResponse):
-        return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+    for route in [chat, completion, embedding]:
+        handler = route(raw_request)
+        if handler is not None:
+            response = await handler.load_lora_adapter(request)
+            if isinstance(response, ErrorResponse):
+                return JSONResponse(content=response.model_dump(),
+                                    status_code=response.code)
 
     return Response(status_code=200, content=response)
 
 
 @router.post("/v1/unload_lora_adapter")
 async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                               raw_request: Request):
-    response = await chat(raw_request).unload_lora_adapter(request)
-    if isinstance(response, ErrorResponse):
-        return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
-
-    response = await completion(raw_request).unload_lora_adapter(request)
-    if isinstance(response, ErrorResponse):
-        return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+    for route in [chat, completion, embedding]:
+        handler = route(raw_request)
+        if handler is not None:
+            response = await handler.unload_lora_adapter(request)
+            if isinstance(response, ErrorResponse):
+                return JSONResponse(content=response.model_dump(),
+                                    status_code=response.code)
 
     return Response(status_code=200, content=response)
@@ -501,7 +519,8 @@ def init_app_state(
         chat_template=args.chat_template,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
-        tool_parser=args.tool_call_parser)
+        tool_parser=args.tool_call_parser,
+    ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -510,14 +529,14 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    )
+    ) if model_config.task == "generate" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
         base_model_paths,
         request_logger=request_logger,
         chat_template=args.chat_template,
-    )
+    ) if model_config.task == "embedding" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 41b9d92f1166d..a64467a311523 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -217,14 +217,14 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
-    )
+    ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
         base_model_paths,
         request_logger=request_logger,
         chat_template=None,
-    )
+    ) if model_config.task == "embedding" else None
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
@@ -241,14 +241,31 @@ async def main(args):
 
         # Determine the type of request and run it.
         if request.url == "/v1/chat/completions":
-            response_futures.append(
-                run_request(openai_serving_chat.create_chat_completion,
-                            request, tracker))
+            handler_fn = (None if openai_serving_chat is None else
+                          openai_serving_chat.create_chat_completion)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg=
+                        "The model does not support Chat Completions API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
             tracker.submitted()
         elif request.url == "/v1/embeddings":
-            response_futures.append(
-                run_request(openai_serving_embedding.create_embedding, request,
-                            tracker))
+            handler_fn = (None if openai_serving_embedding is None else
+                          openai_serving_embedding.create_embedding)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg="The model does not support Embeddings API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
             tracker.submitted()
         else:
             response_futures.append(
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 779fce3c04869..917856cd2b2dd 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -24,17 +24,6 @@
 logger = init_logger(__name__)
 
 
-def check_embedding_mode(model_config: ModelConfig) -> bool:
-    embedding_mode = model_config.task == "embedding"
-
-    if not embedding_mode:
-        logger.warning("embedding_mode is False. Embedding API will not work.")
-    else:
-        logger.info("Activating the server engine with embedding enabled.")
-
-    return embedding_mode
-
-
 def _get_embedding(
     output: EmbeddingOutput,
     encoding_format: Literal["float", "base64"],
@@ -98,8 +87,6 @@ def __init__(
 
         self.chat_template = load_chat_template(chat_template)
 
-        self._enabled = check_embedding_mode(model_config)
-
     async def create_embedding(
         self,
         request: EmbeddingRequest,
@@ -111,8 +98,6 @@ async def create_embedding(
         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
         """
-        if not self._enabled:
-            return self.create_error_response("Embedding API disabled")
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
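
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the optional-handler pattern the patch introduces, using simplified stand-in names (ChatHandler, a plain 400 JSON error, an illustrative init_app_state) rather than vLLM's actual OpenAIServingChat, base() helper, and ErrorResponse types.

from typing import Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


class ChatHandler:
    """Stand-in for OpenAIServingChat; constructed only when the task supports it."""

    async def create_chat_completion(self, payload: dict) -> dict:
        return {"object": "chat.completion", "choices": []}


def init_app_state(task: str) -> None:
    # Mirrors the patch: build the handler only for a matching task, otherwise
    # store None so the route can reject the request gracefully.
    app.state.openai_serving_chat = ChatHandler() if task == "generate" else None


@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
    handler: Optional[ChatHandler] = raw_request.app.state.openai_serving_chat
    if handler is None:
        # The patch returns base(raw_request).create_error_response(...) here;
        # a plain JSON error payload stands in for that behaviour.
        return JSONResponse(
            content={"error": "The model does not support Chat Completions API"},
            status_code=400)
    return await handler.create_chat_completion(await raw_request.json())


# For an embedding model the chat handler stays None, and the route above
# answers with the error payload instead of raising AttributeError.
init_app_state(task="embedding")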