diff --git a/Makefile b/Makefile index 6bc87a9b..8f75c380 100644 --- a/Makefile +++ b/Makefile @@ -21,22 +21,23 @@ install: .PHONY: lint lint: check_pylint_installed check_pytest_installed - @pylint --rcfile=.pylintrc -s n --jobs=32 ./llumnix + @pylint --rcfile=.pylintrc -s n --jobs=128 ./llumnix @pylint --rcfile=.pylintrc \ --disable=protected-access,super-init-not-called,unused-argument,redefined-outer-name,invalid-name \ - -s n --jobs=32 ./tests + -s n --jobs=128 ./tests .PHONY: test test: check_pytest_installed - @pytest -x -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings @python examlpes/offline_inference.py - @pytest -v tests/e2e_test/test_e2e.py - @pytest -v -x ./tests/e2e_test/test_migration.py + @pytest -v ./tests/e2e_test/test_e2e.py + @pytest -v ./tests/e2e_test/test_bench.py + @pytest -v ./tests/e2e_test/test_migration.py .PHONY: unit_test unit_test: check_pytest_installed - @pytest -x -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings .PHONY: offline_test offline_test: @@ -44,7 +45,7 @@ offline_test: .PHONY: e2e_test e2e_test: - @pytest -v tests/e2e_test/test_e2e.py + @pytest -v ./tests/e2e_test/test_e2e.py .PHONY: bench_test bench_test: @@ -52,7 +53,7 @@ bench_test: .PHONY: migration_test migration_test: - @pytest -v -x ./tests/e2e_test/test_migration.py + @pytest -v ./tests/e2e_test/test_migration.py #################### pygloo install for gloo migration backend begin #################### diff --git a/configs/base.yml b/configs/base.yml index afce7127..b9ee7077 100644 --- a/configs/base.yml +++ b/configs/base.yml @@ -2,8 +2,6 @@ SERVER: HOST: '127.0.0.1' PORT: 1234 QUEUE_TYPE: "rayqueue" - -RAY: RAY_CLUSTER_PORT: 6379 LAUNCH_RAY_CLUSTER: True @@ -18,9 +16,10 @@ MANAGER: ENABLE_MIGRATION: True ENABLE_DEFRAG: True - REQUEST_MIGRATION_POLICY: 'SJF' + 
REQUEST_MIGRATION_POLICY: 'SR' MIGRATION_BACKEND: 'gloo' - MIGRATION_CACHE_BLOCKS: 512 + MIGRATION_BUFFER_BLOCKS: 512 + MIGRATION_INTERNAL_BUFFER_NUM: 2 ENABLE_SCALING: False diff --git a/docs/Arguments.md b/docs/Arguments.md index 56474d82..32a21ed8 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -12,12 +12,12 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--initial-instances INITIAL_INSTANCES] [--load-metric {remaining_steps,usage_ratio}] [--polling-interval POLLING_INTERVAL] - [--dispatch-policy {balanced,load,queue}] + [--dispatch-policy {balanced,load,queue,rr}] [--enable-migration] [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY] [--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}] [--migrate-out-threshold MIGRATE_OUT_THRESHOLD] - [--request-migration-policy {LCFS,SJF,LJF}] + [--request-migration-policy {LCR,SR,LR,FCW,FCWSR}] [--enable-defrag ENABLE_DEFRAG] [--enable-scaling] [--min-instances MIN_INSTANCES] @@ -32,12 +32,15 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--profiling-result-file-path PROFILING_RESULT_FILE_PATH] [--gpu-type GPU_TYPE] [--polling-interval POLLING_INTERVAL] - [--migration-backend {gloo,nccl,rpc}] - [--migration-cache-blocks MIGRATION_CACHE_BLOCKS] + [--migration-backend {gloo,rpc}] + [--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS] [--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT] [--migration-num-layers MIGRATION_NUM_LAYERS] [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS] [--max-stages MAX_STAGES] + [--enable-pd-disagg] + [--num-dispatch-instances NUM_DISPATCH_INSTANCES] + [--migration-internal-buffer-num MIGRATION_INTERNAL_BUFFER_NUM] [--log-request-timestamps] ``` @@ -66,7 +69,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] `--dispatch-policy` - Request dispatch policy. 
-- Possible choices: balanced, load, queue +- Possible choices: balanced, load, queue, rr - Default: "load" `--enable-migration` @@ -87,8 +90,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] `--request-migration-policy` - Request migration policy. -- Possible choices: LCFS, SJF, LJF -- Default: "SJF" +- Possible choices: LCR, SR, LR, FCW, FCWSR +- Default: "SR" `--enable-defrag` - Enable defragmentation through migration based on virtual usage. @@ -145,8 +148,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Possible choices: gloo, rpc - Default: "rpc" -`--migration-cache-blocks` -- Number of cache blocks in migration. +`--migration-buffer-blocks` +- Number of cache blocks in each migration buffer. - Default: 512 `--migration-backend-init-timeout` @@ -165,9 +168,19 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Drop migration if the number of stages > max_stages. - Default: 3 +`--migration-internal-buffer-num` +- Number of the buffer in migration backend for sending and receiving +- Default: 2 + `--log-request-timestamps` - Enable logging request timestamps. +`--enable-pd-disagg` +- Enable prefill decoding disaggregation. + +`--num-dispatch-instances` +- Number of available instances for dispatch. 
+ # Unsupported vLLM feature options `--device` diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 70a643cf..5394ea24 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -134,10 +134,11 @@ class EngineManagerArgs: migration_backend_init_timeout: float = None migration_backend: str = None - migration_cache_blocks: int = None + migration_buffer_blocks: int = None migration_num_layers: int = None last_stage_max_blocks: int = None max_stages: int = None + migration_internal_buffer_num: int = None enable_pd_disagg: bool = None @@ -172,11 +173,12 @@ def create_global_scheduler_configs( def create_migration_config(self) -> MigrationConfig: migration_config = MigrationConfig(self.request_migration_policy, self.migration_backend, - self.migration_cache_blocks, + self.migration_buffer_blocks, self.migration_num_layers, self.last_stage_max_blocks, self.max_stages, - self.migration_backend_init_timeout) + self.migration_backend_init_timeout, + self.migration_internal_buffer_num) return migration_config @classmethod @@ -195,6 +197,9 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser): if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest): assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}." + assert args.migration_backend != 'nccl', 'NCCL has been temporarily deprecated due to its incompatibility with \ + concurrent migrations in Llumnix.' 
+ assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \ and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \ ("When using gloo as migration backend, " @@ -223,8 +228,13 @@ def add_cli_args( parser.add_argument('--dispatch-policy', type=str, - choices=['balanced', 'load', 'queue', 'flood'], - help='request dispatch policy') + choices=['balanced', 'load', 'queue', 'flood', 'rr'], + help='The request dispatch policy.\n\n' + '* "balanced" dispatch request to the instance with minimum requests dispatched.\n' + '* "load" dispatch request to the instance with lowest instance load.\n' + '* "queue" dispatch request to the instance with minimum waiting request queue length.\n' + '* "flood" dispatch request to the instance with maximum requests dispatched.\n' + '* "rr" dispatch requests with round-robin policy.\n') parser.add_argument('--num-available-dispatch-instances', type=int, help='number of available instances for dispatching') @@ -238,14 +248,25 @@ def add_cli_args( parser.add_argument('--pair-migration-policy', type=str, choices=['balanced', 'defrag_constrained', 'defrag_relaxed'], - help='pair migration policy') + help='The pair migration policy.\n\n' + '* "balanced" pair migration to make the instance load of instance more balanced.\n' + '* "defrag_constrained" pair migration without balanced constraint to ' + 'achieve defragmentation thoroughly (with instance constraints).\n' + '* "defrag_relaxed" pair migration to without balanced constraint ' + 'to achieve defragmentation thoroughly (without instance constraints).\n') parser.add_argument('--migrate-out-threshold', type=float, help='migrate out instance load threshold') parser.add_argument('--request-migration-policy', type=str, - choices=['LCFS', 'SJF', 'LJF'], - help='request migration policy') + default=None, + choices=['LCR', 'SR', 'LR', 'FCW', 'FCWSR'], + help='The request migration policy.\n\n' + '* "LCR" migrate the running request last 
come.\n' + '* "SR" migrate the running request shortest.\n' + '* "LR" migrate the running request longest.\n' + '* "FCW" migrate the waiting request first come.\n' + '* "FCWSR" migrate the waiting request first come and running request shortest.\n') parser.add_argument('--enable-defrag', type=bool, help='enable defragmentation through migration based on virtual usage') @@ -288,24 +309,30 @@ def add_cli_args( parser.add_argument('--migration-backend', type=str, - choices=['gloo','nccl','rpc'], + choices=['gloo', 'nccl', 'rpc'], help='communication backend of migration') parser.add_argument('--migration-backend-init-timeout', type=float, help='timeout(s) for initializing migration backend') - parser.add_argument('--migration-cache-blocks', + parser.add_argument('--migration-buffer-blocks', type=int, - help='number of cache blocks in migration') + help='number of cache blocks in each migration buffer') parser.add_argument('--migration-num-layers', type=int, help='number of kv-cache layers to transfer in each round during migration') parser.add_argument('--last-stage-max-blocks', type=int, help='if the number pf remain blocks < last_stage_max_blocks, do last stage migration') + parser.add_argument('--migration-internal-buffer-num', + type=int, + help='number of the buffer in migration backend for sending and receiving') parser.add_argument('--max-stages', type=int, help='drop migration if the number of stages > max_stages') parser.add_argument('--enable-pd-disagg', - type=bool, + action='store_true', help='enable prefill decoding disaggregation') + parser.add_argument('--num-dispatch-instances', + type=int, + help='number of available instances for dispatch') return parser diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index 16a8ac1f..28e1e802 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -13,9 +13,9 @@ from abc import ABC, abstractmethod from enum import Enum -from typing 
import Iterable, List, Union +from typing import Iterable, List, Union, Deque -from llumnix.llumlet.request import LlumnixRequest +from llumnix.llumlet.request import LlumnixRequest, RequestStatus from llumnix.server_info import ServerInfo class EngineState(str, Enum): @@ -99,14 +99,21 @@ def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_st raise NotImplementedError @abstractmethod - def get_running_queue(self) -> List[LlumnixRequest]: + def get_running_queue(self) -> Deque[LlumnixRequest]: """ Return backend's running queue. """ raise NotImplementedError @abstractmethod - def remove_running_request(self, request_id: str) -> None: + def get_waiting_queue(self) -> Deque[LlumnixRequest]: + """ + Return backend's waiting queue. + """ + raise NotImplementedError + + @abstractmethod + def remove_running_request(self, request_id: str) -> bool: """ Removes a request from the backend's running queue. @@ -117,6 +124,26 @@ def remove_running_request(self, request_id: str) -> None: Args: request_id: A string identifier for the request that is to be removed from the running queue. This ID uniquely identifies the request within the backend system. + + Returns: + True if the request was successfully removed from the running queue, False otherwise. + """ + raise NotImplementedError + + @abstractmethod + def remove_waiting_request(self, request_id: str) -> bool: + """ + Removes a request from the backend's waiting queue. + + This method is responsible for safely halting and removing an active request from the waiting + queue of the backend engine. This action is performed in waiting request migration. + + Args: + request_id: A string identifier for the request that is to be removed from the waiting + queue. This ID uniquely identifies the request within the backend system. + + Returns: + True if the request was successfully removed from the waiting queue, False otherwise. 
""" raise NotImplementedError @@ -164,17 +191,25 @@ def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]: raise NotImplementedError @abstractmethod - def pre_alloc(self, request_id: str, block_num: int) -> List[int]: + def pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: """Pre-allocates cache blocks for a migrating request. This method selects a specified number of free cache blocks to be reserved for an incoming migration request identified by the given request ID. It updates the pre-allocation cache dictionary with the allocated blocks, which ensures that these blocks are not used by - another process until the migration is finished. + another process until the migration is finished. For the waiting request, it only reserves + free cache blocks when the request is the earliest arrival one among the requests of dst instance's + waiting queue. Args: request_id: The unique identifier of the migration request for which cache blocks are to be pre-allocated. + request_status: The status (waiting/running) of the request. + request_arrival_time: The arrival time of the request. block_num: The number of cache blocks that need to be pre-allocated for the request. Returns: @@ -187,9 +222,8 @@ def add_running_request(self, backend_request: LlumnixRequest) -> None: """ Adds a backend request to the running queue for processing. - This method enqueues a backend request into engine running queue, marking it for - active processing. It is used when a suspend migrating request should be added back - to running queue. + This method enqueues a backend request into engine running queue. + It is used when a suspend migrating request should be added back to running queue. Args: backend_request: An object representing the backend request. 
The type of this @@ -199,19 +233,17 @@ def add_running_request(self, backend_request: LlumnixRequest) -> None: raise NotImplementedError @abstractmethod - def is_request_running(self, backend_request: LlumnixRequest) -> bool: - """Checks if a given backend request is currently in the running queue. + def add_waiting_request(self, backend_request: LlumnixRequest) -> None: + """ + Adds a backend request to the waiting queue for processing. - This method determines whether a backend request is present and actively being processed - in the running queue. + This method enqueues a backend request into engine waiting queue. + It is used when a suspend migrating request should be added back to waiting queue. Args: backend_request: An object representing the backend request. The type of this object is dependent on the backend implementation and the details of the request. - - Returns: - True if the backend request is currently in the running queue; False otherwise. """ raise NotImplementedError diff --git a/llumnix/backends/migration_backend_interface.py b/llumnix/backends/migration_backend_interface.py index 808ba8c8..9fd231cc 100644 --- a/llumnix/backends/migration_backend_interface.py +++ b/llumnix/backends/migration_backend_interface.py @@ -13,7 +13,9 @@ from abc import ABC, abstractmethod from typing import List +import queue +import torch class MigrationBackendBase(ABC): @abstractmethod @@ -39,3 +41,24 @@ def do_send(self, dst_handle, blocks: List[int]): @abstractmethod def do_recv(self, src_handle, blocks: List[int]): raise NotImplementedError + +class BufferMigrationBackend(MigrationBackendBase): + def __init__(self, num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_buffer = num_buffer + + self.dummy_buffer = [ + torch.empty(size=buffer_shape, dtype=buffer_dtype, device=buffer_device, pin_memory=pin_memory) + for _ in range(self.num_buffer) + ] + + self.avaiable_buffer_queue = 
queue.Queue() + for i in range(self.num_buffer): + self.avaiable_buffer_queue.put_nowait(i) + + def get_available_cache(self): + return self.avaiable_buffer_queue.get() + + def put_back_cache(self, buffer_id): + self.avaiable_buffer_queue.put_nowait(buffer_id) diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index bf583366..59b41fa7 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -13,7 +13,7 @@ import time import traceback -from typing import Any, List, Optional, Dict, Union, Iterable, Tuple +from typing import Any, List, Optional, Dict, Union, Iterable, Tuple, Deque from collections import defaultdict import threading import asyncio @@ -34,7 +34,7 @@ from llumnix.instance_info import InstanceInfo from llumnix.backends.backend_interface import BackendInterface, EngineState from llumnix.backends.vllm.scheduler import SchedulerLlumnix -from llumnix.backends.vllm.sequence import SequenceGroupLlumnix +from llumnix.backends.vllm.sequence import SequenceGroupLlumnix, RequestStatus from llumnix.backends.profiling import LatencyMemData from llumnix.server_info import ServerInfo from llumnix.internal_config import MigrationConfig @@ -199,7 +199,7 @@ def _process_model_outputs( # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args. 
return request_outputs, server_infos - async def step_async(self) -> None: + async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: step_begin_time = time.time() request_outputs, server_infos = await super().step_async() for request_output in request_outputs: @@ -295,9 +295,11 @@ def __init__( self.worker_handle_list = self.engine.model_executor.workers.copy() if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size: self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix")) - self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\ - src_worker_handle_list=self.worker_handle_list, - placement_group=placement_group, node_id=node_id) + self._run_workers("init_migration", instance_id=instance_id, + migration_config=migration_config, + src_worker_handle_list=self.worker_handle_list, + placement_group=placement_group, + node_id=node_id) self.state = EngineState.INIT logger.info("engine ({}) current state {}".format(self.instance_id, self.state)) @@ -350,15 +352,22 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None: logger.info("add seq {} to block table".format(seq.seq_id)) pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id) self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id) - backend_request.reset_migration_args() - self.add_running_request(backend_request) + backend_request.reset_migration_args_dst() + assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \ + "The status of request migrated to dst instance should be \ + RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING" + if backend_request.status == RequestStatus.WAITING_MIGRATING: + self.add_waiting_request(backend_request) + else: # RUNNING_MIGRATING: + backend_request.reset_status() + self.add_running_request(backend_request) 
async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: await dst_ray_actor.execute_engine_method.remote("_run_workers", - "migrate_cache", - dst_blocks=dst_blocks, - src_blocks=src_blocks, - src_worker_handle_list=self.worker_handle_list) + "migrate_cache", + dst_blocks=dst_blocks, + src_blocks=src_blocks, + src_worker_handle_list=self.worker_handle_list) def _run_workers(self, *args, **kwargs): # pylint: disable=protected-access @@ -373,15 +382,21 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: request_ids = set(request_id) return self.engine.abort_request(request_ids) - def get_running_queue(self) -> List[SequenceGroupLlumnix]: + def get_running_queue(self) -> Deque[SequenceGroupLlumnix]: return self.engine.scheduler.get_running_queue() + def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]: + return self.engine.scheduler.get_waiting_queue() + def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]: return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs) - def remove_running_request(self, *args, **kwargs) -> None: + def remove_running_request(self, *args, **kwargs) -> bool: return self.engine.scheduler.remove_running_request(*args, **kwargs) + def remove_waiting_request(self, *args, **kwargs) -> bool: + return self.engine.scheduler.remove_waiting_request(*args, **kwargs) + def add_migrating_out_request_last_stage(self, *args, **kwargs) -> None: return self.engine.scheduler.add_migrating_out_request_last_stage(*args, **kwargs) @@ -400,8 +415,8 @@ def should_abort_migration(self, *args, **kwargs) -> bool: def add_running_request(self, *args, **kwargs) -> None: return self.engine.scheduler.add_running_request(*args, **kwargs) - def is_request_running(self, *args, **kwargs) -> bool: - return self.engine.scheduler.is_request_running(*args, **kwargs) + def add_waiting_request(self, *args, **kwargs) -> None: + return 
self.engine.scheduler.add_waiting_request(*args, **kwargs) def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None: return self.engine.scheduler.free_dst_pre_alloc_cache(*args, **kwargs) diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index 947d3e7e..950c1b31 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -15,11 +15,15 @@ import torch from func_timeout import func_set_timeout, FunctionTimedOut +import cupy +from cupy.cuda import nccl import ray import ray.util.collective as col +from ray.util.collective.collective_group import nccl_util + from vllm.worker.cache_engine import CacheEngine from llumnix.internal_config import MigrationConfig -from llumnix.backends.migration_backend_interface import MigrationBackendBase +from llumnix.backends.migration_backend_interface import MigrationBackendBase, BufferMigrationBackend from llumnix.logger import init_logger logger = init_logger(__name__) @@ -40,17 +44,16 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs): NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16] -class RayRpcMigrationBackend(MigrationBackendBase): +class RayRpcMigrationBackend(BufferMigrationBackend): def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \ scheduling_strategy, is_driver_worker, gpu_cache) -> None: - super().__init__() - self.migration_config = migration_config self.cache_engine = cache_engine self.worker_rank = worker_rank self.worker_handle_list = worker_handle_list self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() + self.migration_stream = torch.cuda.Stream() self.rpc_dtype = self.cache_engine.dtype if self.cache_engine.dtype in NUMPY_SUPPORTED_DTYPES: @@ -62,17 +65,13 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.is_driver_worker = is_driver_worker self.gpu_cache = 
gpu_cache self.cache_device = "cpu" - self.num_migration_cache_blocks = self.migration_config.migration_cache_blocks + self.num_migration_buffer_blocks = self.migration_config.migration_buffer_blocks self.num_layers = self.cache_engine.num_layers self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + buffer_shape = (self.num_migration_buffer_blocks, self.num_layers, 2, self.migration_cache_size) - self.dummy_cache = torch.empty( - size=(self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size), - dtype=self.cache_engine.dtype, - device=self.cache_device, - pin_memory=True - ) - self.migration_stream = torch.cuda.Stream() + super().__init__(migration_config.migration_internal_buffer_num, buffer_shape, self.cache_engine.dtype, + self.cache_device, pin_memory=True) def init_backend(self, group_name, world_size, rank) -> bool: logger.info("create rpc migration backend successfully.") @@ -94,30 +93,38 @@ def warmup(self) -> bool: def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None: tot_blocks = len(src_blocks) rpc_numpy_cache = None - for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): - offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) + for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks): + offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx) send_blocks = src_blocks[start_idx:start_idx+offset] ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks) if rpc_numpy_cache is not None: self.do_recv(rpc_numpy_cache, recv_blocks) - rpc_numpy_cache = ray.get(ray_obj) + rpc_numpy_cache_ref = ray.get(ray_obj) + rpc_numpy_cache = ray.get(rpc_numpy_cache_ref) recv_blocks = dst_blocks[start_idx:start_idx+offset] self.do_recv(rpc_numpy_cache, recv_blocks) def do_send(self, dst_handle, blocks: List[int]): num_blocks = len(blocks) - 
send_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} with torch.cuda.stream(self.migration_stream): for layer_idx in range(self.num_layers): self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[layer_idx], src_to_dst) torch.cuda.Stream.synchronize(self.migration_stream) - return send_cache.to(self.rpc_dtype).numpy() + # Here, we use ray.put to store data and finally return the object reference so that we can release the internal buffer. + # This might seem like an anti-pattern, but it's okay since the kv-cache transferred is in the MB range and won't utilize + # Ray's optimization for returning small objects (<100KB). + data = ray.put(send_cache.to(self.rpc_dtype).numpy()) + self.put_back_cache(dummy_cache_idx) + return data def do_recv(self, src_handle, blocks: List[int]): num_blocks = len(blocks) src_to_dst = dict(enumerate(blocks)) - recv_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) # use pin memory dummy_cache to speed up data transfer recv_cache.copy_(torch.from_numpy(src_handle)) @@ -125,6 +132,7 @@ def do_recv(self, src_handle, blocks: List[int]): for layer_idx in range(self.num_layers): self.cache_engine.attn_backend.swap_blocks(recv_cache[layer_idx], self.gpu_cache[layer_idx], src_to_dst) torch.cuda.Stream.synchronize(self.migration_stream) + self.put_back_cache(dummy_cache_idx) def try_import_gloo(): try: @@ -139,19 +147,14 @@ def try_import_gloo(): except ImportError as e: raise ImportError("Gloo is not 
installed. Please install it first.") from e -class RayColMigrationBackend(MigrationBackendBase): +class RayColMigrationBackend(BufferMigrationBackend): def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, local_rank, scheduling_strategy, is_driver_worker, gpu_cache) -> None: - super().__init__() - - # pylint: disable=C0415 - import cupy - self.migration_config = migration_config self.cache_engine = cache_engine self.backend = migration_config.migration_backend self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine.num_layers) - self.num_migration_cache_blocks = migration_config.migration_cache_blocks + self.num_migration_buffer_blocks = migration_config.migration_buffer_blocks self.backend = migration_config.migration_backend self.global_world_size = -1 @@ -162,6 +165,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() self.is_driver_worker = is_driver_worker self.gpu_cache = gpu_cache + self.migration_stream = cupy.cuda.Stream() self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size @@ -169,17 +173,13 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, try_import_gloo() self.cache_device = "cpu" else: + nccl_util.TORCH_NCCL_DTYPE_MAP[torch.bfloat16] = nccl.NCCL_FLOAT16 self.cache_device = torch.device(f"cuda:{self.local_rank}") pin_memory = (self.backend == 'gloo') - self.dummy_cache = torch.empty( - size=(self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size), - dtype=self.cache_engine.dtype, - device=self.cache_device, - pin_memory=pin_memory - ) - - self.migration_stream = cupy.cuda.Stream() + buffer_shape = (self.num_migration_buffer_blocks, self.migration_num_layers, 2, self.migration_cache_size) + 
super().__init__(migration_config.migration_internal_buffer_num, buffer_shape, self.cache_engine.dtype, + self.cache_device, pin_memory=pin_memory) def init_backend(self, group_name, world_size, rank) -> bool: @func_set_timeout(self.migration_config.migration_backend_init_timeout) @@ -224,7 +224,7 @@ def destory_backend(self) -> None: def warmup(self) -> bool: if self.global_world_size > 1: try: - col.allreduce(self.dummy_cache[0], self.group_name) + col.allreduce(self.dummy_buffer[0][0], self.group_name) # pylint: disable=W0703 except Exception as e: logger.info("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {}), err: {}." @@ -241,8 +241,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] tot_blocks = len(src_blocks) src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_global_rank")) - for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): - offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) + for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks): + offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx) send_blocks = src_blocks[start_idx:start_idx+offset] recv_blocks = dst_blocks[start_idx:start_idx+offset] self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.global_rank, send_blocks) @@ -250,7 +250,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] def do_send(self, dst_handle, blocks: List[int]): num_blocks = len(blocks) - send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} with self.migration_stream: @@ -261,11 
+262,13 @@ def do_send(self, dst_handle, blocks: List[int]): # TODO(KuilongCui): check the error code if peer is dead col.send(send_cache, dst_handle, self.group_name) self.migration_stream.synchronize() + self.put_back_cache(dummy_cache_idx) def do_recv(self, src_handle, blocks: List[int]): num_blocks = len(blocks) src_to_dst = dict(enumerate(blocks)) - recv_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) with self.migration_stream: for layer_idx in range(self.cache_engine.num_layers): @@ -274,21 +277,24 @@ def do_recv(self, src_handle, blocks: List[int]): col.recv(recv_cache, src_handle, self.group_name) self.cache_engine.attn_backend.swap_blocks(recv_cache[cache_idx], self.gpu_cache[layer_idx], src_to_dst) self.migration_stream.synchronize() + self.put_back_cache(dummy_cache_idx) def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase: - if cache_engine.num_gpu_blocks < migration_config.migration_cache_blocks: - logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." - .format(migration_config.migration_cache_blocks, cache_engine.num_gpu_blocks)) - migration_config.migration_cache_blocks = cache_engine.num_gpu_blocks + if cache_engine.num_gpu_blocks < migration_config.migration_buffer_blocks: + logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." 
+ .format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks)) + migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks - target_col = None + target_migration_backend = None backend = migration_config.migration_backend + assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported migration backend: {} for llumnix".format(backend) + if backend in ['nccl', 'gloo']: - target_col = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, + target_migration_backend = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, is_driver_worker, gpu_cache) else: - target_col = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list, + target_migration_backend = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list, scheduling_strategy, is_driver_worker, gpu_cache) - return target_col + return target_migration_backend diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index a14db0b3..4c6403ae 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -13,19 +13,24 @@ from asyncio.log import logger import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Deque from collections import deque from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable from vllm.core.scheduler import (Scheduler, PreemptionMode, SequenceStatus, SequenceGroupMetadata, SchedulerOutputs) +from vllm.core.policy import PolicyFactory +from vllm.sequence import SequenceGroup +from vllm.core.interfaces import AllocStatus from llumnix.instance_info import InstanceInfo from llumnix.logger import init_logger -from llumnix.llumlet.request import RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType, RequestStatus from llumnix.backends.vllm.sequence import SequenceGroupLlumnix + logger = 
init_logger(__name__) + # TODO(ZeldaHuang): adapt prefix cache and sliding window, now use v1 manager class BlockManagerLlumnix(BlockSpaceManagerV1): def get_free_blocks(self, num_required_blocks: int) -> BlockTable: @@ -76,9 +81,12 @@ def _get_num_killed_requests(self) -> int: cnt += 1 return cnt - def get_running_queue(self): + def get_running_queue(self) -> Deque[SequenceGroupLlumnix]: return self.running + def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]: + return self.waiting + def get_all_request_ids(self) -> List[str]: request_ids : List[str] = [] for state_queue in [self.waiting, self.running, self.swapped]: @@ -86,18 +94,26 @@ def get_all_request_ids(self) -> List[str]: request_ids.append(seq_group.request_id) return request_ids - def get_request_incremental_blocks(self, backend_request: SequenceGroupLlumnix, pre_stage_num_blocks: int) -> List[int]: + def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]: seq = backend_request.get_seqs()[0] blocks = self.block_manager.get_block_table(seq) return blocks[pre_stage_num_blocks:] - def remove_running_request(self, request_id: str) -> None: + def remove_running_request(self, request_id: str) -> bool: for seq_group in self.running: if seq_group.request_id == request_id: - seq = seq_group.get_seqs()[0] self.running.remove(seq_group) - seq.status = SequenceStatus.WAITING - break + seq_group.set_status(RequestStatus.RUNNING_MIGRATING) + return True + return False + + def remove_waiting_request(self, request_id: str) -> bool: + for seq_group in self.waiting: + if seq_group.request_id == request_id: + self.waiting.remove(seq_group) + seq_group.set_status(RequestStatus.WAITING_MIGRATING) + return True + return False def add_migrating_out_request_last_stage(self, backend_request: SequenceGroupLlumnix) -> None: self.migrating_out_request_last_stage.append(backend_request) @@ -110,7 +126,17 @@ def pop_migrating_out_requests_last_stage(self) -> 
List[SequenceGroupLlumnix]: self.migrating_out_request_last_stage.clear() return migrating_out_request_last_stage - def pre_alloc(self, request_id: str, block_num: int) -> List[int]: + def pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: + # Only migrate waiting request when the waiting request is the earliest arrival one + # among the requests of dst instance's waiting queue. + if request_status == RequestStatus.WAITING_MIGRATING: + if (self.waiting and request_arrival_time > self.waiting[0].arrival_time) \ + or block_num * self.cache_config.block_size > self.prompt_limit: + return [] blocks = self.block_manager.get_free_blocks(block_num) pre_blocks = self.pre_alloc_cache_dict.get(request_id, []) pre_blocks.extend(blocks) @@ -118,13 +144,37 @@ def pre_alloc(self, request_id: str, block_num: int) -> List[int]: blocks = [block.block_number for block in blocks] return blocks - def add_running_request(self, backend_request: SequenceGroupLlumnix) -> None: - seq = backend_request.get_seqs()[0] - seq.status = SequenceStatus.RUNNING + def add_running_request(self, backend_request: LlumnixRequest) -> None: + self._set_status(backend_request, status_to=SequenceStatus.RUNNING) self.running.append(backend_request) - def is_request_running(self, backend_request: SequenceGroupLlumnix) -> bool: - return backend_request in self.running + def add_waiting_request(self, backend_request: LlumnixRequest) -> None: + self._set_status(backend_request, status_to=SequenceStatus.WAITING) + # pylint: disable=E0203 + self.waiting.append(backend_request) + fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") + self.waiting = fcfs_policy.sort_by_priority(time.time(), self.waiting) + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + if seq_group.status == RequestStatus.WAITING_MIGRATING: + return AllocStatus.OK + return super().can_allocate(seq_group) + + def _allocate_and_set_running(self, 
seq_group: SequenceGroup) -> None: + # Change seq status to running, but request status is still waiting_migrating. + if seq_group.status == RequestStatus.WAITING_MIGRATING: + # For the waiting request migrated in, blocks have already been allocated when pre alloc. + self._set_status(seq_group, status_to=SequenceStatus.RUNNING) + seq_group.reset_status() + else: + super()._allocate_and_set_running(seq_group) + + def _set_status(self, + seq_group: SequenceGroup, + status_to: SequenceStatus, + status_from: SequenceStatus = None): + for seq in seq_group.get_seqs(status=status_from): + seq.status = status_to def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: if request_id: @@ -132,6 +182,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: # pylint: disable=protected-access self.block_manager._free_block_table(blocks) else: + # TODO(s5u13b): Only effective with one-to-one migration restriction. # Clear all pre-allocated cache of dst instance when src instance encounters exception. 
request_ids = list(self.pre_alloc_cache_dict.keys()) for req_id in request_ids: @@ -141,7 +192,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None: seq = backend_request.get_seqs()[0] - logger.info("free seq {}".format(seq.seq_id)) + logger.info("free request: {}, free seq: {}".format(backend_request.request_id, seq.seq_id)) self.free_seq(seq) def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -> InstanceInfo: @@ -184,15 +235,18 @@ def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) - if scheduled_seq_groups: instance_info.inference_type = scheduled_seq_groups[-1].inference_type # TODO(ZeldaHuang) adapt chunked-prefill - instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in scheduled_seq_groups])\ - if instance_info.inference_type == RequestInferenceType.PREFILL else len(instance_info.running_seq_lens) - instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()] + instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in scheduled_seq_groups]) \ + if instance_info.inference_type == RequestInferenceType.PREFILL \ + else len(instance_info.running_seq_lens) + instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.finished] return instance_info def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group_metadata_list, scheduler_outputs = super().schedule() self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \ for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups])) + for seq_group in self.waiting: + seq_group.try_schedule_times += 1 return seq_group_metadata_list, scheduler_outputs def _schedule_running(self, running_queue: deque, *args, **kwargs): diff --git a/llumnix/backends/vllm/sequence.py 
b/llumnix/backends/vllm/sequence.py index 3c41a5c6..5964f96d 100644 --- a/llumnix/backends/vllm/sequence.py +++ b/llumnix/backends/vllm/sequence.py @@ -11,9 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from vllm.sequence import SequenceGroup +from vllm.sequence import SequenceGroup, SequenceStatus -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType, RequestStatus class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest): @@ -41,3 +41,29 @@ def inference_type(self) -> RequestInferenceType: if self.is_prefill(): return RequestInferenceType.PREFILL return RequestInferenceType.DECODE + + @property + def finished(self) -> bool: + return self.get_seqs()[0].is_finished() + + @property + def arrival_time(self) -> float: + return self.metrics.arrival_time + + @property + def status(self) -> RequestStatus: + if self._status: + return self._status + status = self.get_seqs()[0].status + if status == SequenceStatus.RUNNING: + request_status = RequestStatus.RUNNING + elif status == SequenceStatus.WAITING: + request_status = RequestStatus.WAITING + else: + request_status = RequestStatus.FINISHED + return request_status + + @property + def prefill_num_blocks(self) -> int: + # Get the prefill len of the waiting request. 
+ return len(self.get_seqs()[0].logical_token_blocks) diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py index b5ccb45b..94367d75 100644 --- a/llumnix/backends/vllm/simulator.py +++ b/llumnix/backends/vllm/simulator.py @@ -36,6 +36,7 @@ def __init__( migration_config: MigrationConfig, profiling_result_file_path: str, engine_args: EngineArgs, + node_id: str = None, ) -> None: # multi-instance args latency_mem = self._get_lantecy_mem(profiling_result_file_path, engine_args) @@ -43,7 +44,8 @@ def __init__( output_queue_type=output_queue_type, migration_config=migration_config, instance_id=instance_id, - latency_mem=latency_mem) + latency_mem=latency_mem, + node_id=node_id) self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler diff --git a/llumnix/backends/vllm/utils.py b/llumnix/backends/vllm/utils.py index 8aafc9f1..7e49720a 100644 --- a/llumnix/backends/vllm/utils.py +++ b/llumnix/backends/vllm/utils.py @@ -48,8 +48,7 @@ def check_engine_args(engine_args: AsyncEngineArgs, engine_manager_args: EngineM engine_config = engine_args.create_engine_config() parallel_config = engine_config.parallel_config if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl': - # TODO(s5u13b): fix logger - print("Llumnix does not support TP or PP enabled model when the migration backend is nccl, change migration backend to gloo.") + logger.info("Llumnix does not support TP or PP enabled model when the migration backend is nccl, change migration backend to gloo.") engine_manager_args.migration_backend = 'gloo' detect_unsupported_feature(engine_args) diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index 92bf1f1b..e38c3423 100644 --- a/llumnix/backends/vllm/worker.py +++ 
b/llumnix/backends/vllm/worker.py @@ -14,7 +14,6 @@ import time from typing import Dict, List import math -import ray import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy @@ -50,10 +49,11 @@ def get_global_rank(self): def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig) -> int: - migrate_cache_blocks_size = migration_config.migration_cache_blocks + migrate_cache_blocks_size = migration_config.migration_buffer_blocks migrate_num_layers = migration_config.migration_num_layers - dummy_cache_size = migrate_num_layers * migrate_cache_blocks_size * CacheEngine.get_cache_block_size( - cache_config, model_config, parallel_config) // model_config.get_num_layers(parallel_config) + dummy_cache_size = migration_config.migration_internal_buffer_num * migrate_num_layers * migrate_cache_blocks_size \ + * CacheEngine.get_cache_block_size(cache_config, model_config, parallel_config) \ + // model_config.get_num_layers(parallel_config) # For nccl migration backend, reserve gpu memory for dummy cache in migration backend. For other backends, # CPU memory is used for the dummy cache, which is almost unlimited, so no special action is needed. @@ -118,7 +118,7 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size( self.cache_config, self.model_config, self.parallel_config) speed = total_kv_cache_size/_GB/(end_time - start_time) - logger.info("[migration_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s." + logger.info("[migrate_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s." 
.format(len(src_blocks), convert_bytes(total_kv_cache_size), end_time-start_time, speed)) def do_recv(self, *args, **kwargs): @@ -150,7 +150,3 @@ def shutdown(self) -> None: del self.migration_backend torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() - - def restart(self) -> None: - self.init_model() - self.init_cache_engine(self.cache_config) diff --git a/llumnix/config/default.py b/llumnix/config/default.py index 17849463..358d9e1b 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -80,7 +80,7 @@ _C.MANAGER.LOAD_METRIC = 'remaining_steps' # Request dispatch policy _C.MANAGER.DISPATCH_POLICY = 'load' -# Number of available dispatch instances. -1 indicates that all instances can be used for dispatching +# Number of available dispatch instances. math.inf indicates that all instances can be used for dispatching _C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf # ----------------------------------------------------------------------------- @@ -95,7 +95,7 @@ # Migrate out instance load threshold _C.MANAGER.MIGRATE_OUT_THRESHOLD = 3.0 # Request migration policy -_C.MANAGER.REQUEST_MIGRATION_POLICY = 'SJF' +_C.MANAGER.REQUEST_MIGRATION_POLICY = 'SR' # Enable defragmentation through migration based on virtual usage _C.MANAGER.ENABLE_DEFRAG = False # Drop migration if the number of stages > max_stages @@ -108,9 +108,11 @@ # Timeout(s) for initializing migration backend _C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 # Number of cache blocks in migration -_C.MANAGER.MIGRATION_CACHE_BLOCKS = 512 +_C.MANAGER.MIGRATION_BUFFER_BLOCKS = 512 # Number of kv-cache layers to transfer in each round during migration _C.MANAGER.MIGRATION_NUM_LAYERS = 1 +# Number of internal buffers in migration backend for sending and receiving +_C.MANAGER.MIGRATION_INTERNAL_BUFFER_NUM = 2 # ----------------------------------------------------------------------------- # SCALING CONFIGURATION diff --git a/llumnix/global_scheduler/dispatch_scheduler.py
b/llumnix/global_scheduler/dispatch_scheduler.py index 175bdbde..0f5ae030 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -71,6 +71,10 @@ def remove_instance(self, instance_id: str) -> None: del self.instance_num_requests[instance_id] if instance_id in self.available_dispatch_instance_set: self.available_dispatch_instance_set.remove(instance_id) + # TODO(KuilongCui): Check it when there is no decode instance. + if self.num_instances >= self.num_dispatch_instances: + free_instance_id = next(iter(self.instance_id_set - self.available_dispatch_instance_set)) + self.available_dispatch_instance_set.add(free_instance_id) def _sort_instance_infos(self, descending: bool = True) -> None: @@ -133,12 +137,26 @@ def dispatch(self, logger.info("dispatch to {}, queue size: {}".format(instance_id, sorted_instance_infos[0].num_waiting_requests)) return instance_id +class RoundRobin(DispatchPolicy): + prev_instance_idx: int = -1 + + def dispatch(self, + instance_num_requests: Dict[str, int], + sorted_instance_infos: List[InstanceInfo]) -> str: + all_instance_ids = sorted(instance_num_requests.keys()) + cur_instance_idx = (self.prev_instance_idx + 1) % len(all_instance_ids) + + target_instance_id = all_instance_ids[cur_instance_idx] + self.prev_instance_idx = cur_instance_idx + return target_instance_id + class DispatchPolicyFactory: _POLICY_REGISTRY = { 'flood': Flood, 'balanced': Balanced, 'load': Load, 'queue': Queue, + 'rr': RoundRobin, } @classmethod diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index 79d6e88e..ec1568bb 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -18,7 +18,8 @@ from llumnix.internal_config import GlobalSchedulerConfig from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler -from 
llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints +from llumnix.global_scheduler.migration_scheduler import MigrationScheduler +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler logger = init_logger(__name__) @@ -48,6 +49,7 @@ def __init__(self, global_scheduler_config.scale_down_threshold, global_scheduler_config.scaling_policy, self.instance_load_calculator, + self.enable_pd_disagg, global_scheduler_config.num_dispatch_instances) self.num_instances = 0 diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py new file mode 100644 index 00000000..ea82e55b --- /dev/null +++ b/llumnix/global_scheduler/migration_filter.py @@ -0,0 +1,149 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Dict, List, Optional +from abc import ABC, abstractmethod + +from llumnix.logger import init_logger +from llumnix.instance_info import InstanceInfo +from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints + +logger = init_logger(__name__) + +class MigrationFilterConfig: + def __init__(self, migrate_out_load_threshold): + self.migrate_out_load_threshold: float = migrate_out_load_threshold + +# TODO(KuilongCui): A filter might contain other filters; leave this for the future. +class MigrationFilterPolicy(ABC): + @abstractmethod + def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: + raise NotImplementedError + + @abstractmethod + def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: + raise NotImplementedError + +class MigrationInstanceFilter(ABC): + def __init__(self, filter_config: MigrationFilterConfig) -> None: + self.filter_config = filter_config + self.registered_filters: Dict[str, MigrationFilterPolicy] = {} + + def register_filter(self, filter_name: str, migration_filter: MigrationFilterPolicy) -> bool: + if filter_name in self.registered_filters: + logger.warning("migration filter {} has been registered.".format(filter_name)) + return False + + self.registered_filters[filter_name] = migration_filter + return True + + def unregister_filter(self, filter_name: str) -> None: + self.registered_filters.pop(filter_name, None) + + def get_filter(self, filter_name: str) -> Optional[MigrationFilterPolicy]: + return self.registered_filters.get(filter_name, None) + + def filter_instances(self, instance_infos: List[InstanceInfo], + pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: + src_filter_conditions = [filter.filter_src_condition() for filter in self.registered_filters.values()] + dst_filter_conditions = 
[filter.filter_dst_condition() for filter in self.registered_filters.values()] + + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + policy_filter = MigrationFilterPolicyFactory.get_policy("load") + elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]: + policy_filter = MigrationFilterPolicyFactory.get_policy('prefill_decode') + else: + raise ValueError(f"Unsupported pair migration type: {pair_migration_type}") + src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type)) + dst_filter_conditions.append(policy_filter.filter_dst_condition(self.filter_config, pair_migration_type)) + + filtered_src_instance_infos = [info for info in instance_infos if all(cond(info) for cond in src_filter_conditions)] + filtered_dst_instance_infos = [info for info in instance_infos if all(cond(info) for cond in dst_filter_conditions)] + + return filtered_src_instance_infos, filtered_dst_instance_infos + +class LoadConstrainedFilter(MigrationFilterPolicy): + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests > 0 \ + or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold + + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests == 0 \ + and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold + +class PddFilter(MigrationFilterPolicy): + INSTANCE_FILTER_RULES = { + PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE), + PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), + } + + def filter_src_condition(self, 
filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + src_type, _ = self.INSTANCE_FILTER_RULES[pair_migration_type] + instance_type_filter = lambda instance_info: instance_info.instance_type == src_type + + if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + inner_policy = MigrationFilterPolicyFactory.get_policy('load') + policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type) + else: + policy_filter = lambda instance_info: True + + return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info) + + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + _, dst_type = self.INSTANCE_FILTER_RULES[pair_migration_type] + instance_type_filter = lambda instance_info: instance_info.instance_type == dst_type + + if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + inner_policy = MigrationFilterPolicyFactory.get_policy('load') + policy_filter = inner_policy.filter_dst_condition(filter_config, pair_migration_type) + else: + policy_filter = lambda instance_info: instance_info.num_killed_requests == 0 + + return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info) + +class CustomFilter(MigrationFilterPolicy): + def __init__(self): + super().__init__() + self.src_filter = lambda _: True + self.dst_filter = lambda _: True + + def set_filter_condtition(self, src_filter: Optional[Callable[[InstanceInfo], bool]] = None, + dst_filter: Optional[Callable[[InstanceInfo], bool]] = None) -> None: + if src_filter: + self.src_filter = src_filter + if dst_filter: + self.dst_filter = dst_filter + + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return self.src_filter + + def 
filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return self.dst_filter + +class MigrationFilterPolicyFactory: + _POLICY_REGISTRY = { + 'load': LoadConstrainedFilter, + 'prefill_decode': PddFilter, + 'custom': CustomFilter, + } + + @classmethod + def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> MigrationFilterPolicy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/llumnix/global_scheduler/migration_policy.py b/llumnix/global_scheduler/migration_policy.py new file mode 100644 index 00000000..c917cce7 --- /dev/null +++ b/llumnix/global_scheduler/migration_policy.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple +from abc import ABC, abstractmethod +from enum import Enum +import copy +import numpy as np + +from llumnix.logger import init_logger +from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator + +logger = init_logger(__name__) + +class PairMigrationConstraints(str, Enum): + """Target of Migration.""" + NO_CONSTRAINTS = "NO_CONSTRAINTS" + + # Enable the prefill-decoding disaggregation.
+ DECODING_2_DECODING = "DECODING_2_DECODING" + PREFILL_2_DECODING = "PREFILL_2_DECODING" + +class PairMigrationPolicy(ABC): + def __init__(self, + migrate_out_load_threshold: float, + instance_load_calculator: InstanceLoadCalculator) -> None: + self.migrate_out_load_threshold = migrate_out_load_threshold + self.instance_load_calculator = instance_load_calculator + + @abstractmethod + def pair_migration(self, + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], + ) -> List[Tuple[str, str]]: + raise NotImplementedError + + def sort_instance_infos(self, instance_infos: List[InstanceInfo], descending: bool = True) -> None: + key_attr = 'instance_load_migrate' + sorted_instance_infos = sorted( + instance_infos, + key=lambda instance_info: getattr(instance_info, key_attr), + reverse=descending + ) + return sorted_instance_infos + +class Balanced(PairMigrationPolicy): + def pair_migration(self, + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], + ) -> List[Tuple[str, str]]: + sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True) + sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) + migrate_instance_pairs = [] + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate + + left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) + right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) + + # Add some constraints to reduce unnecessary migrations + if right_load_after_mig > self.migrate_out_load_threshold: + continue + load_diff_after_mig = left_load_after_mig - right_load_after_mig + if (0 < load_diff_after_mig < load_diff_before_mig) or
(sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, + sorted_dst_instance_infos[i].instance_id)) + return migrate_instance_pairs + + def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: + instance_info_after_migrate = copy.deepcopy(instance_info) + num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request + + if is_migrate_in: + instance_info_after_migrate.num_running_requests += 1 + instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request + else: + instance_info_after_migrate.num_running_requests -= 1 + instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request + + return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') + +class DefragConstrained(PairMigrationPolicy): + def pair_migration(self, + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], + ) -> List[Tuple[str, str]]: + sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True) + sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) + migrate_instance_pairs = [] + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + # No constraints are applied, so that prefill migration happens as soon as possible + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id)) + return migrate_instance_pairs + +class PairMigrationPolicyFactory: + _POLICY_REGISTRY = { + 'balanced': Balanced, + 'defrag_constrained': DefragConstrained, + } + + @classmethod + def get_policy(cls, policy_name: str, **kwargs) -> PairMigrationPolicy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/llumnix/global_scheduler/migration_scheduler.py
b/llumnix/global_scheduler/migration_scheduler.py index 3445b210..ad538f06 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -12,31 +12,22 @@ # limitations under the License. from typing import Dict, List, Tuple, Set -from abc import ABC, abstractmethod -from enum import Enum -import copy -import numpy as np from llumnix.logger import init_logger from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator -from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints, PairMigrationPolicyFactory logger = init_logger(__name__) -class PairMigrationConstraints(str, Enum): - """Target of Migration.""" - NO_CONSTRAINTS = "NO_CONSTRAINTS" - - # Enable the prefill-decoding disaggregration. - DECODING_2_DECODING = "DECODING_2_DECODING" - PREFILL_2_DECODING = "PREFILL_2_DECODING" - class MigrationScheduler: def __init__(self, pair_migration_policy: str, migrate_out_load_threshold: float, instance_load_calculator: InstanceLoadCalculator) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold + self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) + self.migration_filter = MigrationInstanceFilter(self.filter_config) + self.instance_load_calculator = instance_load_calculator self.enable_defrag = instance_load_calculator.enable_defrag if not self.enable_defrag: @@ -57,14 +48,9 @@ def __init__(self, self.sorted_instance_infos: List[InstanceInfo] = None def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: - self._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type) - return 
self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) - - def _get_migration_instance_infos(self, pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: - filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(pair_migration_type, - migrate_out_load_threshold=self.migrate_out_load_threshold) - return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos,pair_migration_type) + src_instance_infos, dst_instance_infos = self.migration_filter.filter_instances( + self.instance_info.values(), pair_migration_type) + return self.pair_migration_policy.pair_migration(src_instance_infos, dst_instance_infos) def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: @@ -77,138 +63,3 @@ def add_instance(self, instance_id: str) -> None: def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) self.num_instances = len(self.instance_id_set) - - def _sort_instance_infos(self, - descending: bool = True) -> None: - instance_infos: List[InstanceInfo] = list(self.instance_info.values()) - key_attr = 'instance_load_migrate' - self.sorted_instance_infos = sorted( - instance_infos, - key=lambda instance_info: getattr(instance_info, key_attr), - reverse=descending - ) - -class FilteringInstanceInfosPolicy(ABC): - def __init__(self, - migrate_out_load_threshold: float) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold - self.filter_instances_rules = { - PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS), - PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE), - PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), - } - - def filter_instances(self, sorted_instance_infos: List[InstanceInfo], - pair_migration_type: PairMigrationConstraints = None) -> Dict[str, InstanceInfo]: - src_type, 
dst_type = self.filter_instances_rules[pair_migration_type] - filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type] - filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type] - src_instance_infos = self.filter_src_instances(filtered_src_instance_infos) - dst_instance_infos = self.filter_dst_instances(filtered_dst_instance_infos) - return src_instance_infos, dst_instance_infos - - @abstractmethod - def filter_src_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]: - raise NotImplementedError - - @abstractmethod - def filter_dst_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]: - raise NotImplementedError - -class FilterConstrained(FilteringInstanceInfosPolicy): - def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - src_instance_infos = [i for i in reversed(filtered_instance_infos) - if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] - return src_instance_infos - - def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - dst_instance_infos = [i for i in filtered_instance_infos - if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] - return dst_instance_infos - -class FilterRelaxed(FilteringInstanceInfosPolicy): - # The policy is currently used to select the decoding instances to migrate requests from the prefill instances. 
- def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - src_instance_infos = list(reversed(filtered_instance_infos)) - return src_instance_infos - - def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - dst_instance_infos = [i for i in filtered_instance_infos - if i.num_killed_requests == 0] - return dst_instance_infos - -class FilteringInstanceInfosPolicyFactory: - _POLICY_REGISTRY = { - PairMigrationConstraints.NO_CONSTRAINTS: FilterConstrained, - PairMigrationConstraints.DECODING_2_DECODING: FilterConstrained, - PairMigrationConstraints.PREFILL_2_DECODING: FilterRelaxed, - } - - @classmethod - def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> FilteringInstanceInfosPolicy: - return cls._POLICY_REGISTRY[policy_name](**kwargs) - -class PairMigrationPolicy(ABC): - def __init__(self, - migrate_out_load_threshold: float, - instance_load_calculator: InstanceLoadCalculator) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold - self.instance_load_calculator = instance_load_calculator - - @abstractmethod - def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], - ) -> List[Tuple[str, str]]: - raise NotImplementedError - -class Balanced(PairMigrationPolicy): - def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], - ) -> List[Tuple[str, str]]: - migrate_instance_pairs = [] - for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): - load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate - left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) - right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], 
is_migrate_in=True) - # Add some constrains to reduce unnecessary migrations - if right_load_after_mig > self.migrate_out_load_threshold: - continue - load_diff_after_mig = left_load_after_mig - right_load_after_mig - if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, - sorted_dst_instance_infos[i].instance_id)) - return migrate_instance_pairs - - def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: - instance_info_after_migrate = copy.deepcopy(instance_info) - num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request - if is_migrate_in: - instance_info_after_migrate.num_running_requests += 1 - instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request - else: - instance_info_after_migrate.num_running_requests -= 1 - instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request - return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') - -class DefragConstrained(PairMigrationPolicy): - def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], - ) -> List[Tuple[str, str]]: - migrate_instance_pairs = [] - for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): - # without any constrain in order to make prefill migrate happens as soon as possible - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id)) - return migrate_instance_pairs - -class PairMigrationPolicyFactory: - _POLICY_REGISTRY = { - 'balanced': Balanced, - 'defrag_constrained': DefragConstrained, - } - - @classmethod - def get_policy(cls, policy_name: str, **kwargs) -> PairMigrationPolicy: - return 
cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py index edcc9627..7607d88a 100644 --- a/llumnix/global_scheduler/scaling_scheduler.py +++ b/llumnix/global_scheduler/scaling_scheduler.py @@ -14,7 +14,6 @@ from typing import Dict, List, Tuple, Set from abc import ABC, abstractmethod from enum import Enum -import math import numpy as np from llumnix.logger import init_logger @@ -36,6 +35,7 @@ def __init__(self, scale_down_threshold: float, scaling_policy: str, instance_load_calculator: InstanceLoadCalculator, + enable_pd_disagg: bool, maximum_prefill_instance_num: int) -> None: self.scale_up_threshold = scale_up_threshold self.scale_down_threshold = scale_down_threshold @@ -46,6 +46,7 @@ def __init__(self, self.num_instances = 0 self.instance_id_set: Set[str] = set() self.maximum_prefill_instance_num = maximum_prefill_instance_num + self.enable_pd_disagg = enable_pd_disagg # instance info args self.instance_info: Dict[str, InstanceInfo] = None self.sorted_instance_infos: List[InstanceInfo] = None @@ -78,7 +79,7 @@ def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) instance_type = None - if self.maximum_prefill_instance_num == math.inf: + if not self.enable_pd_disagg: instance_type = InstanceType.NO_CONSTRAINTS else: if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num: diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 410d38e0..4412c13b 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -16,18 +16,20 @@ def __init__( self, request_migration_policy: str, migration_backend: str, - migration_cache_blocks: int, + migration_buffer_blocks: int, migration_num_layers: int, last_stage_max_blocks: int, max_stages: int, - migration_backend_init_timeout: float) -> None: + migration_backend_init_timeout: float, 
+ migration_internal_buffer_num: int) -> None: self.request_migration_policy = request_migration_policy self.migration_backend = migration_backend self.migration_num_layers = migration_num_layers - self.migration_cache_blocks = migration_cache_blocks + self.migration_buffer_blocks = migration_buffer_blocks self.last_stage_max_blocks = last_stage_max_blocks self.max_stages = max_stages self.migration_backend_init_timeout = migration_backend_init_timeout + self.migration_internal_buffer_num = migration_internal_buffer_num class GlobalSchedulerConfig: def __init__( @@ -49,6 +51,8 @@ def __init__( self.dispatch_policy = dispatch_policy self.pair_migration_policy = pair_migration_policy + # TODO(KuilongCui): Use a better way to set the threshold, as having both positive and negative + # values can cause confusion. self.migrate_out_load_threshold = migrate_out_threshold*(-1) self.enable_defrag = enable_defrag diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 7b47728b..5d8c48a5 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -15,7 +15,6 @@ import time import csv import os -import math from typing import Dict, List, Tuple, Union, Iterable from collections import defaultdict import traceback @@ -42,8 +41,6 @@ RETRIES_INTERVALS = 5.0 # TODO(s5u13b): Fix the logger when manager failover. 
- - class LLMEngineManager: def __init__(self, engine_manager_args: EngineManagerArgs, @@ -71,10 +68,7 @@ def __init__(self, logger.info("num_instances: {}".format(self.num_instances)) logger.info("max_instances: {}, min_instances: {}".format(self.max_instances, self.min_instances)) - # TODO(s5u13b): refactor auto-scaling - self.instances: Dict[str, Llumlet] = {} - self.instance_migrating: Dict[str, bool] = {} self.pending_rebuild_migration_instances = 0 self.global_scheduler = GlobalScheduler(global_scheduler_config) @@ -92,8 +86,9 @@ def __init__(self, # migrate states self.num_instance_info_updates = 0 - self.migrating = False + self.num_migrating = 0 + # TODO(s5u13b): refactor auto-scaling # auto-scaling states self.scale_up_time = -1 self.scale_down_time = -1 @@ -184,26 +179,31 @@ def update_instance_info_done_callback(instance_id: str, fut): self.global_scheduler.update_instance_infos([ret]) else: dead_instance_ids.append(instance_id) + while True: try: await asyncio.sleep(interval) tasks = [] instance_infos = [] dead_instance_ids = [] + for instance_id, instance in self.instances.items(): # Use asyncio.gather to wrap ray remote call to add done callback. task = asyncio.gather(instance.get_instance_info.remote(), return_exceptions=True) task.add_done_callback(partial(update_instance_info_done_callback, instance_id)) tasks.append(task) await asyncio.gather(*tasks, return_exceptions=True) + if len(dead_instance_ids) > 0: logger.info("[_update_instance_info_loop] dead instances: {}.".format(dead_instance_ids)) self.scale_down(dead_instance_ids) self.num_instance_info_updates += 1 + # Push migrate when the instance_info have updated a certain number of times. 
if self.enable_migration and self.num_instance_info_updates != 0 \ and self.num_instance_info_updates % self.pair_migration_frequency == 0: asyncio.create_task(self._push_migrations()) + if self.log_instance_info: self._log_instance_infos_to_csv(instance_infos) # pylint: disable=W0703 @@ -217,28 +217,27 @@ async def _clear_request_instance_loop(self, interval: float): while True: await asyncio.sleep(interval) self.request_instance = {} + async def _push_migrations(self) -> None: # Push migrate when the instance_info have updated a certain number of times. if self.enable_pd_disagg: - asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, math.inf)) - asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING, 1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING)) + asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING)) else: - asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS)) - async def _migrate(self, pair_migration_type: PairMigrationConstraints, migrate_in_num_requests: int) -> None: + async def _migrate(self, pair_migration_type: PairMigrationConstraints) -> None: async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None: - if migrate_instance_pair[0] in self.instance_migrating: - self.instance_migrating[migrate_instance_pair[0]] = False - if migrate_instance_pair[1] in self.instance_migrating: - self.instance_migrating[migrate_instance_pair[1]] = False - if isinstance(ret, (ray.exceptions.RayActorError, KeyError)): + self.num_migrating -= 1 + # TODO(s5u13b): Add more exception types for failover. 
+ if isinstance(ret, (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError)): has_error_pair = await self._check_instance_error(migrate_instance_pair) for i, has_error in enumerate(has_error_pair): # Instance without error should clear migration states. if not has_error: try: await self.instances[migrate_instance_pair[i]].clear_migration_states.remote(is_migrate_in=bool(i)) - except (ray.exceptions.RayActorError, KeyError): + except (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError): has_error = True for i, has_error in enumerate(has_error_pair): if has_error: @@ -252,22 +251,23 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> self.request_instance[migrate_out_request_id] = migrate_instance_pair[1] logger.info("{}->{} migrate done, migrate request {}".format( migrate_instance_pair[0], migrate_instance_pair[1], migrate_out_request_ids)) + def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -> None: ret = fut.result() loop = asyncio.get_event_loop() loop.create_task(migrate_done_callback(ret, migrate_instance_pair)) - migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type) + try: + migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type) + migration_tasks = [] for _, migrate_instance_pair in enumerate(migrate_instance_pairs): + self.num_migrating += 1 migrate_out_instance_id, migrate_in_instance_id = migrate_instance_pair - if self.instance_migrating[migrate_out_instance_id] or self.instance_migrating[migrate_in_instance_id]: - continue - self.instance_migrating[migrate_out_instance_id] = True - self.instance_migrating[migrate_in_instance_id] = True + migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id) # Use asyncio.gather to wrap ray remote call to add done callback. 
- task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests), + task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name), return_exceptions=True) task.add_done_callback(partial(migrate_done_callback_wrapper, migrate_instance_pair)) migration_tasks.append(task) @@ -280,7 +280,7 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - async def rebuild_migrate_backend(self) -> None: # Wait for all instances to finish migration - while any(self.instance_migrating.values()): + while self.num_migrating > 0: await asyncio.sleep(0.1) # During rebuilding migration backend, disable migrate @@ -353,7 +353,6 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles if ins_id not in self.instances: indeed_update = True self.instances[ins_id] = llumlet_actor_handles[idx] - self.instance_migrating[ins_id] = False if self.log_instance_info: self.instance_last_logged_empty[ins_id] = False self.pending_rebuild_migration_instances += 1 @@ -364,7 +363,8 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes # caused by this scale-up (see rebuild_migrate_backend for details). Therefore, we simply return in this case. Specifically, # for RPC, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. 
- if self.engine_manager_args.migration_backend != 'rpc' and indeed_update and no_pending_instance: + if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc' \ + and indeed_update and no_pending_instance: asyncio.create_task(self.rebuild_migrate_backend()) return self.num_instances @@ -381,14 +381,13 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac if ins_id in self.instances: indeed_update = True del self.instances[ins_id] - del self.instance_migrating[ins_id] if self.log_instance_info: del self.instance_last_logged_empty[ins_id] self.pending_rebuild_migration_instances += 1 self.global_scheduler.scale_down(instance_ids) self.num_instances = len(self.instances) - if self.engine_manager_args.migration_backend != 'rpc': + if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc': if len(self.instances) == 0: self.pending_rebuild_migration_instances = 0 diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 5aa3e4c2..3af73ac5 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -27,6 +27,7 @@ from llumnix.server_info import ServerInfo from llumnix.internal_config import MigrationConfig from llumnix.queue.queue_type import QueueType +from llumnix.llumlet.request import LlumnixRequest, RequestStatus logger = init_logger(__name__) @@ -55,7 +56,7 @@ def __init__(self, self.backend_engine) self.log_requests = True - self.check_state_thread = asyncio.create_task(self.check_state()) + asyncio.create_task(self._check_state_loop()) # pylint: disable=broad-except except Exception as e: logger.error("Failed to initialize llumlet: {}".format(e)) @@ -118,7 +119,7 @@ def from_args(cls, llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs) return llumlet - async def check_state(self): + async def _check_state_loop(self): while True: await asyncio.sleep(1) if self.backend_engine.state == 
EngineState.CRASHED: @@ -128,39 +129,55 @@ async def check_state(self): self_actor = ray.get_actor(self.actor_name) ray.kill(self_actor) - async def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]: + async def migrate_out(self, dst_instance_name: str) -> List[str]: + migrate_out_requests = self.migration_scheduler.get_migrate_out_requests() + if len(migrate_out_requests) == 0: + return [] + migrated_request_list = [] + for migrate_out_request in migrate_out_requests: + migrated_request = await self._migrate_out_one_request(migrate_out_request, dst_instance_name) + migrated_request_list.extend(migrated_request) + if len(migrated_request) == 0 and migrate_out_request.eom: + break + return migrated_request_list + + async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, dst_instance_name: str) -> List[LlumnixRequest]: try: + t0 = time.time() migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') dst_instance_id = dst_instance_name[len("instance_"):] - migrated_request_list = [] - continue_migrate = True - while continue_migrate and len(migrated_request_list) < num_requests: - t0 = time.time() - migrate_out_request = self.migration_scheduler.get_migrate_out_request() - if migrate_out_request is not None: - logger.info("migrate_out {}".format(migrate_out_request.request_id)) - if migrate_out_request is None: - return migrated_request_list - logger.info("{}->{} begin migrate out {}".format(self.instance_id, dst_instance_id, migrate_out_request.request_id)) - status = await self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - if status == MigrationStatus.FINISHED_DONE: - await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request) - self.backend_engine.free_src_request(migrate_out_request) - migrated_request_list.append(migrate_out_request.request_id) - migrate_out_request.stage_timestamps.append(time.time()) - 
self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) - else: - migrate_out_request.reset_migration_args() + logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id)) + migrated_request = [] + if migrate_out_request.status == RequestStatus.RUNNING: + status = await self.migration_coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) + elif migrate_out_request.status == RequestStatus.WAITING: + status = await self.migration_coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) + else: + return migrated_request + if status == MigrationStatus.FINISHED: + await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request) + self.backend_engine.free_src_request(migrate_out_request) + self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) + migrated_request.append(migrate_out_request.request_id) + else: # ABORTED_SRC or ABORTED_DST + migrate_out_request.reset_migration_args_src() + migrate_out_request.reset_status() + # If dst aborts itself, dst proactively frees the pre allocated cache in migrate_in_pre_alloc. 
+ if status == MigrationStatus.ABORTED_SRC: await migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id) - continue_migrate = False - t1 = time.time() - logger.info("{}->{} migrate done, migrate request {}, status:{}, len:{} blocks, cost:{} ms" \ - .format(self.instance_id, dst_instance_id, migrated_request_list, status, \ - sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) + t1 = time.time() + logger.info("{}->{} migrate done, migrate request {}, migration status: {}, len: {} blocks, cost: {} ms" \ + .format(self.instance_id, dst_instance_id, migrated_request, status, \ + sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) except ray.exceptions.RayActorError: logger.info("[migrate_out] instance {} is dead".format(dst_instance_name[len("instance_"):])) raise - return migrated_request_list + # pylint: disable=W0703 + except Exception as e: + logger.error("unexpected exception occurs: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) + raise + return migrated_request def get_instance_info(self) -> InstanceInfo: return self.backend_engine.engine.instance_info @@ -202,7 +219,13 @@ def clear_migration_states(self, is_migrate_in: bool) -> None: migrating_out_requests_last_stage = self.backend_engine.pop_migrating_out_requests_last_stage() for backend_request in migrating_out_requests_last_stage: logger.info("clear_migration_states: add request {} back to engine".format(backend_request.request_id)) - self.backend_engine.add_running_request(backend_request) + assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \ + "The status of request in migrating_out_requests_last_stage should be \ + RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING" + if backend_request.status == RequestStatus.RUNNING_MIGRATING: + self.backend_engine.add_running_request(backend_request) + else: # WAITING_MIGRATING + 
self.backend_engine.add_waiting_request(backend_request) def execute_migration_method(self, method, *args, **kwargs): executor = getattr(self.migration_coordinator, method) diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py index e630d982..4f30f850 100644 --- a/llumnix/llumlet/local_migration_scheduler.py +++ b/llumnix/llumlet/local_migration_scheduler.py @@ -11,63 +11,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Deque, List import numpy as np -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestStatus, RequestInferenceType from llumnix.backends.backend_interface import BackendInterface + class LocalMigrationScheduler: def __init__(self, request_migration_policy: str, backend_engine: BackendInterface) -> None: self.request_migration_policy = request_migration_policy self.backend_engine = backend_engine - def get_migrate_out_request(self, min_request_len=0, max_request_len=np.inf) -> Optional[LlumnixRequest]: - # Requests meet the strict pre-migration always have higher prioirity than other migration policy. 
- migrate_out_request = self.get_ready_migration_request(min_request_len, max_request_len) - if migrate_out_request is None: - if self.request_migration_policy == 'LCFS': - migrate_out_request = self.get_last_running_request(min_request_len, max_request_len) - elif self.request_migration_policy == 'LJF': - migrate_out_request = self.get_longest_running_request(min_request_len, max_request_len) - elif self.request_migration_policy == 'SJF': - migrate_out_request = self.get_shortest_running_request(min_request_len, max_request_len) - return migrate_out_request + def get_migrate_out_requests(self, min_request_len=0, max_request_len=np.inf) -> List[LlumnixRequest]: + # Requests meet the strict pre-migration always have higher prioirity than other migration policy. + migrate_out_requests: List[LlumnixRequest] = self.get_required_migration_request() + if len(migrate_out_requests) == 0: + if self.request_migration_policy == 'LCR': + migrate_out_requests = self._get_last_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'LR': + migrate_out_requests = self._get_longest_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'SR': + migrate_out_requests = self._get_shortest_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'FCW': + migrate_out_requests = self._get_first_waiting_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'FCWSR': + migrate_out_requests = self._get_first_waiting_and_shortest_running_requests(min_request_len, max_request_len) + return migrate_out_requests # The function is used to retrieve requests on the backend that have already met the expected_steps. 
- # TODO(xinyi): Currently, the function is only used for Prefill-decoding disaggregation, + # (xinyi): Currently, the function is only used for Prefill-decoding disaggregation, # and only selects request that migrates from the prefill instance to the decoding instance. - def get_ready_migration_request(self, min_request_len, max_request_len): + def get_required_migration_request(self): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() + required_migration_requests = [] for request in reversed(running): - if request.output_len >= request.expected_steps \ + if request.status == RequestStatus.RUNNING \ and request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len: - return request - return None + and request.output_len >= request.expected_steps: + required_migration_requests.append(request) + return required_migration_requests - def get_last_running_request(self, min_request_len, max_request_len): - running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - for request in reversed(running): - if request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len: - return request - return None + def _filter_running_queue(self, running, min_request_len, max_request_len): + filtered_running = [ + request for request in running \ + if request.status == RequestStatus.RUNNING \ + and request.inference_type == RequestInferenceType.DECODE \ + and min_request_len < request.request_len < max_request_len \ + ] + return filtered_running - def get_longest_running_request(self, min_request_len, max_request_len): - running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - condition = lambda request : request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len + def _filter_waiting_queue(self, waiting, min_request_len, max_request_len): + filtered_waiting = [ 
+ request for request in waiting \ + if request.status == RequestStatus.WAITING \ + and request.try_schedule_times >= 1 \ + and min_request_len < request.request_len < max_request_len \ + ] + return filtered_waiting - longest_seq_group = max((request for request in running if condition(request)), \ - key=lambda request: request.request_len, default=None) - return longest_seq_group + def _get_last_running_request(self, min_request_len, max_request_len): + running: Deque[LlumnixRequest] = self.backend_engine.get_running_queue() + filtered_running = self._filter_running_queue(running, min_request_len, max_request_len) + return [filtered_running[-1]] if filtered_running else [] - def get_shortest_running_request(self, min_request_len, max_request_len): - running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - condition = lambda request : request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len + def _get_longest_running_request(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + running: Deque[LlumnixRequest] = self.backend_engine.get_running_queue() + filtered_running = self._filter_running_queue(running, min_request_len, max_request_len) + longest_seq_group = max((request for request in filtered_running), \ + key=lambda request: request.request_len, default=None) + return [longest_seq_group] if longest_seq_group is not None else [] + + def _get_shortest_running_request(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + running: Deque[LlumnixRequest] = self.backend_engine.get_running_queue() + filtered_running = self._filter_running_queue(running, min_request_len, max_request_len) + shortest_seq_group = min((request for request in filtered_running), \ + key=lambda request: request.request_len, default=None) + return [shortest_seq_group] if shortest_seq_group is not None else [] + + def _get_first_waiting_request(self, min_request_len, max_request_len) -> 
List[LlumnixRequest]: + waiting: Deque[LlumnixRequest] = self.backend_engine.get_waiting_queue() + filtered_waiting = self._filter_waiting_queue(waiting, min_request_len, max_request_len) + return [waiting[0]] if filtered_waiting else [] - shortest_seq_group = min((request for request in running if condition(request)), \ - key=lambda request: request.request_len, default=None) - return shortest_seq_group + def _get_first_waiting_and_shortest_running_requests(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + waiting_requests = self._get_first_waiting_request(min_request_len, max_request_len) + running_requests = self._get_shortest_running_request(min_request_len, max_request_len) + if waiting_requests: + waiting_requests[0].eom = True + return waiting_requests + running_requests diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index 03b20cb2..224c41c3 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -19,7 +19,7 @@ import ray from llumnix.logger import init_logger -from llumnix.llumlet.request import LlumnixRequest +from llumnix.llumlet.request import LlumnixRequest, RequestStatus from llumnix.backends.backend_interface import BackendInterface logger = init_logger(__name__) @@ -27,18 +27,16 @@ class MigrationStatus(enum.Enum): """Status of Migration.""" RUNNING = enum.auto() - # aborted by src instance - ABORTED_SRC = enum.auto() - # aborted by dst instance ABORTED_DST = enum.auto() - FINISHED_DONE = enum.auto() + ABORTED_SRC = enum.auto() + FINISHED = enum.auto() @staticmethod def is_finished(status: "MigrationStatus") -> bool: return status in [ - MigrationStatus.ABORTED_SRC, MigrationStatus.ABORTED_DST, - MigrationStatus.FINISHED_DONE + MigrationStatus.ABORTED_SRC, + MigrationStatus.FINISHED ] class MigrationCoordinator: @@ -50,36 +48,88 @@ def __init__(self, self.max_stages = max_stages self.backend_engine = backend_engine - async def 
migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest, ) -> "MigrationStatus": - """one-stage live migration until last stage + async def migrate_out_running_request(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + return await self._migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) + + async def migrate_out_waiting_request(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + """one-stage migration for a waiting request + """ + found = self.backend_engine.remove_waiting_request(migrate_out_request.request_id) + if not found: + return MigrationStatus.ABORTED_SRC + self.backend_engine.add_migrating_out_request_last_stage(migrate_out_request) + dst_blocks = await migrate_in_ray_actor.execute_migration_method \ + .remote("migrate_in_pre_alloc", migrate_out_request.request_id, + migrate_out_request.status, + migrate_out_request.arrival_time, + migrate_out_request.prefill_num_blocks) + if len(dst_blocks) != migrate_out_request.prefill_num_blocks: + self.backend_engine.add_waiting_request(migrate_out_request) + self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) + return MigrationStatus.ABORTED_DST + + return MigrationStatus.FINISHED + + async def _migrate_out_multistage(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + """Migrate out requests to a specified instance, return migrated request id. 
+ Args: + migrate_in_ray_actor: instance actor name, used to get ray actor handle + """ + stage_count = 0 + while stage_count < self.max_stages: + stage_count += 1 + status = await self._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + if MigrationStatus.is_finished(status): + return status + # exceed max stages + return MigrationStatus.ABORTED_SRC + + async def _migrate_out_onestage(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + """one-stage live migration until last stage for a running request """ pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list) incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks) # live migration, transfer all blocks except last one(currently updating) - migration_status = MigrationStatus.RUNNING is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks) or migrate_out_request.blocking_migration if not is_last_stage: + migration_status = MigrationStatus.RUNNING src_blocks = incremental_blocks[:-1] stage_block_num = len(incremental_blocks) - 1 dst_blocks = await migrate_in_ray_actor.execute_migration_method \ - .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num) + .remote("migrate_in_pre_alloc", migrate_out_request.request_id, + migrate_out_request.status, + migrate_out_request.arrival_time, + stage_block_num) else: # last stage migration, stop inference, transfer all blocks - migration_status = MigrationStatus.FINISHED_DONE - self.backend_engine.remove_running_request(migrate_out_request.request_id) + migration_status = MigrationStatus.FINISHED + found = self.backend_engine.remove_running_request(migrate_out_request.request_id) + if not found: + return MigrationStatus.ABORTED_SRC self.backend_engine.add_migrating_out_request_last_stage(migrate_out_request) - stage_block_num = len(incremental_blocks) src_blocks = incremental_blocks[:] + 
stage_block_num = len(incremental_blocks) dst_blocks = await migrate_in_ray_actor.execute_migration_method \ - .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num) + .remote("migrate_in_pre_alloc", migrate_out_request.request_id, + migrate_out_request.status, + migrate_out_request.arrival_time, + stage_block_num) if len(dst_blocks) != len(src_blocks): - # migrate-in instance failed to prev alloc + # migrate-in instance failed to pre alloc if is_last_stage: self.backend_engine.add_running_request(migrate_out_request) self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) - migration_status = MigrationStatus.ABORTED_DST - return migration_status + return MigrationStatus.ABORTED_DST + # do stage send/recv migrate_out_request.stage_timestamps.append(time.time()) migrate_out_request.stage_num_blocks_list.append(stage_block_num) @@ -87,28 +137,21 @@ async def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandl await self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks) if not is_last_stage and migrate_out_request.should_abort_migration(): # migrate-out request abort by scheduler during send/recv - migration_status = MigrationStatus.ABORTED_SRC + return MigrationStatus.ABORTED_SRC return migration_status - async def migrate_out_multistage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest) -> "MigrationStatus": - """Migrate out requests to a specified instance, return migrated request id. 
- Args: - dst_instance_name:instance actor name, used to get ray actor handle - """ - state_count = 0 - while state_count < self.max_stages: - state_count += 1 - status = await self.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) - if MigrationStatus.is_finished(status): - return status - # exceed max stages - return MigrationStatus.ABORTED_SRC - - def migrate_in_pre_alloc(self, request_id: str, block_num: int) -> List[int]: + def migrate_in_pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: """prev alloc blocks to migrate in request """ - pre_alloc_blocks = self.backend_engine.pre_alloc(request_id ,block_num) + pre_alloc_blocks = self.backend_engine.pre_alloc(request_id, + request_status, + request_arrival_time, + block_num) if len(pre_alloc_blocks) != block_num: # failed to alloc, abort request self.free_dst_pre_alloc_cache(request_id) diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index 2319f52f..d92e6564 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -20,6 +20,13 @@ class RequestInferenceType(str, Enum): PREFILL = "prefill" DECODE = "decode" +class RequestStatus(str, Enum): + RUNNING = "running" + WAITING = "waiting" + FINISHED = "finished" + RUNNING_MIGRATING = "running_migrating" + WAITING_MIGRATING = "waiting_migrating" + class LlumnixRequest: def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int) -> None: self.request_id = request_id @@ -32,16 +39,31 @@ def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] + self.try_schedule_times = 0 + self._status = None + + # end-of-migration, for multiple requests migration + self.eom = False + + def reset_migration_args_dst(self): + # By default, there is no limit on the number of steps expected for the request. 
+ self.expected_steps = math.inf - def reset_migration_args(self): self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] - # By default, there is no limit on the number of steps expected for the request. - self.expected_steps = math.inf + self.try_schedule_times = 0 - def is_finished(self) -> bool: - raise NotImplementedError + def reset_migration_args_src(self): + self.last_preemption_time = None + self.stage_timestamps = [] + self.stage_num_blocks_list = [] + + def reset_status(self): + self._status = None + + def set_status(self, status: RequestStatus): + self._status = status @property def inference_type(self) -> RequestInferenceType: @@ -59,6 +81,22 @@ def prompt_len(self) -> int: def output_len(self) -> int: raise NotImplementedError + @property + def finished(self) -> bool: + raise NotImplementedError + + @property + def arrival_time(self) -> float: + raise NotImplementedError + + @property + def status(self) -> RequestStatus: + raise NotImplementedError + + @property + def prefill_num_blocks(self) -> int: + raise NotImplementedError + # Whether the migration of request is completed within one stage. For requests that have already reached # the expected steps, blocking_migration is True. @property @@ -66,7 +104,5 @@ def blocking_migration(self) -> bool: return self.output_len >= self.expected_steps def should_abort_migration(self) -> bool: - return self.output_len == 0 \ - or (self.last_preemption_time and self.last_preemption_time > self.stage_timestamps[-1]) \ - or self.inference_type == RequestInferenceType.PREFILL \ - or self.is_finished() + return self.finished \ + or (self.last_preemption_time is not None and self.last_preemption_time > self.stage_timestamps[-1]) diff --git a/tests/conftest.py b/tests/conftest.py index 2749ba00..ba3b467c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,20 +12,16 @@ # limitations under the License. 
import subprocess -from time import sleep import ray import pytest def pytest_sessionstart(session): - subprocess.run(["ray", "stop", "--force"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - sleep(3) + subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - sleep(3) def pytest_sessionfinish(session, exitstatus): subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - sleep(3) @pytest.fixture def setup_ray_env(): diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index b6d70d8f..5eba27d1 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -20,7 +20,8 @@ import numpy as np from .test_e2e import generate_launch_command, clear_ray_state -from .utils import to_markdown_table +# pylint: disable=unused-import +from .utils import to_markdown_table, setup_ray_env def launch_llumnix_service(command): subprocess.run(command, shell=True, check=True) @@ -90,7 +91,7 @@ def get_markdown_data(key: str, head_name: str): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for simple benchmark") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -async def test_simple_benchmark(model): +async def test_simple_benchmark(setup_ray_env, model): device_count = torch.cuda.device_count() base_port = 37037 for i in range(device_count): @@ -107,7 +108,7 @@ async def run_bench_command(command): tasks = [] for i in range(device_count): - bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=500, + bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300, dataset_type="sharegpt", 
dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" , qps=2, diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py index 42f92512..a3bf1977 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_e2e.py @@ -19,7 +19,8 @@ import torch from vllm import LLM, SamplingParams - +# pylint: disable=unused-import +from .utils import setup_ray_env def parse_launch_mode(launch_mode: str): # 'eief' means that enable init instance by manager and enable fixed node init instance, and so on. @@ -40,11 +41,12 @@ def parse_launch_mode(launch_mode: str): def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool = True, HEAD_NODE_IP: str = "127.0.0.1", ip: str = "127.0.0.1", port: int = 37000, instances_num = 1, dispatch_policy: str = "load", migration_backend = "gloo", model = "facebook/opt-125m", max_model_len: int = 2048, - launch_mode: str = 'eief', log_instance_info: bool = False): + launch_mode: str = 'eief', log_instance_info: bool = False, + request_migration_policy: str = 'SR'): disable_init_instance_by_manager, disable_fixed_node_init_instance = parse_launch_mode(launch_mode) command = ( f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 " - f"nohup python -m llumnix.entrypoints.vllm.api_server " + f"nohup python -u -m llumnix.entrypoints.vllm.api_server " f"--host {ip} " f"--port {port} " f"{'--disable-init-instance-by-manager ' if disable_init_instance_by_manager else ''}" @@ -59,9 +61,10 @@ def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool f"--max-model-len {max_model_len} " f"--dispatch-policy {dispatch_policy} " f"--trust-remote-code " - f"--request-migration-policy LCFS " + f"--request-migration-policy {request_migration_policy} " f"--migration-backend {migration_backend} " - f"--migration-cache-blocks 32 " + f"--migration-buffer-blocks 32 " + f"--migration-internal-buffer-num 2 " f"--tensor-parallel-size 1 " f"--request-output-queue-port {1234+port} " 
f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}" @@ -98,7 +101,7 @@ def clear_ray_state(): continue ray.shutdown() -async def get_llumnix_responce(prompt, sampling_params, ip_ports): +async def get_llumnix_response(prompt, sampling_params, ip_ports): timeout = aiohttp.ClientTimeout(total=60) request = { @@ -119,6 +122,8 @@ async def get_llumnix_responce(prompt, sampling_params, ip_ports): "The future of AI is", ] +vllm_output = {} + @ray.remote(num_gpus=1) def run_vllm(model, max_model_len, sampling_params): vllm_output = {} @@ -133,9 +138,9 @@ def run_vllm(model, max_model_len, sampling_params): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for e2e test") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) @pytest.mark.parametrize("launch_mode", ['eief', 'eidf', 'dief', 'didf']) -async def test_e2e(model, migration_backend, launch_mode): +async def test_e2e(setup_ray_env, model, migration_backend, launch_mode): if migration_backend == 'gloo' and launch_mode != 'eief': pytest.skip("When the migration backend is gloo, the launch mode of llumnix can only be eief") max_model_len = 370 @@ -155,15 +160,18 @@ async def test_e2e(model, migration_backend, launch_mode): llumnix_output = {} for prompt in prompts: - response = await asyncio.wait_for(get_llumnix_responce(prompt, sampling_params, f"127.0.0.1:{base_port}"), + response = await asyncio.wait_for(get_llumnix_response(prompt, sampling_params, f"127.0.0.1:{base_port}"), timeout=60*5) llumnix_output[prompt] = response['text'][0] shutdown_llumnix_service() - vllm_output = ray.get(run_vllm.remote(model, max_model_len, sampling_params)) - clear_ray_state() + global vllm_output + if len(vllm_output) == 0: + vllm_output = ray.get(run_vllm.remote(model, max_model_len, sampling_params)) + + clear_ray_state() # compare 
for prompt in prompts: assert llumnix_output[prompt] == vllm_output[prompt] diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index 7fe167bb..ced1e0be 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -21,7 +21,8 @@ from .test_e2e import generate_launch_command from .test_bench import generate_bench_command, clear_ray_state, shutdown_llumnix_service -from .utils import to_markdown_table +# pylint: disable=unused-import +from .utils import to_markdown_table, setup_ray_env size_pattern = re.compile(r'total_kv_cache_size:\s*([\d.]+)\s*(B|KB|MB|GB|KB|TB)') speed_pattern = re.compile(r'speed:\s*([\d.]+)GB/s') @@ -41,18 +42,18 @@ def parse_instance_log_file(log_files): speed = float(speed_match.group(1)) speed_dict[total_kv_cache_size].append(speed) - averger_speed = {} + average_speed = {} for transfer_size, speeds in speed_dict.items(): if len(speeds) <= 2: continue speeds.sort() trimmed_speeds = speeds[1:-1] - averger_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) + average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) - assert len(averger_speed) > 0, "Migration should have occurred, but it was not detected. " + assert len(average_speed) > 0, "Migration should have occurred, but it was not detected. 
" - return averger_speed + return average_speed def parse_manager_log_file(log_file): df = pd.read_csv(log_file) @@ -65,8 +66,14 @@ def parse_manager_log_file(log_file): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) -async def test_migration_benchmark(model, migration_backend): +@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) +@pytest.mark.parametrize("migrated_request_status", ['running', 'waiting']) +async def test_migration_benchmark(model, migration_backend, migrated_request_status): + if migrated_request_status == 'waiting' and migration_backend != 'rpc': + pytest.skip("When the migrated request status is waiting, only test the rpc migration backend.") + + request_migration_policy = 'SR' if migrated_request_status == 'running' else 'FCW' + base_port = 37037 instance_output_logs = [] @@ -76,37 +83,45 @@ async def test_migration_benchmark(model, migration_backend): instance_output_logs.append("instance_"+output_log) launch_command = generate_launch_command(result_filename=output_log, launch_ray_cluster=False, port=base_port+i, model=model, dispatch_policy="flood", migration_backend=migration_backend, - log_instance_info=True) + log_instance_info=True, + request_migration_policy=request_migration_policy) subprocess.run(launch_command, shell=True, check=True) - await asyncio.sleep(60) + await asyncio.sleep(5) + await asyncio.sleep(30) async def run_bench_command(command): process = await asyncio.create_subprocess_shell(command) await process.wait() assert process.returncode == 0 + tasks = [] for i in range(device_count//2): bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300, dataset_type="sharegpt", dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" , - qps=10) - await 
asyncio.wait_for(run_bench_command(bench_command), timeout=60*30) - await asyncio.sleep(30) + qps=10, + results_filename=f"{base_port+i}.out") + tasks.append(asyncio.create_task(run_bench_command(bench_command))) - parse_manager_log_file("manager_instance.csv") + _, pending = await asyncio.wait(tasks, timeout=60*30) - averger_speed = parse_instance_log_file(instance_output_logs) + await asyncio.sleep(10) - sorted_keys = sorted(averger_speed.keys(), key=lambda x: float(x.split()[0])) + if len(pending) > 0: + raise RuntimeError("migration task Timeout") - data = [ - ['migration_size'] + sorted_keys, - [f'{migration_backend}_speed(GB/s)'] + [f"{averger_speed[key]:.2f}" for key in sorted_keys] - ] + parse_manager_log_file("manager_instance.csv") - with open("performance.txt", "a", encoding="utf-8") as f: - f.write(to_markdown_table(data)) + if migrated_request_status == 'running': + average_speed = parse_instance_log_file(instance_output_logs) + sorted_keys = sorted(average_speed.keys(), key=lambda x: float(x.split()[0])) + data = [ + ['migration_size'] + sorted_keys, + [f'{migration_backend}_speed(GB/s)'] + [f"{average_speed[key]:.2f}" for key in sorted_keys] + ] + with open("performance.txt", "a", encoding="utf-8") as f: + f.write(to_markdown_table(data)) shutdown_llumnix_service() clear_ray_state() - await asyncio.sleep(3) + await asyncio.sleep(10) diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index 62d9bff8..1c38dcc8 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -11,6 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ +import subprocess +import pytest + def to_markdown_table(data): headers = data[0] rows = data[1:] @@ -27,3 +31,11 @@ def to_markdown_table(data): table = f"{header_row}\n{separator_row}\n" + "\n".join(data_rows) + "\n\n" return table + +@pytest.fixture +def setup_ray_env(): + subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True, + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + yield + subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index 2a8ad19e..b74950c2 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -11,20 +11,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List import asyncio import math +from unittest.mock import MagicMock import pytest import ray +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from vllm import EngineArgs, SamplingParams from vllm.utils import random_uuid +from vllm.sequence import SequenceStatus from llumnix.backends.vllm.llm_engine import BackendVLLM from llumnix.llumlet.llumlet import Llumlet from llumnix.backends.utils import BackendType from llumnix.internal_config import MigrationConfig -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import RequestInferenceType, RequestStatus from llumnix.queue.queue_type import QueueType from tests.unit_test.queue.utils import request_output_queue_server @@ -51,22 +53,60 @@ def __init__(self): self.instance_id = "0" self.backend_engine = MockBackendVLLM() +@ray.remote(num_cpus=1, max_concurrency=4) +class MockLlumletDoNotSchedule(Llumlet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # stop the schedule in engine step loop + self.backend_engine.engine.scheduler.schedule = MagicMock() + + # For some reason, if MockScheduelrOutputs is defined outside, the constructor would raise error. 
+ class MockScheduelrOutputs: + def __init__(self): + self.scheduled_seq_groups = [] + self.ignored_seq_groups = [] + self.num_batched_tokens = 0 + + def is_empty(self) -> bool: + return not self.scheduled_seq_groups + + scheduler_outputs = MockScheduelrOutputs() + self.backend_engine.engine.scheduler.schedule.return_value = ([], scheduler_outputs) + + self.step_async = self.backend_engine.engine.step_async + + async def step_async_try_schedule(): + request_outputs, server_infos = await self.step_async() + for seq_group in self.backend_engine.engine.scheduler.waiting: + seq_group.try_schedule_times += 1 + return request_outputs, server_infos + + self.backend_engine.engine.step_async = step_async_try_schedule + +# TODO(s5u13b): Test migrate waiting request. @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migration_request_status", ['waiting', 'running']) @pytest.mark.asyncio -async def test_migration_correctness(setup_ray_env, migration_backend): +async def test_migration_correctness(setup_ray_env, migration_backend, migration_request_status): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - id_rank_map = {"0":0, "1":1} - migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20) + id_rank_map = {"0": 0, "1": 1, "2": 2} + if migration_request_status == 'running': + request_migration_policy = "SR" + elif migration_request_status == 'waiting': + request_migration_policy = "FCW" + migration_config = MigrationConfig(request_migration_policy, migration_backend, 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) asyncio.create_task(que.run_server_loop()) + node_id = ray.get_runtime_context().get_node_id() + scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) llumlet_0: Llumlet = Llumlet.from_args( output_queue_type, False, - True, - ray.get_runtime_context().get_node_id(), + 
False, + node_id, "0", BackendType.VLLM, 1, @@ -76,24 +116,39 @@ async def test_migration_correctness(setup_ray_env, migration_backend): llumlet_1: Llumlet = Llumlet.from_args( output_queue_type, False, - True, - ray.get_runtime_context().get_node_id(), + False, + node_id, "1", BackendType.VLLM, 1, migration_config, engine_args) + llumlet_2: Llumlet = MockLlumletDoNotSchedule.options( + name='instance_2', + namespace='llumnix', + scheduling_strategy=scheduling_strategy).remote( + instance_id="2", + output_queue_type=output_queue_type, + backend_type=BackendType.VLLM, + migration_config=migration_config, + engine_args=engine_args, + node_id=node_id + ) + while True: - res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) + res = ray.get([llumlet_0.is_ready.remote(), llumlet_1.is_ready.remote(), llumlet_2.is_ready.remote()]) if all(res): break ray.get([llumlet_0.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"), - llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) + llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"), + llumlet_2.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) # empty instance migrate out - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + assert not res + res = ray.get(llumlet_2.migrate_out.remote("instance_1")) assert not res # running without migration @@ -110,16 +165,28 @@ async def test_correctness(prompt): origin_output = request_output.outputs[0] finished = request_output.finished - request_id1 = random_uuid() - ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) - # wait prefill done - while True: - running_queue: List[LlumnixRequest] = 
ray.get(llumlet_0.execute_engine_method.remote("get_running_queue")) - if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE: - break - # migrate request - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) - assert len(res) == 1 + if migration_request_status == 'running': + request_id1 = random_uuid() + ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) + # wait prefill done + while True: + running_queue = ray.get(llumlet_0.execute_engine_method.remote("get_running_queue")) + if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE: + break + # migrate request + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + assert len(res) == 1 + elif migration_request_status == 'waiting': + request_id1 = random_uuid() + ray.get(llumlet_2.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) + # wait try schedule done + while True: + waiting_queue = ray.get(llumlet_2.execute_engine_method.remote("get_waiting_queue")) + if len(waiting_queue) > 0 and waiting_queue[0].try_schedule_times >= 1: + break + # migrate request + res = ray.get(llumlet_2.migrate_out.remote("instance_1")) + assert len(res) == 1 request_output_queue = que output = None @@ -127,10 +194,6 @@ async def test_correctness(prompt): while not finished: request_outputs = await request_output_queue.get() for request_output in request_outputs: - origin_output = request_output.outputs[0] - finished = request_output.finished - if request_output.request_id != request_id1: - continue output = request_output.outputs[0] finished = request_output.finished @@ -144,9 +207,9 @@ async def test_correctness(prompt): @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) @pytest.mark.asyncio async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): - engine_args = 
EngineArgs(model="facebook/opt-125m",worker_use_ray=True) - id_rank_map = {"0":0,"1":1} - migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20) + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + id_rank_map = {"0":0, "1":1} + migration_config = MigrationConfig("SR", migration_backend, 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) @@ -174,14 +237,15 @@ async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): migration_config, engine_args, ) + while True: res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) if all(res): break ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"), - llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")]) + llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")]) # empty instance migrate out - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) assert not res # running without migration @@ -204,7 +268,7 @@ async def test_correctness(prompt): ray.get(llumlet_0.generate.remote(request_id1, server_info, request_expected_steps_id1, prompt, sampling_params)) # migrate request for decoding while True: - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests = math.inf)) + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) if len(res) == 1: break request_output_queue = que @@ -213,28 +277,32 @@ async def test_correctness(prompt): while not finished: request_outputs = await request_output_queue.get() for request_output in request_outputs: - origin_output = request_output.outputs[0] + output = request_output.outputs[0] finished = request_output.finished - if request_output.request_id != request_id1: - 
continue - output = request_output.outputs[0] - finished = request_output.finished assert output.text == origin_output.text assert output.cumulative_logprob == origin_output.cumulative_logprob + for prompt in TEST_PROMPTS: await test_correctness(prompt) + que.cleanup() def test_clear_migration_states(): llumlet = MockLlumlet() - llumlet.backend_engine.pre_alloc("0", 1) + llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, 1) num_gpu_blocks = 8 block_size = 4 llumlet.clear_migration_states(is_migrate_in=True) - assert len(llumlet.backend_engine.pre_alloc("0", num_gpu_blocks)) == num_gpu_blocks - _, seq_group = create_dummy_prompt("0",7,block_size) + assert len(llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, num_gpu_blocks)) == num_gpu_blocks + _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.RUNNING) + seq_group.set_status(RequestStatus.RUNNING_MIGRATING) + llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group) + llumlet.clear_migration_states(is_migrate_in=False) + assert len(llumlet.backend_engine.get_running_queue()) == 1 + _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.WAITING) + seq_group.set_status(RequestStatus.WAITING_MIGRATING) llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group) llumlet.clear_migration_states(is_migrate_in=False) - assert len(llumlet.backend_engine.get_running_queue()) > 0 + assert len(llumlet.backend_engine.get_waiting_queue()) == 1 diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py index 2bb008ee..12ec324c 100644 --- a/tests/unit_test/backends/vllm/test_migration_backend.py +++ b/tests/unit_test/backends/vllm/test_migration_backend.py @@ -26,6 +26,44 @@ from tests.conftest import setup_ray_env from .test_worker import create_worker +def get_ready_workers(num_worker, num_gpu_blocks, engine_config, migraiton_config): + workers = [] + worker_ids = [] + + for _ in 
range(num_worker): + worker_id = random_uuid() + worker = create_worker(rank=0, local_rank=0, engine_config=engine_config, + worker_module_name="tests.unit_test.backends.vllm.test_migration_backend", + worker_class_name="MockMigrationWorker") + ray.get(worker.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) + ray.get(worker.execute_method.remote( + 'init_migration', + instance_id=worker_id, + migration_config=migraiton_config, + src_worker_handle_list=[worker], + node_id=ray.get_runtime_context().get_node_id())) + + workers.append(worker) + worker_ids.append(worker_id) + + instance_rank = {} + for idx, worker_id in enumerate(worker_ids): + instance_rank[worker_id] = idx + group_name = random_uuid() + + init_group_tasks =[] + for worker in workers: + init_group_tasks.append(worker.execute_method.remote('rebuild_migration_backend', + instance_rank=instance_rank, group_name=group_name)) + assert all(ray.get(init_group_tasks)) + + warmup_tasks = [] + for worker in workers: + warmup_tasks.append(worker.execute_method.remote('warmup')) + assert all(ray.get(warmup_tasks)) + + return workers, worker_ids + class MockMigrationWorker(MigrationWorker): def set_gpu_cache(self, data): for layer_idx in range(self.cache_engine.num_layers): @@ -34,75 +72,120 @@ def set_gpu_cache(self, data): def get_gpu_cache(self): torch.cuda.synchronize() - return self.gpu_cache + gpu_data = [] + for layer_idx in range(self.cache_engine.num_layers): + gpu_data.append(self.gpu_cache[layer_idx].clone().cpu()) + return gpu_data -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.") -@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) -def test_migrate_cache(setup_ray_env, backend): +@pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Need at least 3 GPU to run the test.") +@pytest.mark.parametrize("backend", ['rpc', 'gloo']) +def test_one_to_many_migrate_cache(setup_ray_env, backend): engine_config 
= EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config() + migration_internal_buffer_num = 2 + migraiton_config = EngineManagerArgs(migration_buffer_blocks=3, migration_num_layers=5, + migration_internal_buffer_num=migration_internal_buffer_num).create_migration_config() migraiton_config.migration_backend = backend - worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config, - worker_module_name="tests.unit_test.backends.vllm.test_migration_backend", - worker_class_name="MockMigrationWorker") - worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config, - worker_module_name="tests.unit_test.backends.vllm.test_migration_backend", - worker_class_name="MockMigrationWorker") - - ray.get(worker0.execute_method.remote('init_device')) - ray.get(worker1.execute_method.remote('init_device')) - - num_gpu_blocks = 8 - ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) - ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) - - worker0_id = random_uuid() - ray.get(worker0.execute_method.remote( - 'init_migration', - instance_id=worker0_id, - migration_config=migraiton_config, - src_worker_handle_list=[worker0], - node_id=ray.get_runtime_context().get_node_id())) - - worker1_id = random_uuid() - ray.get(worker1.execute_method.remote( - 'init_migration', - instance_id=worker1_id, - migration_config=migraiton_config, - src_worker_handle_list=[worker1], - node_id=ray.get_runtime_context().get_node_id())) - - instance_rank = {worker0_id: 0, worker1_id: 1} - group_name = random_uuid() - assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend', - instance_rank=instance_rank, group_name=group_name), - worker1.execute_method.remote('rebuild_migration_backend', - 
instance_rank=instance_rank, group_name=group_name)])) - assert all(ray.get([worker0.execute_method.remote('warmup'), - worker1.execute_method.remote('warmup')])) + num_worker = 3 + num_gpu_blocks = 6000 + workers, _ = get_ready_workers(num_worker, num_gpu_blocks, engine_config, migraiton_config) num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config) head_size = engine_config.model_config.get_head_size() num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config) block_size = engine_config.cache_config.block_size + dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size)) + ray.get(workers[0].execute_method.remote('set_gpu_cache', data=dummy_data)) + worker0_data = ray.get(workers[0].execute_method.remote('get_gpu_cache')) + + dst_blocks = list(range(num_gpu_blocks)) + random.shuffle(dst_blocks) + + single_worker_num_blocks = len(dst_blocks)//(num_worker-1) + migration_tasks = [] + worker_idx = 1 + per_step_blocks = 500 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + src_blocks = list(src_to_dst.keys()) + dst_blocks = list(src_to_dst.values()) + for idx in range(0, len(src_blocks), per_step_blocks): + cur_src_blocks = src_blocks[idx:idx+per_step_blocks] + cur_dst_blocks = dst_blocks[idx:idx+per_step_blocks] + migration_tasks.append(workers[0].execute_method.remote( + 'migrate_cache', + src_worker_handle_list=[workers[worker_idx]], + src_blocks=cur_src_blocks, + dst_blocks=cur_dst_blocks) + ) + worker_idx += 1 + ray.get(migration_tasks) + + worker_idx = 1 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + dst_worker_data = ray.get(workers[worker_idx].execute_method.remote('get_gpu_cache')) + for layer_idx in range(num_layers): + for src_idx, dst_idx in 
src_to_dst.items(): + assert torch.allclose(worker0_data[layer_idx][0][src_idx], dst_worker_data[layer_idx][0][dst_idx]) + assert torch.allclose(worker0_data[layer_idx][1][src_idx], dst_worker_data[layer_idx][1][dst_idx]) + worker_idx += 1 + +@pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Need at least 3 GPU to run the test.") +@pytest.mark.parametrize("backend", ['rpc', 'gloo']) +def test_many_to_one_migrate_cache(setup_ray_env, backend): + engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() + migration_internal_buffer_num = 2 + migraiton_config = EngineManagerArgs(migration_buffer_blocks=3, migration_num_layers=5, + migration_internal_buffer_num=migration_internal_buffer_num).create_migration_config() + migraiton_config.migration_backend = backend + num_worker = 3 + num_gpu_blocks = 6000 + workers, _ = get_ready_workers(num_worker, num_gpu_blocks, engine_config, migraiton_config) + + num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config) + head_size = engine_config.model_config.get_head_size() + num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config) + block_size = engine_config.cache_config.block_size dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size)) - ray.get(worker0.execute_method.remote('set_gpu_cache', data=dummy_data)) - worker0_data = ray.get(worker0.execute_method.remote('get_gpu_cache')) + + worker_datas = [0] + for idx in range(1, num_worker): + ray.get(workers[idx].execute_method.remote('set_gpu_cache', data=dummy_data)) + worker_datas.append(ray.get(workers[idx].execute_method.remote('get_gpu_cache'))) dst_blocks = list(range(num_gpu_blocks)) random.shuffle(dst_blocks) - src_to_dst = dict(enumerate(dst_blocks)) - ray.get(worker1.execute_method.remote( - 'migrate_cache', - src_worker_handle_list=[worker0], - src_blocks=list(src_to_dst.keys()), - 
dst_blocks=list(src_to_dst.values()))) - - worker1_data = ray.get(worker1.execute_method.remote('get_gpu_cache')) - - for layer_idx in range(num_layers): - for src_idx, dst_idx in src_to_dst.items(): - assert torch.allclose(worker0_data[layer_idx][0][src_idx], worker1_data[layer_idx][0][dst_idx]) - assert torch.allclose(worker0_data[layer_idx][1][src_idx], worker1_data[layer_idx][1][dst_idx]) + + single_worker_num_blocks = len(dst_blocks)//(num_worker-1) + migration_tasks = [] + worker_idx = 1 + per_step_blocks = 500 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + src_blocks = list(src_to_dst.keys()) + dst_blocks = list(src_to_dst.values()) + for idx in range(0, len(src_blocks), per_step_blocks): + cur_src_blocks = src_blocks[idx:idx+per_step_blocks] + cur_dst_blocks = dst_blocks[idx:idx+per_step_blocks] + migration_tasks.append(workers[0].execute_method.remote( + 'migrate_cache', + src_worker_handle_list=[workers[worker_idx]], + src_blocks=cur_src_blocks, + dst_blocks=cur_dst_blocks) + ) + worker_idx += 1 + ray.get(migration_tasks) + + dst_worker_data = ray.get(workers[0].execute_method.remote('get_gpu_cache')) + + worker_idx = 1 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + + for layer_idx in range(num_layers): + for src_idx, dst_idx in src_to_dst.items(): + assert torch.allclose(worker_datas[worker_idx][layer_idx][0][src_idx], dst_worker_data[layer_idx][0][dst_idx]) + assert torch.allclose(worker_datas[worker_idx][layer_idx][1][src_idx], dst_worker_data[layer_idx][1][dst_idx]) + worker_idx += 1 diff --git a/tests/unit_test/backends/vllm/test_scheduler.py b/tests/unit_test/backends/vllm/test_scheduler.py index 1c1af7ac..c8a03981 100644 --- a/tests/unit_test/backends/vllm/test_scheduler.py +++ b/tests/unit_test/backends/vllm/test_scheduler.py @@ -12,13 
+12,14 @@ # limitations under the License. import math +import time from vllm.sequence import Sequence from vllm.sequence import Logprob from vllm.core.policy import PolicyFactory from llumnix.backends.vllm.scheduler import BlockManagerLlumnix -from llumnix.llumlet.request import RequestInferenceType +from llumnix.llumlet.request import RequestInferenceType, RequestStatus from .utils import create_dummy_prompt, initialize_scheduler, create_token_budget @@ -129,6 +130,25 @@ def test_scheduler_running_request(): scheduler.add_running_request(seq_group) assert scheduler.get_num_unfinished_seq_groups() == 4 +def test_scheduler_waiting_request(): + scheduler = initialize_scheduler() + num_seq_group = 4 + block_size = 4 + _, seq_group_0 = create_dummy_prompt("0", prompt_length=0, block_size=block_size) + for idx in range(1, num_seq_group + 1): + _, seq_group = create_dummy_prompt(str(idx), prompt_length=idx, block_size=block_size) + scheduler.add_seq_group(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == 4 + scheduler.remove_waiting_request("1") + assert scheduler.get_num_unfinished_seq_groups() == 3 + _, seq_group = create_dummy_prompt("6", prompt_length=idx, block_size=block_size) + scheduler.add_waiting_request(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == 4 + # Test if sort the waiting queue by arrival time in add_waiting_request. 
+ scheduler.add_waiting_request(seq_group_0) + waiting_queue = scheduler.get_waiting_queue() + assert waiting_queue[0] == seq_group_0 + def test_scheduler_migrating_out_request_last_stage(): scheduler = initialize_scheduler() block_size = 4 @@ -142,13 +162,13 @@ def test_scheduler_migrating_out_request_last_stage(): def test_scheduler_pre_alloc(): # total 8 blocks scheduler = initialize_scheduler() - blocks = scheduler.pre_alloc("1", 2) + blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 2) assert len(blocks) == 2 assert len(scheduler.pre_alloc_cache_dict["1"]) == 2 - blocks = scheduler.pre_alloc("1", 4) + blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 4) assert len(blocks) == 4 assert len(scheduler.pre_alloc_cache_dict["1"]) == 6 - blocks = scheduler.pre_alloc("2,", 4) + blocks = scheduler.pre_alloc("2", RequestStatus.RUNNING, 0.0, 4) assert len(blocks) == 0 def test_schedule_running(): @@ -176,3 +196,37 @@ def test_schedule_running(): assert len(running_scheduled.decode_seq_groups) == 1 assert len(running_scheduled.prefill_seq_groups) == 0 assert len(remainig_running) == 1 + + # test pre alloc waiting condition + # total 8 blocks + scheduler = initialize_scheduler() + before_arrival = time.time() + _, seq_group = create_dummy_prompt("1", prompt_length=1, block_size=2, expected_steps=math.inf) + after_arrival = time.time() + blocks = scheduler.pre_alloc("2", RequestStatus.WAITING_MIGRATING, after_arrival, 2) + assert len(blocks) == 2 + scheduler.add_waiting_request(seq_group) + blocks = scheduler.pre_alloc("3", RequestStatus.WAITING_MIGRATING, after_arrival, 2) + assert len(blocks) == 0 + blocks = scheduler.pre_alloc("4", RequestStatus.WAITING_MIGRATING, before_arrival, 2) + assert len(blocks) == 2 + +def test_try_schedule_times(): + # total 8 blocks + scheduler = initialize_scheduler() + _, seq_group_1 = create_dummy_prompt("1", prompt_length=8, block_size=1) + _, seq_group_2 = create_dummy_prompt("2", prompt_length=8, block_size=1) + 
scheduler.add_seq_group(seq_group_1) + scheduler.add_seq_group(seq_group_2) + waiting_queue = scheduler.get_waiting_queue() + assert len(waiting_queue) == 2 + assert seq_group_1.try_schedule_times == 0 + assert seq_group_2.try_schedule_times == 0 + scheduler.schedule() + # seq_group_2 cannot be scheduled due to lack of blocks + assert seq_group_1.try_schedule_times == 0 + assert seq_group_2.try_schedule_times == 1 + scheduler.schedule() + # seq_group_1 is preempted to waiting queue + assert seq_group_1.try_schedule_times == 1 + assert seq_group_2.try_schedule_times == 2 diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py index 7fb94baa..7a2632cf 100644 --- a/tests/unit_test/backends/vllm/test_simulator.py +++ b/tests/unit_test/backends/vllm/test_simulator.py @@ -71,7 +71,7 @@ async def test_backend(setup_ray_env): # TODO(ZeldaHuang): add tests for BackendSimVLLM methods # (currently BackendSimVLLM is just a wrapper of BackendVLLM) engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - migration_config = MigrationConfig("LCFS", "gloo", 16, 1, 4, 5, 20) + migration_config = MigrationConfig("SR", "gloo", 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py index dc014005..09df9ea0 100644 --- a/tests/unit_test/backends/vllm/test_worker.py +++ b/tests/unit_test/backends/vllm/test_worker.py @@ -39,7 +39,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, trust_remote_code=True ) - worker.init_worker.remote( + ray.get(worker.init_worker.remote( model_config=engine_config.model_config, parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, @@ -52,25 +52,25 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, 
lora_config=engine_config.lora_config, vision_language_config=engine_config.vision_language_config, is_driver_worker = False - ) - + )) + ray.get(worker.execute_method.remote('init_device')) return worker @pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) def test_reserve_memory_for_migration(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = EngineManagerArgs(migration_cache_blocks=1).create_migration_config() - migraiton_config.migration_backend = backend + migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config.migration_backend = backend worker = create_worker(rank=0, local_rank=0, engine_config=engine_config) - ray.get(worker.execute_method.remote('init_device')) block_size = CacheEngine.get_cache_block_size(engine_config.cache_config, engine_config.model_config, engine_config.parallel_config) num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config) - occupy_memory = migraiton_config.migration_cache_blocks * block_size * migraiton_config.migration_num_layers // num_layers + occupy_memory = migration_config.migration_internal_buffer_num * migration_config.migration_buffer_blocks \ + * block_size * migration_config.migration_num_layers // num_layers migration_cache_size = ray.get(worker.execute_method.remote('reserve_memory_for_migration', - migration_config=migraiton_config, + migration_config=migration_config, model_config=engine_config.model_config, cache_config=engine_config.cache_config, parallel_config=engine_config.parallel_config)) @@ -80,17 +80,16 @@ def test_reserve_memory_for_migration(setup_ray_env, backend): @pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) def test_rebuild_migration_backend(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - 
migraiton_config = EngineManagerArgs(migration_cache_blocks=1).create_migration_config() - migraiton_config.migration_backend = backend + migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config.migration_backend = backend worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config) worker0_id = random_uuid() - ray.get(worker0.execute_method.remote('init_device')) ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0)) ray.get(worker0.execute_method.remote( 'init_migration', instance_id=worker0_id, - migration_config=migraiton_config, + migration_config=migration_config, src_worker_handle_list=[worker0], node_id=ray.get_runtime_context().get_node_id())) instance_rank = {worker0_id: 0} @@ -100,12 +99,11 @@ def test_rebuild_migration_backend(setup_ray_env, backend): worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config) worker1_id = random_uuid() - ray.get(worker1.execute_method.remote('init_device')) ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0)) ray.get(worker1.execute_method.remote( 'init_migration', instance_id=worker1_id, - migration_config=migraiton_config, + migration_config=migration_config, src_worker_handle_list=[worker1], node_id=ray.get_runtime_context().get_node_id())) diff --git a/tests/unit_test/backends/vllm/utils.py b/tests/unit_test/backends/vllm/utils.py index bc8d1f09..887bdd93 100644 --- a/tests/unit_test/backends/vllm/utils.py +++ b/tests/unit_test/backends/vllm/utils.py @@ -18,7 +18,7 @@ from vllm import SamplingParams from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, Sequence +from vllm.sequence import Logprob, Sequence, SequenceStatus from vllm.config import SchedulerConfig, CacheConfig from vllm.core.scheduler import SchedulingBudget @@ -45,6 +45,7 @@ def create_dummy_prompt( request_id: str, prompt_length: int, block_size: Optional[int] = 
None, + status: SequenceStatus = SequenceStatus.WAITING, lora_request: Optional[LoRARequest] = None, use_beam_search: bool = False, best_of: int = 1, @@ -63,6 +64,7 @@ def create_dummy_prompt( request_id, server_info, expected_steps, [prompt], SamplingParams(use_beam_search=use_beam_search, best_of=best_of), time.time(), lora_request) + seq_group.get_seqs()[0].status = status return prompt, seq_group diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 8cee3a69..114ce551 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -21,7 +21,7 @@ def init_dispatch_scheduler(policy='load'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(1,4)) + dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, 1) return dispatch_scheduler @pytest.fixture @@ -29,7 +29,9 @@ def dispatch_scheduler(): dispatch_scheduler = init_dispatch_scheduler() yield dispatch_scheduler -def test_add_instance_and_remove_instance(dispatch_scheduler): +@pytest.mark.parametrize("num_dispatch_instances", [1, 2, 3]) +def test_add_instance_and_remove_instance(dispatch_scheduler, num_dispatch_instances): + dispatch_scheduler.num_dispatch_instances = num_dispatch_instances dispatch_scheduler.add_instance('instance_1') assert dispatch_scheduler.num_instances == 1 assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 @@ -99,7 +101,7 @@ def test_dispatch_queue(): instance_info.instance_id = instance_id instance_info.num_waiting_requests = random.randint(1, 10) instance_info_dict[instance_id] = instance_info - if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: + if len(dispatch_scheduler.available_dispatch_instance_set) < 
dispatch_scheduler.num_dispatch_instances: dispatch_scheduler.available_dispatch_instance_set.add(instance_id) instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests @@ -110,3 +112,36 @@ def test_dispatch_queue(): key=lambda item: item[1].num_waiting_requests)) instance_id = dispatch_scheduler.dispatch() assert instance_info_dict[min_instance_id].num_waiting_requests == instance_info_dict[instance_id].num_waiting_requests + +def test_dispatch_rr(): + instance_num = 7 + instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) + dispatch_scheduler = DispatchScheduler('rr', instance_load_calculator, 3) + instance_num_requests = {} + instance_info_dict = {} + + for instance_id in [f'instance_{i}' for i in range(instance_num)]: + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_info.num_waiting_requests = random.randint(1, 10) + instance_info_dict[instance_id] = instance_info + if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = 0 + dispatch_scheduler.instance_num_requests = instance_num_requests + dispatch_scheduler.instance_info = instance_info_dict + + num_request = 2 * instance_num + 2 + for idx in range(0, num_request): + instance_id = dispatch_scheduler.dispatch() + target_instance_id = idx%dispatch_scheduler.num_dispatch_instances + assert instance_id == f'instance_{target_instance_id}' + + for idx in range(instance_num): + if idx < dispatch_scheduler.num_dispatch_instances: + dispatch_scheduler.instance_num_requests[f'instance_{idx}'] = \ + num_request // dispatch_scheduler.num_dispatch_instances + (1 \ + if num_request % dispatch_scheduler.num_dispatch_instances > \ + idx % dispatch_scheduler.num_dispatch_instances else 0) + else: + dispatch_scheduler.instance_num_requests[f'instance_{idx}'] = 0 diff 
--git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index b744ced6..5f81baf6 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -26,6 +26,8 @@ from llumnix.server_info import ServerInfo from llumnix.queue.queue_type import QueueType from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.backends.vllm.simulator import BackendSimVLLM +from llumnix.backends.profiling import LatencyMemData # pylint: disable=unused-import from tests.conftest import setup_ray_env @@ -40,6 +42,7 @@ def __init__(self, instance_id): self.request_id_set = set() self.instance_info = None self.num_migrate_out = 0 + self.num_migrate_in = 0 def get_instance_id(self) -> str: return self.instance_id @@ -75,12 +78,29 @@ def abort(self, request_id): self.num_requests = len(self.request_id_set) return self.num_requests - def migrate_out(self, src_instance_name, dst_instance_name): + def migrate_out(self, dst_instance_name): self.num_migrate_out += 1 + migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') + ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name)) + time.sleep(0.1) + return self.num_migrate_out + + def migrate_in(self, src_instance_name): + self.num_migrate_in += 1 + return self.num_migrate_in def get_num_migrate_out(self): return self.num_migrate_out + def get_num_migrate_in(self): + return self.num_migrate_in + +class MockBackendSim(BackendSimVLLM): + def _get_lantecy_mem(self, *args, **kwargs): + latency_mem = LatencyMemData({}, {}, {}) + latency_mem.prefill_model_params = (0,0) + latency_mem.decode_model_params = (0,0,0) + return latency_mem def init_manager(): try: @@ -138,6 +158,18 @@ def test_init_llumlets(setup_ray_env, engine_manager): engine_manager_args = EngineManagerArgs() assert num_instances == engine_manager_args.initial_instances +def 
test_init_llumlets_sim(setup_ray_env, engine_manager): + engine_manager.profiling_result_file_path="//" + # pylint: disable=import-outside-toplevel + import llumnix.backends.vllm.simulator + llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + node_id = ray.get_runtime_context().get_node_id() + instance_ids, llumlets = ray.get(engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue"))) + num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) + engine_manager_args = EngineManagerArgs() + assert num_instances == engine_manager_args.initial_instances + def test_scale_up_and_down(setup_ray_env, engine_manager): initial_instances = 4 instance_ids, llumlets = init_llumlets(initial_instances) @@ -222,20 +254,37 @@ def get_instance_info_migrate_out(instance_id): return instance_info def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager): - instance_ids, llumlets = init_llumlets(2) - instance_id, instance_id_1 = instance_ids[0], instance_ids[1] - llumlet, llumlet_1 = llumlets[0], llumlets[1] - request_id = random_uuid() - request_id_1 = random_uuid() - ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None)) - ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None)) - instance_info_migrate_out = get_instance_info_migrate_out(instance_id) - instance_info_migrate_in = get_instance_info_migrate_in(instance_id_1) - ray.get(llumlet.set_instance_info.remote(instance_info_migrate_out)) - ray.get(llumlet_1.set_instance_info.remote(instance_info_migrate_in)) - num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) - assert num_migrate_out == 0 + num_llumlets = 5 + instance_ids, llumlets = init_llumlets(num_llumlets) + + for i in range(num_llumlets): + for _ in range(2*(i+1)): + ray.get(llumlets[i].generate.remote(random_uuid(), None, math.inf, None, None)) + + instance_info = 
InstanceInfo() + instance_info.instance_type = InstanceType.NO_CONSTRAINTS + + for i in range(num_llumlets): + instance_info.instance_id = instance_ids[i] + instance_info.num_available_gpu_blocks = 40 - i * 10 + instance_info.num_running_requests = i + instance_info.num_blocks_first_waiting_request = i + ray.get(llumlets[i].set_instance_info.remote(instance_info)) + + for i in range(num_llumlets): + num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote()) + assert num_migrate_out == 0 + ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) - time.sleep(0.5) - num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) - assert num_migrate_out != 0 + time.sleep(2) + + for i in range(num_llumlets): + num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote()) + num_migrate_in = ray.get(llumlets[i].get_num_migrate_in.remote()) + + if i == 0: + assert num_migrate_in > 1 and num_migrate_out == 0 + elif i == num_llumlets - 1: + assert num_migrate_in == 0 and num_migrate_out > 1 + else: + assert num_migrate_in == 0 and num_migrate_out == 0 diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index 8fd32105..fa25e1f8 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -17,11 +17,13 @@ import numpy as np from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo -from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints +from llumnix.global_scheduler.migration_scheduler import MigrationScheduler from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints MIGRATE_OUT_LOAD_THRESHOLD = 3.0 -INSTANCE_NUM = 4 +INSTANCE_NUM = 16 def 
init_migration_scheduler(policy='balanced'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) @@ -43,57 +45,66 @@ def test_add_instance_and_remove_instance(migration_scheduler): migration_scheduler.remove_instance('instance_2') assert migration_scheduler.num_instances == 0 -@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS','DECODING_2_DECODING','PREFILL_2_DECODING']) -def test_get_migration_instance_infos(pair_migration_type): +@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS', 'DECODING_2_DECODING', 'PREFILL_2_DECODING']) +def test_migration_filter(pair_migration_type): num_tests = 1000 + migration_filter = MigrationInstanceFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD)) + for _ in range(num_tests): - instance_info_dict = {} - for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: + instance_infos = [] + + total_prefill_instance_num = 0 + + for instance_id in range(1, INSTANCE_NUM + 1): instance_info = InstanceInfo() instance_info.instance_id = instance_id instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1) instance_info.num_killed_requests = random.randint(0, 1) + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: constraint_prefill_instance_num = math.inf else: constraint_prefill_instance_num = random.randint(1, INSTANCE_NUM) - migration_scheduler = init_migration_scheduler() + if constraint_prefill_instance_num == math.inf: instance_info.instance_type = InstanceType.NO_CONSTRAINTS else: - if len([info for info in instance_info_dict.values() - if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num: + if total_prefill_instance_num < constraint_prefill_instance_num: instance_info.instance_type = InstanceType.PREFILL + total_prefill_instance_num += 1 else: instance_info.instance_type = InstanceType.DECODE - instance_info_dict[instance_id] = instance_info - migration_scheduler.instance_info = 
instance_info_dict - migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos = migration_scheduler._get_migration_instance_infos(pair_migration_type) - for instance in sorted_src_instance_infos: - if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: - assert instance.num_killed_requests > 0 \ - or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD - if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - assert instance.instance_type == InstanceType.NO_CONSTRAINTS - elif migration_scheduler == PairMigrationConstraints.DECODING_2_DECODING: - assert instance.instance_type == InstanceType.DECODE - else: - assert instance.instance_type == InstanceType.PREFILL - for instance in sorted_dst_instance_infos: - if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: - assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD - if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - assert instance.instance_type == InstanceType.NO_CONSTRAINTS - elif migration_scheduler == PairMigrationConstraints.DECODING_2_DECODING: + + instance_infos.append(instance_info) + + src_instance_infos, dst_instance_infos = migration_filter.filter_instances(instance_infos, pair_migration_type) + + for instance in src_instance_infos: + if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: + assert instance.num_killed_requests > 0 \ + or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + assert instance.instance_type == InstanceType.NO_CONSTRAINTS + elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + assert instance.instance_type == InstanceType.DECODE + else: + assert instance.instance_type == InstanceType.PREFILL + + for instance in dst_instance_infos: + if pair_migration_type != 
PairMigrationConstraints.PREFILL_2_DECODING: + assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + assert instance.instance_type == InstanceType.NO_CONSTRAINTS + elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + assert instance.instance_type == InstanceType.DECODE + else: assert instance.instance_type == InstanceType.DECODE - else: - assert instance.instance_type == InstanceType.DECODE - assert instance.num_killed_requests == 0 + assert instance.num_killed_requests == 0 -@pytest.mark.parametrize("policy", ['balanced','defrag_constrained']) +@pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained']) def test_pair_migration(policy): num_tests = 1000 + for _ in range(num_tests): migration_scheduler = init_migration_scheduler(policy) instance_info_dict = {} @@ -106,14 +117,9 @@ def test_pair_migration(policy): instance_info.instance_type = InstanceType.NO_CONSTRAINTS instance_info_dict[instance_id] = instance_info migration_scheduler.instance_info = instance_info_dict - migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos = [i for i in reversed(migration_scheduler.sorted_instance_infos) - if i.instance_type == InstanceType.NO_CONSTRAINTS - and (i.num_killed_requests > 0 or i.instance_load_migrate > migration_scheduler.migrate_out_load_threshold)] - sorted_dst_instance_infos = [i for i in migration_scheduler.sorted_instance_infos - if i.instance_type == InstanceType.NO_CONSTRAINTS - and (i.num_killed_requests == 0 and i.instance_load_migrate < migration_scheduler.migrate_out_load_threshold)] - migrate_instance_pairs = migration_scheduler.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) + + migrate_instance_pairs = migration_scheduler.pair_migration(PairMigrationConstraints.NO_CONSTRAINTS) + for migrate_out_instance, migrate_in_instance 
in migrate_instance_pairs: assert migrate_out_instance != migrate_in_instance if policy == 'balanced': diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py index 56b58322..86176ab2 100644 --- a/tests/unit_test/llumlet/test_engine_step_exception.py +++ b/tests/unit_test/llumlet/test_engine_step_exception.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import time import ray import torch @@ -30,28 +29,17 @@ @ray.remote(num_cpus=1, max_concurrency=4) class MockLlumlet(Llumlet): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.origin_step = self.backend_engine.engine.step_async - - def set_error_step(self, broken: bool): - self.backend_engine._stop_event.set() - + def set_error_step(self): async def raise_error_step(): await self.origin_step() raise ValueError("Mock engine step error") - if broken: - self.backend_engine.engine.step_async = raise_error_step - else: - self.backend_engine.engine.step_async = self.origin_step - - asyncio.create_task(self.backend_engine._start_engine_step_loop()) + self.backend_engine.engine.step_async = raise_error_step @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.") def test_engine_step_exception(setup_ray_env): - engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - migration_config = MigrationConfig("LCFS", "rpc", 16, 1, 4, 5, 20) + engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True) + migration_config = MigrationConfig("SR", "rpc", 16, 1, 4, 5, 20, 2) node_id = ray.get_runtime_context().get_node_id() scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) @@ -76,7 +64,7 @@ def test_engine_step_exception(setup_ray_env): cur_free_memory, _ = torch.cuda.mem_get_info() assert cur_free_memory < 
origin_free_memory - ray.get(llumlet.set_error_step.remote(True)) + ray.get(llumlet.set_error_step.remote()) time.sleep(3) all_actors = ray.util.list_named_actors(True) diff --git a/tests/unit_test/llumlet/test_local_migration_scheduler.py b/tests/unit_test/llumlet/test_local_migration_scheduler.py index d585300d..ecca2b71 100644 --- a/tests/unit_test/llumlet/test_local_migration_scheduler.py +++ b/tests/unit_test/llumlet/test_local_migration_scheduler.py @@ -13,20 +13,25 @@ import math from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType, RequestStatus class MockRequest(LlumnixRequest): - def __init__(self, request_id, length, expected_steps) -> None: + def __init__(self, request_id, length, expected_steps, status=RequestStatus.RUNNING) -> None: super().__init__(request_id=request_id, server_info=None, expected_steps=expected_steps) self.length = length - self.status = RequestInferenceType.DECODE + self._status = status + self._inference_type = RequestInferenceType.DECODE + self._finished = False + self.try_schedule_times = 0 + self.eom = False - def is_finished(self) -> bool: - return False + @property + def finished(self) -> bool: + return self._finished @property def inference_type(self) -> RequestInferenceType: - return self.status + return self._inference_type @property def request_len(self) -> int: @@ -40,16 +45,37 @@ def prompt_len(self) -> int: def output_len(self) -> int: return self.length + @property + def arrival_time(self) -> float: + pass + + @property + def status(self) -> RequestStatus: + return self._status + + @property + def prefill_num_blocks(self) -> int: + pass + class MockeEngine(): def __init__(self) -> None: self.running = [] + self.waiting = [] def add_request(self, request_id, length, expected_steps) -> None: self.running.append(MockRequest(request_id, length, 
expected_steps)) + def add_request_waiting(self, request_id, length, expected_steps) -> None: + request = MockRequest(request_id, length, expected_steps, status=RequestStatus.WAITING) + request.try_schedule_times += 1 + self.waiting.append(request) + def get_running_queue(self): return self.running + def get_waiting_queue(self): + return self.waiting + def test_scheduler_policy(): engine = MockeEngine() scheduler = LocalMigrationScheduler("", engine) @@ -57,33 +83,40 @@ def test_scheduler_policy(): engine.add_request(request_id="0", length=1, expected_steps=math.inf) engine.add_request(request_id="1", length=3, expected_steps=math.inf) engine.add_request(request_id="2", length=2, expected_steps=math.inf) - - scheduler.request_migration_policy = "LCFS" - assert scheduler.get_migrate_out_request().request_id == "2" - scheduler.request_migration_policy = "LJF" - assert scheduler.get_migrate_out_request().request_id == "1" - scheduler.request_migration_policy = "SJF" - assert scheduler.get_migrate_out_request().request_id == "0" - - engine.add_request(request_id="3", length=2, expected_steps=1) - request = scheduler.get_migrate_out_request() - assert request.request_id == "3" + engine.add_request_waiting(request_id="3", length=2, expected_steps=math.inf) + engine.add_request_waiting(request_id="4", length=2, expected_steps=math.inf) + + scheduler.request_migration_policy = "LCR" + assert scheduler.get_migrate_out_requests()[0].request_id == "2" + scheduler.request_migration_policy = "LR" + assert scheduler.get_migrate_out_requests()[0].request_id == "1" + scheduler.request_migration_policy = "SR" + assert scheduler.get_migrate_out_requests()[0].request_id == "0" + scheduler.request_migration_policy = "FCW" + assert scheduler.get_migrate_out_requests()[0].request_id == "3" + scheduler.request_migration_policy = "FCWSR" + assert scheduler.get_migrate_out_requests()[0].request_id == "3" + assert scheduler.get_migrate_out_requests()[1].request_id == "0" + + 
engine.add_request(request_id="5", length=2, expected_steps=1) + request = scheduler.get_migrate_out_requests()[0] + assert request.request_id == "5" assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE - engine.add_request(request_id="4", length=3, expected_steps=math.inf) - scheduler.request_migration_policy = "LCFS" - request = scheduler.get_migrate_out_request() - assert request.request_id == "3" + engine.add_request(request_id="6", length=3, expected_steps=math.inf) + scheduler.request_migration_policy = "LCR" + request = scheduler.get_migrate_out_requests()[0] + assert request.request_id == "5" assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE def test_scheduler_should_abort_migration(): req_0 = MockRequest(request_id="0", length=1, expected_steps=math.inf) req_0.stage_timestamps = [1] assert req_0.should_abort_migration() is False - req_0.status = RequestInferenceType.PREFILL - assert req_0.should_abort_migration() is True - req_0.status = RequestInferenceType.DECODE req_0.last_preemption_time = 2 assert req_0.should_abort_migration() is True + req_0.last_preemption_time = None + req_0._finished = True + assert req_0.should_abort_migration() is True def test_blocking_migration(): req_0 = MockRequest(request_id="0", length=1, expected_steps=math.inf) diff --git a/tests/unit_test/llumlet/test_migration_coordinator.py b/tests/unit_test/llumlet/test_migration_coordinator.py index 8a1a4d44..fcdf0638 100644 --- a/tests/unit_test/llumlet/test_migration_coordinator.py +++ b/tests/unit_test/llumlet/test_migration_coordinator.py @@ -38,7 +38,7 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request = MagicMock() # Create an instance of MigrationCoordinator - coordinator = MigrationCoordinator(backend_engine, 1, 3) + coordinator = MigrationCoordinator(backend_engine, last_stage_max_blocks=1, max_stages=3) # Mock method return values 
and test data src_blocks = [1, 2, 3] @@ -49,7 +49,7 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) # Test normal migration scenario - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) assert status == MigrationStatus.RUNNING # Test the last stage of migration @@ -59,20 +59,21 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) - assert status == MigrationStatus.FINISHED_DONE + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED migrate_out_request = MagicMock() - # Test migration aborted scenario + # Test migration dst aborted scenario src_blocks = [1, 2, 3] dst_blocks = [] backend_engine.get_request_incremental_blocks.return_value = src_blocks migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) assert status == MigrationStatus.ABORTED_DST + # Test migration src aborted scenario migrate_out_request = MagicMock() src_blocks = [1, 2, 3] dst_blocks = [1, 2] @@ -80,23 +81,13 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request.should_abort_migration.return_value = True 
migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) assert status == MigrationStatus.ABORTED_SRC - migrate_out_request = MagicMock() - src_blocks = [1, 2, 3] - dst_blocks = [1, 2] - backend_engine.get_request_incremental_blocks.return_value = src_blocks - migrate_out_request.should_abort_migration.return_value = False - migrate_out_request.blocking_migration = True - migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) - assert status == MigrationStatus.ABORTED_DST - -# setup_ray_env should be passed after migrate_out_onestage -@patch.object(MigrationCoordinator, 'migrate_out_onestage') +# setup_ray_env should be passed after _migrate_out_onestage +@patch.object(MigrationCoordinator, '_migrate_out_onestage') @pytest.mark.asyncio -async def test_migrate_out_multistage(_, setup_ray_env): +async def test_migrate_out_running_request(_, setup_ray_env): # Create mock objects backend_engine = MagicMock(spec=BackendInterface) migrate_in_ray_actor = MagicMock() @@ -110,16 +101,41 @@ async def test_migrate_out_multistage(_, setup_ray_env): migrate_in_ray_actor.execute_engine_method.remote = MagicMock() migrate_in_ray_actor.execute_engine_method.remote.return_value = ray_remote_call.remote([1]) migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote([1]) - coordinator.migrate_out_onestage.side_effect = [MigrationStatus.FINISHED_DONE] - status = await coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - assert coordinator.migrate_out_onestage.call_count == 1 - assert status == 
MigrationStatus.FINISHED_DONE + coordinator._migrate_out_onestage.side_effect = [MigrationStatus.FINISHED] + status = await coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) + assert coordinator._migrate_out_onestage.call_count == 1 + assert status == MigrationStatus.FINISHED max_stages = 3 - coordinator.migrate_out_onestage.side_effect = [MigrationStatus.RUNNING, - MigrationStatus.RUNNING, - MigrationStatus.RUNNING, - MigrationStatus.RUNNING] - status = await coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - assert coordinator.migrate_out_onestage.call_count == max_stages + 1 + coordinator._migrate_out_onestage.side_effect = [MigrationStatus.RUNNING, + MigrationStatus.RUNNING, + MigrationStatus.RUNNING, + MigrationStatus.RUNNING] + status = await coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) + assert coordinator._migrate_out_onestage.call_count == max_stages + 1 assert status == MigrationStatus.ABORTED_SRC + +@pytest.mark.asyncio +async def test_migrate_out_waiting_request(): + # Create mock objects + backend_engine = MagicMock(spec=BackendInterface) + migrate_in_ray_actor = MagicMock() + migrate_out_request = MagicMock() + + # Create an instance of MigrationCoordinator + coordinator = MigrationCoordinator(backend_engine, last_stage_max_blocks=1, max_stages=3) + + # Test FINISHED + migrate_out_request.prefill_num_blocks = 3 + dst_blocks = [1, 2, 3] + migrate_in_ray_actor.execute_engine_method = MagicMock() + migrate_in_ray_actor.execute_engine_method.remote = MagicMock() + migrate_in_ray_actor.execute_engine_method.remote.return_value = ray_remote_call.remote(dst_blocks) + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) + status = await coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED + + # Test FINISHED_ABORTED + 
migrate_out_request.prefill_num_blocks = 2 + status = await coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.ABORTED_DST diff --git a/tests/unit_test/queue/test_zmq.py b/tests/unit_test/queue/test_zmq.py index d4303d37..6f62935e 100644 --- a/tests/unit_test/queue/test_zmq.py +++ b/tests/unit_test/queue/test_zmq.py @@ -106,8 +106,8 @@ async def benchmark_queue(qps, ip=None, port=None): signal.alarm(0) @pytest.mark.asyncio -@pytest.mark.parametrize("qps", [128.0, 256.0, 512.0, 1024.0]) -async def test_queue_zmq(setup_ray_env, qps): +async def test_queue_zmq(setup_ray_env): ip = '127.0.0.1' port = 1234 + qps = 1024.0 await benchmark_queue(qps, ip, port)