Commit

Optionally initialize request handlers
DarkLight1337 committed Nov 1, 2024
1 parent c3ba030 commit 46f316f
Showing 3 changed files with 82 additions and 61 deletions.
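In short, each OpenAI-compatible request handler stored on `app.state` becomes optional: a handler is only constructed when the model's task supports the corresponding API, and every route first checks for `None`, falling back to an error response. A minimal, self-contained sketch of that pattern, using simplified stand-in names rather than the actual vLLM classes:

```python
# Simplified stand-ins for OpenAIServingChat / init_app_state; only the
# optional-handler pattern itself comes from the diff below.
import asyncio
from types import SimpleNamespace
from typing import Optional


class ChatHandler:
    async def create_chat_completion(self, request: dict) -> dict:
        return {"object": "chat.completion", "echo": request}


def init_state(task: str) -> SimpleNamespace:
    state = SimpleNamespace()
    # Only generative models get a chat handler; embedding models get None.
    state.openai_serving_chat = ChatHandler() if task == "generate" else None
    return state


async def create_chat_completion(state: SimpleNamespace, request: dict) -> dict:
    handler: Optional[ChatHandler] = state.openai_serving_chat
    if handler is None:
        # Mirrors base(raw_request).create_error_response(...) in the diff.
        return {"error": "The model does not support Chat Completions API"}
    return await handler.create_chat_completion(request)


if __name__ == "__main__":
    state = init_state("embedding")  # chat handler is None for this task
    print(asyncio.run(create_chat_completion(state, {"messages": []})))
```

With `task="generate"` the same call would dispatch to the handler instead of returning the error dict.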
95 changes: 57 additions & 38 deletions vllm/entrypoints/openai/api_server.py
@@ -11,7 +11,7 @@
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Set
from typing import AsyncIterator, Optional, Set

import uvloop
from fastapi import APIRouter, FastAPI, Request
@@ -51,7 +51,7 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -248,20 +248,25 @@ def mount_metrics(app: FastAPI):
app.routes.append(metrics_route)


def chat(request: Request) -> OpenAIServingChat:
def base(request: Request) -> OpenAIServing:
# Reuse the existing instance
return tokenization(request)


def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion:
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding


def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
@@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:

@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_tokenize(request)
handler = tokenization(raw_request)

generator = await handler.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):

@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_detokenize(request)
handler = tokenization(raw_request)

generator = await handler.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):

@router.get("/v1/models")
async def show_available_models(raw_request: Request):
models = await completion(raw_request).show_available_models()
handler = base(raw_request)

models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())


@@ -314,9 +325,12 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Chat Completions API")

generator = await chat(raw_request).create_chat_completion(
request, raw_request)
generator = await handler.create_chat_completion(request, raw_request)

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
@@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await completion(raw_request).create_completion(
request, raw_request)
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API")

generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await embedding(raw_request).create_embedding(
request, raw_request)
handler = embedding(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")

generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -382,30 +404,26 @@ async def stop_profile(raw_request: Request):
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200, content=response)

@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200, content=response)
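Because any of these handlers may now be `None`, the LoRA adapter routes iterate over the getters, call only the handlers that were actually constructed, and return on the first error. A rough sketch of that control flow, where `Handler` and the error dict are illustrative stand-ins for vLLM's serving classes and `ErrorResponse`:

```python
# Illustrative stand-ins only; the loop-and-bail-on-first-error shape is
# what the new load/unload routes implement.
import asyncio
from typing import Optional, Union


class Handler:
    def __init__(self, fail: bool = False) -> None:
        self.fail = fail

    async def load_lora_adapter(self, name: str) -> Union[dict, str]:
        return {"error": f"cannot load {name}"} if self.fail else f"loaded {name}"


async def load_for_all(handlers: list, name: str) -> Optional[Union[dict, str]]:
    response: Optional[Union[dict, str]] = None
    for handler in handlers:
        if handler is not None:             # skip APIs this model does not serve
            response = await handler.load_lora_adapter(name)
            if isinstance(response, dict):  # stands in for ErrorResponse
                return response
    return response


# Chat and embedding handlers exist, the completion handler does not.
print(asyncio.run(load_for_all([Handler(), None, Handler()], "my-adapter")))
```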

@@ -501,7 +519,8 @@ def init_app_state(
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
tool_parser=args.tool_call_parser,
) if model_config.task == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
Expand All @@ -510,14 +529,14 @@ def init_app_state(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
) if model_config.task == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=args.chat_template,
)
) if model_config.task == "embedding" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
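The trailing lines of this file show the new gating in `init_app_state`: the chat and completion handlers are built only when `model_config.task == "generate"`, the embedding handler only when the task is `"embedding"`, and tokenization stays unconditional, which is why `base()` can reuse it for error responses. Condensed into a runnable toy, where the `Serving*` classes are empty placeholders rather than the real constructors:

```python
# Empty placeholder classes; only the task-based gating mirrors the diff.
from types import SimpleNamespace


class ServingChat: ...
class ServingCompletion: ...
class ServingEmbedding: ...
class ServingTokenization: ...


def init_app_state(task: str) -> SimpleNamespace:
    state = SimpleNamespace()
    state.openai_serving_chat = ServingChat() if task == "generate" else None
    state.openai_serving_completion = (ServingCompletion()
                                       if task == "generate" else None)
    state.openai_serving_embedding = (ServingEmbedding()
                                      if task == "embedding" else None)
    # Always constructed, so error responses can be produced for any model.
    state.openai_serving_tokenization = ServingTokenization()
    return state


print(vars(init_app_state("embedding")))  # chat/completion are None here
```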
33 changes: 25 additions & 8 deletions vllm/entrypoints/openai/run_batch.py
@@ -217,14 +217,14 @@ async def main(args):
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
)
) if model_config.task == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=None,
)
) if model_config.task == "embedding" else None

tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
@@ -241,14 +241,31 @@

# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
response_futures.append(
run_request(openai_serving_chat.create_chat_completion,
request, tracker))
handler_fn = (None if openai_serving_chat is None else
openai_serving_chat.create_chat_completion)
if handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg=
"The model does not support Chat Completions API",
))
continue

response_futures.append(run_request(handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
response_futures.append(
run_request(openai_serving_embedding.create_embedding, request,
tracker))
handler_fn = (None if openai_serving_embedding is None else
openai_serving_embedding.create_embedding)
if handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Embeddings API",
))
continue

response_futures.append(run_request(handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
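run_batch.py applies the same check per batch line: resolve the bound handler method up front and, when the model was not initialized for that API, queue an error output instead of submitting the request. A small sketch of that dispatch, where `run_one` and `error_output` are hypothetical stand-ins for `run_request` and `make_async_error_request_output`:

```python
# run_one / error_output are stand-ins; only the None-check dispatch is
# taken from the diff above.
import asyncio
from typing import Awaitable, Callable, Optional

HandlerFn = Callable[[dict], Awaitable[dict]]


async def error_output(request: dict, error_msg: str) -> dict:
    return {"id": request.get("custom_id"), "error": error_msg}


async def run_one(handler_fn: HandlerFn, request: dict) -> dict:
    return await handler_fn(request)


async def dispatch(requests: list, chat_fn: Optional[HandlerFn]) -> list:
    futures = []
    for request in requests:
        if request["url"] == "/v1/chat/completions":
            if chat_fn is None:
                futures.append(error_output(
                    request, "The model does not support Chat Completions API"))
                continue
            futures.append(run_one(chat_fn, request))
    return await asyncio.gather(*futures)


print(asyncio.run(dispatch([{"url": "/v1/chat/completions"}], chat_fn=None)))
```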
15 changes: 0 additions & 15 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -24,17 +24,6 @@
logger = init_logger(__name__)


def check_embedding_mode(model_config: ModelConfig) -> bool:
embedding_mode = model_config.task == "embedding"

if not embedding_mode:
logger.warning("embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")

return embedding_mode


def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
@@ -98,8 +87,6 @@ def __init__(

self.chat_template = load_chat_template(chat_template)

self._enabled = check_embedding_mode(model_config)

async def create_embedding(
self,
request: EmbeddingRequest,
@@ -111,8 +98,6 @@ async def create_embedding(
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
if not self._enabled:
return self.create_error_response("Embedding API disabled")
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
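With the embedding handler now constructed only for embedding models, the per-request `_enabled` flag inside `OpenAIServingEmbedding` becomes redundant, so `check_embedding_mode` and the early return are deleted. Roughly, the task check moves from request time to construction time; a simplified contrast, not the real classes:

```python
# Neither class is the real OpenAIServingEmbedding; the point is where the
# task check lives.

# Before: every create_embedding call re-checked a flag set in __init__.
class EmbeddingHandlerBefore:
    def __init__(self, task: str) -> None:
        self._enabled = task == "embedding"

    async def create_embedding(self, request: dict) -> dict:
        if not self._enabled:
            return {"error": "Embedding API disabled"}
        return {"object": "embedding"}


# After: the handler is simply never built for non-embedding models.
class EmbeddingHandlerAfter:
    async def create_embedding(self, request: dict) -> dict:
        return {"object": "embedding"}


def build_embedding_handler(task: str):
    return EmbeddingHandlerAfter() if task == "embedding" else None


print(build_embedding_handler("generate"))  # -> None
```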
