diff --git a/configs/blade.yml b/configs/bladellm.yml
similarity index 99%
rename from configs/blade.yml
rename to configs/bladellm.yml
index 4a2b9818..d0170196 100644
--- a/configs/blade.yml
+++ b/configs/bladellm.yml
@@ -3,7 +3,6 @@ SERVER:
   LAUNCH_RAY_CLUSTER: True
   REQUEST_OUTPUT_QUEUE_TYPE: "rayqueue"
-
 MANAGER:
   DISABLE_FIXED_NODE_INIT_INSTANCE: False
   DISABLE_INIT_INSTANCE_BY_MANAGER: True
diff --git a/configs/base.yml b/configs/vllm.yml
similarity index 100%
rename from configs/base.yml
rename to configs/vllm.yml
diff --git a/docs/Arguments.md b/docs/Arguments.md
index 2dd707da..6d8a3c0d 100644
--- a/docs/Arguments.md
+++ b/docs/Arguments.md
@@ -147,7 +147,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 `--migration-backend`
 - Communication backend of migration.
-- Possible choices: gloo, rpc, grpc, kvtransfer
+- Possible choices: gloo, rpc, grpc, kvtransfer. gloo and rpc are available for vLLM; grpc and kvtransfer are available for BladeLLM.
 - Default: "rpc"

 `--migration-backend-transfer-type`
diff --git a/llumnix/__init__.py b/llumnix/__init__.py
index d5670921..bdff7d4d 100644
--- a/llumnix/__init__.py
+++ b/llumnix/__init__.py
@@ -42,3 +42,10 @@
     __all__.extend(getattr(vllm, "__all__", []))
 except ImportError:
     pass
+
+try:
+    import blade_llm
+    from blade_llm import *
+    __all__.extend(getattr(blade_llm, "__all__", []))
+except ImportError:
+    pass
diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py
index a71287c0..1d25b139 100644
--- a/llumnix/arg_utils.py
+++ b/llumnix/arg_utils.py
@@ -16,7 +16,7 @@
 import dataclasses
 from dataclasses import dataclass
 import argparse
-from typing import Tuple, Optional
+from typing import Tuple

 from llumnix.internal_config import GlobalSchedulerConfig, MigrationConfig
 from llumnix.config import LlumnixConfig, get_llumnix_config
@@ -175,12 +175,12 @@ def create_global_scheduler_configs(
     def create_migration_config(self) -> MigrationConfig:
         migration_config = MigrationConfig(self.request_migration_policy,
                                            self.migration_backend,
-                                           self.migration_backend_transfer_type,
                                            self.migration_buffer_blocks,
                                            self.migration_num_layers,
                                            self.last_stage_max_blocks,
                                            self.max_stages,
                                            self.migration_backend_init_timeout,
+                                           self.migration_backend_transfer_type,
                                            self.migration_backend_server_address,
                                            self.migration_backend_kvtransfer_naming_url)
         return migration_config
@@ -202,11 +202,13 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser):
             cur_arg = getattr(args, action.dest)
             assert cur_arg in action.choices, f"{action.dest} should be one of {action.choices}, but {cur_arg} is set."
+        # vllm only
         assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
             and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
             ("When using gloo as migration backend, "
              "do not set --disable-init-instance-by-manager and --disable-fixed-node-init-instance.")
+        # bladellm only
         assert args.migration_backend not in ['kvtransfer'] or (args.migration_backend == 'kvtransfer' \
             and args.migration_backend_transfer_type), \
             ("When using kvTransfer as migration backend, "
              "do not set --migration-backend-transfer-type as empty.")
@@ -315,7 +317,8 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser.add_argument('--migration-backend',
                             type=str,
                             choices=['gloo','nccl','rpc','grpc','kvtransfer'],
-                            help='communication backend of migration')
+                            help='communication backend of migration; gloo and rpc are available for vllm, \
+                                  grpc and kvtransfer are available for bladellm')
         parser.add_argument('--migration-backend-transfer-type',
                             type=str,
                             choices=['cuda_ipc','rdma', ''],
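The two asserts above and the Arguments.md text encode the same backend/engine pairing; it reduces to a small lookup table. A minimal sketch, illustration only and not part of this patch (`SUPPORTED_MIGRATION_BACKENDS` and `validate_migration_backend` are hypothetical names):

```python
# Illustration only -- not part of this patch. Encodes the pairing documented
# in docs/Arguments.md: gloo/rpc pair with vLLM, grpc/kvtransfer with BladeLLM.
SUPPORTED_MIGRATION_BACKENDS = {
    "vllm": {"gloo", "rpc"},
    "bladellm": {"grpc", "kvtransfer"},
}

def validate_migration_backend(engine: str, backend: str) -> None:
    supported = SUPPORTED_MIGRATION_BACKENDS.get(engine, set())
    if backend not in supported:
        raise ValueError(
            f"migration backend '{backend}' is not supported by engine "
            f"'{engine}'; expected one of {sorted(supported)}")

validate_migration_backend("bladellm", "grpc")  # passes silently
# validate_migration_backend("vllm", "kvtransfer") would raise ValueError.
```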
diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py
index c1fc3a84..5e34c01f 100644
--- a/llumnix/backends/backend_interface.py
+++ b/llumnix/backends/backend_interface.py
@@ -35,6 +35,7 @@ def is_sim_backend(status: "BackendType") -> bool:
         BackendType.SIM_VLLM,
     ]

+# TODO(KuilongCui): separate backend interface into two parts: DispatchBackendInterface and MigrationBackendInterface
 class BackendInterface(ABC):
     # Methods for inference
     @abstractmethod
diff --git a/llumnix/backends/bladellm/llm_engine.py b/llumnix/backends/bladellm/llm_engine.py
index 02cedc2f..d3ad956e 100644
--- a/llumnix/backends/bladellm/llm_engine.py
+++ b/llumnix/backends/bladellm/llm_engine.py
@@ -13,15 +13,14 @@

 import json
 import traceback
-from typing import Any, List, Optional, Dict, Tuple, Union, Iterable, Deque
+from typing import List, Optional, Tuple, Union, Iterable, Deque
 from collections import defaultdict
 import threading
 import asyncio
 import queue
+import time

 import ray
-import pickle
-import zmq
 from loguru import logger
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
@@ -30,16 +28,47 @@
 from blade_llm.service.args import ServingArgs
 from blade_llm.protocol import GenerateStreamResponse, ServerRequest
 from blade_llm.service.communications.engine_wrapper import APIWrapper
-from blade_llm.protocol import GenerateStreamResponse
 from blade_llm.utils.disagg_utils import InstanceRole
-from blade_llm.service.disagg_decode_server import DecodeEntrypoint
 from blade_llm.service.disagg_pd_engine import PrefillAsyncLLMEngine, DecodeAsyncLLMEngine

 from llumnix.backends.backend_interface import BackendInterface, EngineState
 from llumnix.internal_config import MigrationConfig
 from llumnix.server_info import ServerInfo
-from llumnix.queue.utils import QueueType, AsyncPutQueueActor
+from llumnix.queue.utils import QueueType, QueueClientBase, init_request_output_queue_client
 from llumnix.llumlet.request import LlumnixRequest, RequestStatus
+from llumnix.instance_info import InstanceInfo
+
+class AsyncPutQueueActor:
+    def __init__(self, instance_id, output_queue_type: QueueType):
+        self.instance_id = instance_id
+        self.output_queue_type = output_queue_type
+        self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(output_queue_type)
+        self.engine_actor_handle = None
+
+    async def put_nowait_to_servers(self,
+                                    server_request_outputs,
+                                    server_info_dict) -> None:
+        if self.engine_actor_handle is None:
+            self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
+        tasks = []
+        for server_id, req_outputs in server_request_outputs.items():
+            server_info = server_info_dict[server_id]
+            for req_output in req_outputs:
+                if hasattr(req_output, 'request_timestamps'):
+                    req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
+            tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
+        rets = await asyncio.gather(*tasks, return_exceptions=True)
+        for idx, ret in enumerate(rets):
+            if isinstance(ret, (TimeoutError, ray.exceptions.RayActorError)):
+                server_id = list(server_request_outputs.keys())[idx]
+                server_info = server_info_dict[server_id]
+                logger.info("Server {} is dead".format(server_id))
+                if self.output_queue_type == QueueType.ZMQ:
+                    logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
+                                                                               server_info.request_output_queue_port))
+                req_outputs = list(server_request_outputs.values())[idx]
+                request_ids = [req_output.request_id for req_output in req_outputs]
+                self.engine_actor_handle.abort.remote(request_ids)
+
 class AsyncBackQueue(APIWrapper):
     def __init__(self, placement_group, node_id, instance_id, output_queue_type) -> None:
@@ -61,7 +90,7 @@ def __init__(self, placement_group, node_id, instance_id, output_queue_type) ->
         )
         self.put_queue_args_queue = queue.Queue()
         self.put_queue_loop_thread = threading.Thread(
-            target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
+            target=self._put_request_outputs_loop, args=(), daemon=True, name="put_queue_loop"
         )
         self.async_put_queue_actor = ray.remote(
             num_cpus=1,
@@ -69,9 +98,9 @@
         )(AsyncPutQueueActor).remote(instance_id, output_queue_type)
         self.put_queue_loop_thread.start()

-        self.request_client_map = {}
+        self.request_server_map = {}

-    def _start_put_queue_loop(self):
+    def _put_request_outputs_loop(self):
         while True:
             request_outputs, req_id_outputs, server_info_outputs = [], [], []
@@ -90,7 +119,7 @@
             self._put_request_outputs_to_server(request_outputs, req_id_outputs, server_info_outputs)

-    def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse], 
+    def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse],
                                        req_ids: List[str], server_infos: List[ServerInfo]) -> None:
         server_request_outputs = defaultdict(list)
         server_info_dict = {}
@@ -103,24 +132,26 @@
         logger.debug("_put_request_outputs_to_server, {}", server_request_outputs)
         self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)

+    # pylint: disable=unused-argument
     async def send(self, req_id, msg, reset=False):
-        self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_client_map[req_id]))
+        self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_server_map[req_id]))
         if msg.is_finished:
-            self.request_client_map.pop(req_id)
+            self.request_server_map.pop(req_id)

     async def recv(self):
         return None

     def drop_request(self, request_id: int) -> None:
-        self.request_client_map.pop(request_id)
+        self.request_server_map.pop(request_id)

     def add_request(self, request_id: str, server_info: ServerInfo) -> None:
-        self.request_client_map[request_id] = server_info
+        self.request_server_map[request_id] = server_info

     def stop(self):
         pass

-class LLMEngineLlumnixMixin:
+class AsyncLLMEngineLlumnixMixin:
+    # pylint: disable=unused-argument
     def __init__(self,
                  instance_id: str,
                  output_queue_type: QueueType,
@@ -138,7 +169,7 @@
         self.node_id = node_id

     @property
-    def instance_info(self):
+    def instance_info(self) -> InstanceInfo:
         return self._scheduler.llumnix_metrics.to_instance_info()

     def start(self, loop: asyncio.AbstractEventLoop):
@@ -151,7 +182,7 @@
     async def update_callback(self, resp_list, step_requests):
         await super().update_callback(resp_list, step_requests)
         self._scheduler.llumnix_metrics.engine_step_metrics(self._scheduler)
-
+
     async def _loop(self):
         previous_state = self.state
         self.state = EngineState.RUNNING
@@ -159,6 +190,7 @@
         try:
             await super()._loop()
+        # pylint: disable=broad-except
         except Exception as e:
             logger.error("Error in engine loop: {}".format(e))
             logger.error("exception traceback: {}".format(traceback.format_exc()))
@@ -183,14 +215,16 @@ async def _handle_abort(self, abort: Optional[List[Tuple[int, int, str]]] = None
                 self.trans_wrapper.drop_request(req_id)
         await super()._handle_abort(abort)

-    async def add_request(self, server_request: ServerRequest, server_info: ServerInfo):
+    async def add_request(self, server_info: ServerInfo, server_request: ServerRequest):
+        logger.debug("engine {} add request {}", self.instance_id, server_request)
         self.trans_wrapper.add_request(server_request.id, server_info)
+        # pylint: disable=protected-access
         await self._client._add_request(server_request)

     async def drop_request(self, req_id: int):
         await self._client.drop_request(req_id)

-class AsyncLLMEngineLlumnix(LLMEngineLlumnixMixin, AsyncLLMEngine):
+class AsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, AsyncLLMEngine):
     def __init__(self,
                  instance_id: str,
                  output_queue_type: QueueType,
@@ -200,9 +234,9 @@
                  *args, **kwargs,
                  ) -> None:
         AsyncLLMEngine.__init__(self, *args, **kwargs)
-        LLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)
+        AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

-class PrefillAsyncLLMEngineLlumnix(LLMEngineLlumnixMixin, PrefillAsyncLLMEngine):
+class PrefillAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, PrefillAsyncLLMEngine):
     def __init__(self,
                  instance_id: str,
                  output_queue_type: QueueType,
@@ -212,9 +246,9 @@
                  *args, **kwargs,
                  ) -> None:
         PrefillAsyncLLMEngine.__init__(self, *args, **kwargs)
-        LLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)
+        AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

-class DecodeAsyncLLMEngineLlumnix(LLMEngineLlumnixMixin, DecodeAsyncLLMEngine):
+class DecodeAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, DecodeAsyncLLMEngine):
     def __init__(self,
                  instance_id: str,
                  output_queue_type: QueueType,
@@ -224,7 +258,7 @@
                  *args, **kwargs,
                  ) -> None:
         DecodeAsyncLLMEngine.__init__(self, *args, **kwargs)
-        LLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)
+        AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

 class BackendBladeLLM(BackendInterface):
     def __init__(
@@ -235,22 +269,21 @@ def __init__(
         engine_args: ServingArgs,
         placement_group: PlacementGroup = None,
         node_id: str = None,
-        *args,
-        **kwargs
-    ) -> None:
+    ) -> None:
         self.instance_id = instance_id
         self.engine_args = engine_args
         engine_cls = self._get_engine_cls()
-        self.engine = engine_cls(instance_id, output_queue_type,migration_config, placement_group, node_id, engine_args)
-
+        self.engine = engine_cls(instance_id, output_queue_type, migration_config, placement_group, node_id, engine_args)
+
         self._loop = asyncio.new_event_loop()
         self._engine_ready = threading.Event()
         self._thread = threading.Thread(target=self._start_loop, args=(self._loop,), daemon=True, name="async_engine")
         self._thread.start()
         self._engine_ready.wait()
-
+
     @property
     def _stop_event(self):
+        # pylint: disable=protected-access
         return self.engine._stop_event

     @property
@@ -274,10 +307,10 @@ def _start_loop(self, loop):
         self._engine_ready.set()
         loop.run_forever()

-    def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, server_request: str) -> None:
-        logger.debug("engine {} add request {}", self.instance_id, server_request)
-        server_request = ServerRequest(**json.loads(server_request))
-        asyncio.run_coroutine_threadsafe(self.engine.add_request(server_request, server_info), self._loop)
+    def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None:
+        assert "server_request" in kwargs and kwargs["server_request"]
+        server_request = ServerRequest(**json.loads(kwargs["server_request"]))
+        asyncio.run_coroutine_threadsafe(self.engine.add_request(server_info, server_request), self._loop)

     def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         if isinstance(request_id, str):
@@ -286,9 +319,6 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         for req_id in request_ids:
             self.engine.drop_request(int(req_id))

-    async def _start_engine_step_loop(self) -> None:
-        pass
-
     def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]:
         pass
diff --git a/llumnix/backends/bladellm/metrics.py b/llumnix/backends/bladellm/metrics.py
index 83ddb90c..a64a0a53 100644
--- a/llumnix/backends/bladellm/metrics.py
+++ b/llumnix/backends/bladellm/metrics.py
@@ -18,10 +18,7 @@
 from llumnix.metrics.base_metrics import LlumnixMetrics
 from llumnix.metrics.dumper import LoggerDumper

-class BladellmMetrics(LlumnixMetrics):
-    def __init__(self):
-        super().__init__()
-
+class BladeLLMMetrics(LlumnixMetrics):
     def _init_dumper(self,):
         self.dumper = LoggerDumper()
diff --git a/llumnix/backends/bladellm/scheduler.py b/llumnix/backends/bladellm/scheduler.py
index f59a6066..4535dec3 100644
--- a/llumnix/backends/bladellm/scheduler.py
+++ b/llumnix/backends/bladellm/scheduler.py
@@ -15,14 +15,18 @@
 from blade_llm.service.scheduler_types import SchedulerStepOutput
 from blade_llm.service.args import ServingArgs

-from llumnix.backends.bladellm.metrics import BladellmMetrics
+from llumnix.backends.bladellm.metrics import BladeLLMMetrics

-class PagedSchedulerLlumnix(PagedScheduler):
+class SchedulerLlumnixMixin:
+    def __init__(self):
+        self.llumnix_metrics = BladeLLMMetrics()
+
+class PagedSchedulerLlumnix(PagedScheduler, SchedulerLlumnixMixin):
     def __init__(self, serving_args: ServingArgs, *args, **kwargs) -> None:
         PagedScheduler.__init__(self, serving_args, *args, **kwargs)
-        self.llumnix_metrics = BladellmMetrics()
+        SchedulerLlumnixMixin.__init__(self)
         self.llumnix_metrics.block_manager_init_metrics(self.block_manager)
-
+
     def step(self) -> SchedulerStepOutput:
         step_out = super().step()
         self.llumnix_metrics.scheduler_step_metrics(self)
diff --git a/llumnix/entrypoints/bladellm/api_server.py b/llumnix/entrypoints/bladellm/api_server.py
index 87ad682d..9cd632b5 100644
--- a/llumnix/entrypoints/bladellm/api_server.py
+++ b/llumnix/entrypoints/bladellm/api_server.py
@@ -18,9 +18,9 @@
 from llumnix.backends.backend_interface import BackendType
 from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs, LlumnixArgumentParser
 from llumnix.entrypoints.utils import setup_ray_cluster, setup_llumnix, is_gpu_available
-from llumnix.entrypoints.bladellm.client import LlumnixClientBladellm
+from llumnix.entrypoints.bladellm.client import LlumnixClientBladeLLM
 from llumnix.entrypoints.utils import LlumnixEntrypointsContext
-from llumnix.entrypoints.bladellm.utils import check_args
+from llumnix.entrypoints.bladellm.utils import get_args

 def setup_llumnix_api_server(bladellm_args: ServingArgs, loop: asyncio.AbstractEventLoop):
     # generate llumnix_parser for checking parameters with choices
@@ -28,7 +28,7 @@
     llumnix_parser = LlumnixEntrypointsArgs.add_cli_args(llumnix_parser)
     llumnix_parser = EngineManagerArgs.add_cli_args(llumnix_parser)
     llumnix_config: LlumnixConfig = get_llumnix_config(bladellm_args.llumnix_config)
-    _, engine_manager_args, engine_args = check_args(llumnix_config, llumnix_parser, bladellm_args)
+    _, engine_manager_args, engine_args = get_args(llumnix_config, llumnix_parser, bladellm_args)

     setup_ray_cluster(llumnix_config)
@@ -39,10 +39,10 @@
     instance_ids = None
     if engine_args.enable_disagg:
         instance_ids = [engine_args.disagg_options.inst_id]
-
+
     llumnix_context: LlumnixEntrypointsContext = \
         setup_llumnix(engine_manager_args, engine_args, llumnix_config, BackendType.BLADELLM, world_size, instance_ids=instance_ids)
-    llm_client = LlumnixClientBladellm(bladellm_args, llumnix_context, loop)
+    llm_client = LlumnixClientBladeLLM(bladellm_args, llumnix_context, loop)
     return llm_client
diff --git a/llumnix/entrypoints/bladellm/client.py b/llumnix/entrypoints/bladellm/client.py
index e93ad140..b56f6996 100644
--- a/llumnix/entrypoints/bladellm/client.py
+++ b/llumnix/entrypoints/bladellm/client.py
@@ -23,7 +23,6 @@
 from blade_llm.service.communications.protocol import Stats
 from blade_llm.service.communications.response import LLMResponse
 from blade_llm.service.args import ServingArgs
-from blade_llm.protocol import ServerRequest
 from blade_llm.protocol import ServerRequest, GenerateStreamResponse
 from blade_llm.service.communications.response import error_resp

@@ -35,7 +34,7 @@
 WAIT_MANAGER_INTERVAL = 5

-class LlumnixClientBladellm(MultiProcessingLLMClient):
+class LlumnixClientBladeLLM(MultiProcessingLLMClient):
     def __init__(self, args: ServingArgs, llumnix_context: LlumnixEntrypointsContext, loop: asyncio.AbstractEventLoop):
         super().__init__(args, -1)
         self.entrypoint_id2llumnix_id = {}
@@ -59,7 +58,7 @@
             del self.entrypoint_id2llumnix_id[self.llumnix_id2entrypoint_id[request_id]]
             del self.llumnix_id2entrypoint_id[request_id]
             del self.llumnix_context.request_streams[request_id]
-
+
     async def _add_request(self, request: ServerRequest) -> LLMResponse:
         if request.sampling_params.n > 1 or request.sampling_params.use_beam_search:
             return error_resp(request.id, err_code=400, err_msg="Unsupported feature: multiple sequence decoding in Llumnix.")
@@ -86,8 +85,9 @@
                 server_info_copy.request_timestamps = RequestTimestamps()
                 server_info_copy.request_timestamps.api_server_manager_generate_timestamp = time.time()
             # await to catch exception
-            await self.llumnix_context.engine_manager.generate.remote(str(request_id), server_info_copy, request)
+            await self.llumnix_context.engine_manager.generate.remote(str(request_id), server_info_copy, server_request=request)
             self.llumnix_context.manager_available = True
+        # pylint: disable=broad-except
         except Exception as e:
             logger.error("Error in manager generate: {}".format(e))
             # Do not re-generate the request to avoid duplicate requests.
diff --git a/llumnix/entrypoints/bladellm/utils.py b/llumnix/entrypoints/bladellm/utils.py
index 9b043c62..a2f08bc7 100644
--- a/llumnix/entrypoints/bladellm/utils.py
+++ b/llumnix/entrypoints/bladellm/utils.py
@@ -26,7 +26,7 @@ def detect_unsupported_feature(engine_args: ServingArgs) -> None:
     elif engine_args.use_sps:
         unsupported_feature = "speculative decoding"
     if unsupported_feature:
-        raise ValueError(f'Unsupported feature: Llumnix does not support "{unsupported_feature}" currently.')
+        raise ValueError(f'Llumnix does not support "{unsupported_feature}" for BladeLLM currently.')

 def check_engine_args(engine_args: ServingArgs, engine_manager_args: EngineManagerArgs) -> None:
     migration_config = engine_manager_args.create_migration_config()
@@ -36,12 +36,12 @@
                     change migration backend to gloo.")
         engine_manager_args.migration_backend = 'gloo'
     detect_unsupported_feature(engine_args)
-
-def check_args(llumnixCfg, llumnixParser, engine_args):
-    llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(llumnixCfg)
-    LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, llumnixParser)
-    engine_manager_args = EngineManagerArgs.from_llumnix_config(llumnixCfg)
-    EngineManagerArgs.check_args(engine_manager_args, llumnixParser)
+
+def get_args(llumnix_cfg, llumnix_parser, engine_args):
+    llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(llumnix_cfg)
+    LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, llumnix_parser)
+    engine_manager_args = EngineManagerArgs.from_llumnix_config(llumnix_cfg)
+    EngineManagerArgs.check_args(engine_manager_args, llumnix_parser)
     check_engine_args(engine_args, engine_manager_args)

     logger.info("llumnix_entrypoints_args: {}", llumnix_entrypoints_args)
diff --git a/llumnix/entrypoints/utils.py b/llumnix/entrypoints/utils.py
index a566f573..995c29f0 100644
--- a/llumnix/entrypoints/utils.py
+++ b/llumnix/entrypoints/utils.py
@@ -71,7 +71,7 @@ def launch_ray_cluster(port: int) -> subprocess.CompletedProcess:
         sys.exit(1)
     ray_start_command = None
     if 'HEAD_NODE' in os.environ:
-        ray_start_command = f"ray start --head --node-ip-address={node_ip_address} --port={port} --log-dir=/mnt/xinyi/custom_logs"
+        ray_start_command = f"ray start --head --node-ip-address={node_ip_address} --port={port}"
         try:
             result = subprocess.run(['ray', 'start', '--head', f'--port={port}'], check=True, text=True, capture_output=True)
         except subprocess.CalledProcessError as e:
@@ -147,14 +147,13 @@
         engine_manager = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix')
         logger.info("Get existing LLMEngineManager")
LLMEngineManager") return engine_manager - -def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: str, request_output_queue_type: QueueType, - backend_type: BackendType, world_size: int, *args, **kwargs - ) -> Tuple[List[str], List[Llumlet]]: + +def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: str, request_output_queue_type: QueueType, + backend_type: BackendType, world_size: int, *args, **kwargs) -> Tuple[List[str], List[Llumlet]]: instance_ids: List[str] = [] llumlets: List[Llumlet] = [] instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)] - if 'instance_ids' in kwargs and kwargs['instance_ids']: + if 'instance_ids' in kwargs and kwargs['instance_ids'][0]: instance_ids = kwargs['instance_ids'] migration_configs = engine_manager_args.create_migration_config() for idx in range(engine_manager_args.initial_instances): @@ -167,8 +166,8 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: node_id, instance_id, backend_type, - migration_configs, world_size, + migration_configs, engine_args, *args, **kwargs @@ -182,8 +181,8 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: node_id, instance_id, BackendType.SIM_VLLM, - migration_configs, world_size, + migration_configs, engine_manager_args.profiling_result_file_path, *args, **kwargs, @@ -205,7 +204,7 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs, instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id, request_output_queue_type, *args, **kwargs) else: instance_ids, llumlets = retry_manager_method_sync( - engine_manager.init_llumlets.remote, 'init_llumlets', node_id, request_output_queue_type, *args, **kwargs) + engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id, request_output_queue_type, *args, **kwargs) available_instance_ids = [] dead_instance_ids = [] diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 71421fc4..df3857b4 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -59,6 +59,8 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) + + # TODO(KuilongCui): a hacky method is being used to avoid the only-decode type engine dispatched if "decode" not in instance_id: if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and len(self.available_dispatch_instance_set) < self.num_dispatch_instances): diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index f8c2eb82..5a441e35 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -16,14 +16,14 @@ def __init__( self, request_migration_policy: str, migration_backend: str, - migration_backend_transfer_type: str, migration_buffer_blocks: int, migration_num_layers: int, last_stage_max_blocks: int, max_stages: int, migration_backend_init_timeout: float, - migration_backend_server_address: str, - migration_backend_kvtransfer_naming_url: str, + migration_backend_transfer_type: str = "", + migration_backend_server_address: str = "", + migration_backend_kvtransfer_naming_url: str = "", ) -> None: self.request_migration_policy = request_migration_policy self.migration_backend = migration_backend diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index d067ea77..cd68ed70 
diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py
index d067ea77..cd68ed70 100644
--- a/llumnix/llm_engine_manager.py
+++ b/llumnix/llm_engine_manager.py
@@ -100,7 +100,7 @@ def __init__(self,
         self.instance_last_logged_empty = {}

         # When manager starts, it automatically connects to all existing instances.
-        asyncio.run_coroutine_threadsafe(self._connect_to_instances(), asyncio.get_event_loop())
+        self._connect_to_instances()

     async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None:
         while self.num_instances == 0:
@@ -175,7 +175,7 @@ def update_instance_info_done_callback(instance_id: str, fut):
                 dead_instance_ids.append(instance_id)
                 logger.info("[_update_instance_info_loop] dead instances: {}.".format(ret))
                 logger.info("[_update_instance_info_loop] dead instances: {}.".format(self.instances))
-
+
         while True:
             try:
                 await asyncio.sleep(interval)
@@ -189,7 +189,6 @@
                     tasks.append(task)
                 await asyncio.gather(*tasks, return_exceptions=True)
                 if len(dead_instance_ids) > 0:
-                    logger.info("[_update_instance_info_loop] dead instances: {}.".format(dead_instance_ids))
                     self.scale_down(dead_instance_ids)
                 self.num_instance_info_updates += 1
                 # Push migrate when the instance_info have updated a certain number of times.
@@ -395,7 +394,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac

         return self.num_instances

-    async def _connect_to_instances(self):
+    def _connect_to_instances(self):
         actor_names_dict = ray.util.list_named_actors(True)
         instance_actor_names = [actor_name_dict['name'] for actor_name_dict in actor_names_dict if actor_name_dict['name'] != MANAGER_ACTOR_NAME]
         instance_actor_handles = [ray.get_actor(actor_name, namespace='llumnix') for actor_name in instance_actor_names]
@@ -405,7 +404,7 @@
             instance_id = instance_actor_name[len('instance_'):]
             if instance_id not in self.instances:
                 try:
-                    await instance_actor_handle.is_ready.remote()
+                    ray.get(instance_actor_handle.is_ready.remote())
                 # pylint: disable=W0703
                 except Exception as e:
                     logger.info("connect to instance {} abort, which may be not ready or alive, err: {}".format(instance_id, e))
@@ -448,44 +447,44 @@ def from_args(cls,
         return engine_manager

     # TODO(s5u13b): Fix the logger when enabling init instance by manager.
-    def init_llumlets(self, node_id: str, output_queue_type: QueueType, backend_type: BackendType, world_size: int, *args) -> Tuple[List[str], List[Llumlet]]:
+    def init_llumlets(self, engine_args, node_id: str, request_output_queue_type: QueueType,
+                      backend_type: BackendType, world_size: int, *args, **kwargs) -> Tuple[List[str], List[Llumlet]]:
         engine_manager_args = self.engine_manager_args
         instance_ids: List[str] = []
         llumlets: List[Llumlet] = []
-        for _ in range(engine_manager_args.initial_instances):
-            instance_id = random_uuid()
+        # Use instance ids provided by the caller (e.g. prefill/decode disaggregation) when present.
+        provided_instance_ids = kwargs['instance_ids'] if kwargs.get('instance_ids') else []
+        for idx in range(engine_manager_args.initial_instances):
+            instance_id = provided_instance_ids[idx] if idx < len(provided_instance_ids) else random_uuid()
-            try:
-                if not engine_manager_args.profiling_result_file_path:
-                    llumlet = Llumlet.from_args(
-                        output_queue_type,
-                        engine_manager_args.disable_fixed_node_init_instance,
-                        True,
-                        node_id,
-                        instance_id,
-                        backend_type,
-                        engine_manager_args.create_migration_config(),
-                        world_size,
-                        *args,
-                    )
-                else:
-                    assert backend_type == backend_type.VLLM, f'unimplemented backend SIM_{backend_type}'
-                    llumlet = Llumlet.from_args(
-                        output_queue_type,
-                        engine_manager_args.disable_fixed_node_init_instance,
-                        True,
-                        node_id,
-                        instance_id,
-                        BackendType.SIM_VLLM,
-                        engine_manager_args.create_migration_config(),
-                        world_size,
-                        engine_manager_args.profiling_result_file_path,
-                        *args,
-                    )
-            except Exception as e:
-                import traceback
-                logger.error("Error in engine loop: {}".format(e))
-                logger.error("exception traceback: {}".format(traceback.format_exc()))
-
+            if not engine_manager_args.profiling_result_file_path:
+                llumlet = Llumlet.from_args(
+                    request_output_queue_type,
+                    engine_manager_args.disable_fixed_node_init_instance,
+                    True,
+                    node_id,
+                    instance_id,
+                    backend_type,
+                    world_size,
+                    engine_manager_args.create_migration_config(),
+                    engine_args,
+                    *args,
+                    **kwargs
+                )
+            else:
+                assert backend_type == backend_type.VLLM, f'unimplemented backend SIM_{backend_type}'
+                llumlet = Llumlet.from_args(
+                    request_output_queue_type,
+                    engine_manager_args.disable_fixed_node_init_instance,
+                    True,
+                    node_id,
+                    instance_id,
+                    BackendType.SIM_VLLM,
+                    world_size,
+                    engine_manager_args.create_migration_config(),
+                    engine_manager_args.profiling_result_file_path,
+                    *args,
+                    **kwargs
+                )
             instance_ids.append(instance_id)
             llumlets.append(llumlet)
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index 98bc6090..0ebc11ea 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -15,8 +15,8 @@
 import traceback
 from typing import List, Union, Iterable
 import time
+
 import ray
-import os
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy

 from llumnix.logger import init_logger
@@ -73,8 +73,8 @@ def from_args(cls,
                   node_id: str,
                   instance_id: str,
                   backend_type: BackendType,
-                  migration_config: MigrationConfig,
                   world_size: int,
+                  migration_config: MigrationConfig,
                   *args,
                   **kwargs):
         try:
@@ -104,7 +104,7 @@ def from_args(cls,
             else:
                 kwargs["node_id"] = node_id
             engine_class = ray.remote(num_cpus=1,
-                                      num_gpus=num_gpu, # todo(xinyi) bladellm need this
+                                      num_gpus=num_gpu,
                                      name=actor_name,
                                      namespace='llumnix',
                                      max_concurrency=4,
@@ -128,6 +128,7 @@ def from_args(cls,
                 )
             )
             llumlet = engine_class.remote(instance_id, request_output_queue_type, backend_type, migration_config, *args, **kwargs)
+        # pylint: disable=broad-except
         except Exception as e:
             logger.error("Failed to initialize llumlet: {}".format(e))
             logger.error("exception traceback: {}".format(traceback.format_exc()))
@@ -204,12 +205,7 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds
         return migrated_request

     def get_instance_info(self) -> InstanceInfo:
-        try:
-            return self.backend_engine.engine.instance_info
-        except Exception as e:
-            logger.error("Error in engine loop: {}".format(e))
-            logger.error("exception traceback: {}".format(traceback.format_exc()))
-            return None
+        return self.backend_engine.engine.instance_info

     def is_ready(self) -> bool:
         return True
diff --git a/llumnix/metrics/base_metrics.py b/llumnix/metrics/base_metrics.py
index 67c2ee34..ad7d1799 100644
--- a/llumnix/metrics/base_metrics.py
+++ b/llumnix/metrics/base_metrics.py
@@ -35,26 +35,25 @@ def __init__(self):
     def dump(self):
         self.dumper.dump(_REGISTRY.describe_all())

-    def to_instance_info(self):
+    def to_instance_info(self) -> InstanceInfo:
         return InstanceInfo(**(_REGISTRY.describe_all()))

     def _init_dumper(self,):
         self.dumper = DummyDumper()

     @abstractmethod
-    def block_manager_init_metrics(self, *args, **kwargs):
+    def block_manager_init_metrics(self, block_manager):
         ...

     @abstractmethod
-    def engine_init_metrics(self, *args, **kwargs):
+    def engine_init_metrics(self, engine):
         ...

     @abstractmethod
-    def scheduler_step_metrics(self, *args, **kwargs):
+    def scheduler_step_metrics(self, scheduler):
         ...

     @abstractmethod
-    def engine_step_metrics(self, *args, **kwargs):
+    def engine_step_metrics(self, scheduler):
         ...
-
\ No newline at end of file
diff --git a/llumnix/metrics/variable.py b/llumnix/metrics/variable.py
index b4a44221..4ea191d2 100644
--- a/llumnix/metrics/variable.py
+++ b/llumnix/metrics/variable.py
@@ -67,6 +67,6 @@ def __init__(self, name: str, initial_value: Any = None):

     def collect(self) -> Any:
         return self._value
-
+
     def observe(self, value: Any) -> None:
         self._value = value
diff --git a/llumnix/queue/utils.py b/llumnix/queue/utils.py
index 9dbf3923..c39fa91c 100644
--- a/llumnix/queue/utils.py
+++ b/llumnix/queue/utils.py
@@ -11,11 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import asyncio
-import time
-import traceback
-import ray
-
 from llumnix.queue.queue_server_base import QueueServerBase
 from llumnix.queue.queue_client_base import QueueClientBase
 from llumnix.queue.zmq_server import ZmqServer
@@ -28,43 +23,6 @@

 logger = init_logger(__name__)

-class AsyncPutQueueActor:
-    def __init__(self, instance_id, output_queue_type: QueueType):
-        self.instance_id = instance_id
-        self.output_queue_type = output_queue_type
-        self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(output_queue_type)
-        self.engine_actor_handle = None
-
-    async def put_nowait_to_servers(self,
-                                    server_request_outputs,
-                                    server_info_dict) -> None:
-        try:
-            if self.engine_actor_handle is None:
-                self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
-            tasks = []
-            for server_id, req_outputs in server_request_outputs.items():
-                server_info = server_info_dict[server_id]
-                for req_output in req_outputs:
-                    if hasattr(req_output, 'request_timestamps'):
-                        req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
-                tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
-            rets = await asyncio.gather(*tasks, return_exceptions=True)
-            for idx, ret in enumerate(rets):
-                if isinstance(ret, (TimeoutError, ray.exceptions.RayActorError)):
-                    server_id = list(server_request_outputs.keys())[idx]
-                    server_info = server_info_dict[server_id]
-                    logger.info("Server {} is dead".format(server_id))
-                    if self.output_queue_type == QueueType.ZMQ:
-                        logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
-                                                                                   server_info.request_output_queue_port))
-                    req_outputs = list(server_request_outputs.values())[idx]
-                    request_ids = [req_output.request_id for req_output in req_outputs]
-                    self.engine_actor_handle.abort.remote(request_ids)
-        # pylint: disable=W0703
-        except Exception as e:
-            logger.error("Error in engine loop: {}".format(e))
-            logger.error("exception traceback: {}".format(traceback.format_exc()))
-
 def init_request_output_queue_server(zmq_ip: str, zmq_port: int, queue_type: QueueType) -> QueueServerBase:
     output_queue_server: QueueServerBase = None
     if queue_type == QueueType.ZMQ:
diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
index 106c08a6..11408902 100644
--- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -44,7 +44,7 @@ def test_add_instance_and_remove_instance(dispatch_scheduler, num_dispatch_insta
     assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
     dispatch_scheduler.add_instance('instance_3')
     assert dispatch_scheduler.num_instances == 2
-    assert len(dispatch_scheduler.available_dispatch_instance_set) == 2
+    assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances)

     dispatch_scheduler.remove_instance('instance_2')
     assert dispatch_scheduler.num_instances == 1
diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py
index 5a09b283..16befae9 100644
--- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py
+++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py
@@ -27,6 +27,7 @@
 from llumnix.queue.queue_type import QueueType
 from llumnix.global_scheduler.scaling_scheduler import InstanceType
 from llumnix.backends.vllm.simulator import BackendSimVLLM
+from llumnix.backends.backend_interface import BackendType
 from llumnix.backends.profiling import LatencyMemData

 # pylint: disable=unused-import
@@ -153,7 +154,7 @@ def test_init_llumlet(setup_ray_env, llumlet):
 def test_init_llumlets(setup_ray_env, engine_manager):
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
     node_id = ray.get_runtime_context().get_node_id()
-    instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue")))
+    instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"), BackendType.VLLM, 1))
     num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets))
     engine_manager_args = EngineManagerArgs()
     assert num_instances == engine_manager_args.initial_instances
@@ -165,7 +166,7 @@ def test_init_llumlets_sim(setup_ray_env, engine_manager):
     llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
     node_id = ray.get_runtime_context().get_node_id()
-    instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue")))
+    instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"), BackendType.VLLM, 1))
     num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets))
     engine_manager_args = EngineManagerArgs()
     assert num_instances == engine_manager_args.initial_instances
@@ -219,15 +220,14 @@ def test_generate_and_abort(setup_ray_env, engine_manager, llumlet):
 def test_get_request_instance(setup_ray_env):
     _, llumlets = init_llumlets(2)
     llumlet, llumlet_1 = llumlets[0], llumlets[1]
+    engine_manager = init_manager()
     request_id = random_uuid()
     request_id_1 = random_uuid()
-    ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None))
-    ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None))
+    ray.get(engine_manager.generate.remote(request_id, None, math.inf, None, None))
+    ray.get(engine_manager.generate.remote(request_id_1, None, math.inf, None, None))
     num_requests = ray.get(llumlet.get_num_requests.remote())
     num_requests_1 = ray.get(llumlet_1.get_num_requests.remote())
-    assert num_requests == 1
-    assert num_requests_1 == 1
-    engine_manager = init_manager()
+    assert num_requests + num_requests_1 == 2
     ray.get(engine_manager.abort.remote(request_id))
     ray.get(engine_manager.abort.remote(request_id_1))
     num_requests = ray.get(llumlet.get_num_requests.remote())
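The two `init_llumlets` tests above pin down the new manager-side signature, which now mirrors the entrypoints helper: engine args first, then node id, request output queue type, backend type, and world size. A call sketch assuming the same setup as those tests (`engine_manager` comes from `init_manager()`, as there):

```python
import ray
from vllm import EngineArgs
from llumnix.queue.queue_type import QueueType
from llumnix.backends.backend_interface import BackendType

# Same setup as the tests above; engine_manager is the LLMEngineManager actor.
engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
node_id = ray.get_runtime_context().get_node_id()
instance_ids, llumlets = ray.get(
    engine_manager.init_llumlets.remote(
        engine_args,            # engine args now come first
        node_id,                # node to place the instances on
        QueueType("rayqueue"),  # request output queue type
        BackendType.VLLM,       # engine backend
        1,                      # world_size
    )
)
```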