diff --git a/Makefile b/Makefile index 6bc87a9b..b2cd80f3 100644 --- a/Makefile +++ b/Makefile @@ -21,22 +21,22 @@ install: .PHONY: lint lint: check_pylint_installed check_pytest_installed - @pylint --rcfile=.pylintrc -s n --jobs=32 ./llumnix + @pylint --rcfile=.pylintrc -s n --jobs=128 ./llumnix @pylint --rcfile=.pylintrc \ --disable=protected-access,super-init-not-called,unused-argument,redefined-outer-name,invalid-name \ - -s n --jobs=32 ./tests + -s n --jobs=128 ./tests .PHONY: test test: check_pytest_installed - @pytest -x -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings @python examlpes/offline_inference.py - @pytest -v tests/e2e_test/test_e2e.py + @pytest -v -x tests/e2e_test/test_e2e.py @pytest -v -x ./tests/e2e_test/test_migration.py .PHONY: unit_test unit_test: check_pytest_installed - @pytest -x -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings .PHONY: offline_test offline_test: @@ -44,11 +44,11 @@ offline_test: .PHONY: e2e_test e2e_test: - @pytest -v tests/e2e_test/test_e2e.py + @pytest -v -x tests/e2e_test/test_e2e.py .PHONY: bench_test bench_test: - @pytest -v ./tests/e2e_test/test_bench.py + @pytest -v -x ./tests/e2e_test/test_bench.py .PHONY: migration_test migration_test: diff --git a/configs/base.yml b/configs/base.yml index afce7127..70358339 100644 --- a/configs/base.yml +++ b/configs/base.yml @@ -2,8 +2,6 @@ SERVER: HOST: '127.0.0.1' PORT: 1234 QUEUE_TYPE: "rayqueue" - -RAY: RAY_CLUSTER_PORT: 6379 LAUNCH_RAY_CLUSTER: True @@ -21,6 +19,7 @@ MANAGER: REQUEST_MIGRATION_POLICY: 'SJF' MIGRATION_BACKEND: 'gloo' - MIGRATION_CACHE_BLOCKS: 512 + MIGRATION_BUFFER_BLOCKS: 512 + MIGRATION_INTERNAL_BUFFER_NUM: 2 ENABLE_SCALING: False diff --git a/docs/Arguments.md b/docs/Arguments.md index c8397bfa..a2584417 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -32,14 +32,15 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--profiling-result-file-path PROFILING_RESULT_FILE_PATH] [--gpu-type GPU_TYPE] [--polling-interval POLLING_INTERVAL] - [--migration-backend {gloo,nccl,rpc}] - [--migration-cache-blocks MIGRATION_CACHE_BLOCKS] + [--migration-backend {gloo,rpc}] + [--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS] [--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT] [--migration-num-layers MIGRATION_NUM_LAYERS] [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS] [--max-stages MAX_STAGES] [--enable-pd-disagg] [--num-dispatch-instances NUM_DISPATCH_INSTANCES] + [--migration-internal-buffer-num MIGRATION_INTERNAL_BUFFER_NUM] [--log-request-timestamps] ``` @@ -147,8 +148,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Possible choices: gloo, rpc - Default: "rpc" -`--migration-cache-blocks` -- Number of cache blocks in migration. +`--migration-buffer-blocks` +- Number of cache blocks in each migration buffer. - Default: 512 `--migration-backend-init-timeout` @@ -167,6 +168,10 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Drop migration if the number of stages > max_stages. - Default: 3 +`--migration-internal-buffer-num` +- Number of the buffer in migration backend for sending and receiving +- Default: 2 + `--log-request-timestamps` - Enable logging request timestamps. diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index dd80276d..37b3bbc6 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -22,7 +22,6 @@ from llumnix.config import LlumnixConfig, get_llumnix_config from llumnix.config.default import _C - class LlumnixArgumentParser(argparse.ArgumentParser): def __init__(self, *args, **kwargs): self.cur_namespace = "llumnix" @@ -134,10 +133,11 @@ class EngineManagerArgs: migration_backend_init_timeout: float = None migration_backend: str = None - migration_cache_blocks: int = None + migration_buffer_blocks: int = None migration_num_layers: int = None last_stage_max_blocks: int = None max_stages: int = None + migration_internal_buffer_num: int = None enable_pd_disagg: bool = None @@ -172,11 +172,12 @@ def create_global_scheduler_configs( def create_migration_config(self) -> MigrationConfig: migration_config = MigrationConfig(self.request_migration_policy, self.migration_backend, - self.migration_cache_blocks, + self.migration_buffer_blocks, self.migration_num_layers, self.last_stage_max_blocks, self.max_stages, - self.migration_backend_init_timeout) + self.migration_backend_init_timeout, + self.migration_internal_buffer_num) return migration_config @classmethod @@ -195,6 +196,9 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser): if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest): assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}." + assert args.migration_backend != 'nccl', 'NCCL has been temporarily deprecated due to its incompatibility with \ + concurrent migrations in Llumnix.' + assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \ and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \ ("When using gloo as migration backend, " @@ -288,20 +292,23 @@ def add_cli_args( parser.add_argument('--migration-backend', type=str, - choices=['gloo','nccl','rpc'], + choices=['gloo', 'nccl', 'rpc'], help='communication backend of migration') parser.add_argument('--migration-backend-init-timeout', type=float, help='timeout(s) for initializing migration backend') - parser.add_argument('--migration-cache-blocks', + parser.add_argument('--migration-buffer-blocks', type=int, - help='number of cache blocks in migration') + help='number of cache blocks in each migration buffer') parser.add_argument('--migration-num-layers', type=int, help='number of kv-cache layers to transfer in each round during migration') parser.add_argument('--last-stage-max-blocks', type=int, help='if the number pf remain blocks < last_stage_max_blocks, do last stage migration') + parser.add_argument('--migration-internal-buffer-num', + type=int, + help='number of the buffer in migration backend for sending and receiving') parser.add_argument('--max-stages', type=int, help='drop migration if the number of stages > max_stages') diff --git a/llumnix/backends/migration_backend_interface.py b/llumnix/backends/migration_backend_interface.py index 808ba8c8..9fd231cc 100644 --- a/llumnix/backends/migration_backend_interface.py +++ b/llumnix/backends/migration_backend_interface.py @@ -13,7 +13,9 @@ from abc import ABC, abstractmethod from typing import List +import queue +import torch class MigrationBackendBase(ABC): @abstractmethod @@ -39,3 +41,24 @@ def do_send(self, dst_handle, blocks: List[int]): @abstractmethod def do_recv(self, src_handle, blocks: List[int]): raise NotImplementedError + +class BufferMigrationBackend(MigrationBackendBase): + def __init__(self, num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_buffer = num_buffer + + self.dummy_buffer = [ + torch.empty(size=buffer_shape, dtype=buffer_dtype, device=buffer_device, pin_memory=pin_memory) + for _ in range(self.num_buffer) + ] + + self.avaiable_buffer_queue = queue.Queue() + for i in range(self.num_buffer): + self.avaiable_buffer_queue.put_nowait(i) + + def get_available_cache(self): + return self.avaiable_buffer_queue.get() + + def put_back_cache(self, buffer_id): + self.avaiable_buffer_queue.put_nowait(buffer_id) diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index bf583366..4b2a076d 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -355,10 +355,10 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None: async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: await dst_ray_actor.execute_engine_method.remote("_run_workers", - "migrate_cache", - dst_blocks=dst_blocks, - src_blocks=src_blocks, - src_worker_handle_list=self.worker_handle_list) + "migrate_cache", + dst_blocks=dst_blocks, + src_blocks=src_blocks, + src_worker_handle_list=self.worker_handle_list) def _run_workers(self, *args, **kwargs): # pylint: disable=protected-access diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index 947d3e7e..e69f3479 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -15,11 +15,15 @@ import torch from func_timeout import func_set_timeout, FunctionTimedOut +import cupy +from cupy.cuda import nccl import ray import ray.util.collective as col +from ray.util.collective.collective_group import nccl_util + from vllm.worker.cache_engine import CacheEngine from llumnix.internal_config import MigrationConfig -from llumnix.backends.migration_backend_interface import MigrationBackendBase +from llumnix.backends.migration_backend_interface import MigrationBackendBase, BufferMigrationBackend from llumnix.logger import init_logger logger = init_logger(__name__) @@ -40,17 +44,16 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs): NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16] -class RayRpcMigrationBackend(MigrationBackendBase): +class RayRpcMigrationBackend(BufferMigrationBackend): def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \ scheduling_strategy, is_driver_worker, gpu_cache) -> None: - super().__init__() - self.migration_config = migration_config self.cache_engine = cache_engine self.worker_rank = worker_rank self.worker_handle_list = worker_handle_list self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() + self.migration_stream = torch.cuda.Stream() self.rpc_dtype = self.cache_engine.dtype if self.cache_engine.dtype in NUMPY_SUPPORTED_DTYPES: @@ -62,17 +65,13 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.is_driver_worker = is_driver_worker self.gpu_cache = gpu_cache self.cache_device = "cpu" - self.num_migration_cache_blocks = self.migration_config.migration_cache_blocks + self.num_migration_buffer_blocks = self.migration_config.migration_buffer_blocks self.num_layers = self.cache_engine.num_layers self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + buffer_shape = (self.num_migration_buffer_blocks, self.num_layers, 2, self.migration_cache_size) - self.dummy_cache = torch.empty( - size=(self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size), - dtype=self.cache_engine.dtype, - device=self.cache_device, - pin_memory=True - ) - self.migration_stream = torch.cuda.Stream() + super().__init__(migration_config.migration_internal_buffer_num, buffer_shape, self.cache_engine.dtype, + self.cache_device, pin_memory=True) def init_backend(self, group_name, world_size, rank) -> bool: logger.info("create rpc migration backend successfully.") @@ -94,30 +93,38 @@ def warmup(self) -> bool: def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None: tot_blocks = len(src_blocks) rpc_numpy_cache = None - for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): - offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) + for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks): + offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx) send_blocks = src_blocks[start_idx:start_idx+offset] ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks) if rpc_numpy_cache is not None: self.do_recv(rpc_numpy_cache, recv_blocks) - rpc_numpy_cache = ray.get(ray_obj) + rpc_numpy_cache_ref = ray.get(ray_obj) + rpc_numpy_cache = ray.get(rpc_numpy_cache_ref) recv_blocks = dst_blocks[start_idx:start_idx+offset] self.do_recv(rpc_numpy_cache, recv_blocks) def do_send(self, dst_handle, blocks: List[int]): num_blocks = len(blocks) - send_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} with torch.cuda.stream(self.migration_stream): for layer_idx in range(self.num_layers): self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[layer_idx], src_to_dst) torch.cuda.Stream.synchronize(self.migration_stream) - return send_cache.to(self.rpc_dtype).numpy() + # Here, we use ray.put to store data and finally return the object reference so that we can release the internal buffer. + # This might seem like an anti-pattern, but it's okay since the kv-cache transferred is in the MB range and won't utilize + # Ray's optimization for returning small objects (<100KB). + data = ray.put(send_cache.to(self.rpc_dtype).numpy()) + self.put_back_cache(dummy_cache_idx) + return data def do_recv(self, src_handle, blocks: List[int]): num_blocks = len(blocks) src_to_dst = dict(enumerate(blocks)) - recv_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) # use pin memory dummy_cache to speed up data transfer recv_cache.copy_(torch.from_numpy(src_handle)) @@ -125,6 +132,7 @@ def do_recv(self, src_handle, blocks: List[int]): for layer_idx in range(self.num_layers): self.cache_engine.attn_backend.swap_blocks(recv_cache[layer_idx], self.gpu_cache[layer_idx], src_to_dst) torch.cuda.Stream.synchronize(self.migration_stream) + self.put_back_cache(dummy_cache_idx) def try_import_gloo(): try: @@ -139,19 +147,14 @@ def try_import_gloo(): except ImportError as e: raise ImportError("Gloo is not installed. Please install it first.") from e -class RayColMigrationBackend(MigrationBackendBase): +class RayColMigrationBackend(BufferMigrationBackend): def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, local_rank, scheduling_strategy, is_driver_worker, gpu_cache) -> None: - super().__init__() - - # pylint: disable=C0415 - import cupy - self.migration_config = migration_config self.cache_engine = cache_engine self.backend = migration_config.migration_backend self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine.num_layers) - self.num_migration_cache_blocks = migration_config.migration_cache_blocks + self.num_migration_buffer_blocks = migration_config.migration_buffer_blocks self.backend = migration_config.migration_backend self.global_world_size = -1 @@ -162,6 +165,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() self.is_driver_worker = is_driver_worker self.gpu_cache = gpu_cache + self.migration_stream = cupy.cuda.Stream() self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size @@ -169,17 +173,13 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, try_import_gloo() self.cache_device = "cpu" else: + nccl_util.TORCH_NCCL_DTYPE_MAP[torch.bfloat16] = nccl.NCCL_FLOAT16 self.cache_device = torch.device(f"cuda:{self.local_rank}") pin_memory = (self.backend == 'gloo') - self.dummy_cache = torch.empty( - size=(self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size), - dtype=self.cache_engine.dtype, - device=self.cache_device, - pin_memory=pin_memory - ) - - self.migration_stream = cupy.cuda.Stream() + buffer_shape = (self.num_migration_buffer_blocks, self.migration_num_layers, 2, self.migration_cache_size) + super().__init__(migration_config.migration_internal_buffer_num, buffer_shape, self.cache_engine.dtype, + self.cache_device, pin_memory=pin_memory) def init_backend(self, group_name, world_size, rank) -> bool: @func_set_timeout(self.migration_config.migration_backend_init_timeout) @@ -224,7 +224,7 @@ def destory_backend(self) -> None: def warmup(self) -> bool: if self.global_world_size > 1: try: - col.allreduce(self.dummy_cache[0], self.group_name) + col.allreduce(self.dummy_buffer[0][0], self.group_name) # pylint: disable=W0703 except Exception as e: logger.info("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backbend: {}), err: {}." @@ -241,8 +241,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] tot_blocks = len(src_blocks) src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_global_rank")) - for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): - offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) + for start_idx in range(0, tot_blocks, self.num_migration_buffer_blocks): + offset = min(self.num_migration_buffer_blocks, tot_blocks - start_idx) send_blocks = src_blocks[start_idx:start_idx+offset] recv_blocks = dst_blocks[start_idx:start_idx+offset] self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.global_rank, send_blocks) @@ -250,7 +250,8 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] def do_send(self, dst_handle, blocks: List[int]): num_blocks = len(blocks) - send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + send_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} with self.migration_stream: @@ -261,11 +262,13 @@ def do_send(self, dst_handle, blocks: List[int]): # TODO(KuilongCui): check the error code if peer is dead col.send(send_cache, dst_handle, self.group_name) self.migration_stream.synchronize() + self.put_back_cache(dummy_cache_idx) def do_recv(self, src_handle, blocks: List[int]): num_blocks = len(blocks) src_to_dst = dict(enumerate(blocks)) - recv_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) + dummy_cache_idx = self.get_available_cache() + recv_cache = self.dummy_buffer[dummy_cache_idx][:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) with self.migration_stream: for layer_idx in range(self.cache_engine.num_layers): @@ -274,16 +277,19 @@ def do_recv(self, src_handle, blocks: List[int]): col.recv(recv_cache, src_handle, self.group_name) self.cache_engine.attn_backend.swap_blocks(recv_cache[cache_idx], self.gpu_cache[layer_idx], src_to_dst) self.migration_stream.synchronize() + self.put_back_cache(dummy_cache_idx) def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase: - if cache_engine.num_gpu_blocks < migration_config.migration_cache_blocks: - logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." - .format(migration_config.migration_cache_blocks, cache_engine.num_gpu_blocks)) - migration_config.migration_cache_blocks = cache_engine.num_gpu_blocks + if cache_engine.num_gpu_blocks < migration_config.migration_buffer_blocks: + logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." + .format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks)) + migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks target_col = None backend = migration_config.migration_backend + assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported backend: {} for VLLM".format(backend) + if backend in ['nccl', 'gloo']: target_col = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, is_driver_worker, gpu_cache) diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index 92bf1f1b..2b0cab33 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -14,7 +14,6 @@ import time from typing import Dict, List import math -import ray import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy @@ -50,10 +49,11 @@ def get_global_rank(self): def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig) -> int: - migrate_cache_blocks_size = migration_config.migration_cache_blocks + migrate_cache_blocks_size = migration_config.migration_buffer_blocks migrate_num_layers = migration_config.migration_num_layers - dummy_cache_size = migrate_num_layers * migrate_cache_blocks_size * CacheEngine.get_cache_block_size( - cache_config, model_config, parallel_config) // model_config.get_num_layers(parallel_config) + dummy_cache_size = migration_config.migration_internal_buffer_num * migrate_num_layers * migrate_cache_blocks_size \ + * CacheEngine.get_cache_block_size(cache_config, model_config, parallel_config) \ + // model_config.get_num_layers(parallel_config) # For nccl migration backend, reserve gpu memory for dummy cache in migration backend. For other backends, # CPU memory is used for the dummy cache, which is almost unlimited, so no special action is needed. @@ -111,14 +111,16 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block start_time = time.time() try: self.migration_backend.migrate_cache(src_worker_handle, src_blocks, dst_blocks) - except ray.exceptions.RayActorError: - logger.info("[migrate_cache] self.rank: {}, src_worker_handle {} is dead".format(self.rank, src_worker_handle)) + # pylint: disable=broad-except + except Exception as e: + logger.info("[migrate_cache] self.rank: {}, src_worker_handle {}, meet error : {}" + .format(self.rank, src_worker_handle, e)) end_time = time.time() total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size( self.cache_config, self.model_config, self.parallel_config) speed = total_kv_cache_size/_GB/(end_time - start_time) - logger.info("[migration_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s." + logger.info("[migrate_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s." .format(len(src_blocks), convert_bytes(total_kv_cache_size), end_time-start_time, speed)) def do_recv(self, *args, **kwargs): @@ -150,7 +152,3 @@ def shutdown(self) -> None: del self.migration_backend torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() - - def restart(self) -> None: - self.init_model() - self.init_cache_engine(self.cache_config) diff --git a/llumnix/config/default.py b/llumnix/config/default.py index fb94443b..2a6c7758 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -108,9 +108,11 @@ # Timeout(s) for initializing migration backend _C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0 # Number of cache blocks in migration -_C.MANAGER.MIGRATION_CACHE_BLOCKS = 512 +_C.MANAGER.MIGRATION_BUFFER_BLOCKS = 512 # Number of kv-cache layers to transfer in each round during migration _C.MANAGER.MIGRATION_NUM_LAYERS = 1 +# Number of internal cache size in migration backend for sending and receiving +_C.MANAGER.MIGRATION_INTERNAL_BUFFER_NUM = 2 # ----------------------------------------------------------------------------- # SCALING CONFIGURATION diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 175bdbde..27458f26 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -72,6 +72,10 @@ def remove_instance(self, instance_id: str) -> None: if instance_id in self.available_dispatch_instance_set: self.available_dispatch_instance_set.remove(instance_id) + if self.num_instances >= self.num_dispatch_instances: + free_instance_id = next(iter(self.instance_id_set - self.available_dispatch_instance_set)) + self.available_dispatch_instance_set.add(free_instance_id) + def _sort_instance_infos(self, descending: bool = True) -> None: instance_infos: List[InstanceInfo] = list(self.instance_info.values()) diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 3445b210..77fd9b25 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -170,8 +170,10 @@ def pair_migration(self, migrate_instance_pairs = [] for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate + left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) + # Add some constrains to reduce unnecessary migrations if right_load_after_mig > self.migrate_out_load_threshold: continue @@ -184,12 +186,14 @@ def pair_migration(self, def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: instance_info_after_migrate = copy.deepcopy(instance_info) num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request + if is_migrate_in: instance_info_after_migrate.num_running_requests += 1 instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request else: instance_info_after_migrate.num_running_requests -= 1 instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request + return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') class DefragConstrained(PairMigrationPolicy): diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 410d38e0..08f5283f 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -16,18 +16,20 @@ def __init__( self, request_migration_policy: str, migration_backend: str, - migration_cache_blocks: int, + migration_buffer_blocks: int, migration_num_layers: int, last_stage_max_blocks: int, max_stages: int, - migration_backend_init_timeout: float) -> None: + migration_backend_init_timeout: float, + migration_internal_buffer_num: int) -> None: self.request_migration_policy = request_migration_policy self.migration_backend = migration_backend self.migration_num_layers = migration_num_layers - self.migration_cache_blocks = migration_cache_blocks + self.migration_buffer_blocks = migration_buffer_blocks self.last_stage_max_blocks = last_stage_max_blocks self.max_stages = max_stages self.migration_backend_init_timeout = migration_backend_init_timeout + self.migration_internal_buffer_num = migration_internal_buffer_num class GlobalSchedulerConfig: def __init__( diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 7b47728b..d98f3a8e 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -42,8 +42,6 @@ RETRIES_INTERVALS = 5.0 # TODO(s5u13b): Fix the logger when manager failover. - - class LLMEngineManager: def __init__(self, engine_manager_args: EngineManagerArgs, @@ -71,10 +69,7 @@ def __init__(self, logger.info("num_instances: {}".format(self.num_instances)) logger.info("max_instances: {}, min_instances: {}".format(self.max_instances, self.min_instances)) - # TODO(s5u13b): refactor auto-scaling - self.instances: Dict[str, Llumlet] = {} - self.instance_migrating: Dict[str, bool] = {} self.pending_rebuild_migration_instances = 0 self.global_scheduler = GlobalScheduler(global_scheduler_config) @@ -92,8 +87,9 @@ def __init__(self, # migrate states self.num_instance_info_updates = 0 - self.migrating = False + self.num_migrating = 0 + # TODO(s5u13b): refactor auto-scaling # auto-scaling states self.scale_up_time = -1 self.scale_down_time = -1 @@ -184,26 +180,31 @@ def update_instance_info_done_callback(instance_id: str, fut): self.global_scheduler.update_instance_infos([ret]) else: dead_instance_ids.append(instance_id) + while True: try: await asyncio.sleep(interval) tasks = [] instance_infos = [] dead_instance_ids = [] + for instance_id, instance in self.instances.items(): # Use asyncio.gather to wrap ray remote call to add done callback. task = asyncio.gather(instance.get_instance_info.remote(), return_exceptions=True) task.add_done_callback(partial(update_instance_info_done_callback, instance_id)) tasks.append(task) await asyncio.gather(*tasks, return_exceptions=True) + if len(dead_instance_ids) > 0: logger.info("[_update_instance_info_loop] dead instances: {}.".format(dead_instance_ids)) self.scale_down(dead_instance_ids) self.num_instance_info_updates += 1 + # Push migrate when the instance_info have updated a certain number of times. if self.enable_migration and self.num_instance_info_updates != 0 \ and self.num_instance_info_updates % self.pair_migration_frequency == 0: asyncio.create_task(self._push_migrations()) + if self.log_instance_info: self._log_instance_infos_to_csv(instance_infos) # pylint: disable=W0703 @@ -217,6 +218,7 @@ async def _clear_request_instance_loop(self, interval: float): while True: await asyncio.sleep(interval) self.request_instance = {} + async def _push_migrations(self) -> None: # Push migrate when the instance_info have updated a certain number of times. if self.enable_pd_disagg: @@ -227,10 +229,7 @@ async def _push_migrations(self) -> None: async def _migrate(self, pair_migration_type: PairMigrationConstraints, migrate_in_num_requests: int) -> None: async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None: - if migrate_instance_pair[0] in self.instance_migrating: - self.instance_migrating[migrate_instance_pair[0]] = False - if migrate_instance_pair[1] in self.instance_migrating: - self.instance_migrating[migrate_instance_pair[1]] = False + self.num_migrating -= 1 if isinstance(ret, (ray.exceptions.RayActorError, KeyError)): has_error_pair = await self._check_instance_error(migrate_instance_pair) for i, has_error in enumerate(has_error_pair): @@ -252,19 +251,20 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> self.request_instance[migrate_out_request_id] = migrate_instance_pair[1] logger.info("{}->{} migrate done, migrate request {}".format( migrate_instance_pair[0], migrate_instance_pair[1], migrate_out_request_ids)) + def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -> None: ret = fut.result() loop = asyncio.get_event_loop() loop.create_task(migrate_done_callback(ret, migrate_instance_pair)) - migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type) + try: + migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type) + migration_tasks = [] for _, migrate_instance_pair in enumerate(migrate_instance_pairs): + self.num_migrating += 1 migrate_out_instance_id, migrate_in_instance_id = migrate_instance_pair - if self.instance_migrating[migrate_out_instance_id] or self.instance_migrating[migrate_in_instance_id]: - continue - self.instance_migrating[migrate_out_instance_id] = True - self.instance_migrating[migrate_in_instance_id] = True + migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id) # Use asyncio.gather to wrap ray remote call to add done callback. task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests), @@ -280,7 +280,7 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - async def rebuild_migrate_backend(self) -> None: # Wait for all instances to finish migration - while any(self.instance_migrating.values()): + while self.num_migrating > 0: await asyncio.sleep(0.1) # During rebuilding migration backend, disable migrate @@ -353,7 +353,6 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles if ins_id not in self.instances: indeed_update = True self.instances[ins_id] = llumlet_actor_handles[idx] - self.instance_migrating[ins_id] = False if self.log_instance_info: self.instance_last_logged_empty[ins_id] = False self.pending_rebuild_migration_instances += 1 @@ -381,7 +380,6 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac if ins_id in self.instances: indeed_update = True del self.instances[ins_id] - del self.instance_migrating[ins_id] if self.log_instance_info: del self.instance_last_logged_empty[ins_id] self.pending_rebuild_migration_instances += 1 diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 5aa3e4c2..42be2f2c 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -130,6 +130,7 @@ async def check_state(self): async def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]: try: + migrate_out_request = None migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') dst_instance_id = dst_instance_name[len("instance_"):] migrated_request_list = [] @@ -137,12 +138,13 @@ async def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[s while continue_migrate and len(migrated_request_list) < num_requests: t0 = time.time() migrate_out_request = self.migration_scheduler.get_migrate_out_request() - if migrate_out_request is not None: - logger.info("migrate_out {}".format(migrate_out_request.request_id)) if migrate_out_request is None: - return migrated_request_list + break + + migrate_out_request.migrating = True logger.info("{}->{} begin migrate out {}".format(self.instance_id, dst_instance_id, migrate_out_request.request_id)) status = await self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) + if status == MigrationStatus.FINISHED_DONE: await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request) self.backend_engine.free_src_request(migrate_out_request) @@ -157,8 +159,13 @@ async def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[s logger.info("{}->{} migrate done, migrate request {}, status:{}, len:{} blocks, cost:{} ms" \ .format(self.instance_id, dst_instance_id, migrated_request_list, status, \ sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) - except ray.exceptions.RayActorError: - logger.info("[migrate_out] instance {} is dead".format(dst_instance_name[len("instance_"):])) + # pylint: disable=broad-except + except Exception as e: + if migrate_out_request: + migrate_out_request.reset_migration_args() + + logger.info("[migrate_out] src instance {}, dst instance {}, meet error: {}" + .format(self.instance_id, dst_instance_name[len("instance_"):], e)) raise return migrated_request_list diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py index e630d982..ad676cc1 100644 --- a/llumnix/llumlet/local_migration_scheduler.py +++ b/llumnix/llumlet/local_migration_scheduler.py @@ -39,35 +39,52 @@ def get_migrate_out_request(self, min_request_len=0, max_request_len=np.inf) -> # and only selects request that migrates from the prefill instance to the decoding instance. def get_ready_migration_request(self, min_request_len, max_request_len): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() + target_request: LlumnixRequest = None for request in reversed(running): + if request.migrating: + continue + if request.output_len >= request.expected_steps \ and request.inference_type == RequestInferenceType.DECODE \ and min_request_len <= request.request_len <= max_request_len: - return request - return None + target_request = request + break + + return target_request def get_last_running_request(self, min_request_len, max_request_len): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() + target_request: LlumnixRequest = None + for request in reversed(running): + if request.migrating: + continue + if request.inference_type == RequestInferenceType.DECODE \ and min_request_len <= request.request_len <= max_request_len: - return request - return None + target_request=request + break + + return target_request def get_longest_running_request(self, min_request_len, max_request_len): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() condition = lambda request : request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len + and min_request_len <= request.request_len <= max_request_len \ + and (not request.migrating) longest_seq_group = max((request for request in running if condition(request)), \ key=lambda request: request.request_len, default=None) + return longest_seq_group def get_shortest_running_request(self, min_request_len, max_request_len): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() condition = lambda request : request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len + and min_request_len <= request.request_len <= max_request_len \ + and (not request.migrating) shortest_seq_group = min((request for request in running if condition(request)), \ key=lambda request: request.request_len, default=None) + return shortest_seq_group diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index 2319f52f..c2aeda9e 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -32,6 +32,7 @@ def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] + self.migrating = False def reset_migration_args(self): self.last_preemption_time = None @@ -39,6 +40,7 @@ def reset_migration_args(self): self.stage_num_blocks_list = [] # By default, there is no limit on the number of steps expected for the request. self.expected_steps = math.inf + self.migrating = False def is_finished(self) -> bool: raise NotImplementedError diff --git a/tests/conftest.py b/tests/conftest.py index 2749ba00..ba3b467c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,20 +12,16 @@ # limitations under the License. import subprocess -from time import sleep import ray import pytest def pytest_sessionstart(session): - subprocess.run(["ray", "stop", "--force"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - sleep(3) + subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - sleep(3) def pytest_sessionfinish(session, exitstatus): subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - sleep(3) @pytest.fixture def setup_ray_env(): diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index b6d70d8f..eb93fb89 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -20,7 +20,8 @@ import numpy as np from .test_e2e import generate_launch_command, clear_ray_state -from .utils import to_markdown_table +# pylint: disable=unused-import +from .utils import to_markdown_table, clean_ray def launch_llumnix_service(command): subprocess.run(command, shell=True, check=True) @@ -90,7 +91,7 @@ def get_markdown_data(key: str, head_name: str): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for simple benchmark") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -async def test_simple_benchmark(model): +async def test_simple_benchmark(clean_ray, model): device_count = torch.cuda.device_count() base_port = 37037 for i in range(device_count): @@ -107,7 +108,7 @@ async def run_bench_command(command): tasks = [] for i in range(device_count): - bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=500, + bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300, dataset_type="sharegpt", dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" , qps=2, diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py index 741360f1..11b8617f 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_e2e.py @@ -20,7 +20,8 @@ import torch from vllm import LLM, SamplingParams - +# pylint: disable=unused-import +from .utils import clean_ray def parse_launch_mode(launch_mode: str): # 'eief' means that enable init instance by manager and enable fixed node init instance, and so on. @@ -46,7 +47,7 @@ def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool disable_init_instance_by_manager, disable_fixed_node_init_instance = parse_launch_mode(launch_mode) command = ( f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 " - f"nohup python -m llumnix.entrypoints.vllm.api_server " + f"nohup python -u -m llumnix.entrypoints.vllm.api_server " f"--host {ip} " f"--port {port} " f"{'--disable-init-instance-by-manager ' if disable_init_instance_by_manager else ''}" @@ -63,7 +64,8 @@ def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool f"--trust-remote-code " f"--request-migration-policy LCFS " f"--migration-backend {migration_backend} " - f"--migration-cache-blocks 32 " + f"--migration-buffer-blocks 32 " + f"--migration-internal-buffer-num 2 " f"--tensor-parallel-size 1 " f"--request-output-queue-port {1234+port} " f"{'--enable-pd-disagg ' if enable_pd_disagg else ''} " @@ -123,6 +125,8 @@ async def get_llumnix_response(prompt, sampling_params, ip_ports): "The future of AI is", ] +vllm_output = {} + @ray.remote(num_gpus=1) def run_vllm(model, max_model_len, sampling_params): vllm_output = {} @@ -137,9 +141,9 @@ def run_vllm(model, max_model_len, sampling_params): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for e2e test") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) @pytest.mark.parametrize("launch_mode", ['eief', 'eidf', 'dief', 'didf']) -async def test_e2e(model, migration_backend, launch_mode): +async def test_e2e(clean_ray, model, migration_backend, launch_mode): if migration_backend == 'gloo' and launch_mode != 'eief': pytest.skip("When the migration backend is gloo, the launch mode of llumnix can only be eief") max_model_len = 370 @@ -165,9 +169,12 @@ async def test_e2e(model, migration_backend, launch_mode): shutdown_llumnix_service() - vllm_output = ray.get(run_vllm.remote(model, max_model_len, sampling_params)) - clear_ray_state() + global vllm_output + if len(vllm_output) == 0: + vllm_output = ray.get(run_vllm.remote(model, max_model_len, sampling_params)) + + clear_ray_state() # compare for prompt in prompts: assert llumnix_output[prompt] == vllm_output[prompt] diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index ddf7fb51..b1f446f1 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -22,7 +22,8 @@ from .test_e2e import generate_launch_command from .test_bench import generate_bench_command, clear_ray_state, shutdown_llumnix_service -from .utils import to_markdown_table +# pylint: disable=unused-import +from .utils import to_markdown_table, clean_ray size_pattern = re.compile(r'total_kv_cache_size:\s*([\d.]+)\s*(B|KB|MB|GB|KB|TB)') speed_pattern = re.compile(r'speed:\s*([\d.]+)GB/s') @@ -66,21 +67,19 @@ def parse_manager_log_file(log_file): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) -@pytest.mark.parametrize("enable_pd_disagg", [False, True]) -async def test_migration_benchmark(model, migration_backend, enable_pd_disagg): +@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) +async def test_migration_benchmark(clean_ray, model, migration_backend): base_port = 37037 instance_output_logs = [] device_count = torch.cuda.device_count() - num_dispatch_instances = device_count//2 if enable_pd_disagg else math.inf for i in range(device_count): output_log = f"{base_port+i}.out" instance_output_logs.append("instance_"+output_log) launch_command = generate_launch_command(result_filename=output_log, launch_ray_cluster=False, port=base_port+i, model=model, dispatch_policy="flood", migration_backend=migration_backend, - log_instance_info=True, enable_pd_disagg=enable_pd_disagg, - num_dispatch_instances=num_dispatch_instances) + log_instance_info=True, enable_pd_disagg=False, + num_dispatch_instances=math.inf) subprocess.run(launch_command, shell=True, check=True) await asyncio.sleep(60) @@ -89,13 +88,19 @@ async def run_bench_command(command): await process.wait() assert process.returncode == 0 + tasks = [] for i in range(device_count//2): bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300, dataset_type="sharegpt", dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" , - qps=10) - await asyncio.wait_for(run_bench_command(bench_command), timeout=60*30) - await asyncio.sleep(30) + qps=10, + results_filename=f"{base_port+i}.out") + tasks.append(asyncio.create_task(run_bench_command(bench_command))) + + _, pending = await asyncio.wait(tasks, timeout=60*30) + + if len(pending) > 0: + raise RuntimeError("migration task Timeout") parse_manager_log_file("manager_instance.csv") diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index 62d9bff8..492eb2fd 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -11,6 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import subprocess +import pytest + def to_markdown_table(data): headers = data[0] rows = data[1:] @@ -27,3 +31,11 @@ def to_markdown_table(data): table = f"{header_row}\n{separator_row}\n" + "\n".join(data_rows) + "\n\n" return table + +@pytest.fixture +def clean_ray(): + subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True, + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + yield + subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index 2a8ad19e..e5bf4567 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -56,7 +56,7 @@ def __init__(self): async def test_migration_correctness(setup_ray_env, migration_backend): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) id_rank_map = {"0":0, "1":1} - migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20) + migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) @@ -144,9 +144,9 @@ async def test_correctness(prompt): @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) @pytest.mark.asyncio async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): - engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True) - id_rank_map = {"0":0,"1":1} - migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20) + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + id_rank_map = {"0":0, "1":1} + migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) @@ -174,12 +174,15 @@ async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): migration_config, engine_args, ) + while True: res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) if all(res): break - ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"), - llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")]) + + ray.get([llumlet_0.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"), + llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) + # empty instance migrate out res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) assert not res @@ -222,8 +225,10 @@ async def test_correctness(prompt): assert output.text == origin_output.text assert output.cumulative_logprob == origin_output.cumulative_logprob + for prompt in TEST_PROMPTS: await test_correctness(prompt) + que.cleanup() def test_clear_migration_states(): diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py index 2bb008ee..12ec324c 100644 --- a/tests/unit_test/backends/vllm/test_migration_backend.py +++ b/tests/unit_test/backends/vllm/test_migration_backend.py @@ -26,6 +26,44 @@ from tests.conftest import setup_ray_env from .test_worker import create_worker +def get_ready_workers(num_worker, num_gpu_blocks, engine_config, migraiton_config): + workers = [] + worker_ids = [] + + for _ in range(num_worker): + worker_id = random_uuid() + worker = create_worker(rank=0, local_rank=0, engine_config=engine_config, + worker_module_name="tests.unit_test.backends.vllm.test_migration_backend", + worker_class_name="MockMigrationWorker") + ray.get(worker.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) + ray.get(worker.execute_method.remote( + 'init_migration', + instance_id=worker_id, + migration_config=migraiton_config, + src_worker_handle_list=[worker], + node_id=ray.get_runtime_context().get_node_id())) + + workers.append(worker) + worker_ids.append(worker_id) + + instance_rank = {} + for idx, worker_id in enumerate(worker_ids): + instance_rank[worker_id] = idx + group_name = random_uuid() + + init_group_tasks =[] + for worker in workers: + init_group_tasks.append(worker.execute_method.remote('rebuild_migration_backend', + instance_rank=instance_rank, group_name=group_name)) + assert all(ray.get(init_group_tasks)) + + warmup_tasks = [] + for worker in workers: + warmup_tasks.append(worker.execute_method.remote('warmup')) + assert all(ray.get(warmup_tasks)) + + return workers, worker_ids + class MockMigrationWorker(MigrationWorker): def set_gpu_cache(self, data): for layer_idx in range(self.cache_engine.num_layers): @@ -34,75 +72,120 @@ def set_gpu_cache(self, data): def get_gpu_cache(self): torch.cuda.synchronize() - return self.gpu_cache + gpu_data = [] + for layer_idx in range(self.cache_engine.num_layers): + gpu_data.append(self.gpu_cache[layer_idx].clone().cpu()) + return gpu_data -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.") -@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) -def test_migrate_cache(setup_ray_env, backend): +@pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Need at least 3 GPU to run the test.") +@pytest.mark.parametrize("backend", ['rpc', 'gloo']) +def test_one_to_many_migrate_cache(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config() + migration_internal_buffer_num = 2 + migraiton_config = EngineManagerArgs(migration_buffer_blocks=3, migration_num_layers=5, + migration_internal_buffer_num=migration_internal_buffer_num).create_migration_config() migraiton_config.migration_backend = backend - worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config, - worker_module_name="tests.unit_test.backends.vllm.test_migration_backend", - worker_class_name="MockMigrationWorker") - worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config, - worker_module_name="tests.unit_test.backends.vllm.test_migration_backend", - worker_class_name="MockMigrationWorker") - - ray.get(worker0.execute_method.remote('init_device')) - ray.get(worker1.execute_method.remote('init_device')) - - num_gpu_blocks = 8 - ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) - ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0)) - - worker0_id = random_uuid() - ray.get(worker0.execute_method.remote( - 'init_migration', - instance_id=worker0_id, - migration_config=migraiton_config, - src_worker_handle_list=[worker0], - node_id=ray.get_runtime_context().get_node_id())) - - worker1_id = random_uuid() - ray.get(worker1.execute_method.remote( - 'init_migration', - instance_id=worker1_id, - migration_config=migraiton_config, - src_worker_handle_list=[worker1], - node_id=ray.get_runtime_context().get_node_id())) - - instance_rank = {worker0_id: 0, worker1_id: 1} - group_name = random_uuid() - assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend', - instance_rank=instance_rank, group_name=group_name), - worker1.execute_method.remote('rebuild_migration_backend', - instance_rank=instance_rank, group_name=group_name)])) - assert all(ray.get([worker0.execute_method.remote('warmup'), - worker1.execute_method.remote('warmup')])) + num_worker = 3 + num_gpu_blocks = 6000 + workers, _ = get_ready_workers(num_worker, num_gpu_blocks, engine_config, migraiton_config) num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config) head_size = engine_config.model_config.get_head_size() num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config) block_size = engine_config.cache_config.block_size + dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size)) + ray.get(workers[0].execute_method.remote('set_gpu_cache', data=dummy_data)) + worker0_data = ray.get(workers[0].execute_method.remote('get_gpu_cache')) + + dst_blocks = list(range(num_gpu_blocks)) + random.shuffle(dst_blocks) + + single_worker_num_blocks = len(dst_blocks)//(num_worker-1) + migration_tasks = [] + worker_idx = 1 + per_step_blocks = 500 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + src_blocks = list(src_to_dst.keys()) + dst_blocks = list(src_to_dst.values()) + for idx in range(0, len(src_blocks), per_step_blocks): + cur_src_blocks = src_blocks[idx:idx+per_step_blocks] + cur_dst_blocks = dst_blocks[idx:idx+per_step_blocks] + migration_tasks.append(workers[0].execute_method.remote( + 'migrate_cache', + src_worker_handle_list=[workers[worker_idx]], + src_blocks=cur_src_blocks, + dst_blocks=cur_dst_blocks) + ) + worker_idx += 1 + ray.get(migration_tasks) + + worker_idx = 1 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + dst_worker_data = ray.get(workers[worker_idx].execute_method.remote('get_gpu_cache')) + for layer_idx in range(num_layers): + for src_idx, dst_idx in src_to_dst.items(): + assert torch.allclose(worker0_data[layer_idx][0][src_idx], dst_worker_data[layer_idx][0][dst_idx]) + assert torch.allclose(worker0_data[layer_idx][1][src_idx], dst_worker_data[layer_idx][1][dst_idx]) + worker_idx += 1 + +@pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Need at least 3 GPU to run the test.") +@pytest.mark.parametrize("backend", ['rpc', 'gloo']) +def test_many_to_one_migrate_cache(setup_ray_env, backend): + engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() + migration_internal_buffer_num = 2 + migraiton_config = EngineManagerArgs(migration_buffer_blocks=3, migration_num_layers=5, + migration_internal_buffer_num=migration_internal_buffer_num).create_migration_config() + migraiton_config.migration_backend = backend + num_worker = 3 + num_gpu_blocks = 6000 + workers, _ = get_ready_workers(num_worker, num_gpu_blocks, engine_config, migraiton_config) + + num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config) + head_size = engine_config.model_config.get_head_size() + num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config) + block_size = engine_config.cache_config.block_size dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size)) - ray.get(worker0.execute_method.remote('set_gpu_cache', data=dummy_data)) - worker0_data = ray.get(worker0.execute_method.remote('get_gpu_cache')) + + worker_datas = [0] + for idx in range(1, num_worker): + ray.get(workers[idx].execute_method.remote('set_gpu_cache', data=dummy_data)) + worker_datas.append(ray.get(workers[idx].execute_method.remote('get_gpu_cache'))) dst_blocks = list(range(num_gpu_blocks)) random.shuffle(dst_blocks) - src_to_dst = dict(enumerate(dst_blocks)) - ray.get(worker1.execute_method.remote( - 'migrate_cache', - src_worker_handle_list=[worker0], - src_blocks=list(src_to_dst.keys()), - dst_blocks=list(src_to_dst.values()))) - - worker1_data = ray.get(worker1.execute_method.remote('get_gpu_cache')) - - for layer_idx in range(num_layers): - for src_idx, dst_idx in src_to_dst.items(): - assert torch.allclose(worker0_data[layer_idx][0][src_idx], worker1_data[layer_idx][0][dst_idx]) - assert torch.allclose(worker0_data[layer_idx][1][src_idx], worker1_data[layer_idx][1][dst_idx]) + + single_worker_num_blocks = len(dst_blocks)//(num_worker-1) + migration_tasks = [] + worker_idx = 1 + per_step_blocks = 500 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + src_blocks = list(src_to_dst.keys()) + dst_blocks = list(src_to_dst.values()) + for idx in range(0, len(src_blocks), per_step_blocks): + cur_src_blocks = src_blocks[idx:idx+per_step_blocks] + cur_dst_blocks = dst_blocks[idx:idx+per_step_blocks] + migration_tasks.append(workers[0].execute_method.remote( + 'migrate_cache', + src_worker_handle_list=[workers[worker_idx]], + src_blocks=cur_src_blocks, + dst_blocks=cur_dst_blocks) + ) + worker_idx += 1 + ray.get(migration_tasks) + + dst_worker_data = ray.get(workers[0].execute_method.remote('get_gpu_cache')) + + worker_idx = 1 + for offset in range(0, len(dst_blocks), single_worker_num_blocks): + src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks])) + + for layer_idx in range(num_layers): + for src_idx, dst_idx in src_to_dst.items(): + assert torch.allclose(worker_datas[worker_idx][layer_idx][0][src_idx], dst_worker_data[layer_idx][0][dst_idx]) + assert torch.allclose(worker_datas[worker_idx][layer_idx][1][src_idx], dst_worker_data[layer_idx][1][dst_idx]) + worker_idx += 1 diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py index 7fb94baa..c0753b06 100644 --- a/tests/unit_test/backends/vllm/test_simulator.py +++ b/tests/unit_test/backends/vllm/test_simulator.py @@ -71,7 +71,7 @@ async def test_backend(setup_ray_env): # TODO(ZeldaHuang): add tests for BackendSimVLLM methods # (currently BackendSimVLLM is just a wrapper of BackendVLLM) engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - migration_config = MigrationConfig("LCFS", "gloo", 16, 1, 4, 5, 20) + migration_config = MigrationConfig("LCFS", "gloo", 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py index dc014005..09df9ea0 100644 --- a/tests/unit_test/backends/vllm/test_worker.py +++ b/tests/unit_test/backends/vllm/test_worker.py @@ -39,7 +39,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, trust_remote_code=True ) - worker.init_worker.remote( + ray.get(worker.init_worker.remote( model_config=engine_config.model_config, parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, @@ -52,25 +52,25 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, lora_config=engine_config.lora_config, vision_language_config=engine_config.vision_language_config, is_driver_worker = False - ) - + )) + ray.get(worker.execute_method.remote('init_device')) return worker @pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) def test_reserve_memory_for_migration(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = EngineManagerArgs(migration_cache_blocks=1).create_migration_config() - migraiton_config.migration_backend = backend + migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config.migration_backend = backend worker = create_worker(rank=0, local_rank=0, engine_config=engine_config) - ray.get(worker.execute_method.remote('init_device')) block_size = CacheEngine.get_cache_block_size(engine_config.cache_config, engine_config.model_config, engine_config.parallel_config) num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config) - occupy_memory = migraiton_config.migration_cache_blocks * block_size * migraiton_config.migration_num_layers // num_layers + occupy_memory = migration_config.migration_internal_buffer_num * migration_config.migration_buffer_blocks \ + * block_size * migration_config.migration_num_layers // num_layers migration_cache_size = ray.get(worker.execute_method.remote('reserve_memory_for_migration', - migration_config=migraiton_config, + migration_config=migration_config, model_config=engine_config.model_config, cache_config=engine_config.cache_config, parallel_config=engine_config.parallel_config)) @@ -80,17 +80,16 @@ def test_reserve_memory_for_migration(setup_ray_env, backend): @pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl']) def test_rebuild_migration_backend(setup_ray_env, backend): engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config() - migraiton_config = EngineManagerArgs(migration_cache_blocks=1).create_migration_config() - migraiton_config.migration_backend = backend + migration_config = EngineManagerArgs(migration_buffer_blocks=1).create_migration_config() + migration_config.migration_backend = backend worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config) worker0_id = random_uuid() - ray.get(worker0.execute_method.remote('init_device')) ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0)) ray.get(worker0.execute_method.remote( 'init_migration', instance_id=worker0_id, - migration_config=migraiton_config, + migration_config=migration_config, src_worker_handle_list=[worker0], node_id=ray.get_runtime_context().get_node_id())) instance_rank = {worker0_id: 0} @@ -100,12 +99,11 @@ def test_rebuild_migration_backend(setup_ray_env, backend): worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config) worker1_id = random_uuid() - ray.get(worker1.execute_method.remote('init_device')) ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=8, num_cpu_blocks=0)) ray.get(worker1.execute_method.remote( 'init_migration', instance_id=worker1_id, - migration_config=migraiton_config, + migration_config=migration_config, src_worker_handle_list=[worker1], node_id=ray.get_runtime_context().get_node_id())) diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 8cee3a69..28fc129e 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -21,7 +21,7 @@ def init_dispatch_scheduler(policy='load'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(1,4)) + dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, 1) return dispatch_scheduler @pytest.fixture @@ -29,7 +29,9 @@ def dispatch_scheduler(): dispatch_scheduler = init_dispatch_scheduler() yield dispatch_scheduler -def test_add_instance_and_remove_instance(dispatch_scheduler): +@pytest.mark.parametrize("num_dispatch_instances", [1, 2, 3]) +def test_add_instance_and_remove_instance(dispatch_scheduler, num_dispatch_instances): + dispatch_scheduler.num_dispatch_instances = num_dispatch_instances dispatch_scheduler.add_instance('instance_1') assert dispatch_scheduler.num_instances == 1 assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index 5c5fc644..024ad4bf 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -27,6 +27,7 @@ from llumnix.queue.queue_type import QueueType from llumnix.global_scheduler.scaling_scheduler import InstanceType from llumnix.backends.vllm.simulator import BackendSimVLLM +from llumnix.backends.profiling import LatencyMemData # pylint: disable=unused-import from tests.conftest import setup_ray_env @@ -41,6 +42,7 @@ def __init__(self, instance_id): self.request_id_set = set() self.instance_info = None self.num_migrate_out = 0 + self.num_migrate_in = 0 def get_instance_id(self) -> str: return self.instance_id @@ -76,14 +78,24 @@ def abort(self, request_id): self.num_requests = len(self.request_id_set) return self.num_requests - def migrate_out(self, src_instance_name, dst_instance_name): + def migrate_out(self, dst_instance_name, num_requests): self.num_migrate_out += 1 + migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') + ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name, num_requests)) + time.sleep(0.1) + return self.num_migrate_out + + def migrate_in(self, src_instance_name, num_requests): + self.num_migrate_in += 1 + return self.num_migrate_in def get_num_migrate_out(self): return self.num_migrate_out -class MockBackendSim(BackendSimVLLM): + def get_num_migrate_in(self): + return self.num_migrate_in +class MockBackendSim(BackendSimVLLM): def _get_lantecy_mem(self, *args, **kwargs): latency_mem = LatencyMemData({}, {}, {}) latency_mem.prefill_model_params = (0,0) @@ -242,20 +254,37 @@ def get_instance_info_migrate_out(instance_id): return instance_info def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager): - instance_ids, llumlets = init_llumlets(2) - instance_id, instance_id_1 = instance_ids[0], instance_ids[1] - llumlet, llumlet_1 = llumlets[0], llumlets[1] - request_id = random_uuid() - request_id_1 = random_uuid() - ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None)) - ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None)) - instance_info_migrate_out = get_instance_info_migrate_out(instance_id) - instance_info_migrate_in = get_instance_info_migrate_in(instance_id_1) - ray.get(llumlet.set_instance_info.remote(instance_info_migrate_out)) - ray.get(llumlet_1.set_instance_info.remote(instance_info_migrate_in)) - num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) - assert num_migrate_out == 0 + num_llumlets = 5 + instance_ids, llumlets = init_llumlets(num_llumlets) + + for i in range(num_llumlets): + for _ in range(2*(i+1)): + ray.get(llumlets[i].generate.remote(random_uuid(), None, math.inf, None, None)) + + instance_info = InstanceInfo() + instance_info.instance_type = InstanceType.NO_CONSTRAINTS + + for i in range(num_llumlets): + instance_info.instance_id = instance_ids[i] + instance_info.num_available_gpu_blocks = 40 - i * 10 + instance_info.num_running_requests = i + instance_info.num_blocks_first_waiting_request = i + ray.get(llumlets[i].set_instance_info.remote(instance_info)) + + for i in range(num_llumlets): + num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote()) + assert num_migrate_out == 0 + ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) - time.sleep(0.5) - num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) - assert num_migrate_out != 0 + time.sleep(2) + + for i in range(num_llumlets): + num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote()) + num_migrate_in = ray.get(llumlets[i].get_num_migrate_in.remote()) + + if i == 0: + assert num_migrate_in > 1 and num_migrate_out == 0 + elif i == num_llumlets - 1: + assert num_migrate_in == 0 and num_migrate_out > 1 + else: + assert num_migrate_in == 0 and num_migrate_out == 0 diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py index 56b58322..c630a04f 100644 --- a/tests/unit_test/llumlet/test_engine_step_exception.py +++ b/tests/unit_test/llumlet/test_engine_step_exception.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import time import ray import torch @@ -30,28 +29,17 @@ @ray.remote(num_cpus=1, max_concurrency=4) class MockLlumlet(Llumlet): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.origin_step = self.backend_engine.engine.step_async - - def set_error_step(self, broken: bool): - self.backend_engine._stop_event.set() - + def set_error_step(self): async def raise_error_step(): await self.origin_step() raise ValueError("Mock engine step error") - if broken: - self.backend_engine.engine.step_async = raise_error_step - else: - self.backend_engine.engine.step_async = self.origin_step - - asyncio.create_task(self.backend_engine._start_engine_step_loop()) + self.backend_engine.engine.step_async = raise_error_step @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.") def test_engine_step_exception(setup_ray_env): - engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - migration_config = MigrationConfig("LCFS", "rpc", 16, 1, 4, 5, 20) + engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True) + migration_config = MigrationConfig("LCFS", "rpc", 16, 1, 4, 5, 20, 2) node_id = ray.get_runtime_context().get_node_id() scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) @@ -76,7 +64,7 @@ def test_engine_step_exception(setup_ray_env): cur_free_memory, _ = torch.cuda.mem_get_info() assert cur_free_memory < origin_free_memory - ray.get(llumlet.set_error_step.remote(True)) + ray.get(llumlet.set_error_step.remote()) time.sleep(3) all_actors = ray.util.list_named_actors(True) diff --git a/tests/unit_test/llumlet/test_local_migration_scheduler.py b/tests/unit_test/llumlet/test_local_migration_scheduler.py index d585300d..c0c6f834 100644 --- a/tests/unit_test/llumlet/test_local_migration_scheduler.py +++ b/tests/unit_test/llumlet/test_local_migration_scheduler.py @@ -66,14 +66,18 @@ def test_scheduler_policy(): assert scheduler.get_migrate_out_request().request_id == "0" engine.add_request(request_id="3", length=2, expected_steps=1) - request = scheduler.get_migrate_out_request() - assert request.request_id == "3" - assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE engine.add_request(request_id="4", length=3, expected_steps=math.inf) + engine.add_request(request_id="5", length=4, expected_steps=math.inf) scheduler.request_migration_policy = "LCFS" request = scheduler.get_migrate_out_request() + request.migrating = True assert request.request_id == "3" assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE + request = scheduler.get_migrate_out_request() + request.migrating = True + assert request.request_id == "5" + request = scheduler.get_migrate_out_request() + assert request.request_id == "4" def test_scheduler_should_abort_migration(): req_0 = MockRequest(request_id="0", length=1, expected_steps=math.inf) diff --git a/tests/unit_test/queue/test_zmq.py b/tests/unit_test/queue/test_zmq.py index d4303d37..6f62935e 100644 --- a/tests/unit_test/queue/test_zmq.py +++ b/tests/unit_test/queue/test_zmq.py @@ -106,8 +106,8 @@ async def benchmark_queue(qps, ip=None, port=None): signal.alarm(0) @pytest.mark.asyncio -@pytest.mark.parametrize("qps", [128.0, 256.0, 512.0, 1024.0]) -async def test_queue_zmq(setup_ray_env, qps): +async def test_queue_zmq(setup_ray_env): ip = '127.0.0.1' port = 1234 + qps = 1024.0 await benchmark_queue(qps, ip, port)