diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py
index 441fab2d..0c726899 100644
--- a/llumnix/backends/backend_interface.py
+++ b/llumnix/backends/backend_interface.py
@@ -18,6 +18,11 @@
 from llumnix.llumlet.request import LlumnixRequest
 from llumnix.server_info import ServerInfo
 
+class EngineState(str, Enum):
+    INIT = "INIT"
+    CRASHED = "CRASHED"
+    RUNNING = "RUNNING"
+    STOPPED = "STOPPED"
 
 class BackendType(str, Enum):
     VLLM = "VLLM"
diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py
index 1e110f44..f1924338 100644
--- a/llumnix/backends/utils.py
+++ b/llumnix/backends/utils.py
@@ -16,7 +16,7 @@
 # pylint: disable=unused-import
 from ray.util.placement_group import PlacementGroup
 
-from llumnix.backends.backend_interface import BackendInterface, BackendType
+from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState
 
 
 def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kwargs) -> BackendInterface:
diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index 086444a5..5575ef60 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 
 import time
+import traceback
 from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
 from collections import defaultdict
 import threading
@@ -29,7 +30,7 @@
 
 from llumnix.logger import init_logger
 from llumnix.instance_info import InstanceInfo
-from llumnix.backends.backend_interface import BackendInterface
+from llumnix.backends.backend_interface import BackendInterface, EngineState
 from llumnix.backends.vllm.scheduler import SchedulerLlumnix
 from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
 from llumnix.backends.profiling import LatencyMemData
@@ -237,14 +238,30 @@ def __init__(
         self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
                           src_worker_handle_list=self.worker_handle_list,
                           placement_group=placement_group, node_id=node_id)
+
+        self.state_lock = threading.Lock()
+        self.state = EngineState.INIT
+
         self._thread = threading.Thread(
             target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
         )
         self._thread.start()
 
     def _start_engine_loop(self) -> None:
+        with self.state_lock:
+            self.state = EngineState.RUNNING
+
         while True:
-            self.engine.step()
+            try:
+                self.engine.step()
+            # pylint: disable=broad-except
+            except Exception as e:
+                logger.error("Error in engine loop: {}".format(e))
+                logger.error("exception traceback: {}".format(traceback.format_exc()))
+                self._run_workers("shutdown")
+                with self.state_lock:
+                    self.state = EngineState.CRASHED
+                break
 
     def execute_worker_method(self, method, *args, **kwargs):
         return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
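The llm_engine.py hunks above replace the bare `while True: self.engine.step()` loop with a small state machine: the loop thread flips the engine from INIT to RUNNING when it starts, and an uncaught exception from `step()` logs the traceback, shuts the workers down, marks the engine CRASHED, and exits the thread instead of dying silently. A minimal, self-contained sketch of that pattern follows; `FakeEngine` and `EngineLoop` are illustrative stand-ins, not code from this PR:

```python
import threading
import traceback
from enum import Enum


class EngineState(str, Enum):
    INIT = "INIT"
    CRASHED = "CRASHED"
    RUNNING = "RUNNING"
    STOPPED = "STOPPED"


class FakeEngine:
    """Stand-in for the vLLM engine; fails on the third step."""
    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1
        if self.steps == 3:
            raise RuntimeError("simulated engine failure")


class EngineLoop:
    def __init__(self) -> None:
        self.engine = FakeEngine()
        self.state_lock = threading.Lock()
        self.state = EngineState.INIT
        self._thread = threading.Thread(target=self._start_engine_loop,
                                        daemon=True, name="engine_loop")
        self._thread.start()

    def _start_engine_loop(self) -> None:
        with self.state_lock:
            self.state = EngineState.RUNNING
        while True:
            try:
                self.engine.step()
            except Exception:  # broad except, mirroring the PR
                # Record the crash instead of letting the daemon thread die
                # silently; a watchdog can poll this state (see llumlet.py below).
                traceback.print_exc()
                with self.state_lock:
                    self.state = EngineState.CRASHED
                break


loop = EngineLoop()
loop._thread.join()                         # returns once the loop crashes
assert loop.state == EngineState.CRASHED
```

Because the loop runs on a daemon thread, an unhandled exception would otherwise vanish with no observable effect; recording CRASHED under the lock is what makes the failure visible to the watchdog added in llumlet.py further down.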
diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py
index a710a5a9..745fc1d2 100644
--- a/llumnix/backends/vllm/simulator.py
+++ b/llumnix/backends/vllm/simulator.py
@@ -15,6 +15,7 @@
 import threading
 from typing import List
 
+import ray.actor
 from vllm.engine.arg_utils import EngineArgs
 
 from llumnix.logger import init_logger
@@ -66,5 +67,5 @@ def __init__(
         )
         self._thread.start()
 
-    def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
+    def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
         self.engine.model_executor.send_blocks(len(src_blocks))
diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py
index 4c1ec455..92bf1f1b 100644
--- a/llumnix/backends/vllm/worker.py
+++ b/llumnix/backends/vllm/worker.py
@@ -147,6 +147,7 @@ def shutdown(self) -> None:
         del self.model_runner
         del self.cache_engine
         del self.gpu_cache
+        del self.migration_backend
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
 
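The one-line worker.py change matters for the new crash path: `torch.cuda.empty_cache()` can only return allocator blocks whose last Python reference is gone, so the migration backend (which holds its own GPU buffers) has to be deleted alongside the model runner and KV cache before the cache is flushed. A quick illustration of that allocator behavior, assuming a CUDA device is available (not code from this PR):

```python
import torch

buf = torch.empty(1 << 20, device="cuda")       # ~4 MiB held by a live tensor
torch.cuda.empty_cache()
reserved_before = torch.cuda.memory_reserved()  # still reserved: buf is alive

del buf                                         # drop the last reference first...
torch.cuda.empty_cache()                        # ...then the block can be released
assert torch.cuda.memory_reserved() <= reserved_before
```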
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index 5ad3963f..f58c8128 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import threading
 from typing import List, Union, Iterable
 import time
 import ray
@@ -19,7 +20,7 @@
 
 from llumnix.logger import init_logger
 from llumnix.instance_info import InstanceInfo
-from llumnix.backends.backend_interface import BackendInterface, BackendType
+from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState
 from llumnix.backends.utils import init_backend_engine, initialize_placement_group
 from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus
 from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
@@ -50,6 +51,10 @@ def __init__(self,
                          self.backend_engine)
         self.log_requests = True
 
+        self.actor_name = f"instance_{instance_id}"
+        self.state_check_thread = threading.Thread(target=self.check_state, daemon=True)
+        self.state_check_thread.start()
+
     @classmethod
     def from_args(cls,
                   disable_fixed_node_init_instance: bool,
@@ -63,13 +67,14 @@ def from_args(cls,
                   **kwargs):
         lifetime = "detached" if detached else None
         assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}'
+        actor_name = f"instance_{instance_id}"
         if backend_type == backend_type.VLLM:
             if disable_fixed_node_init_instance:
                 # TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo.
                 placement_group = initialize_placement_group(world_size, detached=detached)
                 kwargs["placement_group"] = placement_group
                 engine_class = ray.remote(num_cpus=1,
-                                          name=f"instance_{instance_id}",
+                                          name=actor_name,
                                           namespace='llumnix',
                                           max_concurrency=4,
                                           lifetime=lifetime)(cls).options(
@@ -79,7 +84,7 @@ def from_args(cls,
             else:
                 kwargs["node_id"] = node_id
                 engine_class = ray.remote(num_cpus=1,
-                                          name=f"instance_{instance_id}",
+                                          name=actor_name,
                                           namespace='llumnix',
                                           max_concurrency=4,
                                           lifetime=lifetime)(cls).options(
@@ -88,7 +93,7 @@ def from_args(cls,
                                               soft=False,))
         else: # backend_type == backend_type.SIM_VLLM:
             engine_class = ray.remote(num_cpus=1,
-                                      name=f"instance_{instance_id}",
+                                      name=actor_name,
                                       namespace='llumnix',
                                       max_concurrency=4,
                                       lifetime=lifetime)(cls).options(
@@ -98,6 +103,15 @@
         llumlet = engine_class.remote(instance_id, backend_type, migration_config, *args, **kwargs)
         return llumlet
 
+    def check_state(self):
+        while True:
+            time.sleep(1)
+
+            with self.backend_engine.state_lock:
+                if self.backend_engine.state == EngineState.CRASHED:
+                    self_actor = ray.get_actor(self.actor_name)
+                    ray.kill(self_actor)
+
     def migrate_out(self, dst_instance_name: str) -> List[str]:
         try:
             t0 = time.time()
diff --git a/tests/unit_test/llumlet/test_local_migration_scheduler.py b/tests/unit_test/llumlet/test_local_migration_scheduler.py
index cd05c247..447dc215 100644
--- a/tests/unit_test/llumlet/test_local_migration_scheduler.py
+++ b/tests/unit_test/llumlet/test_local_migration_scheduler.py
@@ -1,3 +1,16 @@
+# Copyright (c) 2024, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
 from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType
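Note on the llumlet.py changes: `check_state` polls the backend engine's state once per second and, on CRASHED, kills its own actor via `ray.get_actor(self.actor_name)`; the `self.actor_name` assignment in `__init__` is what makes that lookup work. Killing the actor converts an in-process engine crash into an actor death, so upstream components need no extra signaling: any later remote call on the handle raises `RayActorError`. A hedged sketch of that consumer side; `get_instance_info` and `on_instance_dead` are hypothetical names, not part of this PR:

```python
import ray
from ray.exceptions import RayActorError

def poll_instance(instance: "ray.actor.ActorHandle", instance_id: str, on_instance_dead) -> None:
    # Any remote call will do: once check_state() has killed the llumlet
    # actor, the call raises RayActorError instead of returning.
    try:
        ray.get(instance.get_instance_info.remote())  # hypothetical method name
    except RayActorError:
        on_instance_dead(instance_id)                 # hypothetical callback
```

The one-second polling interval trades a little detection latency for simplicity, and because the watchdog runs on a daemon thread it dies cleanly along with the actor it monitors.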