Skip to content

Commit

Permalink
[Misc] exception
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Oct 8, 2024
1 parent e9cf870 commit 5258e78
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 8 deletions.
5 changes: 5 additions & 0 deletions llumnix/backends/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from llumnix.llumlet.request import LlumnixRequest
from llumnix.server_info import ServerInfo

class EngineState(str, Enum):
INIT = "INIT"
CRASHED = "CRASHED"
RUNNING = "RUNNING"
STOPPED = "STOPPED"

class BackendType(str, Enum):
VLLM = "VLLM"
Expand Down
2 changes: 1 addition & 1 deletion llumnix/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# pylint: disable=unused-import
from ray.util.placement_group import PlacementGroup

from llumnix.backends.backend_interface import BackendInterface, BackendType
from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState


def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kwargs) -> BackendInterface:
Expand Down
21 changes: 19 additions & 2 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.

import time
import traceback
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
from collections import defaultdict
import threading
Expand All @@ -29,7 +30,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface
from llumnix.backends.backend_interface import BackendInterface, EngineState
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
from llumnix.backends.profiling import LatencyMemData
Expand Down Expand Up @@ -237,14 +238,30 @@ def __init__(
self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)

self.state_lock = threading.Lock()
self.state = EngineState.INIT

self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()

def _start_engine_loop(self) -> None:
with self.state_lock:
self.state = EngineState.RUNNING

while True:
self.engine.step()
try:
self.engine.step()
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))
self._run_workers("shutdown")
with self.state_lock:
self.state = EngineState.CRASHED
break

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion llumnix/backends/vllm/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import threading
from typing import List

import ray.actor
from vllm.engine.arg_utils import EngineArgs

from llumnix.logger import init_logger
Expand Down Expand Up @@ -66,5 +67,5 @@ def __init__(
)
self._thread.start()

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
self.engine.model_executor.send_blocks(len(src_blocks))
1 change: 1 addition & 0 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def shutdown(self) -> None:
del self.model_runner
del self.cache_engine
del self.gpu_cache
del self.migration_backend
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

Expand Down
22 changes: 18 additions & 4 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import List, Union, Iterable
import time
import ray
Expand All @@ -19,7 +20,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface, BackendType
from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState
from llumnix.backends.utils import init_backend_engine, initialize_placement_group
from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus
from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
Expand Down Expand Up @@ -50,6 +51,9 @@ def __init__(self,
self.backend_engine)
self.log_requests = True

self.state_check_thread = threading.Thread(target=self.check_state, daemon=True)
self.state_check_thread.start()

@classmethod
def from_args(cls,
disable_fixed_node_init_instance: bool,
Expand All @@ -63,13 +67,14 @@ def from_args(cls,
**kwargs):
lifetime = "detached" if detached else None
assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}'
actor_name = f"instance_{instance_id}"
if backend_type == backend_type.VLLM:
if disable_fixed_node_init_instance:
# TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo.
placement_group = initialize_placement_group(world_size, detached=detached)
kwargs["placement_group"] = placement_group
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -79,7 +84,7 @@ def from_args(cls,
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -88,7 +93,7 @@ def from_args(cls,
soft=False,))
else: # backend_type == backend_type.SIM_VLLM:
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -98,6 +103,15 @@ def from_args(cls,
llumlet = engine_class.remote(instance_id, backend_type, migration_config, *args, **kwargs)
return llumlet

def check_state(self):
while True:
time.sleep(1)

with self.backend_engine.state_lock:
if self.backend_engine.state == EngineState.CRASHED:
self_actor = ray.get_actor(self.actor_name)
ray.kill(self_actor)

def migrate_out(self, dst_instance_name: str) -> List[str]:
try:
t0 = time.time()
Expand Down
13 changes: 13 additions & 0 deletions tests/unit_test/llumlet/test_local_migration_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType

Expand Down

0 comments on commit 5258e78

Please sign in to comment.