From d2c399483b55900f01c0cb6a2d4b57e4ff31e189 Mon Sep 17 00:00:00 2001 From: xianyu Date: Thu, 25 Jul 2024 07:22:42 +0000 Subject: [PATCH] [Core] Support gloo and nccl bankend for kv cache transfer --- benchmark/benchmark_serving.py | 6 +- llumnix/arg_utils.py | 11 +- llumnix/backends/vllm/llm_engine.py | 12 +- llumnix/backends/vllm/migrate_backend.py | 278 ++++++++++++++++++ llumnix/backends/vllm/worker.py | 218 +++++--------- llumnix/config.py | 9 +- llumnix/entrypoints/llumnix_utils.py | 15 +- llumnix/entrypoints/vllm/api_server.py | 6 + .../global_scheduler/dispatch_scheduler.py | 8 + requirements.txt | 3 + 10 files changed, 396 insertions(+), 170 deletions(-) create mode 100644 llumnix/backends/vllm/migrate_backend.py diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py index c915e1f5..f63d12b2 100644 --- a/benchmark/benchmark_serving.py +++ b/benchmark/benchmark_serving.py @@ -652,6 +652,8 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") + parser.add_argument('--trust_remote_code', + action='store_true') parser.add_argument('-v', '--verbose', action='store_true') parser.add_argument('--backend', type=GenerationBackend, choices=[e.name for e in GenerationBackend], default='vLLM') @@ -701,7 +703,7 @@ def main(): assert args.random_prompt_count is not None backend = GenerationBackend[args.backend] - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code) print(tokenizer) if args.dataset_type: @@ -798,4 +800,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index f1602bd7..2cac5570 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -70,13 +70,16 @@ def create_engine_manager_configs( return global_scheduler_config def create_migration_configs( - self, + self, instance_rank_map, pp_or_tp_enabled, group_name ) -> MigrationConfig: migration_config = MigrationConfig(self.migrate_policy, self.migration_backend, self.migration_cache_blocks, self.last_stage_max_blocks, - self.max_stages) + self.max_stages, + instance_rank_map, + pp_or_tp_enabled, + group_name) return migration_config @classmethod @@ -107,7 +110,7 @@ def add_cli_args( parser.add_argument('--dispatch-policy', type=str, default=EngineManagerArgs.dispatch_policy, - choices=['balanced', 'load', 'queue'], + choices=['balanced', 'load', 'queue', 'flood'], help='dispatch policy') parser.add_argument('--enable-migrate', @@ -192,7 +195,7 @@ def add_cli_args( parser.add_argument('--migration-backend', type=str, default=EngineManagerArgs.migration_backend, - choices=['gloo','rpc'], + choices=['gloo','nccl','rpc'], help='communication backend during migration') parser.add_argument('--migration-cache_blocks', type=int, diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index a1970670..daf6fac8 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -174,12 +174,11 @@ def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None: class BackendVLLM(BackendInterface): def __init__( self, - instance_id: int, + instance_id: str, migration_config: MigrationConfig, engine_args: EngineArgs, placement_group: "PlacementGroup" ) -> None: - assert migration_config.migration_backend == "rpc", "Gloo support will be released later." 
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, instance_id=instance_id, placement_group=placement_group) # multi-instance args @@ -189,7 +188,7 @@ def __init__( self.worker_handle_list = self.engine.model_executor.workers.copy() if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size: self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix")) - self._run_workers("init_migration", num_migration_cache_blocks=migration_config.migration_cache_blocks,\ + self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\ src_worker_handle_list=self.worker_handle_list, placement_group=placement_group) self._thread = threading.Thread( @@ -201,9 +200,8 @@ def _start_engine_loop(self) -> None: while True: self.engine.step() - def send_cpu_cache(self, *args, **kwargs): - # driver worker migration interface - return self.engine.model_executor.driver_worker.execute_method("send_cpu_cache", *args, **kwargs) + def execute_wroker_method(self, method, *args, **kwargs): + return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs) def stop_shutdown(self) -> None: self.engine.scaling_down = False @@ -250,7 +248,7 @@ def commit_dst_request(self, backend_request: SequenceGroup, server_info: Server def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers", - "migrate_gpu_cache_ray_rpc", + "migrate_cache", dst_blocks=dst_blocks, src_blocks=src_blocks, src_worker_handle_list=self.worker_handle_list)) diff --git a/llumnix/backends/vllm/migrate_backend.py b/llumnix/backends/vllm/migrate_backend.py new file mode 100644 index 00000000..e0ae04c1 --- /dev/null +++ b/llumnix/backends/vllm/migrate_backend.py @@ -0,0 +1,278 @@ +from abc import ABC, abstractmethod +from typing import List +import torch +import numpy as np + +import ray +import ray.util.collective as col +from vllm.logger import init_logger +from vllm.worker.cache_engine import CacheEngine +from llumnix.config import MigrationConfig + +logger = init_logger(__name__) + +class MigrateBackendBase(ABC): + @abstractmethod + def init_col(self, name) -> None: + raise NotImplementedError + + @abstractmethod + def destory_col(self) -> None: + raise NotImplementedError + + @abstractmethod + def warmup(self) -> None: + raise NotImplementedError + + @abstractmethod + def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None: + raise NotImplementedError + + @abstractmethod + def do_send(self, dst_handle, blocks: List[int]): + raise NotImplementedError + + @abstractmethod + def do_recv(self, src_handle, blocks: List[int]): + raise NotImplementedError + +@ray.remote(num_cpus=0) +class ProxyActor: + def exec_method(self, is_driver_worker, handle, *args, **kwargs): + try: + if is_driver_worker: + ret = ray.get(handle.execute_engine_method.remote("execute_wroker_method", *args, **kwargs)) + else: + ret = ray.get(handle.execute_method.remote(*args, **kwargs)) + # pylint: disable=try-except-raise + except: + raise + + return ret + +class RPCMigrateBackend(MigrateBackendBase): + def __init__(self, migrate_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \ + scheduling_strategy, dtype, is_driver_worker, gpu_cache) -> None: + super().__init__() + + self.migrate_config = migrate_config + self.cache_engine = cache_engine + + 
self.worker_rank = worker_rank + self.worker_handle_list = worker_handle_list + self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() + + self.dtype = dtype + self.is_driver_worker = is_driver_worker + self.gpu_cache = gpu_cache + self.init_col(None) + + def init_col(self, name) -> None: + self.cache_device = "cpu" + + self.num_migration_cache_blocks = self.migrate_config.migration_cache_blocks + self.num_layers = self.cache_engine.num_layers + migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + + self.dummy_cache = torch.empty( + size=(2*self.num_migration_cache_blocks, self.num_layers, migration_cache_size), + dtype=self.cache_engine.dtype, + device=self.cache_device, + pin_memory=True + ) + + self.migration_stream = torch.cuda.Stream() + logger.info("create rpc migrate backend success.") + + def destory_col(self) -> None: + pass + + def warmup(self) -> None: + self_handle = self.worker_handle_list[self.worker_rank] + self.actor.exec_method.remote(self.is_driver_worker, self_handle, "do_send", None, + list(range(self.num_migration_cache_blocks))) + + def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None: + tot_blocks = len(src_blocks) + rpc_numpy_cache = None + for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): + offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) + send_blocks = src_blocks[start_idx:start_idx+offset] + ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks) + if rpc_numpy_cache is not None: + self.do_recv(rpc_numpy_cache, recv_blocks) + rpc_numpy_cache = ray.get(ray_obj) + recv_blocks = dst_blocks[start_idx:start_idx+offset] + self.do_recv(rpc_numpy_cache, recv_blocks) + + def do_send(self, dst_handle, blocks: List[int]): + num_blocks = len(blocks) + data = self.dummy_cache[:2*num_blocks] + dummy_key_cpu = self.dummy_cache[:num_blocks] + dummy_value_cpu = self.dummy_cache[num_blocks:2*num_blocks] + with torch.cuda.stream(self.migration_stream): + for layer_idx in range(self.num_layers): + for idx, block_num in enumerate(blocks): + dummy_key_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][0][block_num], non_blocking=True) + dummy_value_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][1][block_num], non_blocking=True) + torch.cuda.Stream.synchronize(self.migration_stream) + return data.to(self.dtype).numpy() + + def do_recv(self, src_handle, blocks: List[int]): + num_blocks = len(blocks) + + # use pin memory dummy_cache to speed up data transfer + data = self.dummy_cache[:2*num_blocks].copy_(torch.from_numpy(src_handle)) + dummy_key = data[:num_blocks] + dummy_value = data[num_blocks:2*num_blocks] + + with torch.cuda.stream(self.migration_stream): + for layer_idx in range(self.num_layers): + for idx, block_num in enumerate(blocks): + self.gpu_cache[layer_idx][0][block_num].copy_(dummy_key[idx][layer_idx], non_blocking=True) + self.gpu_cache[layer_idx][1][block_num].copy_(dummy_value[idx][layer_idx], non_blocking=True) + torch.cuda.Stream.synchronize(self.migration_stream) + +class RayMigrateBackend(MigrateBackendBase): + def __init__(self, migrate_config: MigrationConfig, cache_engine: CacheEngine, ray_world_size, ray_rank, \ + local_rank, scheduling_strategy, dtype, is_driver_worker, gpu_cache) -> None: + super().__init__() + + self.num_migration_cache_blocks = migrate_config.migration_cache_blocks + self.migrate_config = migrate_config + 
self.cache_engine = cache_engine + self.backend = migrate_config.migration_backend + + self.ray_world_size = ray_world_size + self.ray_rank = ray_rank + self.group_name = migrate_config.group_name + + self.local_rank = local_rank + self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() + self.dtype = dtype + self.is_driver_worker = is_driver_worker + self.gpu_cache = gpu_cache + + self.init_col(migrate_config.group_name) + + def init_col(self, name) -> None: + self.group_name = name + + migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + + if self.backend == 'gloo': + self.cache_device = "cpu" + elif self.backend == 'nccl': + self.cache_device = torch.device(f"cuda:{self.local_rank}") + else: + raise ValueError("backend must be 'gloo' or 'nccl'") + + pin_memory = self.backend == 'gloo' + self.dummy_cache = torch.empty( + size=(2*self.num_migration_cache_blocks, self.cache_engine.num_layers, migration_cache_size), + dtype=self.cache_engine.dtype, + device=self.cache_device, + pin_memory=pin_memory + ) + + self.migration_stream = torch.cuda.Stream() + + col.init_collective_group(world_size=self.ray_world_size, rank=self.ray_rank, + backend=self.backend, group_name=self.group_name) + + logger.info("create ray collective group success (group_name:{}, backbend: {})." + .format(self.group_name, self.backend)) + + def warmup(self) -> None: + enable_warmup = self.ray_world_size > 1 + need_warmup = not (self.ray_world_size % 2 != 0 and self.ray_rank == self.ray_world_size - 1) + + if enable_warmup and need_warmup: + if self.ray_rank % 2 == 0: + self.do_send(self.ray_rank+1, list(range(self.migrate_config.migration_cache_blocks))) + else: + self.do_recv(self.ray_rank-1, list(range(self.migrate_config.migration_cache_blocks))) + + def destory_col(self) -> None: + col.destroy_collective_group(self.group_name) + + def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None: + tot_blocks = len(src_blocks) + src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_ray_rank")) + + for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): + offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) + send_blocks = src_blocks[start_idx:start_idx+offset] + recv_blocks = dst_blocks[start_idx:start_idx+offset] + self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.ray_rank, send_blocks) + self.do_recv(src_rank, recv_blocks) + + def do_send(self, dst_handle, blocks: List[int]): + num_blocks = len(blocks) + data = self.dummy_cache[:2*num_blocks] + dummy_key_cpu = data[:num_blocks] + dummy_value_cpu = data[num_blocks:2*num_blocks] + with torch.cuda.stream(self.migration_stream): + for layer_idx in range(self.cache_engine.num_layers): + for idx, block_num in enumerate(blocks): + dummy_key_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][0][block_num], non_blocking=True) + dummy_value_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][1][block_num], non_blocking=True) + torch.cuda.Stream.synchronize(self.migration_stream) + data = self._may_use_numpy_for_tranfer(data) + col.send(data, dst_handle, self.group_name) + + def do_recv(self, src_handle, blocks: List[int]): + num_blocks = len(blocks) + data = self.dummy_cache[:2*num_blocks] + data = self._may_use_numpy_for_tranfer(data) + # note that col.recv use ray.collective inner stream, not migration_stream + col.recv(data, src_handle, self.group_name) + data = 
to_tensor(data) + + dummy_key_cpu = data[:num_blocks] + dummy_value_cpu = data[num_blocks:2*num_blocks] + with torch.cuda.stream(self.migration_stream): + for layer_idx in range(self.cache_engine.num_layers): + for idx, block_num in enumerate(blocks): + self.gpu_cache[layer_idx][0][block_num].copy_(dummy_key_cpu[idx][layer_idx], non_blocking=True) + self.gpu_cache[layer_idx][1][block_num].copy_(dummy_value_cpu[idx][layer_idx], non_blocking=True) + torch.cuda.Stream.synchronize(self.migration_stream) + + def _may_use_numpy_for_tranfer(self, data): + ret = data + if self.backend == 'gloo': + ret = data.to(self.dtype).numpy() + return ret + +def to_tensor(data): + if isinstance(data, torch.Tensor): + return data + + if isinstance(data, np.ndarray): + return torch.from_numpy(data) + + raise TypeError("Input data must be either a numpy array or a PyTorch tensor") + +def get_migrate_collective(migrate_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, \ + dtype, is_driver_worker, gpu_cache, ray_world_size, ray_rank, worker_rank, local_rank) -> MigrateBackendBase: + if migrate_config.pp_or_tp_enabled and migrate_config.migration_backend == 'nccl': + logger.warning("NCCL backend is not supported for PP or TP enabled model, using gloo instead.") + migrate_config.migration_backend = 'gloo' + + if cache_engine.num_gpu_blocks < migrate_config.migration_cache_blocks: + logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." + .format(migrate_config.migration_cache_blocks, cache_engine.num_gpu_blocks)) + migrate_config.migration_cache_blocks = cache_engine.num_gpu_blocks + + target_col = None + backend = migrate_config.migration_backend + if backend in ['nccl', 'gloo']: + target_col = RayMigrateBackend(migrate_config, cache_engine, ray_world_size, ray_rank, \ + local_rank, scheduling_strategy, dtype, is_driver_worker, gpu_cache) + elif backend == 'rpc': + target_col = RPCMigrateBackend(migrate_config, cache_engine, worker_rank, worker_handle_list, scheduling_strategy, \ + dtype, is_driver_worker, gpu_cache) + else: + raise ValueError(f"Unsupported backend {backend}") + + return target_col diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index 3e693d22..816c7fac 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -12,50 +12,66 @@ # limitations under the License. 
from typing import List +import math +import os import ray import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy from vllm.utils import is_pin_memory_available from vllm.worker.worker import Worker +from vllm.config import CacheConfig, ModelConfig, ParallelConfig from llumnix.logger import init_logger from llumnix.backends.vllm.utils import _sample_with_torch +from llumnix.backends.vllm.migrate_backend import get_migrate_collective +from llumnix.config import MigrationConfig + logger = init_logger(__name__) NUMPY_SUPPORT_DTYPES = [torch.float32, torch.float16] -@ray.remote(num_cpus=0) -class RecvActor: - def recv_cpu_cache(self, src_worker_handle, src_blocks, is_driver_worker): - """ - Args: - src_worker_handle: src worker actor handle - blocks: block to send - """ - try: - if is_driver_worker: - migration_cache = ray.get(src_worker_handle.execute_engine_method.remote("send_cpu_cache", src_blocks)) - else: - migration_cache = ray.get(src_worker_handle.execute_method.remote("send_cpu_cache", src_blocks)) - # pylint: disable=try-except-raise - except: - raise - return migration_cache - class MigrationWorker(Worker): def __init__(self, *args, **kwargs) -> None: # replace sampler # pylint: disable=import-outside-toplevel import vllm.model_executor.layers.sampler vllm.model_executor.layers.sampler._sample_with_torch = _sample_with_torch + + backend = os.environ.get("MIGRATE_BACKEND", "rpc") + migrate_size = int(os.environ.get("MIGRATE_CACHE_SIZE", 1)) + + parallel_config: ParallelConfig = kwargs["parallel_config"] + pp_or_tp_enabled = parallel_config.world_size > 1 + + if backend == "nccl" and (not pp_or_tp_enabled): + model_config: ModelConfig = kwargs["model_config"] + cache_config: CacheConfig = kwargs["cache_config"] + + num_layer = model_config.get_num_layers(parallel_config) + block_size = cache_config.block_size + num_head = model_config.get_num_kv_heads(parallel_config) + hidden_size = model_config.get_hidden_size() + total_size = migrate_size * num_layer * block_size * num_head * hidden_size + + device = torch.device(f"cuda:{kwargs['local_rank']}") + _, total_memory = torch.cuda.mem_get_info(device) + migrate_ratio = math.ceil(total_size / total_memory * 100) / 100 + cache_config.gpu_memory_utilization -= migrate_ratio + + logger.info("nccl collective take {} gpu memory, left gpu_memory_utilization {} for kv cache." 
\ + .format(migrate_ratio, cache_config.gpu_memory_utilization)) + super().__init__(*args, **kwargs) def load_model(self): torch.cuda.set_device(self.device) return super().load_model() - def init_migration(self, num_migration_cache_blocks: int, src_worker_handle_list, placement_group=None) -> None: + def get_ray_rank(self): + return self.ray_rank + + def init_migration(self, instance_id: str, migration_config: MigrationConfig, src_worker_handle_list, placement_group=None) -> None: if placement_group: scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, @@ -66,135 +82,53 @@ def init_migration(self, num_migration_cache_blocks: int, src_worker_handle_list node_id=ray.get_runtime_context().get_node_id(), soft=False, ) - self.recv_actor = RecvActor.options(scheduling_strategy=scheduling_strategy).remote() - self.migration_stream = torch.cuda.Stream() - self.default_stream = torch.cuda.current_stream() - self.num_migration_cache_blocks = num_migration_cache_blocks - assert self.migration_stream != self.default_stream pin_memory = is_pin_memory_available() if not pin_memory: # Pinning memory in WSL is not supported. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications logger.warning("Using 'pin_memory=False' as WSL is detected. " - "This may slow down the performance.") - migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + "This may slow down the performance.") + + num_migration_cache_blocks = migration_config.migration_cache_blocks + self.num_migration_cache_blocks = num_migration_cache_blocks + self.rpc_dtype = self.cache_engine.dtype if self.cache_engine.dtype in NUMPY_SUPPORT_DTYPES: self.rpc_dtype = self.cache_engine.dtype else: self.rpc_dtype = torch.float32 logger.warning("Detecting numpy unsupported dtype: {}. 
Using torch.float32.".format(self.cache_engine.dtype)) - # self.migration_cache = torch.zeros( - # size=(self.cache_engine.num_layers, 2, self.num_migration_cache_blocks, migration_cache_size), - # dtype=self.cache_engine.dtype, - # pin_memory=pin_memory, - # ) - self.migration_key_cache = torch.empty( - size=(self.num_migration_cache_blocks, self.cache_engine.num_layers, migration_cache_size), - dtype=self.cache_engine.dtype, - pin_memory=pin_memory, - ) - self.migration_value_cache = torch.empty( - size=(self.num_migration_cache_blocks, self.cache_engine.num_layers, migration_cache_size), - dtype=self.cache_engine.dtype, - pin_memory=pin_memory, - ) - # do dummy rpc - src_worker_handle = src_worker_handle_list[self.rank] - self.recv_actor.recv_cpu_cache.remote(src_worker_handle, [0], self.is_driver_worker) - - def send_cpu_cache(self, blocks: List[int]): - num_blocks = len(blocks) - dummy_key_cpu = self.migration_key_cache[:num_blocks] - dummy_value_cpu = self.migration_value_cache[:num_blocks] - with torch.cuda.stream(self.migration_stream): - for layer_idx in range(self.cache_engine.num_layers): - for idx, block_num in enumerate(blocks): - dummy_key_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][0][block_num]) - dummy_value_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][1][block_num]) - torch.cuda.Stream.synchronize(self.migration_stream) - return (dummy_key_cpu.to(self.rpc_dtype).numpy(), dummy_value_cpu.to(self.rpc_dtype).numpy()) - - def recv_cpu_cache(self, blocks: List[int], rpc_numpy_cache): - num_blocks = len(blocks) - dummy_key = self.migration_key_cache[:num_blocks] - dummy_value = self.migration_value_cache[:num_blocks] - k = rpc_numpy_cache[0] - v = rpc_numpy_cache[1] - dummy_key.copy_(torch.from_numpy(k)) - dummy_value.copy_(torch.from_numpy(v)) - with torch.cuda.stream(self.migration_stream): - for layer_idx in range(self.cache_engine.num_layers): - for idx, block_num in enumerate(blocks): - self.gpu_cache[layer_idx][0][block_num].copy_(dummy_key[idx][layer_idx]) - self.gpu_cache[layer_idx][1][block_num].copy_(dummy_value[idx][layer_idx]) - torch.cuda.Stream.synchronize(self.migration_stream) - - def send_cpu_cache_v2(self, blocks: List[int]): - src_to_dst = dict(enumerate(blocks)) - with torch.cuda.stream(self.migration_stream): - for layer_idx in range(self.cache_engine.num_layers): - self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], self.migration_cache[layer_idx], src_to_dst) - torch.cuda.Stream.synchronize(self.migration_stream) - return self.migration_cache.to(self.rpc_dtype).numpy() - - def recv_cpu_cache_v2(self, blocks: List[int], rpc_numpy_cache): - with torch.cuda.stream(self.migration_stream): - self.migration_cache.copy_(torch.from_numpy(rpc_numpy_cache)) - src_to_dst = dict(enumerate(blocks)) - for layer_idx in range(self.cache_engine.num_layers): - self.cache_engine.attn_backend.swap_blocks(self.migration_cache[layer_idx], self.gpu_cache[layer_idx],src_to_dst) - torch.cuda.Stream.synchronize(self.migration_stream) - - - def migrate_gpu_cache_ray_rpc(self, src_worker_handle_list, src_blocks: List[int], dst_blocks: List[int]): + + self.instance_id = instance_id + num_instance = len(migration_config.instance_rank_map) + self.ray_world_size = num_instance * self.parallel_config.world_size + self.ray_rank = self.rank + migration_config.instance_rank_map[self.instance_id] * self.parallel_config.world_size + self.migrate_col = get_migrate_collective(migrate_config=migration_config, + cache_engine=self.cache_engine, + 
worker_handle_list=src_worker_handle_list, + scheduling_strategy=scheduling_strategy, + dtype=self.rpc_dtype, + is_driver_worker=self.is_driver_worker, + gpu_cache=self.gpu_cache, + ray_world_size=self.ray_world_size, + ray_rank=self.ray_rank, + worker_rank=self.rank, + local_rank=self.local_rank) + self.migrate_col.warmup() + + def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_blocks: List[int]): try: src_worker_handle = src_worker_handle_list[self.rank] - tot_blocks = len(src_blocks) - rpc_numpy_cache = None - for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks): - # send/recv num_migration_cache_blocks per iter - offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx) - send_blocks = src_blocks[start_idx:start_idx+offset] - ray_obj = self.recv_actor.recv_cpu_cache.remote(src_worker_handle, send_blocks, self.is_driver_worker) - if rpc_numpy_cache is not None: - self.recv_cpu_cache(recv_blocks, rpc_numpy_cache) - rpc_numpy_cache = ray.get(ray_obj) - recv_blocks = dst_blocks[start_idx:start_idx+offset] - self.recv_cpu_cache(recv_blocks, rpc_numpy_cache) + self.migrate_col.migrate_cache(src_worker_handle, src_blocks, dst_blocks) except ray.exceptions.RayActorError: - logger.info("[migrate_gpu_cache_ray_rpc] self.rank: {}, src_worker_handle {} is dead".format(self.rank, src_worker_handle)) - - # def send_gpu_cache_ray(self,rank_offset:int, blocks:List[int]): - # with torch.cuda.stream(self.migration_stream): - # dst_rank = self.ray_rank + rank_offset - # num_blocks = len(blocks) - # dummy_key_cpu = self.dummy_key_cpu[:num_blocks] - # dummy_value_cpu = self.dummy_value_cpu[:num_blocks] - # with torch.cuda.stream(self.migration_stream): - # for i in range(self.cache_engine.num_layers): - # for idx,block_num in enumerate(blocks): - # dummy_key_cpu[idx].copy_(self.gpu_cache[i][0][block_num]) - # dummy_value_cpu[idx].copy_(self.gpu_cache[i][1][block_num]) - # col.send(dummy_key_cpu, dst_rank) - # col.send(dummy_value_cpu, dst_rank) - - # torch.cuda.Stream.synchronize(self.migration_stream) - - # def recv_gpu_cache_ray(self,rank_offset:int, blocks): - # with torch.cuda.stream(self.migration_stream): - # src_rank = self.ray_rank + rank_offset - # num_blocks = len(blocks) - # dummy_key = self.dummy_key_cpu[:num_blocks] - # dummy_value = self.dummy_value_cpu[:num_blocks] - # for i in range(self.cache_engine.num_layers): - # col.recv(dummy_key, src_rank) - # col.recv(dummy_value, src_rank) - # for idx,block_num in enumerate(blocks): - # self.gpu_cache[i][0][block_num].copy_(dummy_key[idx]) - # self.gpu_cache[i][1][block_num].copy_(dummy_value[idx]) - # torch.cuda.Stream.synchronize(self.migration_stream) + logger.info("[migrate_cache] self.rank: {}, src_worker_handle {} is dead".format(self.rank, src_worker_handle)) + + def do_recv(self, src_handle, blocks: List[int]): + return self.migrate_col.do_recv(src_handle, blocks=blocks) + + def do_send(self, dst_handle, blocks: List[int]): + return self.migrate_col.do_send(dst_handle, blocks=blocks) def shutdown(self) -> None: torch.cuda.synchronize() @@ -207,21 +141,3 @@ def shutdown(self) -> None: def restart(self) -> None: self.init_model() self.init_cache_engine(self.cache_config) - - # instance_id is changed from int to str, this function should be modified if used - # def init_migration_dist_ray(self, num_instance, instance_id): - # self.ray_world_size = num_instance * self.parallel_config.world_size - # self.ray_rank = self.rank + instance_id * self.parallel_config.world_size - # 
logger.info(f"{self.ray_world_size, self.ray_rank}") - # # col.init_collective_group(world_size=self.ray_world_size, rank=self.ray_rank , backend="gloo") - # # rpc.init_rpc(f"worker_{self.ray_rank}", rank=self.ray_rank, world_size=self.ray_world_size) - - - # def run_migration_warmup(self): - # if self.ray_world_size > 1: - # if self.ray_rank % 2: - # self.recv_gpu_cache_ray(1 if self.ray_rank + 1 < self.ray_world_size else 1-self.ray_world_size,[0]) - # self.send_gpu_cache_ray(-1 if self.ray_rank > 0 else self.ray_world_size-1,[0]) - # else: - # self.send_gpu_cache_ray(-1 if self.ray_rank > 0 else self.ray_world_size-1,[0]) - # self.recv_gpu_cache_ray(1 if self.ray_rank + 1 < self.ray_world_size else 1-self.ray_world_size,[0]) diff --git a/llumnix/config.py b/llumnix/config.py index d70d9c9c..34249c78 100644 --- a/llumnix/config.py +++ b/llumnix/config.py @@ -18,12 +18,18 @@ def __init__( migration_backend: str, migration_cache_blocks: int, last_stage_max_blocks: int, - max_stages: int,) -> None: + max_stages: int, + instance_rank_map: dict, + pp_or_tp_enabled: bool, + group_name: str) -> None: self.migrate_policy = migrate_policy self.migration_backend = migration_backend self.migration_cache_blocks = migration_cache_blocks self.last_stage_max_blocks = last_stage_max_blocks self.max_stages = max_stages + self.instance_rank_map = instance_rank_map + self.pp_or_tp_enabled = pp_or_tp_enabled + self.group_name = group_name class GlobalSchedulerConfig: def __init__( @@ -49,3 +55,4 @@ def __init__( self.scale_policy = scale_policy self.scale_up_threshold = scale_up_threshold*(-1) self.scale_down_threshold = scale_down_threshold*(-1) + \ No newline at end of file diff --git a/llumnix/entrypoints/llumnix_utils.py b/llumnix/entrypoints/llumnix_utils.py index 437adbf7..8f10e570 100644 --- a/llumnix/entrypoints/llumnix_utils.py +++ b/llumnix/entrypoints/llumnix_utils.py @@ -122,15 +122,21 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, parallel_config = engine_config.parallel_config instance_ids: List[str] = [] llumlets: List[Llumlet] = [] - for _ in range(engine_manager_args.initial_instances): - instance_id = random_uuid() + + instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)] + id_rank_map = {instance_id: index for index, instance_id in enumerate(instance_ids)} + pp_or_tp_enabled = parallel_config.world_size > 1 + migration_configs = engine_manager_args.create_migration_configs(id_rank_map, pp_or_tp_enabled, random_uuid()) + + for idx in range(engine_manager_args.initial_instances): + instance_id = instance_ids[idx] if not engine_manager_args.profiling_result_file_path: llumlet = Llumlet.from_args( engine_manager_args.fixed_node_init, instance_id, BackendType.VLLM, parallel_config.world_size, - engine_manager_args.create_migration_configs(), + migration_configs, engine_args, ) else: @@ -139,12 +145,11 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, instance_id, BackendType.SIM_VLLM, parallel_config.world_size, - engine_manager_args.create_migration_configs(), + migration_configs, engine_manager_args.profiling_result_file_path, engine_manager_args.gpu_type, engine_args, ) - instance_ids.append(instance_id) llumlets.append(llumlet) return instance_ids, llumlets diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py index 112043f7..11f7386d 100644 --- a/llumnix/entrypoints/vllm/api_server.py +++ b/llumnix/entrypoints/vllm/api_server.py @@ -17,6 +17,7 @@ import time import asyncio import json +import os 
from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse import uvicorn @@ -214,6 +215,9 @@ async def is_ready(): ready_status = await engine_manager.is_ready.remote() return ready_status +def set_runtime_env(manager_args: EngineManagerArgs): + os.environ['MIGRATE_CACHE_SIZE'] = str(manager_args.migration_cache_blocks) + os.environ['MIGRATE_BACKEND'] = manager_args.migration_backend if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -237,6 +241,8 @@ async def is_ready(): logger.info("engine_args: {}".format(engine_args)) + set_runtime_env(engine_manager_args) + if args.launch_ray_cluster: # Launch the ray cluster for multi-node serving. launch_ray_cluster(args.ray_cluster_port) diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 65e84479..613535ce 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -86,6 +86,13 @@ def dispatch(self, sorted_instance_infos: List[InstanceInfo]) -> int: pass +class Flood(DispatchPolicy): + def dispatch(self, + instance_num_request: Dict[str, int], + sorted_instance_infos: List[InstanceInfo]) -> str: + instance_id = max(instance_num_request, key=instance_num_request.get) + return instance_id + class Balanced(DispatchPolicy): def dispatch(self, instance_num_request: Dict[str, int], @@ -117,6 +124,7 @@ def dispatch(self, class DispatchPolicyFactory: _POLICY_REGISTRY = { + 'flood': Flood, 'balanced': Balanced, 'load': Load, 'queue': Queue, diff --git a/requirements.txt b/requirements.txt index 0db75cb6..b3d35556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,6 @@ aiohttp scipy pandas matplotlib +cupy-cuda12x # for ray.collective +numpy == 1.22.4 +# conda install -c conda-forge gcc==12.1.0 # for ray.collective
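
For readers unfamiliar with the ray.util.collective calls this patch introduces in RayMigrateBackend, the sketch below shows the same init_collective_group / send / recv / destroy_collective_group pattern with the gloo backend between two Ray actors. It is not code from the patch: the actor class, group name, and buffer shape are illustrative assumptions. It only mirrors the patch's gloo path, which transfers numpy float32/float16 buffers staged in pinned CPU memory.

# Minimal sketch (illustrative, not part of the patch) of the collective pattern
# used by RayMigrateBackend with the gloo backend. Worker, the group name, and the
# buffer shape are assumptions made for this example.
import numpy as np
import ray
import ray.util.collective as col

@ray.remote
class Worker:
    def setup(self, world_size: int, rank: int, group_name: str) -> None:
        # Every participant joins the same named collective group, as in init_col().
        col.init_collective_group(world_size=world_size, rank=rank,
                                  backend="gloo", group_name=group_name)

    def send(self, dst_rank: int, group_name: str) -> None:
        # gloo transfers use numpy buffers, mirroring the patch's numpy conversion for gloo.
        data = np.ones(8, dtype=np.float32)
        col.send(data, dst_rank, group_name)

    def recv(self, src_rank: int, group_name: str):
        buf = np.zeros(8, dtype=np.float32)
        col.recv(buf, src_rank, group_name)  # fills buf in place
        return buf

    def teardown(self, group_name: str) -> None:
        col.destroy_collective_group(group_name)

if __name__ == "__main__":
    ray.init()
    src, dst = Worker.remote(), Worker.remote()
    group = "kv_migration_demo"
    ray.get([src.setup.remote(2, 0, group), dst.setup.remote(2, 1, group)])
    ray.get([src.send.remote(1, group), dst.recv.remote(0, group)])
    ray.get([src.teardown.remote(group), dst.teardown.remote(group)])

With backend="nccl", the same calls operate on GPU-resident tensors instead, which is why the patch reserves a slice of gpu_memory_utilization for the migration buffer in MigrationWorker and adds cupy-cuda12x to requirements.txt for ray.collective.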