Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Catch the exception generated in llumlet constructor #50

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import queue
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy

from vllm.engine.llm_engine import LLMEngine
from vllm.core.scheduler import ScheduledSequenceGroup
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(self, instance_id, output_queue_type: QueueType):
self.instance_id = instance_id
self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type)
self.engine_actor_handle = None
self.output_queue_type = output_queue_type

async def put_nowait_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
Expand All @@ -65,7 +66,7 @@ async def put_nowait_to_servers(self,
server_id = list(server_request_outputs.keys())[idx]
server_info = server_info_dict[server_id]
logger.info("Server {} is dead".format(server_id))
if output_queue_type == QueueType.ZMQ:
if self.output_queue_type == QueueType.ZMQ:
logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
server_info.request_output_queue_port))
req_outputs = list(server_request_outputs.values())[idx]
Expand Down
43 changes: 24 additions & 19 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.

import threading
import traceback
from typing import List, Union, Iterable
import time
import ray
Expand All @@ -31,32 +32,36 @@


class Llumlet:
# TODO(KuilongCui): catch the exception generated in ctor
# NOTE(review): this span is a rendered diff of Llumlet.__init__ whose +/- markers and
# indentation were lost. The hunk header reports 24 additions & 19 deletions for this file,
# which matches the unguarded statements below being the removed side and the try/except
# that follows being the added side.
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
backend_type: BackendType,
migration_config: MigrationConfig,
*args,
**kwargs) -> None:
# Removed side (appears to be the pre-change body): initialization runs unguarded, so any
# exception raised here propagates out of the constructor.
self.instance_id = instance_id
self.actor_name = f"instance_{instance_id}"
self.backend_engine: BackendInterface = init_backend_engine(self.instance_id,
output_queue_type,
backend_type,
migration_config,
*args,
**kwargs)
self.migration_coordinator = MigrationCoordinator(self.backend_engine,
migration_config.last_stage_max_blocks,
migration_config.max_stages)
self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy,
self.backend_engine)
self.log_requests = True

self.check_state_thread = threading.Thread(target=self.check_state, daemon=True,
name="llumlet_check_state_loop")
self.check_state_thread.start()
# Added side (appears to be the post-change body): the same initialization sequence,
# statement for statement, wrapped in a try/except.
try:
self.instance_id = instance_id
self.actor_name = f"instance_{instance_id}"
# Extra positional/keyword arguments are forwarded unchanged to init_backend_engine.
self.backend_engine: BackendInterface = init_backend_engine(self.instance_id,
output_queue_type,
backend_type,
migration_config,
*args,
**kwargs)
self.migration_coordinator = MigrationCoordinator(self.backend_engine,
migration_config.last_stage_max_blocks,
migration_config.max_stages)
self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy,
self.backend_engine)
self.log_requests = True

# The daemon thread running self.check_state is created and started last, after the
# backend engine, migration coordinator and migration scheduler have been assigned.
self.check_state_thread = threading.Thread(target=self.check_state, daemon=True,
name="llumlet_check_state_loop")
self.check_state_thread.start()
# pylint: disable=broad-except
except Exception as e:
# NOTE(review): the failure is logged twice (message, then formatted traceback) and is
# not re-raised, so __init__ returns normally and every attribute that would have been
# assigned after the failing statement is left unset on the instance. Confirm that
# callers can detect this state.
logger.error("Failed to initialize llumlet: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))

@classmethod
def from_args(cls,
Expand Down