Skip to content

Commit

Permalink
[Misc] Ensure Llumlet main thread exits on Engine.Step errors (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui authored Oct 10, 2024
1 parent 653ba46 commit fc5ecee
Show file tree
Hide file tree
Showing 16 changed files with 211 additions and 16 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/bench_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand All @@ -27,6 +30,7 @@ jobs:
- name: Build And Test
run: ./tools/bench_test.sh
- name: Create comment from file
if: ${{ github.event_name != 'push' }}
uses: actions/github-script@v7
with:
script: |
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/e2e_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/migration_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand All @@ -27,6 +30,7 @@ jobs:
- name: Build And Test
run: ./tools/migration_test.sh
- name: Create comment from file
if: ${{ github.event_name != 'push' }}
uses: actions/github-script@v7
with:
script: |
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/offline_inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/whl_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
whl_build:
Expand Down
5 changes: 5 additions & 0 deletions llumnix/backends/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from llumnix.llumlet.request import LlumnixRequest
from llumnix.server_info import ServerInfo

class EngineState(str, Enum):
    """Lifecycle state of a backend engine loop.

    The ``str`` mixin makes every member compare equal to its literal
    string value, so states can be logged or compared as plain strings.
    """

    # Engine constructed; loop not yet marked as running.
    INIT = "INIT"
    # The engine loop hit an error and gave up.
    CRASHED = "CRASHED"
    # The engine loop is actively stepping.
    RUNNING = "RUNNING"
    # The engine loop left the running state without crashing.
    STOPPED = "STOPPED"

class BackendType(str, Enum):
VLLM = "VLLM"
Expand Down
37 changes: 34 additions & 3 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.

import time
import traceback
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
from collections import defaultdict
import threading
Expand All @@ -29,7 +30,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface
from llumnix.backends.backend_interface import BackendInterface, EngineState
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
from llumnix.backends.profiling import LatencyMemData
Expand Down Expand Up @@ -244,14 +245,44 @@ def __init__(
self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)

self.state_lock = threading.Lock()
self.state = EngineState.INIT
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()

def _start_engine_loop(self) -> None:
    """Run the engine step loop until stopped or until a step fails.

    State transitions (all performed under ``self.state_lock``):
      * current state -> RUNNING when the loop starts;
      * RUNNING -> CRASHED when ``self.engine.step()`` raises;
      * RUNNING -> STOPPED when the loop exits because ``self._stop_event``
        was set.

    On a step failure the workers are asked to shut down. A failure of that
    shutdown call is logged but must not prevent the CRASHED state from
    being recorded, otherwise nobody polling the state would ever learn
    that the loop is dead.
    """
    self._stop_event.clear()

    with self.state_lock:
        previous_state = self.state
        self.state = EngineState.RUNNING
        logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))

    while not self._stop_event.is_set():
        try:
            self.engine.step()
        # pylint: disable=broad-except
        except Exception as e:
            logger.error("Error in engine loop: {}".format(e))
            logger.error("exception traceback: {}".format(traceback.format_exc()))

            # Best-effort cleanup: never let a failing shutdown mask the crash.
            try:
                self._run_workers("shutdown")
            # pylint: disable=broad-except
            except Exception as shutdown_error:
                logger.error("Error while shutting down workers: {}".format(shutdown_error))

            with self.state_lock:
                previous_state = self.state
                self.state = EngineState.CRASHED
                logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
            break

    # Only a loop that is still RUNNING (i.e. did not crash) becomes STOPPED.
    with self.state_lock:
        if self.state == EngineState.RUNNING:
            self.state = EngineState.STOPPED
            logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import os
import threading
from typing import List
import ray.actor

import ray.actor
from vllm.engine.arg_utils import EngineArgs

from llumnix.logger import init_logger
Expand Down
1 change: 1 addition & 0 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def shutdown(self) -> None:
del self.model_runner
del self.cache_engine
del self.gpu_cache
del self.migration_backend
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

Expand Down
15 changes: 8 additions & 7 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from llumnix.backends.profiling import ProfilingDatabase
from llumnix.server_info import ServerInfo
from llumnix.backends.backend_interface import BackendType
from llumnix.utils import random_uuid
from llumnix.utils import random_uuid, clear_gloo_backend_state
from llumnix.queue.queue_type import QueueType

logger = init_logger(__name__)
Expand Down Expand Up @@ -291,19 +291,17 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs):
self.scale_down(dead_instances, rebuild_migrate_backend=False)

