From 6de3bd783459fc1db2c9aa45c061219768d54ef7 Mon Sep 17 00:00:00 2001
From: Biao Sun
Date: Fri, 11 Oct 2024 09:39:34 +0800
Subject: [PATCH] [Core] Optimize the put-back of request output tokens to
 reduce overhead (#45)

Replace the per-engine AsyncPutQueueThread with a dedicated Ray actor
(AsyncPutQueueActor), so that putting request outputs back to the API
servers no longer blocks the engine step loop. Request outputs are now
sent through the output queue in batches via a new put_nowait API, and
the engine step loop sleeps briefly whenever a step produces no outputs.

---
 configs/base.yml                              |   2 +-
 examlpes/offline_inference.py                 |  15 +--
 llumnix/backends/backend_interface.py         |   2 +-
 llumnix/backends/vllm/llm_engine.py           | 105 +++++++++++-------
 llumnix/backends/vllm/simulator.py            |   6 +-
 llumnix/config/default.py                     |   2 +-
 llumnix/entrypoints/vllm/api_server.py        |  23 ++--
 llumnix/llumlet/llumlet.py                    |  24 ++--
 llumnix/queue/queue_client_base.py            |   5 +
 llumnix/queue/ray_queue_client.py             |   6 +-
 llumnix/queue/zmq_client.py                   |  12 +-
 llumnix/queue/zmq_server.py                   |  13 ++-
 llumnix/queue/zmq_utils.py                    |   6 +-
 llumnix/server_info.py                        |   6 +-
 tests/conftest.py                             |   2 +-
 .../unit_test/backends/vllm/test_migration.py |  24 ++--
 .../entrypoints/test_llumnix_utils.py         |   2 +-
 .../entrypoints/vllm/api_server_manager.py    |   2 +-
 .../llumlet/test_engine_step_exception.py     |  10 +-
 .../{output_queue => queue}/__init__.py       |   0
 .../{output_queue => queue}/test_zmq.py       |  10 +-
 .../{output_queue => queue}/utils.py          |   0
 22 files changed, 173 insertions(+), 104 deletions(-)
 rename tests/unit_test/{output_queue => queue}/__init__.py (100%)
 rename tests/unit_test/{output_queue => queue}/test_zmq.py (93%)
 rename tests/unit_test/{output_queue => queue}/utils.py (100%)

diff --git a/configs/base.yml b/configs/base.yml
index 23df2ec8..d91a5135 100644
--- a/configs/base.yml
+++ b/configs/base.yml
@@ -4,7 +4,7 @@ SERVER:
   QUEUE_TYPE: "rayqueue"
 
 RAY:
-  RAY_CLUSTER_PORT: 30037
+  RAY_CLUSTER_PORT: 6379
   LAUNCH_RAY_CLUSTER: True
 
 MANAGER:
diff --git a/examlpes/offline_inference.py b/examlpes/offline_inference.py
index 25ec12ba..96d86b0e 100644
--- a/examlpes/offline_inference.py
+++ b/examlpes/offline_inference.py
@@ -24,7 +24,7 @@
 # Launch ray cluster
 os.environ['HEAD_NODE'] = '1'
 os.environ['HEAD_NODE_IP'] = '127.0.0.1'
-ray_cluster_port=37000
+ray_cluster_port=6379
 
 # Note: launch_ray_cluster will stop current ray cluster first, then init a new one.
 launch_ray_cluster(ray_cluster_port=ray_cluster_port)
@@ -58,12 +58,13 @@ async def background_process_outputs(num_tasks):
     finish_task = 0
     while finish_task != num_tasks:
-        request_output = await request_output_queue.get()
-        if request_output.finished:
-            finish_task += 1
-            prompt = request_output.prompt
-            generated_text = request_output.outputs[0].text
-            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        request_outputs = await request_output_queue.get()
+        for request_output in request_outputs:
+            if request_output.finished:
+                finish_task += 1
+                prompt = request_output.prompt
+                generated_text = request_output.outputs[0].text
+                print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
     request_output_queue.cleanup()
 
 async def main():
diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py
index 0c726899..53823db3 100644
--- a/llumnix/backends/backend_interface.py
+++ b/llumnix/backends/backend_interface.py
@@ -63,7 +63,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def _start_engine_loop(self) -> None:
+    def _start_engine_step_loop(self) -> None:
        """Start step loop of backend engine.
""" raise NotImplementedError diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index cfe0fb66..34811e50 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -17,8 +17,10 @@ from collections import defaultdict import threading import asyncio +import queue import ray from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from vllm.engine.llm_engine import LLMEngine from vllm.core.scheduler import ScheduledSequenceGroup @@ -42,58 +44,71 @@ logger = init_logger(__name__) -class AsyncPutQueueThread(threading.Thread): +class AsyncPutQueueActor: def __init__(self, instance_id, output_queue_type: QueueType): - super().__init__() self.instance_id = instance_id - - self.request_output_queue_client: QueueClientBase \ - = get_output_queue_client(output_queue_type) + self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type) self.engine_actor_handle = None - self.loop = asyncio.new_event_loop() - self.daemon = True - - def run(self): - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - async def _put_nowait_batch_to_servers(self, - server_request_outputs: Dict[str, List[RequestOutput]], - server_info_dict: Dict[str, ServerInfo]) -> None: + async def put_nowait_to_servers(self, + server_request_outputs: Dict[str, List[RequestOutput]], + server_info_dict: Dict[str, ServerInfo]) -> None: if self.engine_actor_handle is None: self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix") tasks = [] for server_id, req_outputs in server_request_outputs.items(): server_info = server_info_dict[server_id] - tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait_batch(req_outputs, server_info))) + tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info))) rets = await asyncio.gather(*tasks, return_exceptions=True) for idx, ret in enumerate(rets): - if isinstance(ret, TimeoutError): + if isinstance(ret, (TimeoutError, ray.exceptions.RayActorError)): server_id = list(server_request_outputs.keys())[idx] server_info = server_info_dict[server_id] logger.info("Server {} is dead".format(server_id)) - logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip, - server_info.request_output_queue_port)) + if output_queue_type == QueueType.ZMQ: + logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip, + server_info.request_output_queue_port)) req_outputs = list(server_request_outputs.values())[idx] request_ids = [req_output.request_id for req_output in req_outputs] self.engine_actor_handle.abort_request.remote(request_ids) - def put_nowait_batch_to_servers(self, - server_request_outputs: Dict[str, List[RequestOutput]], - server_info_dict: Dict[str, ServerInfo]) -> None: - asyncio.run_coroutine_threadsafe(self._put_nowait_batch_to_servers(server_request_outputs, server_info_dict), - self.loop) - class LLMEngineLlumnix(LLMEngine): - def __init__(self, instance_id: str, output_queue_type: QueueType, *arg, **kwargs) -> None: + def __init__(self, + instance_id: str, + output_queue_type: QueueType, + placement_group: Optional[PlacementGroup], + node_id: Optional[str], + *arg, **kwargs) -> None: super().__init__(*arg, **kwargs) self.instance_id = instance_id self.step_counter = Counter() self.instance_info = None - # TODO(s5u13b): Reduce the overhead. 
-        self.async_put_queue_thread = AsyncPutQueueThread(instance_id, output_queue_type)
-        self.async_put_queue_thread.start()
+        # Place the async put queue actor together with the instance.
+        if placement_group:
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+            )
+        elif node_id:
+            scheduling_strategy = NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            )
+        else:  # When using the simulator, placement_group and node_id are both None.
+            scheduling_strategy = NodeAffinitySchedulingStrategy(
+                node_id=ray.get_runtime_context().get_node_id(),
+                soft=False,
+            )
+        self.put_queue_args_queue = queue.Queue()
+        self.put_queue_loop_thread = threading.Thread(
+            target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
+        )
+        self.async_put_queue_actor = ray.remote(
+            num_cpus=1,
+            scheduling_strategy=scheduling_strategy
+        )(AsyncPutQueueActor).remote(instance_id, output_queue_type)
+        self.put_queue_loop_thread.start()
 
     # pylint: disable=W0221
     @classmethod
     def from_engine_args(
@@ -105,7 +120,7 @@ def from_engine_args(
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         instance_id: str = None,
         placement_group: Optional[PlacementGroup] = None,
-        node_id: str = None,
+        node_id: Optional[str] = None,
         latency_mem: Optional[LatencyMemData] = None
     ) -> "LLMEngineLlumnix":
         """Creates an LLM engine from the engine arguments."""
@@ -130,6 +145,8 @@ def from_engine_args(
         engine = cls(
             instance_id=instance_id,
             output_queue_type=output_queue_type,
+            placement_group=placement_group,
+            node_id=node_id,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
@@ -186,10 +203,12 @@ def step(self) -> None:
                 tot_blocks.extend(blocks)
             tot_blocks = set(tot_blocks)
             instance_info.num_blocks_last_running_request = len(tot_blocks)
-        if request_outputs:
-            self._put_request_outputs_to_server(request_outputs, server_infos)
+        self.put_queue_args_queue.put((request_outputs, server_infos))
         self.instance_info = instance_info
 
+        num_request_outputs = len(request_outputs)
+
+        return num_request_outputs
+
     def update_instance_info(self, instance_info: InstanceInfo) -> None:
         # These fields are updated after step.
@@ -208,7 +227,13 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
                                    seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
         self.scheduler.scheduler_lock.release()
 
-    def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
+    def _start_put_queue_loop(self):
+        while True:
+            args = self.put_queue_args_queue.get()
+            request_outputs, server_infos = args
+            self._put_request_outputs_to_server(request_outputs, server_infos)
+
+    def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], server_infos: List[ServerInfo]) -> None:
         server_request_outputs = defaultdict(list)
         server_info_dict = {}
         # Reorganize data in order to put request outputs to the queue in one batch.
         for request_output, server_info in zip(request_outputs, server_infos):
             server_id = server_info.server_id
             server_request_outputs[server_id].append(request_output)
             if server_id not in server_info_dict:
                 server_info_dict[server_id] = server_info
+        # TODO(s5u13b): Reduce the across-actor overhead.
+        # TODO(s5u13b): It is not necessary to use async_put_queue_actor when output_queue_type is RayQueue.
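+        # Note: .remote() is fire-and-forget, so step() is never blocked here; put
+        # failures are handled inside AsyncPutQueueActor, which aborts the requests
+        # of unreachable servers.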
+        self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)
 
 class BackendVLLM(BackendInterface):
     def __init__(
@@ -251,12 +278,12 @@ def __init__(
         logger.info("engine ({}) current state {}".format(self.instance_id, self.state))
 
         self._stop_event = threading.Event()
-        self._thread = threading.Thread(
-            target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
+        self.engine_step_loop_thread = threading.Thread(
+            target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
         )
-        self._thread.start()
+        self.engine_step_loop_thread.start()
 
-    def _start_engine_loop(self) -> None:
+    def _start_engine_step_loop(self) -> None:
         self._stop_event.clear()
 
         with self.state_lock:
@@ -266,7 +293,9 @@ def _start_engine_loop(self) -> None:
 
         while not self._stop_event.is_set():
             try:
-                self.engine.step()
+                num_request_outputs = self.engine.step()
+                if num_request_outputs == 0:
+                    time.sleep(0.01)
             # pylint: disable=broad-except
             except Exception as e:
                 logger.error("Error in engine loop: {}".format(e))
diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py
index 99ed96ad..a87b455c 100644
--- a/llumnix/backends/vllm/simulator.py
+++ b/llumnix/backends/vllm/simulator.py
@@ -64,10 +64,10 @@ def __init__(
         self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
         self.engine.output_processor.scheduler = self.engine.scheduler
         self.instance_id = instance_id
-        self._thread = threading.Thread(
-            target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
+        self.engine_step_loop_thread = threading.Thread(
+            target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
         )
-        self._thread.start()
+        self.engine_step_loop_thread.start()
 
     # pylint: disable=unused-argument
     def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
diff --git a/llumnix/config/default.py b/llumnix/config/default.py
index ec4ecbc7..c79c27a4 100644
--- a/llumnix/config/default.py
+++ b/llumnix/config/default.py
@@ -42,7 +42,7 @@
 # -----------------------------------------------------------------------------
 _C.RAY = LC()
 # Port number for the Ray cluster
-_C.RAY.RAY_CLUSTER_PORT = 30050
+_C.RAY.RAY_CLUSTER_PORT = 6379
 # If True, launch Ray cluster in API server
 _C.RAY.LAUNCH_RAY_CLUSTER = False
diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py
index 76f55d2a..9334e8ba 100644
--- a/llumnix/entrypoints/vllm/api_server.py
+++ b/llumnix/entrypoints/vllm/api_server.py
@@ -56,15 +56,16 @@
 
 async def _background_process_outputs():
     while True:
-        request_output = await request_output_queue.get()
-        request_id = request_output.request_id
-        # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
-        if request_id not in request_streams:
-            continue
-        request_streams[request_id].put(request_output)
-        if request_output.finished:
-            request_streams[request_id].finish()
-            del request_streams[request_id]
+        request_outputs = await request_output_queue.get()
+        for request_output in request_outputs:
+            request_id = request_output.request_id
+            # A request could be dispatched twice when the manager is dead; the first dispatch frees the request_streams entry once it finishes.
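+            # Outputs may therefore still arrive for an already-freed stream; skip them.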
+            if request_id not in request_streams:
+                continue
+            request_streams[request_id].put(request_output)
+            if request_output.finished:
+                request_streams[request_id].finish()
+                del request_streams[request_id]
 
 # pylint: disable=unused-argument
 @asynccontextmanager
@@ -180,11 +181,11 @@ async def generate_benchmark(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
+    start = time.time()
+
     results_generator = await manager_generate(prompt, sampling_params, request_id)
 
     per_token_latency = []
-    start = time.time()
-
     # Non-streaming case
     final_output = None
     async for request_output in results_generator:
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index 097a04c4..13320e20 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -15,8 +15,7 @@
 from typing import List, Union, Iterable
 import time
 import ray
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
 
 from llumnix.logger import init_logger
 from llumnix.instance_info import InstanceInfo
@@ -86,7 +85,9 @@ def from_args(cls,
                              lifetime=lifetime)(cls).options(
                                  scheduling_strategy=PlacementGroupSchedulingStrategy(
                                      placement_group=placement_group,
-                                     placement_group_bundle_index=0,))
+                                     placement_group_bundle_index=0,
+                                 )
+                             )
         else:
             kwargs["node_id"] = node_id
             engine_class = ray.remote(num_cpus=1,
@@ -96,16 +97,21 @@ def from_args(cls,
                              lifetime=lifetime)(cls).options(
                                  scheduling_strategy=NodeAffinitySchedulingStrategy(
                                      node_id=node_id,
-                                     soft=False,))
+                                     soft=False,
+                                 )
+                             )
         else: # backend_type == backend_type.SIM_VLLM:
+            kwargs["node_id"] = node_id
             engine_class = ray.remote(num_cpus=1,
                                       name=actor_name,
                                       namespace='llumnix',
                                       max_concurrency=4,
                                       lifetime=lifetime)(cls).options(
-                                          scheduling_strategy=NodeAffinitySchedulingStrategy(
-                                              node_id=node_id,
-                                              soft=False,))
+                                          scheduling_strategy=NodeAffinitySchedulingStrategy(
+                                              node_id=node_id,
+                                              soft=False,
+                                          )
+                                      )
         llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
         return llumlet
@@ -118,8 +124,8 @@ def check_state(self):
             logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
             # pylint: disable=protected-access
             self.backend_engine._stop_event.set()
-            if self.backend_engine._thread.is_alive():
-                self.backend_engine._thread.join()
+            if self.backend_engine.engine_step_loop_thread.is_alive():
+                self.backend_engine.engine_step_loop_thread.join()
             self_actor = ray.get_actor(self.actor_name)
             ray.kill(self_actor)
diff --git a/llumnix/queue/queue_client_base.py b/llumnix/queue/queue_client_base.py
index 967ae1e4..9e2c52c8 100644
--- a/llumnix/queue/queue_client_base.py
+++ b/llumnix/queue/queue_client_base.py
@@ -12,11 +12,16 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from typing import Any
 from collections.abc import Iterable
 
 from llumnix.server_info import ServerInfo
 
 class QueueClientBase(ABC):
+    @abstractmethod
+    async def put_nowait(self, item: Any, server_info: ServerInfo):
+        raise NotImplementedError
+
     @abstractmethod
     async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
         raise NotImplementedError
diff --git a/llumnix/queue/ray_queue_client.py b/llumnix/queue/ray_queue_client.py
index 590f730e..628c637a 100644
--- a/llumnix/queue/ray_queue_client.py
+++ b/llumnix/queue/ray_queue_client.py
@@ -11,13 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+from typing import Any
 from collections.abc import Iterable
 
 from llumnix.server_info import ServerInfo
 from llumnix.queue.queue_client_base import QueueClientBase
 
 class RayQueueClient(QueueClientBase):
+    async def put_nowait(self, item: Any, server_info: ServerInfo):
+        output_queue = server_info.request_output_queue
+        return await output_queue.actor.put_nowait.remote(item)
+
     async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
         output_queue = server_info.request_output_queue
         return await output_queue.actor.put_nowait_batch.remote(items)
diff --git a/llumnix/queue/zmq_client.py b/llumnix/queue/zmq_client.py
index 9cd58cbd..afbdf170 100644
--- a/llumnix/queue/zmq_client.py
+++ b/llumnix/queue/zmq_client.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
 from contextlib import contextmanager
 from collections.abc import Iterable
 
@@ -22,8 +23,8 @@
 from llumnix.server_info import ServerInfo
 from llumnix.queue.zmq_utils import (RPC_GET_DATA_TIMEOUT_MS, RPC_SOCKET_LIMIT_CUTOFF,
                                      RPC_ZMQ_HWM, RPC_SUCCESS_STR,
-                                     RPCClientClosedError, RPC_REQUEST_TYPE, RPCUtilityRequest, RPCPutNoWaitBatchQueueRequest,
-                                     get_open_zmq_ipc_path)
+                                     RPCClientClosedError, RPC_REQUEST_TYPE, RPCUtilityRequest, RPCPutNoWaitQueueRequest,
+                                     RPCPutNoWaitBatchQueueRequest, get_open_zmq_ipc_path)
 
 logger = init_logger(__name__)
 
@@ -104,6 +105,13 @@ async def wait_for_server_rpc(self,
             rpc_path=rpc_path,
             error_message="Unable to start RPC Server")
 
+    async def put_nowait(self, item: Any, server_info: ServerInfo):
+        rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
+        await self._send_one_way_rpc_request(
+            request=RPCPutNoWaitQueueRequest(item=item),
+            rpc_path=rpc_path,
+            error_message="Unable to put item into queue.")
+
     async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
         rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
         await self._send_one_way_rpc_request(
diff --git a/llumnix/queue/zmq_server.py b/llumnix/queue/zmq_server.py
index 503820c4..f1114782 100644
--- a/llumnix/queue/zmq_server.py
+++ b/llumnix/queue/zmq_server.py
@@ -20,7 +20,7 @@
 import cloudpickle
 
 from llumnix.queue.zmq_utils import (RPC_ZMQ_HWM, RPC_SUCCESS_STR, RPC_SOCKET_LIMIT_CUTOFF,
-                                     RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest)
+                                     RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest)
 from llumnix.logger import init_logger
 
 logger = init_logger(__name__)
 
@@ -110,6 +110,8 @@ def _make_handler_coro(self, identity,
         request = cloudpickle.loads(message)
         if request == RPCUtilityRequest.IS_SERVER_READY:
             return self._is_server_ready(identity)
+        if isinstance(request, RPCPutNoWaitQueueRequest):
+            return self._put_nowait(identity, request)
         if isinstance(request, RPCPutNoWaitBatchQueueRequest):
             return self._put_nowait_batch(identity, request)
 
@@ -119,6 +121,15 @@ async def _is_server_ready(self, identity):
         await self.socket.send_multipart(
             [identity, cloudpickle.dumps(RPC_SUCCESS_STR)])
 
+    async def _put_nowait(self, identity, put_nowait_queue_request: RPCPutNoWaitQueueRequest):
+        try:
+            self.put_nowait(put_nowait_queue_request.item)
+            await self.socket.send_multipart(
+                [identity, cloudpickle.dumps(RPC_SUCCESS_STR)])
+        # pylint: disable=W0703
+        except Exception as e:
+            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+
     async def _put_nowait_batch(self, identity, put_nowait_batch_queue_request: RPCPutNoWaitBatchQueueRequest):
         try:
             self.put_nowait_batch(put_nowait_batch_queue_request.items)
diff --git a/llumnix/queue/zmq_utils.py b/llumnix/queue/zmq_utils.py
index bdd76bc8..f98b4ba3 100644
--- a/llumnix/queue/zmq_utils.py
+++ b/llumnix/queue/zmq_utils.py
@@ -20,6 +20,10 @@
 RPC_ZMQ_HWM = 0
 RPC_SUCCESS_STR = "SUCCESS"
 
+@dataclass
+class RPCPutNoWaitQueueRequest:
+    item: Any = None
+
 @dataclass
 class RPCPutNoWaitBatchQueueRequest:
     items: List[Any] = None
@@ -28,7 +32,7 @@ class RPCUtilityRequest(Enum):
     IS_SERVER_READY = 1
 
 # pylint: disable=C0103
-RPC_REQUEST_TYPE = Union[RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest]
+RPC_REQUEST_TYPE = Union[RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest]
 
 class RPCClientClosedError(Exception):
     """Exception class raised when the client is used post-close.
diff --git a/llumnix/server_info.py b/llumnix/server_info.py
index 521618c7..51e288db 100644
--- a/llumnix/server_info.py
+++ b/llumnix/server_info.py
@@ -23,10 +23,8 @@ def __init__(self,
                  request_output_queue_port: int) -> None:
         self.server_id = server_id
         self.output_queue_type = output_queue_type
-        if output_queue_type == QueueType.RAYQUEUE:
-            assert request_output_queue is not None and hasattr(request_output_queue, "queue")
-        self.request_output_queue = request_output_queue.queue if hasattr(request_output_queue, "queue") else None
-
+        assert request_output_queue is not None
+        self.request_output_queue = request_output_queue.queue if output_queue_type == QueueType.RAYQUEUE else None
         self.request_output_queue_ip = request_output_queue_ip
         self.request_output_queue_port = request_output_queue_port
diff --git a/tests/conftest.py b/tests/conftest.py
index dbf2996e..2749ba00 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,7 +19,7 @@ def pytest_sessionstart(session):
     subprocess.run(["ray", "stop", "--force"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     sleep(3)
-    subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=30050"], check=True,
+    subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True,
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     sleep(3)
diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py
index 14416d06..4289a0b3 100644
--- a/tests/unit_test/backends/vllm/test_migration.py
+++ b/tests/unit_test/backends/vllm/test_migration.py
@@ -26,7 +26,7 @@
 from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType
 from llumnix.queue.queue_type import QueueType
 
-from tests.unit_test.output_queue.utils import request_output_queue_server
+from tests.unit_test.queue.utils import request_output_queue_server
 # pylint: disable=unused-import
 from tests.conftest import setup_ray_env
@@ -104,9 +104,10 @@ async def test_correctness(prompt):
     origin_output = None
     finished = False
     while not finished:
-        request_output = await request_output_queue.get()
-        origin_output = request_output.outputs[0]
-        finished = request_output.finished
+        request_outputs = await request_output_queue.get()
+        for request_output in request_outputs:
+            origin_output = request_output.outputs[0]
+            finished = request_output.finished
 
     request_id1 = random_uuid()
     ray.get(llumlet_0.generate.remote(request_id1, server_info, prompt, sampling_params))
@@ -123,13 +124,14 @@ async def test_correctness(prompt):
     output = None
     finished = False
     while not finished:
-        request_output = await request_output_queue.get()
-        origin_output = request_output.outputs[0]
-        finished = request_output.finished
-        if request_output.request_id != request_id1:
-            continue
-        output = request_output.outputs[0]
-        finished = request_output.finished
+        request_outputs = await request_output_queue.get()
+        for request_output in request_outputs:
+            if request_output.request_id != request_id1:
+                continue
+            output = request_output.outputs[0]
+            finished = request_output.finished
 
     assert output.text == origin_output.text
     assert output.cumulative_logprob == origin_output.cumulative_logprob
diff --git a/tests/unit_test/entrypoints/test_llumnix_utils.py b/tests/unit_test/entrypoints/test_llumnix_utils.py
index 77919acf..aad1b7de 100644
--- a/tests/unit_test/entrypoints/test_llumnix_utils.py
+++ b/tests/unit_test/entrypoints/test_llumnix_utils.py
@@ -32,7 +32,7 @@ def test_launch_ray_cluster():
     ip_address = get_ip_address()
     os.environ['HEAD_NODE'] = '1'
     os.environ['HEAD_NODE_IP'] = ip_address
-    result = launch_ray_cluster(30050)
+    result = launch_ray_cluster(6379)
     assert result.returncode == 0
 
 def test_init_manager(setup_ray_env):
diff --git a/tests/unit_test/entrypoints/vllm/api_server_manager.py b/tests/unit_test/entrypoints/vllm/api_server_manager.py
index ad879536..b3b6015d 100644
--- a/tests/unit_test/entrypoints/vllm/api_server_manager.py
+++ b/tests/unit_test/entrypoints/vllm/api_server_manager.py
@@ -41,7 +41,7 @@ async def generate(self, request_id, server_info, *args, **kwargs):
         self._num_generates += 1
         completion_output = CompletionOutput(0, "", [], 0.0, None)
         request_output = RequestOutput(request_id, "", [], None, [completion_output], finished=True)
-        await self.request_output_queue.put_nowait_batch([request_output], server_info)
+        await self.request_output_queue.put_nowait([request_output], server_info)
 
     async def abort(self, request_id):
         self._num_aborts += 1
diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py
index cdd768dc..6ce2fa1a 100644
--- a/tests/unit_test/llumlet/test_engine_step_exception.py
+++ b/tests/unit_test/llumlet/test_engine_step_exception.py
@@ -36,8 +36,8 @@ def __init__(self, *args, **kwargs) -> None:
 
     def set_error_step(self, broken: bool):
         self.backend_engine._stop_event.set()
-        if self.backend_engine._thread.is_alive():
-            self.backend_engine._thread.join()
+        if self.backend_engine.engine_step_loop_thread.is_alive():
+            self.backend_engine.engine_step_loop_thread.join()
 
         def raise_error_step():
             self.origin_step()
@@ -48,10 +48,10 @@ def raise_error_step():
         else:
             self.backend_engine.engine.step = self.origin_step
 
-        self.backend_engine._thread = threading.Thread(
-            target=self.backend_engine._start_engine_loop, args=(), daemon=True, name="engine_loop"
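+        # Recreate the engine step loop thread so the swapped-in step function takes effect.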
+        self.backend_engine.engine_step_loop_thread = threading.Thread(
+            target=self.backend_engine._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
         )
-        self.backend_engine._thread.start()
+        self.backend_engine.engine_step_loop_thread.start()
 
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.")
 def test_engine_step_exception(setup_ray_env):
diff --git a/tests/unit_test/output_queue/__init__.py b/tests/unit_test/queue/__init__.py
similarity index 100%
rename from tests/unit_test/output_queue/__init__.py
rename to tests/unit_test/queue/__init__.py
diff --git a/tests/unit_test/output_queue/test_zmq.py b/tests/unit_test/queue/test_zmq.py
similarity index 93%
rename from tests/unit_test/output_queue/test_zmq.py
rename to tests/unit_test/queue/test_zmq.py
index 891c61d9..fd2ba393 100644
--- a/tests/unit_test/output_queue/test_zmq.py
+++ b/tests/unit_test/queue/test_zmq.py
@@ -39,10 +39,10 @@ def __init__(self, rpc_path):
 
     async def _background_process_outputs(self, request_output_queue):
         while True:
-            request_output = await request_output_queue.get()
-            if request_output.finished:
-                break
-        self.stop_signal.set()
+            request_outputs = await request_output_queue.get()
+            for request_output in request_outputs:
+                if request_output.finished:
+                    self.stop_signal.set()
 
     async def _wait_until_done(self):
         await self.stop_signal.wait()
@@ -68,7 +68,7 @@ async def async_request_output_gen(generator, qps):
         return
 
 async def put_queue(request_output_queue, request_output, server_info):
-    await request_output_queue.put_nowait_batch([request_output], server_info)
+    await request_output_queue.put_nowait([request_output], server_info)
 
 class TimeoutException(Exception):
     pass
diff --git a/tests/unit_test/output_queue/utils.py b/tests/unit_test/queue/utils.py
similarity index 100%
rename from tests/unit_test/output_queue/utils.py
rename to tests/unit_test/queue/utils.py