
Commit

fix comment
KuilongCui committed Dec 13, 2024
1 parent e18cfa5 commit a3ff024
Showing 22 changed files with 176 additions and 182 deletions.
1 change: 0 additions & 1 deletion configs/blade.yml → configs/bladellm.yml
@@ -3,7 +3,6 @@ SERVER:
LAUNCH_RAY_CLUSTER: True
REQUEST_OUTPUT_QUEUE_TYPE: "rayqueue"

-
MANAGER:
DISABLE_FIXED_NODE_INIT_INSTANCE: False
DISABLE_INIT_INSTANCE_BY_MANAGER: True
File renamed without changes.
2 changes: 1 addition & 1 deletion docs/Arguments.md
@@ -147,7 +147,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]

`--migration-backend`
- Communication backend of migration.
-- Possible choices: gloo, rpc, grpc, kvtransfer
+- Possible choices: gloo, rpc, grpc, kvtransfer. [gloo, rpc] are available for vllm and [grpc, kvtransfer] are available for bladellm.
- Default: "rpc"

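For instance, a vllm-based launch of `llumnix.entrypoints.vllm.api_server` would pass `--migration-backend gloo` or `--migration-backend rpc`, while a bladellm-based deployment would use `grpc` or `kvtransfer` instead.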
`--migration-backend-transfer-type`
7 changes: 7 additions & 0 deletions llumnix/__init__.py
@@ -42,3 +42,10 @@
__all__.extend(getattr(vllm, "__all__", []))
except ImportError:
pass
+
+try:
+    import blade_llm
+    from blade_llm import *
+    __all__.extend(getattr(blade_llm, "__all__", []))
+except ImportError:
+    pass
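The new block mirrors the vllm guard above it, so `llumnix` re-exports whichever engine package is importable. A minimal sketch of the same guarded-import idiom, usable to probe which backends are present (the helper is ours, not part of llumnix):

```python
import importlib.util

def available_engines():
    # Engine packages that can be imported in this environment.
    return [name for name in ("vllm", "blade_llm")
            if importlib.util.find_spec(name) is not None]
```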
9 changes: 6 additions & 3 deletions llumnix/arg_utils.py
@@ -16,7 +16,7 @@
import dataclasses
from dataclasses import dataclass
import argparse
-from typing import Tuple, Optional
+from typing import Tuple

from llumnix.internal_config import GlobalSchedulerConfig, MigrationConfig
from llumnix.config import LlumnixConfig, get_llumnix_config
@@ -175,12 +175,12 @@ def create_global_scheduler_configs(
def create_migration_config(self) -> MigrationConfig:
migration_config = MigrationConfig(self.request_migration_policy,
self.migration_backend,
+self.migration_backend_transfer_type,
self.migration_buffer_blocks,
self.migration_num_layers,
self.last_stage_max_blocks,
self.max_stages,
self.migration_backend_init_timeout,
-self.migration_backend_transfer_type,
self.migration_backend_server_address,
self.migration_backend_kvtransfer_naming_url)
return migration_config
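The change here moves `migration_backend_transfer_type` up so the positional arguments match `MigrationConfig`'s parameter order. The constructor implied by this call site would be roughly as follows (a sketch inferred from the arguments, not the actual definition in `llumnix.internal_config`):

```python
class MigrationConfig:
    def __init__(self, request_migration_policy, migration_backend,
                 migration_backend_transfer_type, migration_buffer_blocks,
                 migration_num_layers, last_stage_max_blocks, max_stages,
                 migration_backend_init_timeout, migration_backend_server_address,
                 migration_backend_kvtransfer_naming_url):
        ...
```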
@@ -202,11 +202,13 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser):
cur_arg = getattr(args, action.dest)
assert cur_arg in action.choices, f"{action.dest} should be one of {action.choices}, but {cur_arg} is set."

+# vllm only
assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
("When using gloo as migration backend, "
"do not set --disable-init-instance-by-manager and --disable-fixed-node-init-instance.")

+# bladellm only
assert args.migration_backend not in ['kvtransfer'] or (args.migration_backend == 'kvtransfer' \
and args.migration_backend_transfer_type), \
("When using kvTransfer as migration backend, "
@@ -315,7 +317,8 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--migration-backend',
type=str,
choices=['gloo','nccl','rpc','grpc','kvtransfer'],
-help='communication backend of migration')
+help='communication backend of migration, [gloo, rpc] are available for vllm \
+and [grpc, kvtransfer] are available for bladellm')
parser.add_argument('--migration-backend-transfer-type',
type=str,
choices=['cuda_ipc','rdma', ''],
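Both `check_args` assertions above use the `assert not P or (P and Q)` shape, which reads as "backend P implies constraint Q". A standalone sketch of the kvtransfer case (the function name is ours):

```python
def validate_kvtransfer_args(migration_backend: str, transfer_type: str) -> None:
    # "migration_backend == 'kvtransfer' implies a transfer type is set",
    # written in the same `not P or (P and Q)` style as check_args.
    assert migration_backend != 'kvtransfer' or (migration_backend == 'kvtransfer' and transfer_type), \
        "When using kvtransfer as migration backend, set --migration-backend-transfer-type."
```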
1 change: 1 addition & 0 deletions llumnix/backends/backend_interface.py
@@ -35,6 +35,7 @@ def is_sim_backend(status: "BackendType") -> bool:
BackendType.SIM_VLLM,
]

+# TODO(KuilongCui): separate backend interface into two parts: DispatchBackendInterface and MigrationBackendInterface
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
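The new TODO proposes splitting `BackendInterface` in two. Purely as an illustration of that direction (neither class exists in the codebase yet), the split might look like:

```python
from abc import ABC, abstractmethod

class DispatchBackendInterface(ABC):
    # Request-serving half: enqueue and abort requests.
    @abstractmethod
    def add_request(self, request_id, server_info, expected_steps, *args, **kwargs): ...

class MigrationBackendInterface(ABC):
    # Migration half: move KV blocks of a running request between instances.
    @abstractmethod
    def get_request_incremental_blocks(self, backend_request, pre_stage_num_blocks): ...
```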
104 changes: 67 additions & 37 deletions llumnix/backends/bladellm/llm_engine.py
@@ -13,15 +13,13 @@

import json
import traceback
-from typing import Any, List, Optional, Dict, Tuple, Union, Iterable, Deque
+from typing import List, Optional, Tuple, Union, Iterable, Deque
from collections import defaultdict
import threading
import asyncio
import queue

import ray
-import pickle
-import zmq
from loguru import logger
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
@@ -30,16 +28,47 @@
from blade_llm.service.args import ServingArgs
from blade_llm.protocol import GenerateStreamResponse, ServerRequest
from blade_llm.service.communications.engine_wrapper import APIWrapper
-from blade_llm.protocol import GenerateStreamResponse
from blade_llm.utils.disagg_utils import InstanceRole
from blade_llm.service.disagg_decode_server import DecodeEntrypoint
from blade_llm.service.disagg_pd_engine import PrefillAsyncLLMEngine, DecodeAsyncLLMEngine

from llumnix.backends.backend_interface import BackendInterface, EngineState
from llumnix.internal_config import MigrationConfig
from llumnix.server_info import ServerInfo
-from llumnix.queue.utils import QueueType, AsyncPutQueueActor
+from llumnix.queue.utils import QueueType, QueueClientBase, init_request_output_queue_client
from llumnix.llumlet.request import LlumnixRequest, RequestStatus
+from llumnix.instance_info import InstanceInfo

+class AsyncPutQueueActor:
+    def __init__(self, instance_id, output_queue_type: QueueType):
+        self.instance_id = instance_id
+        self.output_queue_type = output_queue_type
+        self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(output_queue_type)
+        self.engine_actor_handle = None
+
+    async def put_nowait_to_servers(self,
+                                    server_request_outputs,
+                                    server_info_dict) -> None:
+        if self.engine_actor_handle is None:
+            self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
+        tasks = []
+        for server_id, req_outputs in server_request_outputs.items():
+            server_info = server_info_dict[server_id]
+            for req_output in req_outputs:
+                if hasattr(req_output, 'request_timestamps'):
+                    req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
+            tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
+        rets = await asyncio.gather(*tasks, return_exceptions=True)
+        for idx, ret in enumerate(rets):
+            if isinstance(ret, (TimeoutError, ray.exceptions.RayActorError)):
+                server_id = list(server_request_outputs.keys())[idx]
+                server_info = server_info_dict[server_id]
+                logger.info("Server {} is dead".format(server_id))
+                if self.output_queue_type == QueueType.ZMQ:
+                    logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
+                                                                               server_info.request_output_queue_port))
+                req_outputs = list(server_request_outputs.values())[idx]
+                request_ids = [req_output.request_id for req_output in req_outputs]
+                self.engine_actor_handle.abort.remote(request_ids)
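For context, `put_nowait_to_servers` fans outputs out per API server: `server_request_outputs` maps a server id to that server's batch of responses, and `server_info_dict` maps the same ids to their `ServerInfo`. A hypothetical invocation (the ids and objects are placeholders):

```python
# resp_* stand for GenerateStreamResponse instances; server_info_* for ServerInfo.
server_request_outputs = {"server_0": [resp_a, resp_b], "server_1": [resp_c]}
server_info_dict = {"server_0": server_info_0, "server_1": server_info_1}
async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)
```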

class AsyncBackQueue(APIWrapper):
def __init__(self, placement_group, node_id, instance_id, output_queue_type) -> None:
@@ -61,17 +90,17 @@ def __init__(self, placement_group, node_id, instance_id, output_queue_type) ->
)
self.put_queue_args_queue = queue.Queue()
self.put_queue_loop_thread = threading.Thread(
-target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
+target=self._put_request_outputs_loop, args=(), daemon=True, name="put_queue_loop"
)
self.async_put_queue_actor = ray.remote(
num_cpus=1,
scheduling_strategy=scheduling_strategy
)(AsyncPutQueueActor).remote(instance_id, output_queue_type)
self.put_queue_loop_thread.start()

-self.request_client_map = {}
+self.request_server_map = {}

-def _start_put_queue_loop(self):
+def _put_request_outputs_loop(self):
def _put_request_outputs_loop(self):
while True:
request_outputs, req_id_outputs, server_info_outputs = [], [], []

@@ -90,7 +119,7 @@ def _start_put_queue_loop(self):

self._put_request_outputs_to_server(request_outputs, req_id_outputs, server_info_outputs)

-def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse],
+def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse],
req_ids: List[str], server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
server_info_dict = {}
@@ -103,24 +132,26 @@ def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse],
logger.debug("_put_request_outputs_to_server, {}", server_request_outputs)
self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)

+# pylint: disable=unused-argument
async def send(self, req_id, msg, reset=False):
-self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_client_map[req_id]))
+self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_server_map[req_id]))
if msg.is_finished:
-self.request_client_map.pop(req_id)
+self.request_server_map.pop(req_id)

async def recv(self):
return None

def drop_request(self, request_id: int) -> None:
-self.request_client_map.pop(request_id)
+self.request_server_map.pop(request_id)

def add_request(self, request_id: str, server_info: ServerInfo) -> None:
-self.request_client_map[request_id] = server_info
+self.request_server_map[request_id] = server_info

def stop(self):
pass

-class LLMEngineLlumnixMixin:
+class AsyncLLMEngineLlumnixMixin:
+    # pylint: disable=unused-argument
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
@@ -138,7 +169,7 @@ def __init__(self,
self.node_id = node_id

@property
-def instance_info(self):
+def instance_info(self) -> InstanceInfo:
return self._scheduler.llumnix_metrics.to_instance_info()

def start(self, loop: asyncio.AbstractEventLoop):
@@ -151,14 +182,15 @@ def start(self, loop: asyncio.AbstractEventLoop):
async def update_callback(self, resp_list, step_requests):
await super().update_callback(resp_list, step_requests)
self._scheduler.llumnix_metrics.engine_step_metrics(self._scheduler)

async def _loop(self):
previous_state = self.state
self.state = EngineState.RUNNING
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))

try:
await super()._loop()
+# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))
@@ -183,14 +215,16 @@ async def _handle_abort(self, abort: Optional[List[Tuple[int, int, str]]] = None):
self.trans_wrapper.drop_request(req_id)
await super()._handle_abort(abort)

-async def add_request(self, server_request: ServerRequest, server_info: ServerInfo):
+async def add_request(self, server_info: ServerInfo, server_request: ServerRequest):
+logger.debug("engine {} add request {}", self.instance_id, server_request)
self.trans_wrapper.add_request(server_request.id, server_info)
+# pylint: disable=protected-access
await self._client._add_request(server_request)

async def drop_request(self, req_id: int):
await self._client.drop_request(req_id)

-class AsyncLLMEngineLlumnix(LLMEngineLlumnixMixin, AsyncLLMEngine):
+class AsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, AsyncLLMEngine):
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
@@ -200,9 +234,9 @@ def __init__(self,
*args, **kwargs,
) -> None:
AsyncLLMEngine.__init__(self, *args, **kwargs)
-LLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)
+AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

-class PrefillAsyncLLMEngineLlumnix(LLMEngineLlumnixMixin, PrefillAsyncLLMEngine):
+class PrefillAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, PrefillAsyncLLMEngine):
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
@@ -212,9 +246,9 @@ def __init__(self,
*args, **kwargs,
) -> None:
PrefillAsyncLLMEngine.__init__(self, *args, **kwargs)
-LLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)
+AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

-class DecodeAsyncLLMEngineLlumnix(LLMEngineLlumnixMixin, DecodeAsyncLLMEngine):
+class DecodeAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, DecodeAsyncLLMEngine):
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
@@ -224,7 +258,7 @@ def __init__(self,
*args, **kwargs,
) -> None:
DecodeAsyncLLMEngine.__init__(self, *args, **kwargs)
-LLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)
+AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

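`BackendBladeLLM` below chooses among these three classes via `_get_engine_cls()`, which this diff does not show. A plausible sketch, assuming the role comes from the serving args' disaggregation options (`disagg_options` and the `InstanceRole` member names are guesses; `InstanceRole` is imported above):

```python
def _get_engine_cls(self):
    # Hypothetical selection by instance role; not confirmed by this diff.
    if self.engine_args.disagg_options is None:
        return AsyncLLMEngineLlumnix
    if self.engine_args.disagg_options.inst_role == InstanceRole.PREFILL:
        return PrefillAsyncLLMEngineLlumnix
    return DecodeAsyncLLMEngineLlumnix
```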
class BackendBladeLLM(BackendInterface):
def __init__(
@@ -235,22 +269,21 @@ def __init__(
engine_args: ServingArgs,
placement_group: PlacementGroup = None,
node_id: str = None,
-*args,
-**kwargs
-) -> None:
+) -> None:
self.instance_id = instance_id
self.engine_args = engine_args
engine_cls = self._get_engine_cls()
-self.engine = engine_cls(instance_id, output_queue_type,migration_config, placement_group, node_id, engine_args)
+self.engine = engine_cls(instance_id, output_queue_type, migration_config, placement_group, node_id, engine_args)

self._loop = asyncio.new_event_loop()
self._engine_ready = threading.Event()
self._thread = threading.Thread(target=self._start_loop, args=(self._loop,), daemon=True, name="async_engine")
self._thread.start()
self._engine_ready.wait()

@property
def _stop_event(self):
+# pylint: disable=protected-access
return self.engine._stop_event

@property
@@ -274,10 +307,10 @@ def _start_loop(self, loop):
self._engine_ready.set()
loop.run_forever()

-    def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, server_request: str) -> None:
-        logger.debug("engine {} add request {}", self.instance_id, server_request)
-        server_request = ServerRequest(**json.loads(server_request))
-        asyncio.run_coroutine_threadsafe(self.engine.add_request(server_request, server_info), self._loop)
+    def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None:
+        assert "server_request" in kwargs and kwargs["server_request"]
+        server_request = ServerRequest(**json.loads(kwargs["server_request"]))
+        asyncio.run_coroutine_threadsafe(self.engine.add_request(server_info, server_request), self._loop)

def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
if isinstance(request_id, str):
@@ -286,9 +319,6 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
for req_id in request_ids:
self.engine.drop_request(int(req_id))

-    async def _start_engine_step_loop(self) -> None:
-        pass

def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]:
pass

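With the reworked signature, `BackendBladeLLM.add_request` now receives the serialized request through `kwargs`. A hypothetical caller would look like this (the `ServerRequest` fields and the `expected_steps` value are placeholders):

```python
import json

backend.add_request(request_id="0",
                    server_info=server_info,
                    expected_steps=-1,
                    server_request=json.dumps({"id": 0, "prompt": "hello"}))
```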
5 changes: 1 addition & 4 deletions llumnix/backends/bladellm/metrics.py
@@ -18,10 +18,7 @@
from llumnix.metrics.base_metrics import LlumnixMetrics
from llumnix.metrics.dumper import LoggerDumper

-class BladellmMetrics(LlumnixMetrics):
-    def __init__(self):
-        super().__init__()
-
+class BladeLLMMetrics(LlumnixMetrics):
def _init_dumper(self,):
self.dumper = LoggerDumper()

12 changes: 8 additions & 4 deletions llumnix/backends/bladellm/scheduler.py
@@ -15,14 +15,18 @@
from blade_llm.service.scheduler_types import SchedulerStepOutput
from blade_llm.service.args import ServingArgs

-from llumnix.backends.bladellm.metrics import BladellmMetrics
+from llumnix.backends.bladellm.metrics import BladeLLMMetrics

-class PagedSchedulerLlumnix(PagedScheduler):
+class SchedulerLlumnixMixin:
+    def __init__(self):
+        self.llumnix_metrics = BladeLLMMetrics()
+
+class PagedSchedulerLlumnix(PagedScheduler, SchedulerLlumnixMixin):
def __init__(self, serving_args: ServingArgs, *args, **kwargs) -> None:
PagedScheduler.__init__(self, serving_args, *args, **kwargs)
-self.llumnix_metrics = BladellmMetrics()
+SchedulerLlumnixMixin.__init__(self)
self.llumnix_metrics.block_manager_init_metrics(self.block_manager)

def step(self) -> SchedulerStepOutput:
step_out = super().step()
self.llumnix_metrics.scheduler_step_metrics(self)
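Note that `SchedulerLlumnixMixin.__init__` is invoked explicitly rather than through `super()`, since `PagedScheduler.__init__` may not cooperate in the MRO chain. The same explicit-init pattern in isolation (class names are ours):

```python
class MetricsMixin:
    def __init__(self):
        self.metrics = {}

class Base:
    def __init__(self, cfg):
        self.cfg = cfg  # does not chain to super().__init__()

class Combined(Base, MetricsMixin):
    def __init__(self, cfg):
        Base.__init__(self, cfg)
        MetricsMixin.__init__(self)  # explicit, like SchedulerLlumnixMixin.__init__(self)
```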
