diff --git a/.gitignore b/.gitignore index 58572718..eba62d4a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# Proto files +*_pb2.py +*_pb2_grpc.py + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/Makefile b/Makefile index a071a8d0..c3618524 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ init: .PHONY: install install: - @pip install -e . + @pip install -e .[vllm] .PHONY: lint lint: check_pylint_installed check_pytest_installed @@ -27,6 +27,30 @@ lint: check_pylint_installed check_pytest_installed --disable=protected-access,super-init-not-called,unused-argument,redefined-outer-name,invalid-name \ -s n --jobs=128 ./tests +.PHONY: clean +clean: proto-clean + +###################################### proto begin ###################################### + +.PHONY: proto +proto: + @find . -type d -name "proto" | while read dir; do \ + dir_base=$$(dirname $$dir); \ + find $$dir -name "*.proto" | while read proto_file; do \ + echo "Compiling $$proto_file"; \ + PYTHONWARNINGS="ignore::DeprecationWarning" python -m grpc_tools.protoc --proto_path=. --python_out=. --grpc_python_out=. $$proto_file; \ + done; \ + done; + +.PHONY: proto-clean +proto-clean: + @find . -name "*_pb2_grpc.py" | xargs rm -f + @find . -name "*_pb2.py" | xargs rm -f + +####################################### proto end ####################################### + +###################################### test begin ####################################### + .PHONY: test test: check_pytest_installed @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings @@ -55,6 +79,8 @@ bench_test: migration_test: @pytest -v -x -s --tb=long ./tests/e2e_test/test_migration.py +####################################### test end ######################################## + #################### pygloo install for gloo migration backend begin #################### BAZEL_CMD = bazel diff --git a/configs/bladellm.yml b/configs/bladellm.yml new file mode 100644 index 00000000..d0170196 --- /dev/null +++ b/configs/bladellm.yml @@ -0,0 +1,22 @@ +SERVER: + RAY_CLUSTER_PORT: 6379 + LAUNCH_RAY_CLUSTER: True + REQUEST_OUTPUT_QUEUE_TYPE: "rayqueue" + +MANAGER: + DISABLE_FIXED_NODE_INIT_INSTANCE: False + DISABLE_INIT_INSTANCE_BY_MANAGER: True + + LOAD_METRIC: 'remaining_steps' + DISPATCH_POLICY: 'load' + + ENABLE_MIGRATION: False + ENABLE_DEFRAG: True + REQUEST_MIGRATION_POLICY: 'SR' + + MIGRATION_BACKEND: 'grpc' + MIGRATION_BUFFER_BLOCKS: 512 + + ENABLE_SCALING: False + + LOG_INSTANCE_INFO: False diff --git a/configs/base.yml b/configs/vllm.yml similarity index 100% rename from configs/base.yml rename to configs/vllm.yml diff --git a/docs/Arguments.md b/docs/Arguments.md index 37ff6a99..d6c8b7fe 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -32,8 +32,11 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--profiling-result-file-path PROFILING_RESULT_FILE_PATH] [--gpu-type GPU_TYPE] [--polling-interval POLLING_INTERVAL] - [--migration-backend {gloo,nccl,rpc}] + [--migration-backend {gloo,nccl,rayrpc,grpc,kvtransfer}] [--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS] + [--migration-backend-transfer-type {cuda_ipc,rdma,}] + [--migration-backend-kvtransfer-naming-url MIGRATION_BACKEND_KVTRANSFER_NAMING_URL] + [--migration-backend-server-address MIGRATION_BACKEND_SERVER_ADDRESS] [--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT] [--migration-num-layers MIGRATION_NUM_LAYERS] [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS] @@ -144,11 +147,24 @@ usage: -m 
llumnix.entrypoints.vllm.api_server [-h] `--migration-backend` - Communication backend of migration. -- Possible choices: gloo, rpc -- Default: "rpc" +- Possible choices: gloo, rayrpc, nccl, grpc, kvtransfer. [gloo, rayrpc, nccl] are available for vllm and [grpc, kvtransfer] are available for bladellm. +- Default: "gloo" + +`--migration-backend-transfer-type` +- Transfer type for migration backend kvTransfer. +- Possible choices: cuda_ipc, rdma +- Default: "rdma" + +`--migration-backend-server-address` +- Address of grpc server for migration backend +- Default: "127.0.0.1:50051" + +`--migration-backend-kvtransfer-naming-url` +- URL of naming server for kvtransfer migration backend +- Default: "file:/tmp/llumnix/naming/" `--migration-buffer-blocks` -- Number of cache blocks in migration. +- Number of buffer blocks in migration. - Default: 512 `--migration-backend-init-timeout` diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 78024cb5..4fcd605f 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -22,7 +22,7 @@ cd llumnix make install ``` -The default migration backend is RPC. If you want to use NCCL as the migration backend, run `make cupy-cuda` to install [cupy-cuda](https://pypi.org/search/?q=cupy-cuda) manually, as it is related to the CUDA version. +The default migration backend is rayrpc. If you want to use NCCL as the migration backend, run `make cupy-cuda` to install [cupy-cuda](https://pypi.org/search/?q=cupy-cuda) manually, as it is related to the CUDA version. If you want to use Gloo as migration backend, **in addition to installing cupy-cuda**, please refer to [this link](https://github.com/ZeldaHuang/pygloo/blob/main/.github/workflows/ubuntu_basic.yml#L24C1-L26C1) to install [Bazel](https://github.com/bazelbuild/bazel) >= 5.1.0. Then, run `make pygloo` to install [pygloo](https://github.com/ZeldaHuang/pygloo). diff --git a/examlpes/offline_inference.py b/examlpes/offline_inference.py index dabb4f94..5ab8f39f 100644 --- a/examlpes/offline_inference.py +++ b/examlpes/offline_inference.py @@ -6,7 +6,7 @@ from llumnix import launch_ray_cluster, connect_to_ray_cluster, init_manager, init_llumlets from llumnix import (SamplingParams, ServerInfo, EngineManagerArgs, LLMEngineManager, Llumlet, - EngineArgs, QueueType) + EngineArgs, QueueType, BackendType) from llumnix.utils import random_uuid from llumnix.queue.ray_queue_server import RayQueueServer @@ -40,7 +40,7 @@ llumlets: List[Llumlet] = None llumlet_ids, llumlets = init_llumlets( manager_args, engine_args, ray.get_runtime_context().get_node_id(), - QueueType("rayqueue") + QueueType("rayqueue"), BackendType.VLLM, 1, ) diff --git a/llumnix/__init__.py b/llumnix/__init__.py index 5ef7ecee..09bad2ad 100644 --- a/llumnix/__init__.py +++ b/llumnix/__init__.py @@ -11,9 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import vllm -from vllm import * - from llumnix.server_info import ServerInfo from llumnix.entrypoints.setup import (launch_ray_cluster, connect_to_ray_cluster, @@ -23,8 +20,8 @@ from llumnix.llm_engine_manager import LLMEngineManager from llumnix.llumlet.llumlet import Llumlet from llumnix.queue.queue_type import QueueType - -from .version import __version__ +from llumnix.backends.backend_interface import BackendType +from llumnix.version import __version__ __all__ = [ "__version__", @@ -37,6 +34,20 @@ "LLMEngineManager", "Llumlet", "QueueType", + "BackendType", ] -__all__.extend(getattr(vllm, "__all__", [])) +try: + import vllm + from vllm import * + __all__.extend(getattr(vllm, "__all__", [])) +except ImportError: + pass + +# TODO(KuilongCui): import blade_llm after cuda is ready +# try: +# import blade_llm +# from blade_llm import * +# __all__.extend(getattr(blade_llm, "__all__", [])) +# except ImportError: +# pass diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 3a2a25c2..2386fa4f 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -51,7 +51,7 @@ class LlumnixEntrypointsArgs: request_output_queue_port: int = None disable_log_requests_server: bool = None log_request_timestamps: bool = None - config_file: bool = None + config_file: str = None def __post_init__(self): for attr in dataclasses.fields(self): @@ -132,9 +132,12 @@ class EngineManagerArgs: log_instance_info: bool = None profiling_result_file_path: str = None + migration_backend_kvtransfer_naming_url: str = None + migration_backend_server_address: str = None migration_backend_init_timeout: float = None migration_backend: str = None migration_buffer_blocks: int = None + migration_backend_transfer_type: str = None migration_num_layers: int = None last_stage_max_blocks: int = None max_stages: int = None @@ -177,7 +180,10 @@ def create_migration_config(self) -> MigrationConfig: self.migration_num_layers, self.last_stage_max_blocks, self.max_stages, - self.migration_backend_init_timeout) + self.migration_backend_init_timeout, + self.migration_backend_transfer_type, + self.migration_backend_server_address, + self.migration_backend_kvtransfer_naming_url) return migration_config @classmethod @@ -194,16 +200,23 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser): # pylint: disable=protected-access for action in parser._optionals._actions: if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest): - assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}." + cur_arg = getattr(args, action.dest) + assert cur_arg in action.choices, f"{action.dest} should be one of {action.choices}, but {cur_arg} is set." 
+ # vllm only assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \ and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \ ("When using gloo as migration backend, " "do not set --disable-init-instance-by-manager and --disable-fixed-node-init-instance.") + # bladellm only + assert args.migration_backend not in ['kvtransfer'] or (args.migration_backend == 'kvtransfer' \ + and args.migration_backend_transfer_type), \ + ("When using kvTransfer as migration backend, " + "do not set --migration-backend-transfer-type as empty.") + @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument('--disable-fixed-node-init-instance', action='store_true', help='disable fixing the placement of instance to current node') @@ -302,17 +315,27 @@ def add_cli_args( parser.add_argument('--profiling-result-file-path', type=str, help='profiling result file path') - parser.add_argument('--migration-backend', type=str, - choices=['gloo', 'nccl', 'rpc'], - help='communication backend of migration') + choices=['gloo','nccl','rayrpc','grpc','kvtransfer'], + help='communication backend of migration, [gloo, rayrpc, nccl] are available for vllm \ + and [grpc, kvtransfer] are available for bladellm') + parser.add_argument('--migration-backend-transfer-type', + type=str, + choices=['cuda_ipc','rdma', ''], + help='transfer type for migration backend grpc and kvTransfer') + parser.add_argument('--grpc-migration-backend-address', + type=str, + help='address of grpc server for migration backend') + parser.add_argument('--migration-backend-kvtransfer-naming-url', + type=str, + help='url of naming server for kvtransfer migration backend') parser.add_argument('--migration-backend-init-timeout', type=float, help='timeout(s) for initializing migration backend') parser.add_argument('--migration-buffer-blocks', type=int, - help='number of cache blocks in migration') + help='number of buffer blocks in migration') parser.add_argument('--migration-num-layers', type=int, help='number of kv-cache layers to transfer in each round during migration') diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index 28e1e802..5e34c01f 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -27,6 +27,7 @@ class EngineState(str, Enum): class BackendType(str, Enum): VLLM = "VLLM" SIM_VLLM = "SIM_VLLM" + BLADELLM = "BLADELLM" @staticmethod def is_sim_backend(status: "BackendType") -> bool: @@ -34,6 +35,7 @@ def is_sim_backend(status: "BackendType") -> bool: BackendType.SIM_VLLM, ] +# TODO(KuilongCui): separate backend interface into two parts: DispatchBackendInterface and MigrationBackendInterface class BackendInterface(ABC): # Methods for inference @abstractmethod @@ -67,12 +69,6 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """ raise NotImplementedError - @abstractmethod - async def _start_engine_step_loop(self) -> None: - """Start step loop of backend engine. 
- """ - raise NotImplementedError - # Methods for migration @abstractmethod def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]: diff --git a/llumnix/backends/bladellm/__init__.py b/llumnix/backends/bladellm/__init__.py new file mode 100644 index 00000000..4638bd9c --- /dev/null +++ b/llumnix/backends/bladellm/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/llumnix/backends/bladellm/llm_engine.py b/llumnix/backends/bladellm/llm_engine.py new file mode 100644 index 00000000..557b1bc1 --- /dev/null +++ b/llumnix/backends/bladellm/llm_engine.py @@ -0,0 +1,341 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import traceback +from typing import List, Optional, Tuple, Union, Iterable, Deque +from collections import defaultdict +import threading +import asyncio +import queue + +import ray +from loguru import logger +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy + +from blade_llm.service.engine import AsyncLLMEngine +from blade_llm.service.args import ServingArgs +from blade_llm.protocol import GenerateStreamResponse, ServerRequest +from blade_llm.service.communications.engine_wrapper import APIWrapper +from blade_llm.utils.disagg_utils import InstanceRole +from blade_llm.service.disagg_pd_engine import PrefillAsyncLLMEngine, DecodeAsyncLLMEngine + +from llumnix.backends.backend_interface import BackendInterface, EngineState +from llumnix.internal_config import MigrationConfig +from llumnix.server_info import ServerInfo +from llumnix.backends.utils import AsyncPutQueueActor +from llumnix.llumlet.request import LlumnixRequest, RequestStatus +from llumnix.instance_info import InstanceInfo +from llumnix.queue.queue_type import QueueType + +class AsyncBackQueueWrapper(APIWrapper): + def __init__(self, placement_group, node_id, instance_id, output_queue_type) -> None: + super().__init__(args=None, resp_queue=None) + if placement_group: + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + ) + elif node_id: + scheduling_strategy = NodeAffinitySchedulingStrategy( + node_id=node_id, + soft=False, + ) + else: # When use simulator, placement_group and node_id are both None. 
+ scheduling_strategy = NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False, + ) + self.put_queue_args_queue = queue.Queue() + self.put_queue_loop_thread = threading.Thread( + target=self._put_request_outputs_loop, args=(), daemon=True, name="put_queue_loop" + ) + self.async_put_queue_actor = ray.remote( + num_cpus=1, + scheduling_strategy=scheduling_strategy + )(AsyncPutQueueActor).remote(instance_id, output_queue_type) + self.put_queue_loop_thread.start() + + self.request_server_map = {} + + def _put_request_outputs_loop(self): + while True: + request_outputs, req_id_outputs, server_info_outputs = [], [], [] + + resp, req_id, server_info = self.put_queue_args_queue.get() + request_outputs.append(resp) + req_id_outputs.append(req_id) + server_info_outputs.append(server_info) + + if self.put_queue_args_queue.qsize() > 0: + request_size = self.put_queue_args_queue.qsize() + for _ in range(request_size): + resp, req_id, server_info = self.put_queue_args_queue.get() + request_outputs.append(resp) + req_id_outputs.append(req_id) + server_info_outputs.append(server_info) + + self._put_request_outputs_to_server(request_outputs, req_id_outputs, server_info_outputs) + + def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse], + req_ids: List[str], server_infos: List[ServerInfo]) -> None: + server_request_outputs = defaultdict(list) + server_info_dict = {} + # Reorganize data in orther to put request output to queue in batch at one time. + for request_output, req_id, server_info in zip(request_outputs, req_ids, server_infos): + server_id = server_info.server_id + server_request_outputs[server_id].append((req_id, request_output.model_dump_json())) + if server_id not in server_info_dict: + server_info_dict[server_id] = server_info + logger.debug("_put_request_outputs_to_server, {}", server_request_outputs) + self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict) + + # pylint: disable=unused-argument + async def send(self, req_id, msg, reset=False): + self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_server_map[req_id])) + if msg.is_finished: + self.request_server_map.pop(req_id) + + async def recv(self): + return None + + def drop_request(self, request_id: int) -> None: + self.request_server_map.pop(request_id) + + def add_request(self, request_id: str, server_info: ServerInfo) -> None: + self.request_server_map[request_id] = server_info + + def stop(self): + pass + +class AsyncLLMEngineLlumnixMixin: + # pylint: disable=unused-argument + def __init__(self, + instance_id: str, + output_queue_type: QueueType, + migration_config: MigrationConfig, + placement_group: Optional[PlacementGroup], + node_id: Optional[str], + ) -> None: + self.instance_id = instance_id + + self.state = EngineState.INIT + logger.info("engine ({}) current state {}".format(self.instance_id, self.state)) + + self.placement_group = placement_group + self.output_queue_type = output_queue_type + self.node_id = node_id + + @property + def instance_info(self) -> InstanceInfo: + return self._scheduler.llumnix_metrics.to_instance_info() + + def start(self, loop: asyncio.AbstractEventLoop): + super().start(loop) + self._client = self.init_client_from_engine() + self.trans_wrapper: AsyncBackQueueWrapper = AsyncBackQueueWrapper(self.placement_group, + self.node_id, self.instance_id, self.output_queue_type) + self._scheduler.llumnix_metrics.engine_init_metrics(self) + + async def update_callback(self, 
resp_list, step_requests): + await super().update_callback(resp_list, step_requests) + self._scheduler.llumnix_metrics.engine_step_metrics(self._scheduler) + + async def _loop(self): + previous_state = self.state + self.state = EngineState.RUNNING + logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state)) + + try: + await super()._loop() + # pylint: disable=broad-except + except Exception as e: + logger.error("Error in engine loop: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) + + previous_state = self.state + self.state = EngineState.CRASHED + logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state)) + + if self.state == EngineState.RUNNING: + self.state = EngineState.STOPPED + logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state)) + + async def _handle_dropped_request(self): + if self._dropped_req: + for req_id in self._dropped_req: + self.trans_wrapper.drop_request(req_id) + await super()._handle_dropped_request() + + async def _handle_abort(self, abort: Optional[List[Tuple[int, int, str]]] = None): + if abort is not None and len(abort) > 0: + for req_id, _, _ in abort: + self.trans_wrapper.drop_request(req_id) + await super()._handle_abort(abort) + + async def add_request(self, server_info: ServerInfo, server_request: ServerRequest): + logger.debug("engine {} add request {}", self.instance_id, server_request) + self.trans_wrapper.add_request(server_request.id, server_info) + # pylint: disable=protected-access + await self._client._add_request(server_request) + + async def drop_request(self, req_id: int): + await self._client.drop_request(req_id) + +class AsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, AsyncLLMEngine): + def __init__(self, + instance_id: str, + output_queue_type: QueueType, + migration_config: MigrationConfig, + placement_group: Optional[PlacementGroup], + node_id: Optional[str], + *args, **kwargs, + ) -> None: + AsyncLLMEngine.__init__(self, *args, **kwargs) + AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id) + +class PrefillAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, PrefillAsyncLLMEngine): + def __init__(self, + instance_id: str, + output_queue_type: QueueType, + migration_config: MigrationConfig, + placement_group: Optional[PlacementGroup], + node_id: Optional[str], + *args, **kwargs, + ) -> None: + PrefillAsyncLLMEngine.__init__(self, *args, **kwargs) + AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id) + +class DecodeAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, DecodeAsyncLLMEngine): + def __init__(self, + instance_id: str, + output_queue_type: QueueType, + migration_config: MigrationConfig, + placement_group: Optional[PlacementGroup], + node_id: Optional[str], + *args, **kwargs, + ) -> None: + DecodeAsyncLLMEngine.__init__(self, *args, **kwargs) + AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id) + +class BackendBladeLLM(BackendInterface): + def __init__( + self, + instance_id: str, + output_queue_type: QueueType, + migration_config: MigrationConfig, + engine_args: ServingArgs, + placement_group: PlacementGroup = None, + node_id: str = None, + ) -> None: + self.instance_id = instance_id + self.engine_args = engine_args + engine_cls = self._get_engine_cls() + 
self.engine = engine_cls(instance_id, output_queue_type, migration_config, placement_group, node_id, engine_args) + + self._loop = asyncio.new_event_loop() + self._engine_ready = threading.Event() + self._thread = threading.Thread(target=self._start_loop, args=(self._loop,), daemon=True, name="async_engine") + self._thread.start() + self._engine_ready.wait() + + @property + def _stop_event(self): + # pylint: disable=protected-access + return self.engine._stop_event + + @property + def state(self): + return self.engine.state + + def _get_engine_cls(self): + engine_cls = None + if not self.engine_args.enable_disagg: + engine_cls = AsyncLLMEngineLlumnix + else: + if self.engine_args.disagg_options.inst_role == InstanceRole.PREFILL: + engine_cls = PrefillAsyncLLMEngineLlumnix + else: + engine_cls = DecodeAsyncLLMEngineLlumnix + return engine_cls + + def _start_loop(self, loop): + asyncio.set_event_loop(loop) + self.engine.start(loop) + self._engine_ready.set() + loop.run_forever() + + def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None: + assert "server_request" in kwargs and kwargs["server_request"] + server_request = ServerRequest(**json.loads(kwargs["server_request"])) + asyncio.run_coroutine_threadsafe(self.engine.add_request(server_info, server_request), self._loop) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id,) + request_ids = set(request_id) + for req_id in request_ids: + self.engine.drop_request(int(req_id)) + + def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]: + pass + + def get_running_queue(self) -> Deque[LlumnixRequest]: + pass + + def get_waiting_queue(self) -> Deque[LlumnixRequest]: + pass + + def remove_running_request(self, request_id: str) -> bool: + pass + + def remove_waiting_request(self, request_id: str) -> bool: + pass + + def add_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None: + pass + + def remove_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None: + pass + + def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]: + pass + + def pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: + pass + + def add_running_request(self, backend_request: LlumnixRequest) -> None: + pass + + def add_waiting_request(self, backend_request: LlumnixRequest) -> None: + pass + + def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: + pass + + def free_src_request(self, backend_request: LlumnixRequest) -> None: + pass + + async def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]): + pass + + def commit_dst_request(self, backend_request: LlumnixRequest) -> None: + pass + + def get_all_request_ids(self) -> List[str]: + pass diff --git a/llumnix/backends/bladellm/metrics.py b/llumnix/backends/bladellm/metrics.py new file mode 100644 index 00000000..205e1ab8 --- /dev/null +++ b/llumnix/backends/bladellm/metrics.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from blade_llm.service.block_space_manager import BlockSpaceManager +from blade_llm.service.schedulers.paged_scheduler import PagedScheduler + +from llumnix.backends.bladellm.llm_engine import AsyncLLMEngineLlumnixMixin +from llumnix.metrics.base_metrics import LlumnixMetrics +from llumnix.metrics.dumper import LoggerDumper + +class BladeLLMMetrics(LlumnixMetrics): + def _init_dumper(self,): + self.dumper = LoggerDumper() + + def block_manager_init_metrics(self, block_manager: BlockSpaceManager): + self.num_total_gpu_blocks.observe(block_manager.num_total_gpu_blocks) + self.num_watermark_blocks.observe(block_manager.reserved_blocks) + + def engine_init_metrics(self, engine: AsyncLLMEngineLlumnixMixin): + self.instance_id.observe(engine.instance_id) + + def scheduler_step_metrics(self, scheduler: PagedScheduler): + block_manager: BlockSpaceManager = scheduler.block_manager + self.num_used_gpu_blocks.observe(block_manager.get_blocks_usage()*block_manager.num_total_gpu_blocks) + self.num_running_requests.observe(len(scheduler.running)) + self.num_waiting_requests.observe(len(scheduler.waiting)) + + num_blocks_all_waiting_requests = 0 + for gen_group_state in scheduler.waiting: + num_blocks_all_waiting_requests += sum([page_req.required_blocks for page_req in gen_group_state.paged_reqs]) + self.num_blocks_all_waiting_requests.observe(num_blocks_all_waiting_requests) + + self.dump() + + def engine_step_metrics(self, scheduler: PagedScheduler): + block_manager: BlockSpaceManager = scheduler.block_manager + self.num_used_gpu_blocks.observe(block_manager.get_blocks_usage()*block_manager.num_total_gpu_blocks) + self.num_running_requests.observe(len(scheduler.running)) + self.num_waiting_requests.observe(len(scheduler.waiting)) + + num_blocks_all_waiting_requests = 0 + for gen_group_state in scheduler.waiting: + num_blocks_all_waiting_requests += sum([page_req.required_blocks for page_req in gen_group_state.paged_reqs]) + self.num_blocks_all_waiting_requests.observe(num_blocks_all_waiting_requests) + + self.dump() diff --git a/llumnix/backends/bladellm/scheduler.py b/llumnix/backends/bladellm/scheduler.py new file mode 100644 index 00000000..4535dec3 --- /dev/null +++ b/llumnix/backends/bladellm/scheduler.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from blade_llm.service.schedulers import PagedScheduler +from blade_llm.service.scheduler_types import SchedulerStepOutput +from blade_llm.service.args import ServingArgs + +from llumnix.backends.bladellm.metrics import BladeLLMMetrics + +class SchedulerLlumnixMixin: + def __init__(self): + self.llumnix_metrics = BladeLLMMetrics() + +class PagedSchedulerLlumnix(PagedScheduler, SchedulerLlumnixMixin): + def __init__(self, serving_args: ServingArgs, *args, **kwargs) -> None: + PagedScheduler.__init__(self, serving_args, *args, **kwargs) + SchedulerLlumnixMixin.__init__(self) + self.llumnix_metrics.block_manager_init_metrics(self.block_manager) + + def step(self) -> SchedulerStepOutput: + step_out = super().step() + self.llumnix_metrics.scheduler_step_metrics(self) + return step_out diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py index 16e4da4d..876542cc 100644 --- a/llumnix/backends/utils.py +++ b/llumnix/backends/utils.py @@ -11,13 +11,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, List +import asyncio +import time import ray from ray.util.placement_group import PlacementGroup +from loguru import logger from llumnix.backends.backend_interface import BackendInterface, BackendType from llumnix.queue.queue_type import QueueType +from llumnix.queue.queue_client_base import QueueClientBase +from llumnix.queue.utils import init_request_output_queue_client +from llumnix.server_info import ServerInfo + +class AsyncPutQueueActor: + def __init__(self, instance_id, request_output_queue_type: QueueType): + self.instance_id = instance_id + self.request_output_queue_type = request_output_queue_type + self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(request_output_queue_type) + self.engine_actor_handle = None + + async def put_nowait_to_servers(self, + server_request_outputs: Dict[str, List], + server_info_dict: Dict[str, ServerInfo]) -> None: + if self.engine_actor_handle is None: + self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix") + tasks = [] + for server_id, req_outputs in server_request_outputs.items(): + server_info = server_info_dict[server_id] + for req_output in req_outputs: + if hasattr(req_output, 'request_timestamps'): + req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time() + tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info))) + rets = await asyncio.gather(*tasks, return_exceptions=True) + for idx, ret in enumerate(rets): + if isinstance(ret, Exception): + server_id = list(server_request_outputs.keys())[idx] + server_info = server_info_dict[server_id] + logger.info("server {} is dead".format(server_id)) + if self.request_output_queue_type == QueueType.ZMQ: + logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip, + server_info.request_output_queue_port)) + req_outputs = list(server_request_outputs.values())[idx] + request_ids = [req_output.request_id for req_output in req_outputs] + self.engine_actor_handle.abort_request.remote(request_ids) def init_backend_engine(instance_id: str, request_output_queue_type: QueueType, backend_type: BackendType, *args, **kwargs) -> BackendInterface: @@ -29,6 +67,10 @@ def init_backend_engine(instance_id: str, request_output_queue_type: QueueType, # pylint: disable=import-outside-toplevel from 
llumnix.backends.vllm.simulator import BackendSimVLLM backend_engine = BackendSimVLLM(instance_id, request_output_queue_type, *args, **kwargs) + elif backend_type == BackendType.BLADELLM: + # pylint: disable=import-outside-toplevel + from llumnix.backends.bladellm.llm_engine import BackendBladeLLM + backend_engine = BackendBladeLLM(instance_id, request_output_queue_type, *args, **kwargs) else: raise ValueError(f'Unsupported backend: {backend_type}') return backend_engine diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 13a2f6e9..1af547cc 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -13,7 +13,7 @@ import time import traceback -from typing import Any, List, Optional, Dict, Union, Iterable, Tuple, Deque +from typing import Any, List, Optional, Union, Iterable, Tuple, Deque from collections import defaultdict import threading import asyncio @@ -38,60 +38,21 @@ from llumnix.backends.profiling import LatencyMemData from llumnix.server_info import ServerInfo from llumnix.internal_config import MigrationConfig -from llumnix.queue.queue_client_base import QueueClientBase -from llumnix.queue.utils import init_request_output_queue_client, QueueType +from llumnix.queue.utils import QueueType +from llumnix.backends.utils import AsyncPutQueueActor logger = init_logger(__name__) NO_OUTPUTS_STEP_INTERVAL = 0.01 - -class AsyncPutQueueActor: - def __init__(self, instance_id, request_output_queue_type: QueueType): - self.instance_id = instance_id - self.request_output_queue_type = request_output_queue_type - self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(request_output_queue_type) - self.engine_actor_handle = None - - async def put_nowait_to_servers(self, - server_request_outputs: Dict[str, List[RequestOutput]], - server_info_dict: Dict[str, ServerInfo]) -> None: - try: - if self.engine_actor_handle is None: - self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix") - tasks = [] - for server_id, req_outputs in server_request_outputs.items(): - server_info = server_info_dict[server_id] - for req_output in req_outputs: - if hasattr(req_output, 'request_timestamps'): - req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time() - tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info))) - rets = await asyncio.gather(*tasks, return_exceptions=True) - for idx, ret in enumerate(rets): - if isinstance(ret, Exception): - server_id = list(server_request_outputs.keys())[idx] - server_info = server_info_dict[server_id] - logger.info("server {} is dead".format(server_id)) - if self.request_output_queue_type == QueueType.ZMQ: - logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip, - server_info.request_output_queue_port)) - req_outputs = list(server_request_outputs.values())[idx] - request_ids = [req_output.request_id for req_output in req_outputs] - self.engine_actor_handle.abort_request.remote(request_ids) - # pylint: disable=W0703 - except Exception as e: - logger.error("Error in engine loop: {}".format(e)) - logger.error("exception traceback: {}".format(traceback.format_exc())) - - class LLMEngineLlumnix(_AsyncLLMEngine): def __init__(self, instance_id: str, request_output_queue_type: QueueType, placement_group: Optional[PlacementGroup], node_id: Optional[str], - *arg, **kwargs) -> None: - super().__init__(*arg, **kwargs) + *args, **kwargs) 
-> None: + super().__init__(*args, **kwargs) self.instance_id = instance_id self.step_counter = Counter() self.instance_info = None diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index a6f2c375..f21c2bab 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -75,7 +75,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.migration_stream = torch.cuda.Stream() def init_backend(self, group_name, world_size, rank) -> bool: - logger.info("create rpc migration backend successfully.") + logger.info("create rayrpc migration backend successfully.") return True def destory_backend(self) -> None: @@ -85,7 +85,7 @@ def destory_backend(self) -> None: def warmup(self) -> bool: self.actor.exec_method.remote(self.is_driver_worker, "do_send", [0]) - logger.info("rpc migration backend warmup successfully.") + logger.info("rayrpc migration backend warmup successfully.") return True # The src actor will pack the kv-cache data layer by layer. Specifically, NumPy is used for the transfer @@ -285,7 +285,7 @@ def get_migration_backend(migration_config: MigrationConfig, cache_engine: Cache target_migration_backend = None backend = migration_config.migration_backend - assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported migration backend: {} for llumnix".format(backend) + assert backend in ['nccl', 'gloo', 'rayrpc'], "Unsupported migration backend: {} for llumnix".format(backend) if backend in ['nccl', 'gloo']: target_migration_backend = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, diff --git a/llumnix/config/default.py b/llumnix/config/default.py index 67de97ee..aa76ba30 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -105,6 +105,12 @@ # Communication backend of migration _C.MANAGER.MIGRATION_BACKEND = "gloo" +# Transfer type for migration backend kvTransfer +_C.MANAGER.MIGRATION_BACKEND_TRANSFER_TYPE = "rdma" +# Address of grpc server for migration backend +_C.MANAGER.MIGRATION_BACKEND_SERVER_ADDRESS = "127.0.0.1:50051" +# URL of naming server for kvtransfer migration backend +_C.MANAGER.MIGRATION_BACKEND_KVTRANSFER_NAMING_URL = "file:/tmp/llumnix/naming/" # Timeout(s) for initializing migration backend _C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 # Number of cache blocks in migration diff --git a/llumnix/entrypoints/bladellm/__init__.py b/llumnix/entrypoints/bladellm/__init__.py new file mode 100644 index 00000000..4638bd9c --- /dev/null +++ b/llumnix/entrypoints/bladellm/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
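For reference, a minimal sketch of how the new kvtransfer-related options introduced by this change flow into `MigrationConfig` (see `EngineManagerArgs.create_migration_config` above and `llumnix/internal_config.py` below). This is illustrative only: the string values mirror the defaults added in `llumnix/config/default.py`, while the numeric stage/layer values are placeholders, not defaults taken from this diff.

```python
# Illustrative only: wiring of the new kvtransfer-related fields on MigrationConfig.
from llumnix.internal_config import MigrationConfig

migration_config = MigrationConfig(
    request_migration_policy="SR",
    migration_backend="kvtransfer",           # BladeLLM-only backend
    migration_buffer_blocks=512,
    migration_num_layers=1,                   # placeholder value
    last_stage_max_blocks=16,                 # placeholder value
    max_stages=3,                             # placeholder value
    migration_backend_init_timeout=10.0,
    migration_backend_transfer_type="rdma",   # or "cuda_ipc"
    migration_backend_server_address="127.0.0.1:50051",
    migration_backend_kvtransfer_naming_url="file:/tmp/llumnix/naming/",
)
```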
diff --git a/llumnix/entrypoints/bladellm/api_server.py b/llumnix/entrypoints/bladellm/api_server.py new file mode 100644 index 00000000..ad9d5327 --- /dev/null +++ b/llumnix/entrypoints/bladellm/api_server.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from blade_llm.service.args import ServingArgs + +from llumnix.config import get_llumnix_config, LlumnixConfig +from llumnix.backends.backend_interface import BackendType +from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs, LlumnixArgumentParser +from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix, is_gpu_available +from llumnix.entrypoints.bladellm.client import LlumnixClientBladeLLM +from llumnix.entrypoints.setup import LlumnixEntrypointsContext +from llumnix.entrypoints.bladellm.utils import get_args + +def setup_llumnix_api_server(bladellm_args: ServingArgs, loop: asyncio.AbstractEventLoop): + # generate llumnix_parser for checking parameters with choices + llumnix_parser: LlumnixArgumentParser = LlumnixArgumentParser() + llumnix_parser = LlumnixEntrypointsArgs.add_cli_args(llumnix_parser) + llumnix_parser = EngineManagerArgs.add_cli_args(llumnix_parser) + llumnix_config: LlumnixConfig = get_llumnix_config(bladellm_args.llumnix_config) + _, engine_manager_args, engine_args = get_args(llumnix_config, llumnix_parser, bladellm_args) + + setup_ray_cluster(llumnix_config) + + llm_client = None + # if gpu is not available, it means that this node is head pod x any llumnix components + if is_gpu_available(): + world_size = engine_args.tensor_parallel_size * engine_args.pipeline_parallel_size + instance_ids = None + if engine_args.enable_disagg: + instance_ids = [engine_args.disagg_options.inst_id] + + llumnix_context: LlumnixEntrypointsContext = \ + setup_llumnix(engine_manager_args, engine_args, llumnix_config, BackendType.BLADELLM, + world_size, instance_ids=instance_ids) + llm_client = LlumnixClientBladeLLM(bladellm_args, llumnix_context, loop) + + return llm_client diff --git a/llumnix/entrypoints/bladellm/client.py b/llumnix/entrypoints/bladellm/client.py new file mode 100644 index 00000000..1d0d0366 --- /dev/null +++ b/llumnix/entrypoints/bladellm/client.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import time +import asyncio +import copy +import random +from typing import Dict + +import ray + +from blade_llm.service.communications.engine_client import MultiProcessingLLMClient +from blade_llm.service.communications.protocol import Stats +from blade_llm.service.communications.response import LLMResponse +from blade_llm.service.args import ServingArgs +from blade_llm.protocol import ServerRequest, GenerateStreamResponse +from blade_llm.service.communications.response import error_resp + +from llumnix.server_info import RequestTimestamps +from llumnix.entrypoints.setup import LlumnixEntrypointsContext +from llumnix.logger import init_logger + +logger = init_logger(__name__) + +WAIT_MANAGER_INTERVAL = 5 + +class LlumnixClientBladeLLM(MultiProcessingLLMClient): + def __init__(self, args: ServingArgs, llumnix_context: LlumnixEntrypointsContext, loop: asyncio.AbstractEventLoop): + super().__init__(args, -1) + self.entrypoint_id2llumnix_id = {} + self.llumnix_id2entrypoint_id = {} + self.llumnix_context = llumnix_context + self.request_streams: Dict[str, asyncio.Queue] = {} + loop.create_task(self.background_process_outputs()) + + async def background_process_outputs(self): + while True: + request_outputs = await self.llumnix_context.request_output_queue.get() + if request_outputs is None: + continue + for (request_id, request_output) in request_outputs: + request_output = GenerateStreamResponse(**json.loads(request_output)) + # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished. + if request_id not in self.request_streams: + continue + await self.request_streams[request_id].put(request_output) + if request_output.is_finished: + logger.info("Client Recv: {}".format(request_output)) + del self.entrypoint_id2llumnix_id[self.llumnix_id2entrypoint_id[request_id]] + del self.llumnix_id2entrypoint_id[request_id] + del self.request_streams[request_id] + + async def _add_request(self, request: ServerRequest) -> LLMResponse: + if request.sampling_params.n > 1 or request.sampling_params.use_beam_search: + return error_resp(request.id, err_code=400, err_msg="Unsupported feature: multiple sequence decoding in Llumnix.") + + llumnix_id = random.randint(0, 2147483647) # 1<<31-1 + self.llumnix_id2entrypoint_id[str(llumnix_id)] = request.id + self.entrypoint_id2llumnix_id[request.id] = llumnix_id + request.id = llumnix_id + resp_stream = await self._manager_generate(request.model_dump_json(), str(llumnix_id)) + return resp_stream + + async def _manager_generate(self, request, request_id: str) -> LLMResponse: + logger.debug("Client Add request: {}:{}".format(request_id, request)) + + results_queue = asyncio.Queue() + self.request_streams[request_id] = results_queue + + # This request's outputs will be put to the request_output_queue of this api server no matter which instance it's running in. + # If manager is unavailable, request will be directly added to the llumlet held by api server. + try: + server_info_copy = copy.deepcopy(self.llumnix_context.server_info) + if self.llumnix_context.log_request_timestamps: + # Hack request timestamps in server_info for latency breakdown. 
+ server_info_copy.request_timestamps = RequestTimestamps() + server_info_copy.request_timestamps.api_server_manager_generate_timestamp = time.time() + # await to catch exception + await self.llumnix_context.engine_manager.generate.remote(str(request_id), server_info_copy, server_request=request) + self.llumnix_context.manager_available = True + # pylint: disable=broad-except + except Exception as e: + logger.error("Error in manager generate: {}".format(e)) + # Do not re-generate the request to avoid duplicate requests. + if self.llumnix_context.manager_available: + self.llumnix_context.manager_available = False + return LLMResponse(request_id, resp_queue=results_queue) + try: + if self.llumnix_context.instance_num_requests: + instance_id = min(self.llumnix_context.instance_num_requests, key=self.llumnix_context.instance_num_requests.get) + self.llumnix_context.instance_num_requests[instance_id] += 1 + # TODO[xinyi]: set expected step here + await self.llumnix_context.instances[instance_id].generate.remote(request_id, server_info_copy, -1, request) + logger.info("Manager is unavailable, directly pass request {} to instance {}".format(request_id, instance_id)) + else: + logger.info("Manager is unavailable, but there is no instance behind this api server, " + "sleep {}s, waiting for manager restarts".format(WAIT_MANAGER_INTERVAL)) + await asyncio.sleep(WAIT_MANAGER_INTERVAL) + return await asyncio.create_task(self._manager_generate(request, request_id)) + except (ray.exceptions.RayActorError, KeyError): + if instance_id in self.llumnix_context.instances: + logger.info("[manager_generate] instance {} is dead".format(instance_id)) + del self.llumnix_context.instances[instance_id] + del self.llumnix_context.instance_num_requests[instance_id] + return await asyncio.create_task(self._manager_generate(request, request_id)) + return LLMResponse(request_id, resp_queue=results_queue) + + async def drop_request(self, req_id: int): + llumnix_id = self.entrypoint_id2llumnix_id.get(req_id, None) + if llumnix_id: + try: + logger.info("abort request: {}.".format(req_id)) + await self.llumnix_context.engine_manager.abort.remote(str(req_id)) + self.entrypoint_id2llumnix_id.pop(req_id, None) + except ray.exceptions.RayActorError: + logger.info("Manager is unavailable") + + def connect(self): + pass + + def close(self): + pass + + async def get_stats(self) -> Stats: + pass + + async def get_metrics(self) -> str: + pass + + async def start_profiler(self) -> None: + pass + + async def stop_profiler(self) -> None: + pass diff --git a/llumnix/entrypoints/bladellm/utils.py b/llumnix/entrypoints/bladellm/utils.py new file mode 100644 index 00000000..3b9f8d14 --- /dev/null +++ b/llumnix/entrypoints/bladellm/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from loguru import logger + +from blade_llm.service.args import ServingArgs +from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs + +def detect_unsupported_feature(engine_args: ServingArgs) -> None: + unsupported_feature = None + if engine_args.enable_lora: + unsupported_feature = "multi-lora serving" + elif not engine_args.disable_prompt_cache: + unsupported_feature = "automatic prompt caching" + elif engine_args.use_sps: + unsupported_feature = "speculative decoding" + elif engine_args.enable_remote_worker: + unsupported_feature = "enable_remote_worker" + + if unsupported_feature: + raise ValueError(f'Llumnix does not support "{unsupported_feature}" for bladeLLM currently.') + +def check_engine_args(engine_args: ServingArgs, engine_manager_args: EngineManagerArgs) -> None: + migration_config = engine_manager_args.create_migration_config() + if (engine_args.tensor_parallel_size > 1 or engine_args.tensor_parallel_size > 1) and \ + migration_config.migration_backend == 'nccl': + logger.info("Llumnix does not support TP or PP enabled model when the migration backend is nccl, \ + change migration backend to gloo.") + engine_manager_args.migration_backend = 'gloo' + detect_unsupported_feature(engine_args) + +def get_args(llumnix_cfg, llumnix_parser, engine_args): + llumnix_entrypoints_args = LlumnixEntrypointsArgs.from_llumnix_config(llumnix_cfg) + LlumnixEntrypointsArgs.check_args(llumnix_entrypoints_args, llumnix_parser) + engine_manager_args = EngineManagerArgs.from_llumnix_config(llumnix_cfg) + EngineManagerArgs.check_args(engine_manager_args, llumnix_parser) + check_engine_args(engine_args, engine_manager_args) + + logger.info("llumnix_entrypoints_args: {}", llumnix_entrypoints_args) + logger.info("engine_manager_args: {}", engine_manager_args) + logger.info("engine_args: {}", engine_args) + + return llumnix_entrypoints_args, engine_manager_args, engine_args diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py index 35280e0f..23660186 100644 --- a/llumnix/entrypoints/setup.py +++ b/llumnix/entrypoints/setup.py @@ -152,14 +152,14 @@ def init_manager(engine_manager_args: EngineManagerArgs) -> LLMEngineManager: logger.info("Get existing LLMEngineManager") return engine_manager -def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: str, - request_output_queue_type: QueueType) -> Tuple[List[str], List[Llumlet]]: - engine_config = engine_args.create_engine_config() - parallel_config = engine_config.parallel_config +def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: str, request_output_queue_type: QueueType, + backend_type: BackendType, world_size: int, *args, **kwargs) -> Tuple[List[str], List[Llumlet]]: instance_ids: List[str] = [] llumlets: List[Llumlet] = [] - instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)] + if 'instance_ids' in kwargs: + instance_ids = kwargs['instance_ids'] if kwargs['instance_ids'] else instance_ids + kwargs.pop('instance_ids') migration_configs = engine_manager_args.create_migration_config() for idx in range(engine_manager_args.initial_instances): instance_id = instance_ids[idx] @@ -170,12 +170,15 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: False, node_id, instance_id, - BackendType.VLLM, - parallel_config.world_size, + backend_type, + world_size, migration_configs, engine_args, + *args, + **kwargs ) else: + assert backend_type == backend_type.VLLM, f'unimplemented backend SIM_{backend_type}' 
llumlet = Llumlet.from_args( request_output_queue_type, engine_manager_args.disable_fixed_node_init_instance, @@ -183,10 +186,11 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: node_id, instance_id, BackendType.SIM_VLLM, - parallel_config.world_size, + world_size, migration_configs, engine_manager_args.profiling_result_file_path, - engine_args, + *args, + **kwargs, ) llumlets.append(llumlet) return instance_ids, llumlets @@ -196,13 +200,16 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs, node_id: str, request_output_queue_type: QueueType, ip: str, - request_output_queue_port: str): + request_output_queue_port: str, + *args, + **kwargs + ): engine_manager = init_manager(engine_manager_args) if engine_manager_args.disable_init_instance_by_manager: - instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id, request_output_queue_type) + instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id, request_output_queue_type, *args, **kwargs) else: instance_ids, llumlets = retry_manager_method_sync( - engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id, request_output_queue_type) + engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id, request_output_queue_type, *args, **kwargs) available_instance_ids = [] dead_instance_ids = [] @@ -227,7 +234,7 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs, return engine_manager, available_instance_ids, available_llumlets, request_output_queue -def setup_llumnix(engine_manager_args, engine_args, cfg): +def setup_llumnix(engine_manager_args, engine_args, cfg, *args, **kwargs): ip = get_ip_address() node_id = ray.get_runtime_context().get_node_id() engine_manager, instance_ids, llumlets, request_output_queue = \ @@ -236,7 +243,9 @@ def setup_llumnix(engine_manager_args, engine_args, cfg): node_id, cfg.SERVER.REQUEST_OUTPUT_QUEUE_TYPE, ip, - cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT) + cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT, + *args, + **kwargs) server_id = random_uuid() server_info = ServerInfo(server_id, cfg.SERVER.REQUEST_OUTPUT_QUEUE_TYPE, diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py index 4d6fa730..46cbf842 100644 --- a/llumnix/entrypoints/vllm/api_server.py +++ b/llumnix/entrypoints/vllm/api_server.py @@ -34,6 +34,7 @@ from llumnix.logger import init_logger from llumnix.utils import random_uuid from llumnix.config import get_llumnix_config, LlumnixConfig +from llumnix.backends.backend_interface import BackendType # Code file with __main__ should set the logger name to inherit the llumnix logger configuration. logger = init_logger("llumnix.entrypoints.vllm.api_server") @@ -188,8 +189,11 @@ async def is_ready(): # if gpu is not available, it means that this node is head pod without any llumnix components if is_gpu_available(): - llumnix_entrypoints_context = setup_llumnix(engine_manager_args, engine_args, cfg) + engine_config = engine_args.create_engine_config() + parallel_config = engine_config.parallel_config + llumnix_entrypoints_context = setup_llumnix(engine_manager_args, engine_args, cfg, BackendType.VLLM, parallel_config.world_size) llumnix_client = LlumnixClientVLLM(llumnix_entrypoints_context) + # Start the api server after all the components of llumnix are ready. 
logger.info("Start Api Server on '{}:{}'".format(cfg.SERVER.HOST, cfg.SERVER.PORT)) uvicorn.run(app, diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 0f5ae030..df3857b4 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -59,10 +59,13 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) - if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and - len(self.available_dispatch_instance_set) < self.num_dispatch_instances): - self.available_dispatch_instance_set.add(instance_id) - self.instance_num_requests[instance_id] = 0 + + # TODO(KuilongCui): a hacky method is being used to avoid the only-decode type engine dispatched + if "decode" not in instance_id: + if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and + len(self.available_dispatch_instance_set) < self.num_dispatch_instances): + self.available_dispatch_instance_set.add(instance_id) + self.instance_num_requests[instance_id] = 0 def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 9c448ebf..61516dd7 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -33,8 +33,8 @@ def __init__(self, # to prevent instances of migration backends that have not been initialized from participating in migration. migration_backend_init_filter = CustomFilter() migration_backend_init_filter.set_filter_condtition( - src_filter=lambda _: migration_backend == 'rpc', - dst_filter=lambda _: migration_backend == 'rpc') + src_filter=lambda _: migration_backend not in ['gloo', 'nccl'], + dst_filter=lambda _: migration_backend not in ['gloo', 'nccl']) self.migration_filter.register_filter("migration_backend_init_filter", migration_backend_init_filter) diff --git a/llumnix/instance_info.py b/llumnix/instance_info.py index 1848362b..95f7dd5f 100644 --- a/llumnix/instance_info.py +++ b/llumnix/instance_info.py @@ -36,7 +36,8 @@ def __init__(self, num_blocks_all_waiting_requests: int = 0, inference_type: RequestInferenceType = RequestInferenceType.PREFILL, instance_type: str = "", - num_batched_tokens: int = 0) -> None: + num_batched_tokens: int = 0, + instance_id: str = "",) -> None: self.num_total_gpu_blocks = num_total_gpu_blocks self.num_watermark_blocks = num_watermark_blocks self.num_used_gpu_blocks = num_used_gpu_blocks @@ -67,7 +68,7 @@ def __init__(self, self.finished_request_ids = None # For record statistics, assigned in backend engine. 
- self.instance_id = None + self.instance_id = instance_id self.step_id = None self.timestamp = None self.profiling_data = () diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 9584f983..b21d45d7 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -20,14 +20,21 @@ def __init__( migration_num_layers: int, last_stage_max_blocks: int, max_stages: int, - migration_backend_init_timeout: float) -> None: + migration_backend_init_timeout: float, + migration_backend_transfer_type: str = "", + migration_backend_server_address: str = "", + migration_backend_kvtransfer_naming_url: str = "", + ) -> None: self.request_migration_policy = request_migration_policy self.migration_backend = migration_backend + self.migration_backend_transfer_type = migration_backend_transfer_type self.migration_num_layers = migration_num_layers self.migration_buffer_blocks = migration_buffer_blocks self.last_stage_max_blocks = last_stage_max_blocks self.max_stages = max_stages self.migration_backend_init_timeout = migration_backend_init_timeout + self.migration_backend_server_address = migration_backend_server_address + self.migration_backend_kvtransfer_naming_url = migration_backend_kvtransfer_naming_url class GlobalSchedulerConfig: def __init__( diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 93c20f37..931d33a8 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -101,18 +101,14 @@ def __init__(self, self.instance_last_logged_empty = {} # When manager starts, it automatically connects to all existing instances. - self._connect_to_instances() - - async def generate( - self, - request_id: str, - server_info: ServerInfo, - *args, - **kwargs,) -> None: + asyncio.run_coroutine_threadsafe(self._connect_to_instances(), asyncio.get_event_loop()) + + async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None: while self.num_instances == 0: logger.info("No instance available temporarily, sleep {}s, " "and retry generate request {} again....".format(RETRIES_INTERVALS, request_id)) await asyncio.sleep(RETRIES_INTERVALS) + instance_id, request_expected_steps = self.global_scheduler.dispatch() try: if hasattr(server_info, 'request_timestamps'): @@ -178,6 +174,9 @@ def update_instance_info_done_callback(instance_id: str, fut): self.global_scheduler.update_instance_infos([ret]) else: dead_instance_ids.append(instance_id) + logger.info("[_update_instance_info_loop] dead instances: {}.".format(ret)) + logger.info("[_update_instance_info_loop] dead instances: {}.".format(self.instances)) + while True: try: await asyncio.sleep(interval) @@ -191,7 +190,6 @@ def update_instance_info_done_callback(instance_id: str, fut): tasks.append(task) await asyncio.gather(*tasks, return_exceptions=True) if len(dead_instance_ids) > 0: - logger.info("[_update_instance_info_loop] dead instances: {}.".format(dead_instance_ids)) self.scale_down(dead_instance_ids) self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. @@ -211,6 +209,7 @@ async def _clear_request_instance_loop(self, interval: float): while True: await asyncio.sleep(interval) self.request_instance = {} + async def _push_migrations(self) -> None: # Push migrate when the instance_info have updated a certain number of times. 
if self.enable_pd_disagg: @@ -366,7 +365,7 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes # caused by this scale-up (see rebuild_migrate_backend for details). Therefore, we simply return in this case. Specifically, # for RPC, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. - if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc' \ + if self.enable_migration and self.engine_manager_args.migration_backend in ['gloo', 'nccl'] \ and indeed_update and no_pending_instance: asyncio.create_task(self.rebuild_migrate_backend()) @@ -391,7 +390,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac self.global_scheduler.scale_down(instance_ids) self.num_instances = len(self.instances) - if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc': + if self.enable_migration and self.engine_manager_args.migration_backend in ['gloo', 'nccl']: if len(self.instances) == 0: self.pending_rebuild_migration_instances = 0 @@ -402,7 +401,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac return self.num_instances - def _connect_to_instances(self): + async def _connect_to_instances(self): actor_names_dict = ray.util.list_named_actors(True) instance_actor_names = [actor_name_dict['name'] for actor_name_dict in actor_names_dict if actor_name_dict['name'] != MANAGER_ACTOR_NAME] instance_actor_handles = [ray.get_actor(actor_name, namespace='llumnix') for actor_name in instance_actor_names] @@ -412,7 +411,7 @@ def _connect_to_instances(self): instance_id = instance_actor_name[len('instance_'):] if instance_id not in self.instances: try: - ray.get(instance_actor_handle.is_ready.remote()) + await instance_actor_handle.is_ready.remote() # pylint: disable=W0703 except Exception as e: logger.info("connect to instance {} abort, which may be not ready or alive, err: {}".format(instance_id, e)) @@ -454,13 +453,14 @@ def from_args(cls, profiling_database=profiling_database) return engine_manager - # TODO(s5u13b): Significant duplication with llumlet_utils.init_llumlets. Consider reducing duplicate codes. - def init_llumlets(self, engine_args, node_id: str, request_output_queue_type: QueueType) -> Tuple[List[str], List[Llumlet]]: + # TODO(s5u13b): Fix the logger when enabling init instance by manager. 
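+    # backend_type and world_size are supplied by the caller (e.g., the vLLM api_server passes
+    # BackendType.VLLM and parallel_config.world_size) rather than being derived from engine_args here.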
+ def init_llumlets(self, engine_args, node_id: str, request_output_queue_type: QueueType, + backend_type: BackendType, world_size: int, *args, **kwargs) -> Tuple[List[str], List[Llumlet]]: engine_manager_args = self.engine_manager_args - engine_config = engine_args.create_engine_config() - parallel_config = engine_config.parallel_config instance_ids: List[str] = [] llumlets: List[Llumlet] = [] + if 'instance_ids' in kwargs and kwargs['instance_ids'][0]: + instance_ids = kwargs['instance_ids'] for _ in range(engine_manager_args.initial_instances): instance_id = random_uuid() if not engine_manager_args.profiling_result_file_path: @@ -470,12 +470,15 @@ def init_llumlets(self, engine_args, node_id: str, request_output_queue_type: Qu True, node_id, instance_id, - BackendType.VLLM, - parallel_config.world_size, + backend_type, + world_size, engine_manager_args.create_migration_config(), engine_args, + *args, + **kwargs ) else: + assert backend_type == backend_type.VLLM, f'unimplemented backend SIM_{backend_type}' llumlet = Llumlet.from_args( request_output_queue_type, engine_manager_args.disable_fixed_node_init_instance, @@ -483,10 +486,11 @@ def init_llumlets(self, engine_args, node_id: str, request_output_queue_type: Qu node_id, instance_id, BackendType.SIM_VLLM, - parallel_config.world_size, + world_size, engine_manager_args.create_migration_config(), engine_manager_args.profiling_result_file_path, - engine_args, + *args, + **kwargs ) instance_ids.append(instance_id) llumlets.append(llumlet) diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 63dd23e7..56ab4435 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -15,6 +15,7 @@ import traceback from typing import List, Union, Iterable import time + import ray from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy @@ -76,49 +77,62 @@ def from_args(cls, migration_config: MigrationConfig, *args, **kwargs): - lifetime = "detached" if detached else None - assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}' - actor_name = f"instance_{instance_id}" - if backend_type == backend_type.VLLM: - if disable_fixed_node_init_instance: - # TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo. - placement_group = initialize_placement_group(world_size, detached=detached) - kwargs["placement_group"] = placement_group - engine_class = ray.remote(num_cpus=1, - name=actor_name, - namespace='llumnix', - max_concurrency=4, - lifetime=lifetime)(cls).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_bundle_index=0, + try: + lifetime = "detached" if detached else None + assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM, backend_type.BLADELLM], \ + f'unimplemented backend {backend_type}' + num_gpu = 0 + if backend_type == backend_type.BLADELLM: + num_gpu = world_size + actor_name = f"instance_{instance_id}" + if backend_type in [backend_type.VLLM, backend_type.BLADELLM]: + if disable_fixed_node_init_instance: + # TODO(s5u13b): Support placement_group lifetime management when the migration backend is gloo. 
+ placement_group = initialize_placement_group(world_size, detached=detached) + kwargs["placement_group"] = placement_group + engine_class = ray.remote(num_cpus=1, + num_gpus=num_gpu, + name=actor_name, + namespace='llumnix', + max_concurrency=4, + lifetime=lifetime)(cls).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=0, + ) ) - ) - else: + else: + kwargs["node_id"] = node_id + engine_class = ray.remote(num_cpus=1, + num_gpus=num_gpu, + name=actor_name, + namespace='llumnix', + max_concurrency=4, + lifetime=lifetime)(cls).options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False, + ) + ) + else: # backend_type == backend_type.SIM_VLLM: kwargs["node_id"] = node_id engine_class = ray.remote(num_cpus=1, - name=actor_name, - namespace='llumnix', - max_concurrency=4, - lifetime=lifetime)(cls).options( + num_gpu=num_gpu, + name=actor_name, + namespace='llumnix', + max_concurrency=4, + lifetime=lifetime)(cls).options( scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=node_id, soft=False, ) ) - else: # backend_type == backend_type.SIM_VLLM: - kwargs["node_id"] = node_id - engine_class = ray.remote(num_cpus=1, - name=actor_name, - namespace='llumnix', - max_concurrency=4, - lifetime=lifetime)(cls).options( - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=node_id, - soft=False, - ) - ) - llumlet = engine_class.remote(instance_id, request_output_queue_type, backend_type, migration_config, *args, **kwargs) + llumlet = engine_class.remote(instance_id, request_output_queue_type, backend_type, migration_config, *args, **kwargs) + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to initialize llumlet: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) + return llumlet async def _check_engine_state_loop(self): @@ -128,6 +142,7 @@ async def _check_engine_state_loop(self): logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id)) # pylint: disable=protected-access self.backend_engine._stop_event.set() + await asyncio.sleep(0) self_actor = ray.get_actor(self.actor_name) ray.kill(self_actor) @@ -198,14 +213,7 @@ def is_ready(self) -> bool: def get_all_request_ids(self) -> List[str]: return self.backend_engine.get_all_request_ids() - def generate( - self, - request_id: str, - server_info: ServerInfo, - expected_steps: int, - *args, - **kwargs, - ) -> None: + def generate(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None: # This should not be used for logging, as it is monotonic time. if hasattr(server_info, 'request_timestamps'): server_info.request_timestamps.llumlet_generate_timestamp = time.time() diff --git a/llumnix/metrics/__init__.py b/llumnix/metrics/__init__.py new file mode 100644 index 00000000..4638bd9c --- /dev/null +++ b/llumnix/metrics/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/llumnix/metrics/base_metrics.py b/llumnix/metrics/base_metrics.py new file mode 100644 index 00000000..ad7d1799 --- /dev/null +++ b/llumnix/metrics/base_metrics.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from llumnix.metrics.variable import _REGISTRY, Status +from llumnix.metrics.dumper import Dumper, DummyDumper +from llumnix.instance_info import InstanceInfo + +class LlumnixMetrics(ABC): + def __init__(self): + self.instance_id = Status("instance_id") + + # used for dispatch + self.num_total_gpu_blocks = Status("num_total_gpu_blocks") + self.num_watermark_blocks = Status("num_watermark_blocks") + self.num_used_gpu_blocks = Status("num_used_gpu_blocks") + self.num_blocks_all_waiting_requests = Status("num_blocks_all_waiting_requests") + self.num_running_requests = Status("num_running_requests") + self.num_waiting_requests = Status("num_waiting_requests") + + self.dumper: Dumper = None + self._init_dumper() + + def dump(self): + self.dumper.dump(_REGISTRY.describe_all()) + + def to_instance_info(self) -> InstanceInfo: + return InstanceInfo(**(_REGISTRY.describe_all())) + + def _init_dumper(self,): + self.dumper = DummyDumper() + + @abstractmethod + def block_manager_init_metrics(self, block_manager): + ... + + @abstractmethod + def engine_init_metrics(self, engine): + ... + + @abstractmethod + def scheduler_step_metrics(self, scheduler): + ... + + @abstractmethod + def engine_step_metrics(self, scheduler): + ... + \ No newline at end of file diff --git a/llumnix/metrics/dumper.py b/llumnix/metrics/dumper.py new file mode 100644 index 00000000..0334198a --- /dev/null +++ b/llumnix/metrics/dumper.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from loguru import logger + +class Dumper(ABC): + @abstractmethod + def dump(self, metrics: Dict[str, Any]) -> None: + ... 
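+
+# LoggerDumper below writes the collected metrics to the loguru logger, while DummyDumper
+# (the default installed by LlumnixMetrics._init_dumper) silently discards them.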
+ +class LoggerDumper(Dumper): + def dump(self, metrics: Dict[str, Any]) -> None: + logger.info("Metrics: {}", metrics) + +class DummyDumper(Dumper): + def dump(self, metrics: Dict[str, Any]) -> None: + pass diff --git a/llumnix/metrics/variable.py b/llumnix/metrics/variable.py new file mode 100644 index 00000000..4ea191d2 --- /dev/null +++ b/llumnix/metrics/variable.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +class Registery: + def __init__(self): + self._metrics: Dict[str, Variable] = {} + + def get(self, name: str) -> Optional['Variable']: + return self._metrics.get(name, None) + + def register(self, name: str, metric: 'Variable'): + if name in self._metrics: + raise RuntimeError(f"Metric name already registered: {name}") + self._metrics[name] = metric + + def describe_all(self) -> Dict[str, Any]: + ret = {} + for _, metric in self._metrics.items(): + ret.update(metric.describe()) + return ret + + def clear(self): + self._metrics.clear() + + def remove(self, key) -> None: + del self._metrics[key] + +_REGISTRY = Registery() + +class Variable(ABC): + def __init__(self, name: str): + self._name: str = name + _REGISTRY.register(name, self) + + @abstractmethod + def collect(self) -> Any: + ... + + @abstractmethod + def observe(self, value: Any) -> None: + ... + + def describe(self): + return {self.name : self.collect()} + + @property + def name(self) -> str: + return self._name + +class Status(Variable): + def __init__(self, name: str, initial_value: Any = None): + super().__init__(name) + self._value: Any = initial_value + + def collect(self) -> Any: + return self._value + + def observe(self, value: Any) -> None: + self._value = value diff --git a/llumnix/queue/utils.py b/llumnix/queue/utils.py index 35472c76..c39fa91c 100644 --- a/llumnix/queue/utils.py +++ b/llumnix/queue/utils.py @@ -19,7 +19,9 @@ from llumnix.queue.ray_queue_client import RayQueueClient from llumnix.queue.zmq_utils import get_open_zmq_ipc_path from llumnix.queue.queue_type import QueueType +from llumnix.logger import init_logger +logger = init_logger(__name__) def init_request_output_queue_server(zmq_ip: str, zmq_port: int, queue_type: QueueType) -> QueueServerBase: output_queue_server: QueueServerBase = None diff --git a/requirements/requirements_bladellm.txt b/requirements/requirements_bladellm.txt new file mode 100644 index 00000000..3c66e6c7 --- /dev/null +++ b/requirements/requirements_bladellm.txt @@ -0,0 +1,7 @@ +ray >= 2.9.0 +pyarrow # Required for Ray data. 
+aiohttp +pandas +matplotlib +pyyaml +yacs diff --git a/requirements.txt b/requirements/requirements_vllm.txt similarity index 96% rename from requirements.txt rename to requirements/requirements_vllm.txt index 25cf07ec..f9fbe6a6 100644 --- a/requirements.txt +++ b/requirements/requirements_vllm.txt @@ -10,3 +10,4 @@ pyyaml yacs numpy < 1.24.0 # for gloo migration backend's compatibility with numpy.float pyzmq +loguru diff --git a/setup.py b/setup.py index 8aef38f6..1444d513 100644 --- a/setup.py +++ b/setup.py @@ -18,11 +18,11 @@ ROOT_DIR = os.path.dirname(__file__) def get_path(*filepath) -> str: - return os.path.join(ROOT_DIR, *filepath) + return os.path.join(ROOT_DIR, 'requirements', *filepath) -def get_requirements() -> List[str]: +def get_requirements(engine: str) -> List[str]: """Get Python package dependencies from requirements.txt.""" - with open(get_path("requirements.txt")) as f: + with open(get_path(f"requirements_{engine}.txt")) as f: requirements = f.read().strip().split("\n") return requirements @@ -42,7 +42,10 @@ def readme(): url='https://github.com/AlibabaPAI/llumnix', license="Apache 2.0", packages=find_packages(), - install_requires=get_requirements(), + extras_require={ + 'vllm': get_requirements('vllm'), + 'bladellm': get_requirements('bladellm'), + }, platforms=["all"], classifiers=[ 'Programming Language :: Python', diff --git a/tests/conftest.py b/tests/conftest.py index 65a1ae9b..42b8b2ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,15 +20,10 @@ from llumnix.utils import random_uuid - def pytest_sessionstart(session): - subprocess.run(["ray", "stop"], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) -def pytest_sessionfinish(session, exitstatus): - subprocess.run(["ray", "stop"], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - @pytest.fixture def setup_ray_env(): ray.init(namespace="llumnix", ignore_reinit_error=True) diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py index 24dd896e..87b03417 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_e2e.py @@ -62,7 +62,7 @@ def run_vllm(model, max_model_len, sampling_params): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for e2e test") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) +@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo']) @pytest.mark.parametrize("launch_mode", ['eief', 'eidf', 'dief', 'didf']) async def test_e2e(cleanup_ray_env, shutdown_llumnix_service, model, migration_backend, launch_mode): if migration_backend == 'gloo' and launch_mode != 'eief': diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index 700b8da3..ba0793da 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -90,11 +90,11 @@ def get_instance_num_blocks(): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) +@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo']) @pytest.mark.parametrize("migrated_request_status", ['running', 'waiting']) async def test_migration_benchmark(cleanup_ray_env, 
shutdown_llumnix_service, model, migration_backend, migrated_request_status): - if migrated_request_status == 'waiting' and migration_backend != 'rpc': - pytest.skip("When the migrated request status is waiting, only test the rpc migration backend.") + if migrated_request_status == 'waiting' and migration_backend != 'rayrpc': + pytest.skip("When the migrated request status is waiting, only test the rayrpc migration backend.") request_migration_policy = 'SR' if migrated_request_status == 'running' else 'FCW' ip = "127.0.0.1" diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index d247e160..d73b130e 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -83,7 +83,7 @@ async def step_async_try_schedule(): self.backend_engine.engine.step_async = step_async_try_schedule -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl']) @pytest.mark.parametrize("migration_request_status", ['waiting', 'running']) @pytest.mark.asyncio async def test_migration_correctness(setup_ray_env, migration_backend, migration_request_status): @@ -203,7 +203,7 @@ async def test_correctness(prompt): await test_correctness(prompt) que.cleanup() -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl']) @pytest.mark.asyncio async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py index c6e23e10..5b92fb9c 100644 --- a/tests/unit_test/backends/vllm/test_migration_backend.py +++ b/tests/unit_test/backends/vllm/test_migration_backend.py @@ -37,7 +37,7 @@ def get_gpu_cache(self): return self.gpu_cache @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.") -@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_migrate_cache(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() migraiton_config = EngineManagerArgs(migration_buffer_blocks=3, migration_num_layers=5).create_migration_config() diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py index a42b9f28..440bf6e9 100644 --- a/tests/unit_test/backends/vllm/test_worker.py +++ b/tests/unit_test/backends/vllm/test_worker.py @@ -56,7 +56,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, return worker -@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_reserve_memory_for_migration(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() @@ -77,7 +77,7 @@ def test_reserve_memory_for_migration(setup_ray_env, backend): assert migration_cache_size == occupy_memory @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.") -@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) 
+@pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) def test_rebuild_migration_backend(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 114ce551..11408902 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -21,7 +21,7 @@ def init_dispatch_scheduler(policy='load'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, 1) + dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, 2) return dispatch_scheduler @pytest.fixture diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py index 18c83f85..7079c96f 100644 --- a/tests/unit_test/global_scheduler/test_global_scheduler.py +++ b/tests/unit_test/global_scheduler/test_global_scheduler.py @@ -25,7 +25,7 @@ def init_global_scheduler(): global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', math.inf, 'defrag_constrained', 3.0, True, 'avg_load', - 10, 60, False, 'rpc') + 10, 60, False, 'rayrpc') global_scheduler = GlobalScheduler(global_scheduler_config) return global_scheduler diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index 5a09b283..595512c1 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -27,6 +27,7 @@ from llumnix.queue.queue_type import QueueType from llumnix.global_scheduler.scaling_scheduler import InstanceType from llumnix.backends.vllm.simulator import BackendSimVLLM +from llumnix.backends.backend_interface import BackendType from llumnix.backends.profiling import LatencyMemData # pylint: disable=unused-import @@ -104,7 +105,7 @@ def _get_lantecy_mem(self, *args, **kwargs): def init_manager(): try: - engine_manager_args = EngineManagerArgs(migration_backend="rpc", enable_migration=True) + engine_manager_args = EngineManagerArgs(migration_backend="rayrpc", enable_migration=True) engine_manager_args.log_instance_info = False engine_manager = LLMEngineManager.from_args(engine_manager_args, None) except ValueError: @@ -153,7 +154,7 @@ def test_init_llumlet(setup_ray_env, llumlet): def test_init_llumlets(setup_ray_env, engine_manager): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) node_id = ray.get_runtime_context().get_node_id() - instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"))) + instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"), BackendType.VLLM, 1)) num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) engine_manager_args = EngineManagerArgs() assert num_instances == engine_manager_args.initial_instances @@ -165,7 +166,7 @@ def test_init_llumlets_sim(setup_ray_env, engine_manager): llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) node_id = 
ray.get_runtime_context().get_node_id() - instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"))) + instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"), BackendType.VLLM, 1)) num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) engine_manager_args = EngineManagerArgs() assert num_instances == engine_manager_args.initial_instances @@ -219,15 +220,14 @@ def test_generate_and_abort(setup_ray_env, engine_manager, llumlet): def test_get_request_instance(setup_ray_env): _, llumlets = init_llumlets(2) llumlet, llumlet_1 = llumlets[0], llumlets[1] + engine_manager = init_manager() request_id = random_uuid() request_id_1 = random_uuid() - ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None)) - ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None)) + ray.get(engine_manager.generate.remote(request_id, None, math.inf, None, None)) + ray.get(engine_manager.generate.remote(request_id_1, None, math.inf, None, None)) num_requests = ray.get(llumlet.get_num_requests.remote()) num_requests_1 = ray.get(llumlet_1.get_num_requests.remote()) - assert num_requests == 1 - assert num_requests_1 == 1 - engine_manager = init_manager() + assert num_requests + num_requests_1 == 2 ray.get(engine_manager.abort.remote(request_id)) ray.get(engine_manager.abort.remote(request_id_1)) num_requests = ray.get(llumlet.get_num_requests.remote()) diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index 89b813c3..3ed1655a 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -27,7 +27,7 @@ def init_migration_scheduler(policy='balanced'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator, 'rpc') + migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator, 'rayrpc') return migration_scheduler @pytest.fixture diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py index 736520d2..50c621c0 100644 --- a/tests/unit_test/llumlet/test_engine_step_exception.py +++ b/tests/unit_test/llumlet/test_engine_step_exception.py @@ -51,7 +51,7 @@ async def raise_error_step(): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.") def test_engine_step_exception(setup_ray_env): engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True) - migration_config = MigrationConfig("SR", "rpc", 16, 1, 4, 5, 20) + migration_config = MigrationConfig("SR", "rayrpc", 16, 1, 4, 5, 20) node_id = ray.get_runtime_context().get_node_id() scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) diff --git a/tools/run_test.sh b/tools/run_test.sh index 42350908..7872a186 100755 --- a/tools/run_test.sh +++ b/tools/run_test.sh @@ -8,4 +8,4 @@ pgrep -f benchmark_serving.py | { while read pid; do kill -9 "$pid"; done; } nvidia-docker run --rm -t --net host --ipc host -v ${PWD}:/workspace -v /mnt:/mnt -w /workspace \ registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ - sh -c "pip install -e . 
> /dev/null && make $test_mode" + sh -c "make install > /dev/null && make $test_mode"