diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 34811e50..08590ba8 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -20,7 +20,7 @@ import queue import ray from ray.util.placement_group import PlacementGroup -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy from vllm.engine.llm_engine import LLMEngine from vllm.core.scheduler import ScheduledSequenceGroup @@ -49,6 +49,7 @@ def __init__(self, instance_id, output_queue_type: QueueType): self.instance_id = instance_id self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type) self.engine_actor_handle = None + self.output_queue_type = output_queue_type async def put_nowait_to_servers(self, server_request_outputs: Dict[str, List[RequestOutput]], @@ -65,7 +66,7 @@ async def put_nowait_to_servers(self, server_id = list(server_request_outputs.keys())[idx] server_info = server_info_dict[server_id] logger.info("Server {} is dead".format(server_id)) - if output_queue_type == QueueType.ZMQ: + if self.output_queue_type == QueueType.ZMQ: logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip, server_info.request_output_queue_port)) req_outputs = list(server_request_outputs.values())[idx] diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 13320e20..cf9dbf1f 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -12,6 +12,7 @@ # limitations under the License. import threading +import traceback from typing import List, Union, Iterable import time import ray @@ -31,7 +32,6 @@ class Llumlet: - # TODO(KuilongCui): catch the exception generated in ctor def __init__(self, instance_id: str, output_queue_type: QueueType, @@ -39,24 +39,29 @@ def __init__(self, migration_config: MigrationConfig, *args, **kwargs) -> None: - self.instance_id = instance_id - self.actor_name = f"instance_{instance_id}" - self.backend_engine: BackendInterface = init_backend_engine(self.instance_id, - output_queue_type, - backend_type, - migration_config, - *args, - **kwargs) - self.migration_coordinator = MigrationCoordinator(self.backend_engine, - migration_config.last_stage_max_blocks, - migration_config.max_stages) - self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy, - self.backend_engine) - self.log_requests = True - - self.check_state_thread = threading.Thread(target=self.check_state, daemon=True, - name="llumlet_check_state_loop") - self.check_state_thread.start() + try: + self.instance_id = instance_id + self.actor_name = f"instance_{instance_id}" + self.backend_engine: BackendInterface = init_backend_engine(self.instance_id, + output_queue_type, + backend_type, + migration_config, + *args, + **kwargs) + self.migration_coordinator = MigrationCoordinator(self.backend_engine, + migration_config.last_stage_max_blocks, + migration_config.max_stages) + self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy, + self.backend_engine) + self.log_requests = True + + self.check_state_thread = threading.Thread(target=self.check_state, daemon=True, + name="llumlet_check_state_loop") + self.check_state_thread.start() + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to initialize llumlet: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) @classmethod def from_args(cls,