if self.engine_manager_args.migration_backend == 'gloo':
try:
# clear gloo migrate backend intermediate state
ray.kill(ray.get_actor("gloo_queue", "llumnix"))
except ValueError:
# gloo_queue may not have been created yet; just ignore this error.
pass
clear_gloo_backend_state()

return dead_instances

alive_instances = sorted(self.instances.keys())
pending_task = self.pending_rebuild_migration_instances
group_name = None

if self.engine_manager_args.migration_backend == 'gloo':
clear_gloo_backend_state()

while len(alive_instances) > 0 and self.pending_rebuild_migration_instances > 0:
dead_instances = set()
group_name = random_uuid()
Expand Down Expand Up @@ -376,6 +374,9 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac
if self.engine_manager_args.migration_backend != 'rpc':
if len(self.instances) == 0:
self.pending_rebuild_migration_instances = 0

if self.engine_manager_args.migration_backend == 'gloo':
clear_gloo_backend_state()
elif indeed_update and no_pending_instance and rebuild_migrate_backend:
asyncio.create_task(self.rebuild_migrate_backend())

Expand Down
29 changes: 25 additions & 4 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import List, Union, Iterable
import time
import ray
Expand All @@ -19,7 +20,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface, BackendType
from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState
from llumnix.backends.utils import init_backend_engine, initialize_placement_group
from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus
from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
Expand Down Expand Up @@ -54,6 +55,10 @@ def __init__(self,
self.backend_engine)
self.log_requests = True

self.check_state_thread = threading.Thread(target=self.check_state, daemon=True,
name="llumlet_check_state_loop")
self.check_state_thread.start()

@classmethod
def from_args(cls,
output_queue_type: QueueType,
Expand All @@ -68,13 +73,14 @@ def from_args(cls,
**kwargs):
lifetime = "detached" if detached else None
assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}'
actor_name = f"instance_{instance_id}"
if backend_type == backend_type.VLLM:
if disable_fixed_node_init_instance:
# TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo.
placement_group = initialize_placement_group(world_size, detached=detached)
kwargs["placement_group"] = placement_group
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -84,7 +90,7 @@ def from_args(cls,
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -93,7 +99,7 @@ def from_args(cls,
soft=False,))
else: # backend_type == backend_type.SIM_VLLM:
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -103,6 +109,21 @@ def from_args(cls,
llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
return llumlet

def check_state(self):
    """Poll the backend engine once per second and kill this actor if it crashed.

    The engine state is read under ``backend_engine.state_lock``, but the
    lock is released before joining the engine thread: that thread takes
    the same lock again in its final state update before it can exit, so
    joining while holding the lock could deadlock.
    """
    while True:
        time.sleep(1)

        with self.backend_engine.state_lock:
            crashed = self.backend_engine.state == EngineState.CRASHED
        if not crashed:
            continue

        logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
        # pylint: disable=protected-access
        self.backend_engine._stop_event.set()
        if self.backend_engine._thread.is_alive():
            self.backend_engine._thread.join()

        # The actor is registered under the 'llumnix' namespace, so look it
        # up there explicitly instead of relying on the current namespace.
        self_actor = ray.get_actor(self.actor_name, namespace="llumnix")
        ray.kill(self_actor)
        # Nothing left to watch once the kill has been requested.
        return

def migrate_out(self, dst_instance_name: str) -> List[str]:
try:
t0 = time.time()
Expand Down
10 changes: 9 additions & 1 deletion llumnix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.

import uuid

import ray

def random_uuid() -> str:
    """Return a fresh random UUID4 as a 32-character lowercase hex string."""
    generated = uuid.uuid4()
    # ``UUID.hex`` is already a ``str``; no further conversion is needed.
    return generated.hex
Expand All @@ -30,3 +30,11 @@ def convert_bytes(bytes_size):
index += 1

return f"{bytes_size:.2f} {size_suffixes[index]}"

def clear_gloo_backend_state():
    """Kill the ``gloo_queue`` actor in the ``llumnix`` namespace, if present.

    This clears the gloo migrate backend's intermediate state. A
    ``ValueError`` is ignored because ``gloo_queue`` may not have been
    created yet.
    """
    try:
        gloo_queue = ray.get_actor("gloo_queue", "llumnix")
        ray.kill(gloo_queue)
    except ValueError:
        # Nothing to clear.
        return
Loading

0 comments on commit fc5ecee

Please sign in to comment.