feat: venus adapter
jimpang committed Sep 12, 2024
1 parent 3fd2b0d commit c0c25d3
Showing 2 changed files with 71 additions and 53 deletions.
vllm/entrypoints/api_server.py (90 changes: 53 additions & 37 deletions)
@@ -9,6 +9,7 @@
import json
import ssl
from argparse import Namespace
from dataclasses import asdict
from typing import Any, AsyncGenerator, Optional

from fastapi import FastAPI, Request
@@ -17,6 +18,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
@@ -53,19 +55,28 @@ async def generate(request: Request) -> Response:
request_id = random_uuid()

assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# jimpang: accept pre-tokenized prompts (a list of token IDs) as well as text
inputs = prompt
if prompt and len(prompt) > 0:
first_element = prompt[0]
if isinstance(first_element, int):
inputs = TokensPrompt(prompt_token_ids=prompt)

results_generator = engine.generate(
inputs=inputs, sampling_params=sampling_params, request_id=request_id)
results_generator = iterate_with_cancellation(
results_generator, is_cancelled=request.is_disconnected)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
output_tokens = [output.token_ids for output in request_output.outputs]
logprobs = [[{k: asdict(v) for k, v in logprobs.items()} for logprobs in
output.logprobs] if output.logprobs is not None else None for output in request_output.outputs]
ret = {"text": text_outputs, "output_token_ids": output_tokens, "logprobs": logprobs}
yield (json.dumps(ret) + "\0").encode("utf-8")

if stream:
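
For reference, a minimal streaming-client sketch against the updated endpoint. It assumes the demo server listens on http://localhost:8000, that /generate still accepts a JSON body with "prompt" and a "stream" flag, and that it keeps the "\0"-delimited JSON framing shown above; the requests-based client is illustrative, not part of this commit.

import json

import requests

payload = {"prompt": "Hello, my name is", "stream": True, "max_tokens": 16}
with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    # Chunks are "\0"-delimited JSON documents, as produced by stream_results().
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        # Each chunk now carries token IDs and logprobs alongside the text.
        print(data["text"], data["output_token_ids"])
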
@@ -80,10 +91,11 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return Response(status_code=499)

assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
text_outputs = [output.text for output in final_output.outputs]
output_tokens = [output.token_ids for output in final_output.outputs]
logprobs = [[{k: asdict(v) for k, v in logprobs.items()} for logprobs in
output.logprobs] if output.logprobs is not None else None for output in final_output.outputs]
ret = {"text": text_outputs, "output_token_ids": output_tokens, "logprobs": logprobs}
return JSONResponse(ret)
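
A corresponding non-streaming sketch, this time passing a pre-tokenized prompt to exercise the new TokensPrompt path. The token IDs, host, and the assumption that extra JSON keys are forwarded to SamplingParams (including logprobs) are illustrative, not taken from this diff.

import requests

payload = {
    "prompt": [1, 15043, 29892, 590, 1024, 338],  # token IDs instead of text
    "max_tokens": 32,
    "temperature": 0.0,
    "logprobs": 1,  # ask for per-token logprobs so the new field is populated
}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
body = resp.json()
print(body["text"])              # generated text only; the prompt is no longer prepended
print(body["output_token_ids"])  # one list of generated token IDs per completion
print(body["logprobs"])          # per-token logprob dicts, or None when not requested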


@@ -95,8 +107,8 @@ def build_app(args: Namespace) -> FastAPI:


async def init_app(
args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
app = build_app(args)

@@ -105,7 +117,7 @@ async def init_app(
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER))
engine_args, usage_context=UsageContext.API_SERVER))

return app

@@ -137,28 +149,32 @@ async def run_server(args: Namespace,


if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

asyncio.run(run_server(args))
try:
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

asyncio.run(run_server(args))
except Exception as e:
logger.error(str(e))
raise
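
The new logprobs field is built with dataclasses.asdict over vLLM's per-token Logprob entries. A minimal sketch of the resulting JSON shape; the Logprob stand-in below assumes the usual logprob/rank/decoded_token fields and is for illustration only.

from dataclasses import asdict, dataclass
from typing import Optional

@dataclass
class Logprob:  # stand-in for vllm.sequence.Logprob (assumed field names)
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None

# One decoding step maps candidate token IDs to Logprob objects...
step = {15043: Logprob(logprob=-0.11, rank=1, decoded_token="Hello")}
# ...and the server serializes it exactly as in the handler above.
print({k: asdict(v) for k, v in step.items()})
# -> {15043: {'logprob': -0.11, 'rank': 1, 'decoded_token': 'Hello'}}
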
vllm/entrypoints/openai/api_server.py (34 changes: 18 additions & 16 deletions)
@@ -81,7 +81,6 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,

@asynccontextmanager
async def lifespan(app: FastAPI):

async def _force_log():
while True:
await asyncio.sleep(10)
@@ -98,7 +97,6 @@ async def _force_log():
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:

# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
@@ -109,15 +107,14 @@ async def build_async_engine_client(

async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:

async_engine_client = engine # type: ignore[assignment]
yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
@@ -285,7 +282,6 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):

generator = await openai_serving_chat.create_chat_completion(
request, raw_request)

@@ -330,13 +326,15 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")


@router.post("/start_profile")
async def start_profile():
logger.info("Starting profiler...")
await async_engine_client.start_profile()
logger.info("Profiler started.")
return Response(status_code=200)


@router.post("/stop_profile")
async def stop_profile():
logger.info("Stopping profiler...")
@@ -429,8 +427,8 @@ async def authentication(request: Request, call_next):


async def init_app(
async_engine_client: AsyncEngineClient,
args: Namespace,
async_engine_client: AsyncEngineClient,
args: Namespace,
) -> FastAPI:
app = build_app(args)

@@ -521,11 +519,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:


if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()

asyncio.run(run_server(args))
try:
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()

asyncio.run(run_server(args))
except Exception as e:
logger.error(str(e))
raise
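
Both entrypoints now wrap their __main__ blocks in the same guard: any startup or runtime failure is logged and then re-raised, so the process still exits non-zero with a traceback. A self-contained sketch of the pattern (stdlib logging stands in for vllm.logger here):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_server")

async def run_server() -> None:
    # Stand-in for the real server coroutine; raise to exercise the guard.
    raise RuntimeError("engine failed to start")

if __name__ == "__main__":
    try:
        asyncio.run(run_server())
    except Exception as e:
        logger.error(str(e))  # record the failure...
        raise                 # ...then propagate so the exit status stays non-zero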
