Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Ensure Llumlet main thread exits on Engine.Step errors #38

Merged
merged 4 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/bench_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand All @@ -27,6 +30,7 @@ jobs:
- name: Build And Test
run: ./tools/bench_test.sh
- name: Create comment from file
if: ${{ github.event_name != 'push' }}
uses: actions/github-script@v7
with:
script: |
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/e2e_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/migration_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand All @@ -27,6 +30,7 @@ jobs:
- name: Build And Test
run: ./tools/migration_test.sh
- name: Create comment from file
if: ${{ github.event_name != 'push' }}
uses: actions/github-script@v7
with:
script: |
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/offline_inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
cancel_previous_workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/whl_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
whl_build:
Expand Down
5 changes: 5 additions & 0 deletions llumnix/backends/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from llumnix.llumlet.request import LlumnixRequest
from llumnix.server_info import ServerInfo

class EngineState(str, Enum):
INIT = "INIT"
CRASHED = "CRASHED"
RUNNING = "RUNNING"
STOPPED = "STOPPED"

class BackendType(str, Enum):
VLLM = "VLLM"
Expand Down
37 changes: 34 additions & 3 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.

import time
import traceback
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
from collections import defaultdict
import threading
Expand All @@ -29,7 +30,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface
from llumnix.backends.backend_interface import BackendInterface, EngineState
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
from llumnix.backends.profiling import LatencyMemData
Expand Down Expand Up @@ -244,14 +245,44 @@ def __init__(
self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)

self.state_lock = threading.Lock()
self.state = EngineState.INIT
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()

def _start_engine_loop(self) -> None:
while True:
self.engine.step()
self._stop_event.clear()

with self.state_lock:
zhypku marked this conversation as resolved.
Show resolved Hide resolved
previous_state = self.state
self.state = EngineState.RUNNING
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))

while not self._stop_event.is_set():
try:
self.engine.step()
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
s5u13b marked this conversation as resolved.
Show resolved Hide resolved
logger.error("exception traceback: {}".format(traceback.format_exc()))
self._run_workers("shutdown")

with self.state_lock:
previous_state = self.state
self.state = EngineState.CRASHED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
break

with self.state_lock:
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import os
import threading
from typing import List
import ray.actor

import ray.actor
from vllm.engine.arg_utils import EngineArgs

from llumnix.logger import init_logger
Expand Down
1 change: 1 addition & 0 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def shutdown(self) -> None:
del self.model_runner
del self.cache_engine
del self.gpu_cache
del self.migration_backend
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

Expand Down
15 changes: 8 additions & 7 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from llumnix.backends.profiling import ProfilingDatabase
from llumnix.server_info import ServerInfo
from llumnix.backends.backend_interface import BackendType
from llumnix.utils import random_uuid
from llumnix.utils import random_uuid, clear_gloo_backend_state
from llumnix.queue.queue_type import QueueType

logger = init_logger(__name__)
Expand Down Expand Up @@ -291,19 +291,17 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs):
self.scale_down(dead_instances, rebuild_migrate_backend=False)

if self.engine_manager_args.migration_backend == 'gloo':
try:
# clear gloo migrate backend intermediate state
ray.kill(ray.get_actor("gloo_queue", "llumnix"))
except ValueError:
# gloo_queue may not have been created yet; just ignore this error.
pass
clear_gloo_backend_state()

return dead_instances

alive_instances = sorted(self.instances.keys())
pending_task = self.pending_rebuild_migration_instances
group_name = None

if self.engine_manager_args.migration_backend == 'gloo':
clear_gloo_backend_state()

while len(alive_instances) > 0 and self.pending_rebuild_migration_instances > 0:
dead_instances = set()
group_name = random_uuid()
Expand Down Expand Up @@ -376,6 +374,9 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac
if self.engine_manager_args.migration_backend != 'rpc':
if len(self.instances) == 0:
self.pending_rebuild_migration_instances = 0

if self.engine_manager_args.migration_backend == 'gloo':
clear_gloo_backend_state()
elif indeed_update and no_pending_instance and rebuild_migrate_backend:
asyncio.create_task(self.rebuild_migrate_backend())

Expand Down
29 changes: 25 additions & 4 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import List, Union, Iterable
import time
import ray
Expand All @@ -19,7 +20,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface, BackendType
from llumnix.backends.backend_interface import BackendInterface, BackendType, EngineState
from llumnix.backends.utils import init_backend_engine, initialize_placement_group
from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus
from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
Expand Down Expand Up @@ -54,6 +55,10 @@ def __init__(self,
self.backend_engine)
self.log_requests = True

self.check_state_thread = threading.Thread(target=self.check_state, daemon=True,
name="llumlet_check_state_loop")
self.check_state_thread.start()

@classmethod
def from_args(cls,
output_queue_type: QueueType,
Expand All @@ -68,13 +73,14 @@ def from_args(cls,
**kwargs):
lifetime = "detached" if detached else None
assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}'
actor_name = f"instance_{instance_id}"
if backend_type == backend_type.VLLM:
if disable_fixed_node_init_instance:
# TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo.
placement_group = initialize_placement_group(world_size, detached=detached)
kwargs["placement_group"] = placement_group
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -84,7 +90,7 @@ def from_args(cls,
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -93,7 +99,7 @@ def from_args(cls,
soft=False,))
else: # backend_type == backend_type.SIM_VLLM:
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
Expand All @@ -103,6 +109,21 @@ def from_args(cls,
llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
return llumlet

def check_state(self):
while True:
time.sleep(1)

with self.backend_engine.state_lock:
if self.backend_engine.state == EngineState.CRASHED:
logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
# pylint: disable=protected-access
self.backend_engine._stop_event.set()
if self.backend_engine._thread.is_alive():
self.backend_engine._thread.join()

self_actor = ray.get_actor(self.actor_name)
ray.kill(self_actor)

def migrate_out(self, dst_instance_name: str) -> List[str]:
try:
t0 = time.time()
Expand Down
10 changes: 9 additions & 1 deletion llumnix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.

import uuid

import ray

def random_uuid() -> str:
return str(uuid.uuid4().hex)
Expand All @@ -30,3 +30,11 @@ def convert_bytes(bytes_size):
index += 1

return f"{bytes_size:.2f} {size_suffixes[index]}"

def clear_gloo_backend_state():
try:
# clear gloo migrate backend intermediate state
ray.kill(ray.get_actor("gloo_queue", "llumnix"))
except ValueError:
# gloo_queue may not have been created yet; just ignore this error.
pass
Loading