From 052b49977495eeae5c5ca81fc631602003c8970e Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Oct 2024 07:41:42 +0000 Subject: [PATCH] fix --- llumnix/backends/utils.py | 1 - llumnix/backends/vllm/llm_engine.py | 15 +++++++-------- llumnix/backends/vllm/simulator.py | 8 +++++++- llumnix/llumlet/llumlet.py | 13 +++++++++---- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py index 162107fd..f1924338 100644 --- a/llumnix/backends/utils.py +++ b/llumnix/backends/utils.py @@ -28,7 +28,6 @@ def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kw # pylint: disable=import-outside-toplevel from llumnix.backends.vllm.simulator import BackendSimVLLM backend_engine = BackendSimVLLM(instance_id, *args, **kwargs) - backend_engine.state = EngineState.RUNNING else: raise ValueError(f'Unsupported backend: {backend_type}') return backend_engine diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 60403c2a..0071704a 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -238,29 +238,28 @@ def __init__( src_worker_handle_list=self.worker_handle_list, placement_group=placement_group, node_id=node_id) + self.state_lock = threading.Lock() self.state = EngineState.INIT + self._thread = threading.Thread( target=self._start_engine_loop, args=(), daemon=True, name="engine_loop" ) self._thread.start() def _start_engine_loop(self) -> None: - self.state = EngineState.RUNNING + with self.state_lock: + self.state = EngineState.RUNNING + while True: try: self.engine.step() - - raise ValueError("test") # pylint: disable=broad-except except Exception as e: logger.error("Error in engine loop: {}".format(e)) logger.error("exception traceback: {}".format(traceback.format_exc())) self._run_workers("shutdown") - self.state = EngineState.CRASHED - - named_actors = ray.util.list_named_actors(True) - for actor in named_actors: - print(actor['name']) + with self.state_lock: + self.state = EngineState.CRASHED break def execute_worker_method(self, method, *args, **kwargs): diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py index 061c517d..9ef6b774 100644 --- a/llumnix/backends/vllm/simulator.py +++ b/llumnix/backends/vllm/simulator.py @@ -12,12 +12,15 @@ # limitations under the License. import os +import ray.actor +import threading from typing import List from vllm.utils import Counter from vllm.engine.arg_utils import EngineArgs from llumnix.logger import init_logger +from llumnix.backends.backend_interface import EngineState from llumnix.internal_config import MigrationConfig from llumnix.backends.vllm.scheduler import SchedulerLlumnix from llumnix.backends.vllm.llm_engine import LLMEngineLlumnix, BackendVLLM @@ -36,6 +39,9 @@ def __init__( gpu_type: str, engine_args: EngineArgs, ) -> None: + self.state_lock = threading.Lock() + self.state = EngineState.RUNNING + # load database profiling_database = ProfilingDatabase(profiling_result_file_path) engine_config = engine_args.create_engine_config() @@ -61,5 +67,5 @@ def __init__( self.instance_id = instance_id self.step_counter = Counter() - def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: + def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None: self.engine.model_executor.send_blocks(len(src_blocks)) diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 62eaed66..f58c8128 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading from typing import List, Union, Iterable import time import ray @@ -50,6 +51,9 @@ def __init__(self, self.backend_engine) self.log_requests = True + self.state_check_thread = threading.Thread(target=self.check_state, daemon=True) + self.state_check_thread.start() + @classmethod def from_args(cls, disable_fixed_node_init_instance: bool, @@ -102,10 +106,11 @@ def from_args(cls, def check_state(self): while True: time.sleep(1) - - if self.backend_engine.state == EngineState.CRASHED: - self_actor = ray.get_actor(self.actor_name) - ray.kill(self_actor) + + with self.backend_engine.state_lock: + if self.backend_engine.state == EngineState.CRASHED: + self_actor = ray.get_actor(self.actor_name) + ray.kill(self_actor) def migrate_out(self, dst_instance_name: str) -> List[str]: try: