diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py
index 6e2ff539..1f65adeb 100644
--- a/llumnix/backends/vllm/migration_backend.py
+++ b/llumnix/backends/vllm/migration_backend.py
@@ -88,6 +88,9 @@ def warmup(self) -> bool:
         logger.info("rpc migration backend warmup successfully.")
         return True
 
+    # The src actor packs the kv-cache data layer by layer. NumPy is used for the transfer because,
+    # on a single node, Ray RPC can pass NumPy arrays via shared memory. The recv actor first copies
+    # the data into a pinned-memory dummy cache before moving it to the GPU, which speeds up the transfer.
     def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
         tot_blocks = len(src_blocks)
         rpc_numpy_cache = None
@@ -207,10 +210,10 @@ def destory_backend(self) -> None:
             err_info = e
 
         if err_info is not None:
-            logger.info("destory migration backend successfully (group_name:{}, backbend: {}), meet_err: {}."
+            logger.info("destroy migration backend successfully (group_name: {}, backend: {}), error: {}."
                         .format(self.group_name, self.backend, err_info))
         else:
-            logger.info("destory migration backend successfully (group_name:{}, backbend: {})."
+            logger.info("destroy migration backend successfully (group_name: {}, backend: {})."
                         .format(self.group_name, self.backend))
 
         self.group_name = None
@@ -221,14 +224,16 @@ def warmup(self) -> bool:
             col.allreduce(self.dummy_cache[0], self.group_name)
         # pylint: disable=W0703
         except Exception as e:
-            logger.info("warmup migration backend failed (group_name:{}, world_size: {}, rank: {}, backbend: {}), err: {}."
+            logger.info("warmup migration backend failed (group_name: {}, world_size: {}, rank: {}, backend: {}), err: {}."
                         .format(self.group_name, self.global_world_size, self.global_rank, self.backend, e))
             return False
 
-        logger.info("migration backend warmup successfully (group_name:{}, world_size: {}, rank: {}, backbend: {})."
+        logger.info("migration backend warmup successfully (group_name: {}, world_size: {}, rank: {}, backend: {})."
                     .format(self.group_name, self.global_world_size, self.global_rank, self.backend))
         return True
 
+    # Ray collective is used to construct the gloo and nccl backends. do_send/do_recv transmit the
+    # data layer by layer. Note that col.send/recv are blocking operations.
     def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
         tot_blocks = len(src_blocks)
         src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_global_rank"))
@@ -249,6 +254,7 @@ def do_send(self, dst_handle, blocks: List[int]):
                 cache_idx = layer_idx % self.migration_num_layers
                 self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[cache_idx], src_to_dst)
                 if cache_idx + 1 == self.migration_num_layers or layer_idx + 1 == self.cache_engine.num_layers:
+                    # TODO(KuilongCui): check the error code if the peer is dead
                     col.send(send_cache, dst_handle, self.group_name)
 
         torch.cuda.Stream.synchronize(self.migration_stream)
@@ -276,10 +282,8 @@ def get_migration_backend(migration_config: MigrationConfig, cache_engine: Cache
     if backend in ['nccl', 'gloo']:
         target_col = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy,
                                             is_driver_worker, gpu_cache)
-    elif backend == 'rpc':
+    else:
         target_col = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list,
                                             scheduling_strategy, is_driver_worker, gpu_cache)
-    else:
-        raise ValueError(f"Unsupported backend {backend}")
 
     return target_col
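The first new comment compresses the whole rpc transfer path (NumPy for single-node shared-memory transport, a pinned-memory staging cache for the host-to-GPU copy). A minimal sketch of that staging pattern, with hypothetical names (`pack_layer`, `unpack_layer`, `pinned_cache`) rather than the repo's actual helpers:

```python
import numpy as np
import torch

def pack_layer(gpu_chunk: torch.Tensor) -> np.ndarray:
    # Stage the chunk on the host as NumPy: on a single node, Ray's object
    # store hands NumPy arrays to the receiver through shared memory.
    return gpu_chunk.cpu().numpy()

def unpack_layer(rpc_numpy_cache: np.ndarray,
                 pinned_cache: torch.Tensor,
                 gpu_cache: torch.Tensor) -> None:
    # Copy into page-locked (pinned) host memory first: host-to-device copies
    # from pinned memory are faster and can overlap compute via non_blocking=True.
    pinned_cache.copy_(torch.from_numpy(rpc_numpy_cache))
    gpu_cache.copy_(pinned_cache, non_blocking=True)

# One-off pinned staging buffer, analogous to the backend's dummy cache.
pinned_cache = torch.empty(256, dtype=torch.float16, pin_memory=True)
```

The pinned buffer is allocated once and reused across layers, so the per-layer cost reduces to the two copies.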
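The second new comment states that col.send/recv are blocking, which is also why the added TODO about a dead peer matters: a vanished receiver leaves the sender stuck inside col.send. A self-contained toy of that pairing, assuming pygloo is installed (`Worker`, `do_send`, and `do_recv` here are illustrative, not the repo's classes):

```python
import ray
import torch
import ray.util.collective as col

@ray.remote
class Worker:
    def setup(self, world_size: int, rank: int) -> None:
        # Collective call: returns only once every rank has joined the group.
        # The real backends would pass backend="nccl" for GPU-to-GPU transfer.
        col.init_collective_group(world_size, rank, backend="gloo", group_name="migrate")

    def do_send(self) -> None:
        chunk = torch.arange(4, dtype=torch.float32)
        # Blocks until the peer posts the matching recv; a dead peer means
        # this call never returns, hence the TODO in the diff.
        col.send(chunk, 1, "migrate")

    def do_recv(self) -> torch.Tensor:
        chunk = torch.zeros(4, dtype=torch.float32)
        col.recv(chunk, 0, "migrate")  # blocks until the data arrives
        return chunk

ray.init()
src, dst = Worker.remote(), Worker.remote()
ray.get([src.setup.remote(2, 0), dst.setup.remote(2, 1)])
recv_ref = dst.do_recv.remote()
src.do_send.remote()
print(ray.get(recv_ref))  # tensor([0., 1., 2., 3.])
```

Because every flush of the staging cache is one blocking col.send, the sender naturally throttles itself to the receiver's pace.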