Merge branch 'main' into vllm_upgrade
ZeldaHuang committed Dec 17, 2024
2 parents 05b5499 + 4029b00 commit c44810e
Showing 10 changed files with 219 additions and 178 deletions.
2 changes: 1 addition & 1 deletion llumnix/__init__.py
@@ -15,7 +15,7 @@
from vllm import *

from llumnix.server_info import ServerInfo
from llumnix.entrypoints.utils import (launch_ray_cluster,
from llumnix.entrypoints.setup import (launch_ray_cluster,
connect_to_ray_cluster,
init_manager,
init_llumlets)
72 changes: 25 additions & 47 deletions llumnix/entrypoints/utils.py → llumnix/entrypoints/setup.py
@@ -29,6 +29,7 @@
from llumnix.queue.queue_type import QueueType
from llumnix.server_info import ServerInfo, RequestTimestamps
from llumnix.queue.utils import init_request_output_queue_server
from llumnix.queue.queue_server_base import QueueServerBase

logger = init_logger(__name__)

@@ -39,17 +40,20 @@


class LlumnixEntrypointsContext:
def __init__(self):
self.engine_manager: LLMEngineManager = None
self.instances: Dict[str, Llumlet] = {}
self.request_output_queue: QueueServerBase = None
self.server_info: ServerInfo = None
self.request_streams: Dict[str, AsyncStream] = {}
self.manager_available = True
self.num_finished_requests = 0
self.instance_num_requests: Dict[str, int] = {}
self.log_requests: bool = None
self.log_request_timestamps: bool = None
def __init__(self,
engine_manager: LLMEngineManager,
instances: Dict[str, Llumlet],
request_output_queue: QueueServerBase,
server_info: ServerInfo,
log_requests: bool,
log_request_timestamps: bool):
self.engine_manager = engine_manager
self.instances = instances
self.request_output_queue = request_output_queue
self.server_info = server_info
self.log_requests = log_requests
self.log_request_timestamps = log_request_timestamps


def get_ip_address():
hostname = socket.gethostname()
@@ -240,41 +244,21 @@ def setup_llumnix(engine_manager_args, engine_args, cfg):
ip,
cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)
instances: Dict[str, Llumlet] = {}
instance_num_requests: Dict[str, int] = {}
for idx, ins_id in enumerate(instance_ids):
instances[ins_id] = llumlets[idx]
instance_num_requests[ins_id] = 0

log_requests = not cfg.SERVER.DISABLE_LOG_REQUESTS_SERVER
log_request_timestamps = cfg.SERVER.LOG_REQUEST_TIMESTAMPS
logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps))

context = LlumnixEntrypointsContext()
context.engine_manager = engine_manager
context.instances = instances
context.request_output_queue = request_output_queue
context.server_info = server_info
context.instance_num_requests = instance_num_requests
context.log_requests = log_requests
context.log_request_timestamps = log_request_timestamps

return context
llumnix_entrypoints_context = LlumnixEntrypointsContext(engine_manager,
instances,
request_output_queue,
server_info,
log_requests,
log_request_timestamps)

# TODO(s5u13b): Fix the potential output token out-of-order issue caused by the migration.
async def _background_process_outputs(llumnix_context):
while True:
request_outputs = await llumnix_context.request_output_queue.get()
for request_output in request_outputs:
if hasattr(request_output, 'request_timestamps'):
request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
for request_output in request_outputs:
request_id = request_output.request_id
# Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
if request_id not in llumnix_context.request_streams:
continue
llumnix_context.request_streams[request_id].put(request_output)
if request_output.finished:
llumnix_context.request_streams[request_id].finish()
del llumnix_context.request_streams[request_id]
return llumnix_entrypoints_context

def init_per_token_latency_breakdown_dict() -> Dict[str, int]:
per_token_latency_breakdown_dict = {
@@ -290,11 +274,5 @@ def init_per_token_latency_breakdown_dict() -> Dict[str, int]:
return per_token_latency_breakdown_dict

def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, int], request_timestamps: RequestTimestamps):
per_token_latency_breakdown_dict['step_latency_engine'].append(request_timestamps.step_latency_engine)
per_token_latency_breakdown_dict['process_model_outputs_latency'].append(request_timestamps.process_model_outputs_latency)
per_token_latency_breakdown_dict['step_postprocess_latency'].append(request_timestamps.step_postprocess_latency)
per_token_latency_breakdown_dict['across_async_put_queue_thread_latency'].append(request_timestamps.across_async_put_queue_thread_latency)
per_token_latency_breakdown_dict['across_async_put_queue_actor_latency'].append(request_timestamps.across_async_put_queue_actor_latency)
per_token_latency_breakdown_dict['queue_rpc_latency'].append(request_timestamps.queue_rpc_latency)
per_token_latency_breakdown_dict['background_process_get_queue_latency'].append(request_timestamps.background_process_get_queue_latency)
per_token_latency_breakdown_dict['generate_benchmark_return_output_latency'].append(request_timestamps.generate_benchmark_return_output_latency)
for key in per_token_latency_breakdown_dict.keys():
per_token_latency_breakdown_dict[key].append(getattr(request_timestamps, key))
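
The rewritten record_per_token_latency_breakdown above depends on every key of the breakdown dict matching an attribute name on RequestTimestamps. A minimal, self-contained illustration of that keys-as-attribute-names pattern (the timestamps object and its value below are stand-ins for this example, not part of the commit):

# Stand-in illustration of the pattern used by the new loop above.
# Only 'step_latency_engine' is a real key from this file; the value is made up.
class _FakeTimestamps:
    step_latency_engine = 0.8  # hypothetical per-token latency, seconds

per_token_latency_breakdown_dict = {'step_latency_engine': []}
for key in per_token_latency_breakdown_dict.keys():
    per_token_latency_breakdown_dict[key].append(getattr(_FakeTimestamps, key))

print(per_token_latency_breakdown_dict)  # {'step_latency_engine': [0.8]}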
41 changes: 19 additions & 22 deletions llumnix/entrypoints/vllm/api_server.py
@@ -23,18 +23,14 @@
from vllm.sampling_params import SamplingParams

from llumnix.arg_utils import LlumnixArgumentParser
from llumnix.entrypoints.utils import (setup_ray_cluster,
from llumnix.entrypoints.setup import (setup_ray_cluster,
setup_llumnix,
is_gpu_available,
LlumnixEntrypointsContext,
_background_process_outputs,
init_per_token_latency_breakdown_dict,
record_per_token_latency_breakdown)
from llumnix.entrypoints.vllm.utils import (add_cli_args,
get_args,
manager_generate,
manager_abort,
manager_is_ready)
from llumnix.entrypoints.vllm.arg_utils import (add_cli_args,
get_args)
from llumnix.entrypoints.vllm.client import LlumnixClientVLLM
from llumnix.logger import init_logger
from llumnix.utils import random_uuid
from llumnix.config import get_llumnix_config, LlumnixConfig
@@ -44,16 +40,16 @@

TIMEOUT_KEEP_ALIVE = 5 # seconds.

llumnix_context: LlumnixEntrypointsContext = None
llumnix_client: LlumnixClientVLLM = None


# pylint: disable=unused-argument
@asynccontextmanager
async def lifespan(fastapi_app: FastAPI):
asyncio.create_task(llumnix_context.request_output_queue.run_server_loop())
asyncio.create_task(_background_process_outputs(llumnix_context))
asyncio.create_task(llumnix_client.request_output_queue.run_server_loop())
asyncio.create_task(llumnix_client.get_request_outputs_loop())
yield
llumnix_context.request_output_queue.cleanup()
llumnix_client.request_output_queue.cleanup()

app = FastAPI(lifespan=lifespan)

@@ -79,8 +75,8 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

# Use manager_generate and manager_abort to replace with vllm async engine generate and abort api.
results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
# Use LlumnixClientVLLM's generate and abort api to replace with vLLM AsyncLLMEngine's generate and abort api.
results_generator = await llumnix_client.generate(prompt, sampling_params, request_id)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -100,7 +96,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator.generator():
if await request.is_disconnected():
# Abort the request if the client disconnects.
await manager_abort(request_id, llumnix_context)
await llumnix_client.abort(request_id)
return Response(status_code=499)
final_output = request_output

@@ -128,7 +124,7 @@ async def generate_benchmark(request: Request) -> Response:

start = time.time()

results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
results_generator = await llumnix_client.generate(prompt, sampling_params, request_id)

# Non-streaming case
final_output = None
@@ -137,7 +133,7 @@ async def generate_benchmark(request: Request) -> Response:
async for request_output in results_generator.generator():
if await request.is_disconnected():
# Abort the request if the client disconnects.
await manager_abort(request_id, llumnix_context)
await llumnix_client.abort(request_id)
return Response(status_code=499)
now = time.time()
per_token_latency.append([now, (now - start)*1000])
@@ -148,10 +144,10 @@ async def generate_benchmark(request: Request) -> Response:
record_per_token_latency_breakdown(per_token_latency_breakdown_dict, request_output.request_timestamps)
assert final_output is not None

if llumnix_context.log_requests:
llumnix_context.num_finished_requests += 1
if llumnix_client.log_requests:
llumnix_client.num_finished_requests += 1
logger.info("entrypoints finished request {}.".format(request_id))
logger.info("num_finished_requests {}.".format(llumnix_context.num_finished_requests))
logger.info("num_finished_requests {}.".format(llumnix_client.num_finished_requests))

generation = final_output.outputs[0].text
num_output_tokens = len(final_output.outputs[0].token_ids)
@@ -172,7 +168,7 @@ async def generate_benchmark(request: Request) -> Response:

@app.get("/is_ready")
async def is_ready():
return await manager_is_ready(llumnix_context)
return await llumnix_client.is_ready()


if __name__ == "__main__":
@@ -192,7 +188,8 @@ async def is_ready():

# if gpu is not available, it means that this node is head pod without any llumnix components
if is_gpu_available():
llumnix_context = setup_llumnix(engine_manager_args, engine_args, cfg)
llumnix_entrypoints_context = setup_llumnix(engine_manager_args, engine_args, cfg)
llumnix_client = LlumnixClientVLLM(llumnix_entrypoints_context)
# Start the api server after all the components of llumnix are ready.
logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT))
uvicorn.run(app,
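
Taken together, the api_server.py changes replace the module-level llumnix_context with a LlumnixClientVLLM instance that owns all request-routing state. A condensed sketch of the resulting startup path, assuming the imports shown in this file's diff; argument/config handling and the remaining uvicorn options are elided because they are not part of this hunk:

# Condensed sketch of the post-change __main__ path (not verbatim source).
# engine_manager_args, engine_args and cfg come from the unchanged argument
# handling earlier in __main__, which this diff does not show.
if is_gpu_available():
    llumnix_entrypoints_context = setup_llumnix(engine_manager_args, engine_args, cfg)
    llumnix_client = LlumnixClientVLLM(llumnix_entrypoints_context)
    # Start the api server after all the components of llumnix are ready.
    logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT))
    uvicorn.run(app)  # host/port/timeout options unchanged from the original and omitted here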
31 changes: 31 additions & 0 deletions llumnix/entrypoints/vllm/arg_utils.py
@@ -0,0 +1,31 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from llumnix.backends.vllm.utils import check_engine_args

from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs
from llumnix.logger import init_logger

logger = init_logger(__name__)


def add_cli_args(parser):
parser.set_namespace("llumnix")
parser = LlumnixEntrypointsArgs.add_cli_args(parser)
parser = EngineManagerArgs.add_cli_args(parser)
parser.set_namespace("vllm")
parser = AsyncEngineArgs.add_cli_args(parser)
cli_args = parser.parse_args()
return cli_args

def get_args(cfg, parser, cli_args):
llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(cfg)
LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, parser)
engine_manager_args = EngineManagerArgs.from_llumnix_config(cfg)
EngineManagerArgs.check_args(engine_manager_args, parser)
engine_args = AsyncEngineArgs.from_cli_args(cli_args)
check_engine_args(engine_args, engine_manager_args)

logger.info("llumnix_entrypoints_args: {}".format(llumnix_entrypoints_args))
logger.info("engine_manager_args: {}".format(engine_manager_args))
logger.info("engine_args: {}".format(engine_args))

return llumnix_entrypoints_args, engine_manager_args, engine_args
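
Splitting CLI handling into this new arg_utils.py keeps api_server.py down to two imports (add_cli_args, get_args). A rough wiring sketch under stated assumptions: LlumnixArgumentParser and get_llumnix_config are imported in api_server.py above, but their exact call signatures are not shown in this diff, so treat them as assumptions.

# Hypothetical entrypoint wiring of add_cli_args/get_args (the signatures of
# LlumnixArgumentParser and get_llumnix_config are assumptions, not shown here).
parser = LlumnixArgumentParser()
cli_args = add_cli_args(parser)     # registers llumnix + vLLM args, then parses argv
cfg = get_llumnix_config(cli_args)  # assumed signature
llumnix_entrypoints_args, engine_manager_args, engine_args = get_args(cfg, parser, cli_args)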
128 changes: 128 additions & 0 deletions llumnix/entrypoints/vllm/client.py
@@ -0,0 +1,128 @@
import copy
import time
import asyncio
import ray

from vllm.engine.async_llm_engine import AsyncStream
from vllm import SamplingParams

from llumnix.logger import init_logger
from llumnix.entrypoints.setup import LlumnixEntrypointsContext
from llumnix.server_info import RequestTimestamps
from llumnix.queue.queue_server_base import QueueServerBase
from llumnix.server_info import ServerInfo

logger = init_logger(__name__)

WAIT_MANAGER_INTERVAL = 5


class LlumnixClientVLLM:
def __init__(self,
llumnix_entrypoints_context: LlumnixEntrypointsContext):
self.engine_manager: LLMEngineManager = llumnix_entrypoints_context.engine_manager
self.instances: Dict[str, Llumlet] = llumnix_entrypoints_context.instances
self.request_output_queue: QueueServerBase = llumnix_entrypoints_context.request_output_queue
self.server_info: ServerInfo = llumnix_entrypoints_context.server_info
self.log_requests: bool = llumnix_entrypoints_context.log_requests
self.log_request_timestamps: bool = llumnix_entrypoints_context.log_request_timestamps

self.request_streams: Dict[str, AsyncStream] = {}
self.instance_num_requests: Dict[str, int] = {}
for ins_id in self.instances.keys():
self.instance_num_requests[ins_id] = 0
self.num_finished_requests: int = 0
self.manager_available: bool = True

async def generate(self,
prompt: str,
sampling_params: SamplingParams,
request_id: str,
*args,
**kwargs) -> AsyncStream:
if sampling_params.n > 1:
raise ValueError("Unsupported feature: multiple sequence decoding")

results_generator = AsyncStream(request_id, cancel=None)
self.request_streams[request_id] = results_generator
server_info_copy = copy.deepcopy(self.server_info)

# If manager is unavailable, request will be directly added to the llumlet held by api server.
try:
await self._generate_by_manager(request_id, server_info_copy, prompt, sampling_params, *args, **kwargs)
self.manager_available = True
except ray.exceptions.RayActorError:
# Do not re-generate the request to avoid duplicate requests.
if self.manager_available:
self.manager_available = False
return results_generator
await self._generate_by_instance(request_id, server_info_copy, prompt, sampling_params, *args, **kwargs)

return results_generator

async def _generate_by_manager(self,
request_id: str,
server_info: ServerInfo,
prompt: str,
sampling_params: SamplingParams,
*args,
**kwargs) -> AsyncStream:
if self.log_request_timestamps:
# Hack request timestamps in server_info for latency breakdown.
server_info.request_timestamps = RequestTimestamps()
server_info.request_timestamps.api_server_manager_generate_timestamp = time.time()
await self.engine_manager.generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs)

async def _generate_by_instance(self,
request_id: str,
server_info: ServerInfo,
prompt: str,
sampling_params: SamplingParams,
*args,
**kwargs) -> AsyncStream:
try:
if self.instance_num_requests:
instance_id = min(self.instance_num_requests, key=self.instance_num_requests.get)
self.instance_num_requests[instance_id] += 1
await self.instances[instance_id].generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs)
logger.info("LLMEngineManager is unavailable temporarily, dispatch request {} to instance {}".format(
request_id, instance_id))
else:
logger.info("LLMEngineManager is unavailable temporarily, but there is no instance behind this api server, "
"sleep {}s, waiting for manager available".format(WAIT_MANAGER_INTERVAL))
await asyncio.sleep(WAIT_MANAGER_INTERVAL)
return await asyncio.create_task(self.generate(prompt, sampling_params, request_id, *args, **kwargs))
except (ray.exceptions.RayActorError, KeyError):
if instance_id in self.instances:
logger.info("[manager_generate] instance {} is dead".format(instance_id))
del self.instances[instance_id]
del self.instance_num_requests[instance_id]
return await asyncio.create_task(self.generate(prompt, sampling_params, request_id, *args, **kwargs))

async def abort(self, request_id: str) -> None:
try:
logger.info("abort request: {}.".format(request_id))
await self.engine_manager.abort.remote(request_id)
except ray.exceptions.RayActorError:
logger.info("manager is unavailable")

async def is_ready(self) -> bool:
ready_status = await self.engine_manager.is_ready.remote()
return ready_status

# TODO(s5u13b): Fix the potential output token out-of-order issue caused by the migration.
async def get_request_outputs_loop(self):
while True:
request_outputs = await self.request_output_queue.get()
for request_output in request_outputs:
if hasattr(request_output, 'request_timestamps'):
request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
for request_output in request_outputs:
request_id = request_output.request_id
# Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
if request_id not in self.request_streams:
continue
self.request_streams[request_id].put(request_output)
if request_output.finished:
self.request_streams[request_id].finish()
del self.request_streams[request_id]
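
For reference, a minimal sketch of driving this client the way the /generate_benchmark handler above does: obtain an AsyncStream from generate(), then consume it via generator(). The prompt and sampling parameters here are illustrative, and the surrounding FastAPI plumbing is omitted.

# Minimal driving sketch mirroring api_server.py's handlers (illustrative only).
async def run_one_request(llumnix_client: LlumnixClientVLLM, prompt: str,
                          sampling_params: SamplingParams) -> str:
    request_id = random_uuid()  # from llumnix.utils, as used in api_server.py
    results_generator = await llumnix_client.generate(prompt, sampling_params, request_id)
    final_output = None
    async for request_output in results_generator.generator():
        final_output = request_output  # the last output carries the full generation
    return final_output.outputs[0].text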
98 changes: 0 additions & 98 deletions llumnix/entrypoints/vllm/utils.py

This file was deleted.

1 change: 1 addition & 0 deletions llumnix/queue/ray_queue_server.py
@@ -19,6 +19,7 @@

from llumnix.queue.queue_server_base import QueueServerBase


class RayQueueServer(QueueServerBase):
def __init__(self) -> None:
self.queue = RayQueue(
2 changes: 1 addition & 1 deletion tests/unit_test/entrypoints/test_utils.py
@@ -16,7 +16,7 @@
import ray

from llumnix.arg_utils import EngineManagerArgs
from llumnix.entrypoints.utils import (get_ip_address,
from llumnix.entrypoints.setup import (get_ip_address,
launch_ray_cluster,
init_manager,
retry_manager_method_sync,
18 changes: 11 additions & 7 deletions tests/unit_test/entrypoints/vllm/api_server_manager.py
@@ -24,7 +24,8 @@
from llumnix.server_info import ServerInfo, RequestTimestamps
from llumnix.utils import random_uuid
from llumnix.queue.utils import init_request_output_queue_server, init_request_output_queue_client, QueueType
from llumnix.entrypoints.utils import LlumnixEntrypointsContext
from llumnix.entrypoints.setup import LlumnixEntrypointsContext
from llumnix.entrypoints.vllm.client import LlumnixClientVLLM

app = llumnix.entrypoints.vllm.api_server.app
engine_manager = None
@@ -73,17 +74,20 @@ def stats() -> Response:

request_output_queue_type = QueueType(args.request_output_queue_type)
engine_manager = init_manager(request_output_queue_type)
llumnix.entrypoints.vllm.api_server.llumnix_context = LlumnixEntrypointsContext()
llumnix.entrypoints.vllm.api_server.llumnix_context.engine_manager = engine_manager
ip = '127.0.0.1'
port = 1234
llumnix.entrypoints.vllm.api_server.llumnix_context.request_output_queue = \
init_request_output_queue_server(ip, port, request_output_queue_type)
request_output_queue = init_request_output_queue_server(ip, port, request_output_queue_type)
ray_queue_server = None
if request_output_queue_type == QueueType.RAYQUEUE:
ray_queue_server = llumnix.entrypoints.vllm.api_server.llumnix_context.request_output_queue
ray_queue_server = request_output_queue
server_info = ServerInfo(random_uuid(), request_output_queue_type, ray_queue_server, ip, port)
llumnix.entrypoints.vllm.api_server.llumnix_context.server_info = server_info
llumnix_context = LlumnixEntrypointsContext(engine_manager,
{'0': None},
request_output_queue,
server_info,
None,
None)
llumnix.entrypoints.vllm.api_server.llumnix_client = LlumnixClientVLLM(llumnix_context)

uvicorn.run(
app,
4 changes: 2 additions & 2 deletions tests/unit_test/queue/test_zmq.py
@@ -34,10 +34,10 @@ def __init__(self, rpc_path):
asyncio.create_task(self.server.run_server_loop())
request_output_queue = self.server
self.stop_signal = asyncio.Event()
asyncio.create_task(self._background_process_outputs(request_output_queue))
asyncio.create_task(self.get_request_outputs_loop(request_output_queue))
asyncio.create_task(self._wait_until_done())

async def _background_process_outputs(self, request_output_queue):
async def get_request_outputs_loop(self, request_output_queue):
while True:
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
