From 4029b00bd5b590fbc136d182d597e86f0b9e57ee Mon Sep 17 00:00:00 2001
From: Biao Sun
Date: Fri, 13 Dec 2024 11:34:18 +0800
Subject: [PATCH] [Entrypoints] Use LlumnixClient class in entrypoints (#84)

---
 llumnix/__init__.py                        |   2 +-
 llumnix/entrypoints/{utils.py => setup.py} |  72 ++++------
 llumnix/entrypoints/vllm/api_server.py     |  41 +++---
 llumnix/entrypoints/vllm/arg_utils.py      |  31 +++++
 llumnix/entrypoints/vllm/client.py         | 128 ++++++++++++++++++
 llumnix/entrypoints/vllm/utils.py          |  98 --------------
 llumnix/queue/ray_queue_server.py          |   1 +
 tests/unit_test/entrypoints/test_utils.py  |   2 +-
 .../entrypoints/vllm/api_server_manager.py |  18 ++-
 tests/unit_test/queue/test_zmq.py          |   4 +-
 10 files changed, 219 insertions(+), 178 deletions(-)
 rename llumnix/entrypoints/{utils.py => setup.py} (78%)
 create mode 100644 llumnix/entrypoints/vllm/arg_utils.py
 create mode 100644 llumnix/entrypoints/vllm/client.py
 delete mode 100644 llumnix/entrypoints/vllm/utils.py

diff --git a/llumnix/__init__.py b/llumnix/__init__.py
index a6892514..5ef7ecee 100644
--- a/llumnix/__init__.py
+++ b/llumnix/__init__.py
@@ -15,7 +15,7 @@
 from vllm import *
 
 from llumnix.server_info import ServerInfo
-from llumnix.entrypoints.utils import (launch_ray_cluster,
+from llumnix.entrypoints.setup import (launch_ray_cluster,
                                        connect_to_ray_cluster,
                                        init_manager,
                                        init_llumlets)

diff --git a/llumnix/entrypoints/utils.py b/llumnix/entrypoints/setup.py
similarity index 78%
rename from llumnix/entrypoints/utils.py
rename to llumnix/entrypoints/setup.py
index 496f151d..35280e0f 100644
--- a/llumnix/entrypoints/utils.py
+++ b/llumnix/entrypoints/setup.py
@@ -29,6 +29,7 @@
 from llumnix.queue.queue_type import QueueType
 from llumnix.server_info import ServerInfo, RequestTimestamps
 from llumnix.queue.utils import init_request_output_queue_server
+from llumnix.queue.queue_server_base import QueueServerBase
 
 logger = init_logger(__name__)
 
@@ -39,17 +40,20 @@
 
 class LlumnixEntrypointsContext:
-    def __init__(self):
-        self.engine_manager: LLMEngineManager = None
-        self.instances: Dict[str, Llumlet] = {}
-        self.request_output_queue: QueueServerBase = None
-        self.server_info: ServerInfo = None
-        self.request_streams: Dict[str, AsyncStream] = {}
-        self.manager_available = True
-        self.num_finished_requests = 0
-        self.instance_num_requests: Dict[str, int] = {}
-        self.log_requests: bool = None
-        self.log_request_timestamps: bool = None
+    def __init__(self,
+                 engine_manager: LLMEngineManager,
+                 instances: Dict[str, Llumlet],
+                 request_output_queue: QueueServerBase,
+                 server_info: ServerInfo,
+                 log_requests: bool,
+                 log_request_timestamps: bool):
+        self.engine_manager = engine_manager
+        self.instances = instances
+        self.request_output_queue = request_output_queue
+        self.server_info = server_info
+        self.log_requests = log_requests
+        self.log_request_timestamps = log_request_timestamps
+
 
 def get_ip_address():
     hostname = socket.gethostname()
@@ -240,41 +244,21 @@ def setup_llumnix(engine_manager_args, engine_args, cfg):
                                                              ip, cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)
 
     instances: Dict[str, Llumlet] = {}
-    instance_num_requests: Dict[str, int] = {}
     for idx, ins_id in enumerate(instance_ids):
         instances[ins_id] = llumlets[idx]
-        instance_num_requests[ins_id] = 0
+
     log_requests = not cfg.SERVER.DISABLE_LOG_REQUESTS_SERVER
     log_request_timestamps = cfg.SERVER.LOG_REQUEST_TIMESTAMPS
     logger.info("log_requests: {}, log_request_timestamps: {}".format(log_requests, log_request_timestamps))
 
-    context = LlumnixEntrypointsContext()
-    context.engine_manager = engine_manager
-    context.instances = instances
-    context.request_output_queue = request_output_queue
-    context.server_info = server_info
-    context.instance_num_requests = instance_num_requests
-    context.log_requests = log_requests
-    context.log_request_timestamps = log_request_timestamps
-
-    return context
+    llumnix_entrypoints_context = LlumnixEntrypointsContext(engine_manager,
+                                                            instances,
+                                                            request_output_queue,
+                                                            server_info,
+                                                            log_requests,
+                                                            log_request_timestamps)
 
-# TODO(s5u13b): Fix the potential output token out-of-order issue caused by the migration.
-async def _background_process_outputs(llumnix_context):
-    while True:
-        request_outputs = await llumnix_context.request_output_queue.get()
-        for request_output in request_outputs:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
-        for request_output in request_outputs:
-            request_id = request_output.request_id
-            # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
-            if request_id not in llumnix_context.request_streams:
-                continue
-            llumnix_context.request_streams[request_id].put(request_output)
-            if request_output.finished:
-                llumnix_context.request_streams[request_id].finish()
-                del llumnix_context.request_streams[request_id]
+    return llumnix_entrypoints_context
 
 def init_per_token_latency_breakdown_dict() -> Dict[str, int]:
     per_token_latency_breakdown_dict = {
@@ -290,11 +274,5 @@ def init_per_token_latency_breakdown_dict() -> Dict[str, int]:
     return per_token_latency_breakdown_dict
 
 def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, int], request_timestamps: RequestTimestamps):
-    per_token_latency_breakdown_dict['step_latency_engine'].append(request_timestamps.step_latency_engine)
-    per_token_latency_breakdown_dict['process_model_outputs_latency'].append(request_timestamps.process_model_outputs_latency)
-    per_token_latency_breakdown_dict['step_postprocess_latency'].append(request_timestamps.step_postprocess_latency)
-    per_token_latency_breakdown_dict['across_async_put_queue_thread_latency'].append(request_timestamps.across_async_put_queue_thread_latency)
-    per_token_latency_breakdown_dict['across_async_put_queue_actor_latency'].append(request_timestamps.across_async_put_queue_actor_latency)
-    per_token_latency_breakdown_dict['queue_rpc_latency'].append(request_timestamps.queue_rpc_latency)
-    per_token_latency_breakdown_dict['background_process_get_queue_latency'].append(request_timestamps.background_process_get_queue_latency)
-    per_token_latency_breakdown_dict['generate_benchmark_return_output_latency'].append(request_timestamps.generate_benchmark_return_output_latency)
+    for key in per_token_latency_breakdown_dict.keys():
+        per_token_latency_breakdown_dict[key].append(getattr(request_timestamps, key))
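The loop-based record_per_token_latency_breakdown above works only because every key placed in the dict by init_per_token_latency_breakdown_dict mirrors an attribute name on RequestTimestamps. A minimal sketch of that invariant, with a stub dataclass standing in for the real RequestTimestamps:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class StubTimestamps:                       # stands in for llumnix.server_info.RequestTimestamps
    step_latency_engine: float = 0.0
    queue_rpc_latency: float = 0.0

def init_breakdown_dict() -> Dict[str, List[float]]:
    # Keys must mirror attribute names on the timestamps object,
    # otherwise getattr() in record() raises AttributeError.
    return {'step_latency_engine': [], 'queue_rpc_latency': []}

def record(breakdown: Dict[str, List[float]], ts: StubTimestamps) -> None:
    for key in breakdown:
        breakdown[key].append(getattr(ts, key))

breakdown = init_breakdown_dict()
record(breakdown, StubTimestamps(step_latency_engine=1.5, queue_rpc_latency=0.3))
assert breakdown['step_latency_engine'] == [1.5]

The refactor trades eight hand-written append lines for one loop, at the cost of coupling the dict keys to the attribute names.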
diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py
index ca7ead93..4d6fa730 100644
--- a/llumnix/entrypoints/vllm/api_server.py
+++ b/llumnix/entrypoints/vllm/api_server.py
@@ -23,18 +23,14 @@
 from vllm.sampling_params import SamplingParams
 
 from llumnix.arg_utils import LlumnixArgumentParser
-from llumnix.entrypoints.utils import (setup_ray_cluster,
+from llumnix.entrypoints.setup import (setup_ray_cluster,
                                        setup_llumnix,
                                        is_gpu_available,
-                                       LlumnixEntrypointsContext,
-                                       _background_process_outputs,
                                        init_per_token_latency_breakdown_dict,
                                        record_per_token_latency_breakdown)
-from llumnix.entrypoints.vllm.utils import (add_cli_args,
-                                            get_args,
-                                            manager_generate,
-                                            manager_abort,
-                                            manager_is_ready)
+from llumnix.entrypoints.vllm.arg_utils import (add_cli_args,
+                                                get_args)
+from llumnix.entrypoints.vllm.client import LlumnixClientVLLM
 from llumnix.logger import init_logger
 from llumnix.utils import random_uuid
 from llumnix.config import get_llumnix_config, LlumnixConfig
@@ -44,16 +40,16 @@
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 
-llumnix_context: LlumnixEntrypointsContext = None
+llumnix_client: LlumnixClientVLLM = None
 
 # pylint: disable=unused-argument
 @asynccontextmanager
 async def lifespan(fastapi_app: FastAPI):
-    asyncio.create_task(llumnix_context.request_output_queue.run_server_loop())
-    asyncio.create_task(_background_process_outputs(llumnix_context))
+    asyncio.create_task(llumnix_client.request_output_queue.run_server_loop())
+    asyncio.create_task(llumnix_client.get_request_outputs_loop())
     yield
-    llumnix_context.request_output_queue.cleanup()
+    llumnix_client.request_output_queue.cleanup()
 
 app = FastAPI(lifespan=lifespan)
 
@@ -79,8 +75,8 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    # Use manager_generate and manager_abort to replace with vllm async engine generate and abort api.
-    results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
+    # Use LlumnixClientVLLM's generate and abort APIs in place of vLLM AsyncLLMEngine's generate and abort APIs.
+    results_generator = await llumnix_client.generate(prompt, sampling_params, request_id)
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -100,7 +96,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await manager_abort(request_id, llumnix_context)
+            await llumnix_client.abort(request_id)
             return Response(status_code=499)
         final_output = request_output
@@ -128,7 +124,7 @@ async def generate_benchmark(request: Request) -> Response:
 
     start = time.time()
 
-    results_generator = await manager_generate(prompt, sampling_params, request_id, llumnix_context)
+    results_generator = await llumnix_client.generate(prompt, sampling_params, request_id)
 
     # Non-streaming case
     final_output = None
@@ -137,7 +133,7 @@ async def generate_benchmark(request: Request) -> Response:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await manager_abort(request_id, llumnix_context)
+            await llumnix_client.abort(request_id)
             return Response(status_code=499)
         now = time.time()
         per_token_latency.append([now, (now - start)*1000])
@@ -148,10 +144,10 @@
             record_per_token_latency_breakdown(per_token_latency_breakdown_dict, request_output.request_timestamps)
 
     assert final_output is not None
-    if llumnix_context.log_requests:
-        llumnix_context.num_finished_requests += 1
+    if llumnix_client.log_requests:
+        llumnix_client.num_finished_requests += 1
         logger.info("entrypoints finished request {}.".format(request_id))
-        logger.info("num_finished_requests {}.".format(llumnix_context.num_finished_requests))
+        logger.info("num_finished_requests {}.".format(llumnix_client.num_finished_requests))
 
     generation = final_output.outputs[0].text
     num_output_tokens = len(final_output.outputs[0].token_ids)
@@ -172,7 +168,7 @@
 
 @app.get("/is_ready")
 async def is_ready():
-    return await manager_is_ready(llumnix_context)
+    return await llumnix_client.is_ready()
 
 
 if __name__ == "__main__":
@@ -192,7 +188,8 @@ async def is_ready():
 
     # if gpu is not available, it means that this node is head pod without any llumnix components
     if is_gpu_available():
-        llumnix_context = setup_llumnix(engine_manager_args, engine_args, cfg)
+        llumnix_entrypoints_context = setup_llumnix(engine_manager_args, engine_args, cfg)
+        llumnix_client = LlumnixClientVLLM(llumnix_entrypoints_context)
         # Start the api server after all the components of llumnix are ready.
         logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT))
         uvicorn.run(app,
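For orientation, here is how a caller might exercise the /generate route defined above. The payload shape is an assumption based on vLLM api_server conventions (a "prompt" field plus SamplingParams keyword arguments), and the host/port are placeholders for cfg.SERVER.HOST / cfg.SERVER.PORT, so treat this as a sketch rather than a documented API:

# Hypothetical client for the non-streaming /generate endpoint above.
import requests

response = requests.post(
    "http://127.0.0.1:8000/generate",     # placeholder host/port
    json={
        "prompt": "San Francisco is a",   # consumed by the handler
        "temperature": 0.8,               # remaining fields feed SamplingParams(**request_dict)
        "max_tokens": 32,
    },
    timeout=60,
)
response.raise_for_status()
print(response.json())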
diff --git a/llumnix/entrypoints/vllm/arg_utils.py b/llumnix/entrypoints/vllm/arg_utils.py
new file mode 100644
index 00000000..bb7daacd
--- /dev/null
+++ b/llumnix/entrypoints/vllm/arg_utils.py
@@ -0,0 +1,31 @@
+from vllm.engine.arg_utils import AsyncEngineArgs
+from llumnix.backends.vllm.utils import check_engine_args
+
+from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs
+from llumnix.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def add_cli_args(parser):
+    parser.set_namespace("llumnix")
+    parser = LlumnixEntrypointsArgs.add_cli_args(parser)
+    parser = EngineManagerArgs.add_cli_args(parser)
+    parser.set_namespace("vllm")
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    cli_args = parser.parse_args()
+    return cli_args
+
+def get_args(cfg, parser, cli_args):
+    llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(cfg)
+    LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, parser)
+    engine_manager_args = EngineManagerArgs.from_llumnix_config(cfg)
+    EngineManagerArgs.check_args(engine_manager_args, parser)
+    engine_args = AsyncEngineArgs.from_cli_args(cli_args)
+    check_engine_args(engine_args, engine_manager_args)
+
+    logger.info("llumnix_entrypoints_args: {}".format(llumnix_entrypoints_args))
+    logger.info("engine_manager_args: {}".format(engine_manager_args))
+    logger.info("engine_args: {}".format(engine_args))
+
+    return llumnix_entrypoints_args, engine_manager_args, engine_args
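add_cli_args interleaves two flag namespaces on a single parser: llumnix-level flags and vLLM engine flags. A toy illustration of that layout using plain argparse; argument groups only approximate what LlumnixArgumentParser.set_namespace does, and the flag names below are hypothetical:

# Toy stand-in for the two-namespace CLI layout built by add_cli_args().
import argparse

parser = argparse.ArgumentParser("llumnix-api-server")

llumnix_group = parser.add_argument_group("llumnix")      # ~ parser.set_namespace("llumnix")
llumnix_group.add_argument("--request-output-queue-port", type=int, default=1234)

vllm_group = parser.add_argument_group("vllm")            # ~ parser.set_namespace("vllm")
vllm_group.add_argument("--model", type=str, default="facebook/opt-125m")

cli_args = parser.parse_args([])
print(cli_args.request_output_queue_port, cli_args.model)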
diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py
new file mode 100644
index 00000000..b59ee4be
--- /dev/null
+++ b/llumnix/entrypoints/vllm/client.py
@@ -0,0 +1,128 @@
+import copy
+import time
+import asyncio
+from typing import Dict
+
+import ray
+
+from vllm.engine.async_llm_engine import AsyncStream
+from vllm import SamplingParams
+
+from llumnix.logger import init_logger
+from llumnix.entrypoints.setup import LlumnixEntrypointsContext
+from llumnix.llm_engine_manager import LLMEngineManager
+from llumnix.llumlet.llumlet import Llumlet
+from llumnix.server_info import RequestTimestamps
+from llumnix.queue.queue_server_base import QueueServerBase
+from llumnix.server_info import ServerInfo
+
+logger = init_logger(__name__)
+
+WAIT_MANAGER_INTERVAL = 5
+
+
+class LlumnixClientVLLM:
+    def __init__(self,
+                 llumnix_entrypoints_context: LlumnixEntrypointsContext):
+        self.engine_manager: LLMEngineManager = llumnix_entrypoints_context.engine_manager
+        self.instances: Dict[str, Llumlet] = llumnix_entrypoints_context.instances
+        self.request_output_queue: QueueServerBase = llumnix_entrypoints_context.request_output_queue
+        self.server_info: ServerInfo = llumnix_entrypoints_context.server_info
+        self.log_requests: bool = llumnix_entrypoints_context.log_requests
+        self.log_request_timestamps: bool = llumnix_entrypoints_context.log_request_timestamps
+
+        self.request_streams: Dict[str, AsyncStream] = {}
+        self.instance_num_requests: Dict[str, int] = {}
+        for ins_id in self.instances.keys():
+            self.instance_num_requests[ins_id] = 0
+        self.num_finished_requests: int = 0
+        self.manager_available: bool = True
+
+    async def generate(self,
+                       prompt: str,
+                       sampling_params: SamplingParams,
+                       request_id: str,
+                       *args,
+                       **kwargs) -> AsyncStream:
+        if sampling_params.n > 1 or sampling_params.use_beam_search:
+            raise ValueError("Unsupported feature: multiple sequence decoding")
+
+        results_generator = AsyncStream(request_id)
+        self.request_streams[request_id] = results_generator
+        server_info_copy = copy.deepcopy(self.server_info)
+
+        # If manager is unavailable, request will be directly added to the llumlet held by api server.
+        try:
+            await self._generate_by_manager(request_id, server_info_copy, prompt, sampling_params, *args, **kwargs)
+            self.manager_available = True
+        except ray.exceptions.RayActorError:
+            # Do not re-generate the request to avoid duplicate requests.
+            if self.manager_available:
+                self.manager_available = False
+                return results_generator
+            await self._generate_by_instance(request_id, server_info_copy, prompt, sampling_params, *args, **kwargs)
+
+        return results_generator
+
+    async def _generate_by_manager(self,
+                                   request_id: str,
+                                   server_info: ServerInfo,
+                                   prompt: str,
+                                   sampling_params: SamplingParams,
+                                   *args,
+                                   **kwargs) -> AsyncStream:
+        if self.log_request_timestamps:
+            # Hack request timestamps in server_info for latency breakdown.
+            server_info.request_timestamps = RequestTimestamps()
+            server_info.request_timestamps.api_server_manager_generate_timestamp = time.time()
+        await self.engine_manager.generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs)
+
+    async def _generate_by_instance(self,
+                                    request_id: str,
+                                    server_info: ServerInfo,
+                                    prompt: str,
+                                    sampling_params: SamplingParams,
+                                    *args,
+                                    **kwargs) -> AsyncStream:
+        try:
+            if self.instance_num_requests:
+                instance_id = min(self.instance_num_requests, key=self.instance_num_requests.get)
+                self.instance_num_requests[instance_id] += 1
+                await self.instances[instance_id].generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs)
+                logger.info("LLMEngineManager is temporarily unavailable, dispatching request {} to instance {}".format(
+                    request_id, instance_id))
+            else:
+                logger.info("LLMEngineManager is temporarily unavailable and there is no instance behind this api server, "
+                            "sleeping {}s while waiting for the manager to become available".format(WAIT_MANAGER_INTERVAL))
+                await asyncio.sleep(WAIT_MANAGER_INTERVAL)
+                return await asyncio.create_task(self.generate(prompt, sampling_params, request_id, *args, **kwargs))
+        except (ray.exceptions.RayActorError, KeyError):
+            if instance_id in self.instances:
+                logger.info("[_generate_by_instance] instance {} is dead".format(instance_id))
+                del self.instances[instance_id]
+                del self.instance_num_requests[instance_id]
+            return await asyncio.create_task(self.generate(prompt, sampling_params, request_id, *args, **kwargs))
+
+    async def abort(self, request_id: str) -> None:
+        try:
+            logger.info("abort request: {}.".format(request_id))
+            await self.engine_manager.abort.remote(request_id)
+        except ray.exceptions.RayActorError:
+            logger.info("manager is unavailable")
+
+    async def is_ready(self) -> bool:
+        ready_status = await self.engine_manager.is_ready.remote()
+        return ready_status
+
+    # TODO(s5u13b): Fix the potential output token out-of-order issue caused by the migration.
+    async def get_request_outputs_loop(self):
+        while True:
+            request_outputs = await self.request_output_queue.get()
+            for request_output in request_outputs:
+                if hasattr(request_output, 'request_timestamps'):
+                    request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
+            for request_output in request_outputs:
+                request_id = request_output.request_id
+                # Request could be dispatched twice when the manager is dead; the first request frees the request_streams entry when finished.
+                if request_id not in self.request_streams:
+                    continue
+                self.request_streams[request_id].put(request_output)
+                if request_output.finished:
+                    self.request_streams[request_id].finish()
+                    del self.request_streams[request_id]
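When the manager actor is unreachable, _generate_by_instance above falls back to the least-loaded Llumlet tracked in instance_num_requests. The selection step in isolation, directly mirroring the min() call in the diff:

from typing import Dict

def pick_least_loaded(instance_num_requests: Dict[str, int]) -> str:
    # min() iterates the dict's keys; key=...get compares in-flight request
    # counts, so the instance currently serving the fewest requests wins.
    return min(instance_num_requests, key=instance_num_requests.get)

loads = {"instance_0": 3, "instance_1": 1, "instance_2": 2}
assert pick_least_loaded(loads) == "instance_1"

Incrementing the chosen instance's count immediately after selection keeps subsequent fallback requests spread across instances while the manager is down.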
diff --git a/llumnix/entrypoints/vllm/utils.py b/llumnix/entrypoints/vllm/utils.py
deleted file mode 100644
index 25e2bfb1..00000000
--- a/llumnix/entrypoints/vllm/utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import copy
-import time
-import asyncio
-import ray
-
-from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.async_llm_engine import AsyncStream
-from vllm import SamplingParams
-
-from llumnix.backends.vllm.utils import check_engine_args
-from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs
-from llumnix.logger import init_logger
-from llumnix.entrypoints.utils import LlumnixEntrypointsContext
-from llumnix.server_info import RequestTimestamps
-
-logger = init_logger(__name__)
-
-WAIT_MANAGER_INTERVAL = 5
-
-
-def add_cli_args(parser):
-    parser.set_namespace("llumnix")
-    parser = LlumnixEntrypointsArgs.add_cli_args(parser)
-    parser = EngineManagerArgs.add_cli_args(parser)
-    parser.set_namespace("vllm")
-    parser = AsyncEngineArgs.add_cli_args(parser)
-    cli_args = parser.parse_args()
-    return cli_args
-
-def get_args(cfg, parser, cli_args):
-    llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(cfg)
-    LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, parser)
-    engine_manager_args = EngineManagerArgs.from_llumnix_config(cfg)
-    EngineManagerArgs.check_args(engine_manager_args, parser)
-    engine_args = AsyncEngineArgs.from_cli_args(cli_args)
-    check_engine_args(engine_args, engine_manager_args)
-
-    logger.info("llumnix_entrypoints_args: {}".format(llumnix_entrypoints_args))
-    logger.info("engine_manager_args: {}".format(engine_manager_args))
-    logger.info("engine_args: {}".format(engine_args))
-
-    return llumnix_entrypoints_args, engine_manager_args, engine_args
-
-async def manager_generate(prompt: str,
-                           sampling_params: SamplingParams,
-                           request_id: str,
-                           llumnix_context: LlumnixEntrypointsContext) -> AsyncStream:
-    results_generator = AsyncStream(request_id)
-    llumnix_context.request_streams[request_id] = results_generator
-
-    if sampling_params.n > 1 or sampling_params.use_beam_search:
-        raise ValueError("Unsupported feature: multiple sequence decoding")
-    # This request's outputs will be put to the request_output_queue of this api server no matter which instance it's running in.
-    # If manager is unavailable, request will be directly added to the llumlet held by api server.
-    try:
-        server_info_copy = copy.deepcopy(llumnix_context.server_info)
-        if llumnix_context.log_request_timestamps:
-            # Hack request timestamps in server_info for latency breakdown.
-            server_info_copy.request_timestamps = RequestTimestamps()
-            server_info_copy.request_timestamps.api_server_manager_generate_timestamp = time.time()
-        # await to catch exception
-        await llumnix_context.engine_manager.generate.remote(request_id, server_info_copy, prompt, sampling_params)
-        llumnix_context.manager_available = True
-    except ray.exceptions.RayActorError:
-        # Do not re-generate the request to avoid duplicate requests.
-        if llumnix_context.manager_available:
-            llumnix_context.manager_available = False
-            return results_generator
-        try:
-            if llumnix_context.instance_num_requests:
-                instance_id = min(llumnix_context.instance_num_requests, key=llumnix_context.instance_num_requests.get)
-                llumnix_context.instance_num_requests[instance_id] += 1
-                await llumnix_context.instances[instance_id].generate.remote(request_id, server_info_copy, prompt, sampling_params)
-                logger.info("Manager is unavailable, directly pass request {} to instance {}".format(request_id, instance_id))
-            else:
-                logger.info("Manager is unavailable, but there is no instance behind this api server, "
-                            "sleep {}s, waiting for manager restarts".format(WAIT_MANAGER_INTERVAL))
-                await asyncio.sleep(WAIT_MANAGER_INTERVAL)
-                return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id, llumnix_context))
-        except (ray.exceptions.RayActorError, KeyError):
-            if instance_id in llumnix_context.instances:
-                logger.info("[manager_generate] instance {} is dead".format(instance_id))
-                del llumnix_context.instances[instance_id]
-                del llumnix_context.instance_num_requests[instance_id]
-            return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id, llumnix_context))
-
-    return results_generator
-
-async def manager_abort(request_id: str, llumnix_context: LlumnixEntrypointsContext) -> None:
-    try:
-        logger.info("abort request: {}.".format(request_id))
-        await llumnix_context.engine_manager.abort.remote(request_id)
-    except ray.exceptions.RayActorError:
-        logger.info("manager is unavailable")
-
-async def manager_is_ready(llumnix_context: LlumnixEntrypointsContext):
-    ready_status = await llumnix_context.engine_manager.is_ready.remote()
-    return ready_status

diff --git a/llumnix/queue/ray_queue_server.py b/llumnix/queue/ray_queue_server.py
index b20fd271..6cff2607 100644
--- a/llumnix/queue/ray_queue_server.py
+++ b/llumnix/queue/ray_queue_server.py
@@ -19,6 +19,7 @@
 
 from llumnix.queue.queue_server_base import QueueServerBase
 
+
 class RayQueueServer(QueueServerBase):
     def __init__(self) -> None:
         self.queue = RayQueue(

diff --git a/tests/unit_test/entrypoints/test_utils.py b/tests/unit_test/entrypoints/test_utils.py
index 033835ce..d6ff1fae 100644
--- a/tests/unit_test/entrypoints/test_utils.py
+++ b/tests/unit_test/entrypoints/test_utils.py
@@ -16,7 +16,7 @@
 import ray
 
 from llumnix.arg_utils import EngineManagerArgs
-from llumnix.entrypoints.utils import (get_ip_address,
+from llumnix.entrypoints.setup import (get_ip_address,
                                        launch_ray_cluster,
                                        init_manager,
                                        retry_manager_method_sync,

diff --git a/tests/unit_test/entrypoints/vllm/api_server_manager.py b/tests/unit_test/entrypoints/vllm/api_server_manager.py
index f9616555..63837d04 100644
--- a/tests/unit_test/entrypoints/vllm/api_server_manager.py
+++ b/tests/unit_test/entrypoints/vllm/api_server_manager.py
@@ -24,7 +25,8 @@
 from llumnix.server_info import ServerInfo, RequestTimestamps
 from llumnix.utils import random_uuid
 from llumnix.queue.utils import init_request_output_queue_server, init_request_output_queue_client, QueueType
-from llumnix.entrypoints.utils import LlumnixEntrypointsContext
+from llumnix.entrypoints.setup import LlumnixEntrypointsContext
+from llumnix.entrypoints.vllm.client import LlumnixClientVLLM
 
 app = llumnix.entrypoints.vllm.api_server.app
 engine_manager = None
@@ -73,17 +74,20 @@ def stats() -> Response:
     request_output_queue_type = QueueType(args.request_output_queue_type)
     engine_manager = init_manager(request_output_queue_type)
 
-    llumnix.entrypoints.vllm.api_server.llumnix_context = LlumnixEntrypointsContext()
-    llumnix.entrypoints.vllm.api_server.llumnix_context.engine_manager = engine_manager
     ip = '127.0.0.1'
     port = 1234
-    llumnix.entrypoints.vllm.api_server.llumnix_context.request_output_queue = \
-        init_request_output_queue_server(ip, port, request_output_queue_type)
+    request_output_queue = init_request_output_queue_server(ip, port, request_output_queue_type)
     ray_queue_server = None
     if request_output_queue_type == QueueType.RAYQUEUE:
-        ray_queue_server = llumnix.entrypoints.vllm.api_server.llumnix_context.request_output_queue
+        ray_queue_server = request_output_queue
     server_info = ServerInfo(random_uuid(), request_output_queue_type, ray_queue_server, ip, port)
-    llumnix.entrypoints.vllm.api_server.llumnix_context.server_info = server_info
+    llumnix_context = LlumnixEntrypointsContext(engine_manager,
+                                                {'0': None},
+                                                request_output_queue,
+                                                server_info,
+                                                None,
+                                                None)
+    llumnix.entrypoints.vllm.api_server.llumnix_client = LlumnixClientVLLM(llumnix_context)
 
     uvicorn.run(
         app,

diff --git a/tests/unit_test/queue/test_zmq.py b/tests/unit_test/queue/test_zmq.py
index d4303d37..9787f77a 100644
--- a/tests/unit_test/queue/test_zmq.py
+++ b/tests/unit_test/queue/test_zmq.py
@@ -34,10 +34,10 @@ def __init__(self, rpc_path):
         asyncio.create_task(self.server.run_server_loop())
         request_output_queue = self.server
         self.stop_signal = asyncio.Event()
-        asyncio.create_task(self._background_process_outputs(request_output_queue))
+        asyncio.create_task(self.get_request_outputs_loop(request_output_queue))
         asyncio.create_task(self._wait_until_done())
 
-    async def _background_process_outputs(self, request_output_queue):
+    async def get_request_outputs_loop(self, request_output_queue):
         while True:
             request_outputs = await request_output_queue.get()
             for request_output in request_outputs:
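The renamed get_request_outputs_loop in the test above mirrors the consumer pattern used by LlumnixClientVLLM: drain a batch from the queue, route each output to its per-request stream, and drop the stream once the request finishes. A self-contained sketch of that lifecycle, with asyncio.Queue standing in for the Llumnix queue server and a namedtuple for vLLM's RequestOutput:

import asyncio
from collections import namedtuple

FakeOutput = namedtuple("FakeOutput", ["request_id", "text", "finished"])

async def outputs_loop(queue: asyncio.Queue, streams: dict) -> None:
    while True:
        request_outputs = await queue.get()   # one batch of outputs per get()
        for request_output in request_outputs:
            stream = streams.get(request_output.request_id)
            if stream is None:                # unknown or already-freed request: skip
                continue
            stream.append(request_output)
            if request_output.finished:       # free the stream once the request is done
                del streams[request_output.request_id]

async def main() -> None:
    queue, streams = asyncio.Queue(), {"req-0": []}
    task = asyncio.create_task(outputs_loop(queue, streams))
    await queue.put([FakeOutput("req-0", "hello", False),
                     FakeOutput("req-0", " world", True)])
    await asyncio.sleep(0.1)
    assert "req-0" not in streams             # stream freed after the finished output
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

asyncio.run(main())

The skip-on-missing-stream check is what makes duplicate dispatch after a manager failure harmless, as noted in the client.py comment.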