Skip to content

Commit

Permalink
Revert "[Frontend] Factor out code for running uvicorn" (vllm-project…
Browse files Browse the repository at this point in the history
…#7012)

Co-authored-by: Robert Shaw <[email protected]>
  • Loading branch information
simon-mo and robertgshaw2-neuralmagic authored Jul 31, 2024
1 parent 7ecf28a commit c15b011
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 117 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ files = [
"vllm/logging",
"vllm/multimodal",
"vllm/platforms",
"vllm/server",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
Expand Down
74 changes: 24 additions & 50 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,21 @@
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio

import json
import ssl
from argparse import Namespace
from typing import Any, AsyncGenerator, Optional
from typing import AsyncGenerator

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.server import serve_http
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")

Expand Down Expand Up @@ -83,50 +81,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return JSONResponse(ret)


def build_app(args: Namespace) -> FastAPI:
global app

app.root_path = args.root_path
return app


async def init_app(
args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
app = build_app(args)

global engine

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER))

return app


async def run_server(args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs: Any) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

app = await init_app(args, llm_engine)
await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)


if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
Expand All @@ -151,5 +105,25 @@ async def run_server(args: Namespace,
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER)

app.root_path = args.root_path

asyncio.run(run_server(args))
logger.info("Available routes are:")
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)

uvicorn.run(app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
72 changes: 51 additions & 21 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import importlib
import inspect
import re
from argparse import Namespace
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any, Optional, Set
from typing import Optional, Set

from fastapi import APIRouter, FastAPI, Request
import fastapi
import uvicorn
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
Expand Down Expand Up @@ -36,7 +38,6 @@
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.server import serve_http
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
Expand All @@ -56,7 +57,7 @@


@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: fastapi.FastAPI):

async def _force_log():
while True:
Expand All @@ -74,7 +75,7 @@ async def _force_log():
router = APIRouter()


def mount_metrics(app: FastAPI):
def mount_metrics(app: fastapi.FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
Expand Down Expand Up @@ -164,8 +165,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
def build_app(args):
app = fastapi.FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path

Expand Down Expand Up @@ -213,8 +214,11 @@ async def authentication(request: Request, call_next):
return app


async def init_app(args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None) -> FastAPI:
async def build_server(
args,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs,
) -> uvicorn.Server:
app = build_app(args)

if args.served_model_name is not None:
Expand Down Expand Up @@ -277,17 +281,14 @@ async def init_app(args: Namespace,
)
app.root_path = args.root_path

return app


async def run_server(args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs: Any) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
logger.info("Available routes are:")
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)

app = await init_app(args, llm_engine)
await serve_http(
config = uvicorn.Config(
app,
host=args.host,
port=args.port,
Expand All @@ -300,6 +301,36 @@ async def run_server(args: Namespace,
**uvicorn_kwargs,
)

return uvicorn.Server(config)


async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)

loop = asyncio.get_running_loop()

server_task = loop.create_task(server.serve())

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()

loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()


if __name__ == "__main__":
# NOTE(simon):
Expand All @@ -308,5 +339,4 @@ async def run_server(args: Namespace,
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()

asyncio.run(run_server(args))
3 changes: 0 additions & 3 deletions vllm/server/__init__.py

This file was deleted.

42 changes: 0 additions & 42 deletions vllm/server/launch.py

This file was deleted.

0 comments on commit c15b011

Please sign in to comment.