[Misc] Ensure Llumlet main thread exits on Engine.Step errors #38
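This change makes engine failures terminal instead of silent: the vLLM engine loop now tracks an explicit EngineState, catches exceptions raised by engine.step(), shuts the workers down, and marks the engine CRASHED, while a watchdog thread in Llumlet polls that state and kills its own Ray actor so a crashed instance exits rather than hangs. The CI workflows are also extended to run on pushes to main.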

Merged
4 commits merged on Oct 10, 2024
4 changes: 4 additions & 0 deletions .github/workflows/bench_test.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  cancel_previous_workflows:
@@ -27,6 +30,7 @@ jobs:
      - name: Build And Test
        run: ./tools/bench_test.sh
      - name: Create comment from file
        if: ${{ github.event_name != 'push' }}
        uses: actions/github-script@v7
        with:
          script: |
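Note on the workflow changes, here and in the six files that follow: each CI workflow now also triggers on pushes to main, not only on pull requests, and the steps that post benchmark or test results as a PR comment are gated with `if: ${{ github.event_name != 'push' }}`, since a push event has no pull request to comment on.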
3 changes: 3 additions & 0 deletions .github/workflows/e2e_test.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  cancel_previous_workflows:
4 changes: 4 additions & 0 deletions .github/workflows/migration_test.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  cancel_previous_workflows:
@@ -27,6 +30,7 @@ jobs:
      - name: Build And Test
        run: ./tools/migration_test.sh
      - name: Create comment from file
        if: ${{ github.event_name != 'push' }}
        uses: actions/github-script@v7
        with:
          script: |
3 changes: 3 additions & 0 deletions .github/workflows/offline_inference.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  cancel_previous_workflows:
3 changes: 3 additions & 0 deletions .github/workflows/pylint.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  cancel_previous_workflows:
3 changes: 3 additions & 0 deletions .github/workflows/unit_test.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  cancel_previous_workflows:
3 changes: 3 additions & 0 deletions .github/workflows/whl_build.yml
@@ -4,6 +4,9 @@ on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - main

jobs:
  whl_build:
5 changes: 5 additions & 0 deletions llumnix/backends/backend_interface.py
@@ -18,6 +18,11 @@
from llumnix.llumlet.request import LlumnixRequest
from llumnix.server_info import ServerInfo

class EngineState(str, Enum):
    INIT = "INIT"
    CRASHED = "CRASHED"
    RUNNING = "RUNNING"
    STOPPED = "STOPPED"

class BackendType(str, Enum):
    VLLM = "VLLM"
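The new enum makes the engine lifecycle explicit. A minimal sketch of the transitions, as driven by the engine loop in llm_engine.py below (the enum matches the diff; the comments summarize where each transition happens):

```python
from enum import Enum

class EngineState(str, Enum):
    INIT = "INIT"        # set in LLMEngineLlumnix.__init__ before the loop thread starts
    RUNNING = "RUNNING"  # set when _start_engine_loop begins iterating
    CRASHED = "CRASHED"  # set when engine.step() raises; the Llumlet watchdog reacts to this
    STOPPED = "STOPPED"  # set when the loop exits cleanly via the stop event
```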
37 changes: 34 additions & 3 deletions llumnix/backends/vllm/llm_engine.py
@@ -12,6 +12,7 @@
# limitations under the License.

import time
import traceback
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
from collections import defaultdict
import threading
@@ -29,7 +30,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface, EngineState
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
from llumnix.backends.profiling import LatencyMemData
@@ -244,14 +245,44 @@ def __init__(
self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)

self.state_lock = threading.Lock()
self.state = EngineState.INIT
logger.info("{} current state {}".format(self.instance_id, self.state))

        self._stop_event = threading.Event()
        self._thread = threading.Thread(
            target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
        )
        self._thread.start()

    def _start_engine_loop(self) -> None:
        self._stop_event.clear()

        with self.state_lock:
            previous_state = self.state
            self.state = EngineState.RUNNING
            logger.info("{} change state: {} -> {}".format(self.instance_id, previous_state, self.state))

        while not self._stop_event.is_set():
            try:
                self.engine.step()
            # pylint: disable=broad-except
            except Exception as e:
                logger.error("Error in engine loop: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))
self._run_workers("shutdown")

with self.state_lock:
previous_state = self.state
self.state = EngineState.CRASHED
logger.info("{} change state: {} -> {}".format(self.instance_id, previous_state, self.state))
break

with self.state_lock:
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED
logger.info("{} change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
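Stripped of the vLLM specifics, _start_engine_loop follows a common supervised-worker pattern: run the step function until a stop event fires, record every state transition under a lock, and let a crash be observable from another thread. A self-contained sketch (SupervisedLoop and its method names are illustrative, not part of the diff):

```python
import threading

class SupervisedLoop:
    """Worker loop that records its lifecycle state under a lock
    so an external watchdog thread can detect a crash."""

    def __init__(self, step):
        self._step = step                      # unit of work, e.g. engine.step
        self._stop_event = threading.Event()
        self.state_lock = threading.Lock()
        self.state = "INIT"

    def run(self):
        self._stop_event.clear()
        with self.state_lock:
            self.state = "RUNNING"
        while not self._stop_event.is_set():
            try:
                self._step()
            # broad catch is deliberate: any step error must surface as CRASHED
            # pylint: disable=broad-except
            except Exception:
                with self.state_lock:
                    self.state = "CRASHED"
                break
        with self.state_lock:
            if self.state == "RUNNING":        # exited via the stop event, not a crash
                self.state = "STOPPED"
```

The final state == "RUNNING" check is what distinguishes a clean stop from a crash: after a crash the state is already CRASHED, so the loop must not overwrite it with STOPPED.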
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/simulator.py
@@ -14,8 +14,8 @@
import os
import threading
from typing import List

import ray.actor
from vllm.engine.arg_utils import EngineArgs

from llumnix.logger import init_logger
1 change: 1 addition & 0 deletions llumnix/backends/vllm/worker.py
@@ -147,6 +147,7 @@ def shutdown(self) -> None:
        del self.model_runner
        del self.cache_engine
        del self.gpu_cache
        del self.migration_backend
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()

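This one-line addition matters on the crash path: shutdown now drops the worker's reference to the migration backend as well, so its GPU buffers can be freed before torch.cuda.empty_cache() runs. The new unit test below depends on this, asserting that free GPU memory returns to its original level after the engine crashes.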
28 changes: 24 additions & 4 deletions llumnix/llumlet/llumlet.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import List, Union, Iterable
import time
import ray
@@ -19,7 +20,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState
from llumnix.backends.utils import init_backend_engine, initialize_placement_group
from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus
from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
@@ -54,6 +55,9 @@ def __init__(self,
                                                 self.backend_engine)
        self.log_requests = True

        self.state_check_thread = threading.Thread(target=self.check_state, daemon=True)
        self.state_check_thread.start()

    @classmethod
    def from_args(cls,
                  output_queue_type: QueueType,
@@ -68,13 +72,14 @@ def from_args(cls,
                  **kwargs):
        lifetime = "detached" if detached else None
        assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}'
actor_name = f"instance_{instance_id}"
if backend_type == backend_type.VLLM:
if disable_fixed_node_init_instance:
# TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo.
placement_group = initialize_placement_group(world_size, detached=detached)
kwargs["placement_group"] = placement_group
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
@@ -84,7 +89,7 @@
        else:
            kwargs["node_id"] = node_id
            engine_class = ray.remote(num_cpus=1,
                                      name=actor_name,
                                      namespace='llumnix',
                                      max_concurrency=4,
                                      lifetime=lifetime)(cls).options(
@@ -93,7 +98,7 @@
                                          soft=False,))
        else: # backend_type == backend_type.SIM_VLLM:
            engine_class = ray.remote(num_cpus=1,
                                      name=actor_name,
                                      namespace='llumnix',
                                      max_concurrency=4,
                                      lifetime=lifetime)(cls).options(
@@ -103,6 +108,21 @@
        llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
        return llumlet

    def check_state(self):
        while True:
            time.sleep(1)

            with self.backend_engine.state_lock:
                if self.backend_engine.state == EngineState.CRASHED:
                    logger.warning("llumlet({}) detected backend engine crashed. Stopping...".format(self.instance_id))
                    # pylint: disable=protected-access
                    self.backend_engine._stop_event.set()
                    if self.backend_engine._thread.is_alive():
                        self.backend_engine._thread.join()

                    self_actor = ray.get_actor(self.actor_name)
                    ray.kill(self_actor)

    def migrate_out(self, dst_instance_name: str) -> List[str]:
        try:
            t0 = time.time()
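check_state polls once a second; when it sees CRASHED it stops and joins the engine loop thread, then terminates its own actor, which is why from_args was refactored to name the actor consistently via actor_name. A minimal sketch of the self-termination idiom (the name and namespace mirror what from_args uses; "instance_0" is illustrative):

```python
import ray

# From inside an actor method: look up our own named handle and kill it,
# so the actor's death is visible to anything watching named actors.
self_actor = ray.get_actor("instance_0", namespace="llumnix")
ray.kill(self_actor)  # by default the killed actor is not restarted
```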
92 changes: 92 additions & 0 deletions tests/unit_test/llumlet/test_engine_step_exception.py
@@ -0,0 +1,92 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
import ray
import torch
import pytest

from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.engine.arg_utils import EngineArgs

from llumnix.backends.backend_interface import BackendType
from llumnix.llumlet.llumlet import Llumlet
from llumnix.internal_config import MigrationConfig
from llumnix.queue.queue_type import QueueType
# pylint: disable=unused-import
from tests.conftest import setup_ray_env

@ray.remote(num_cpus=1, max_concurrency=4)
class MockLlumlet(Llumlet):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.origin_step = self.backend_engine.engine.step

    def set_error_step(self, broken: bool):
        self.backend_engine._stop_event.set()
        if self.backend_engine._thread.is_alive():
            self.backend_engine._thread.join()

        def raise_error_step():
            self.origin_step()
            raise ValueError("Mock engine step error")

        if broken:
            self.backend_engine.engine.step = raise_error_step
        else:
            self.backend_engine.engine.step = self.origin_step

        self.backend_engine._thread = threading.Thread(
            target=self.backend_engine._start_engine_loop, args=(), daemon=True, name="engine_loop"
        )
        self.backend_engine._thread.start()

@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.")
def test_engine_step_exception(setup_ray_env):
    engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
    migration_config = MigrationConfig("LCFS", "rpc", 16, 1, 4, 5, 20)
    node_id = ray.get_runtime_context().get_node_id()
    scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)

    origin_free_memory, _ = torch.cuda.mem_get_info()

    actor_name = "instance_0"
    llumlet = MockLlumlet.options(name=actor_name, namespace='llumnix',
                                  scheduling_strategy=scheduling_strategy).remote(
        output_queue_type=QueueType.RAYQUEUE,
        instance_id="0",
        backend_type=BackendType.VLLM,
        migration_config=migration_config,
        engine_args=engine_args,
        node_id=node_id
    )
    ray.get(llumlet.is_ready.remote())

    all_actors = ray.util.list_named_actors(True)
    all_actor_names = [actor["name"] for actor in all_actors]
    assert actor_name in all_actor_names

    cur_free_memory, _ = torch.cuda.mem_get_info()
    assert cur_free_memory < origin_free_memory

    ray.get(llumlet.set_error_step.remote(True))
    time.sleep(3)

    all_actors = ray.util.list_named_actors(True)
    all_actor_names = [actor["name"] for actor in all_actors]
    assert actor_name not in all_actor_names

    cur_free_memory, _ = torch.cuda.mem_get_info()
    assert origin_free_memory == cur_free_memory
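The test walks the full failure path: set_error_step(True) monkey-patches engine.step to raise after one real step, the engine loop transitions to CRASHED, the Llumlet watchdog kills the named actor (so it drops out of list_named_actors), and the worker shutdown releases GPU memory, which is why free memory is expected to return exactly to its original level.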
13 changes: 13 additions & 0 deletions tests/unit_test/llumlet/test_local_migration_scheduler.py
@@ -1,3 +1,16 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType
