[Core] Support gloo and nccl backends for kv cache transfer
KuilongCui committed Jul 31, 2024
1 parent 688cd8e commit d2c3994
Showing 10 changed files with 396 additions and 170 deletions.
6 changes: 4 additions & 2 deletions benchmark/benchmark_serving.py
@@ -652,6 +652,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", type=str, required=True,
help="Name or path of the tokenizer.")
parser.add_argument('--trust_remote_code',
action='store_true')
parser.add_argument('-v', '--verbose', action='store_true')
parser.add_argument('--backend', type=GenerationBackend,
choices=[e.name for e in GenerationBackend], default='vLLM')
@@ -701,7 +703,7 @@ def main():
assert args.random_prompt_count is not None

backend = GenerationBackend[args.backend]
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
print(tokenizer)

if args.dataset_type:
@@ -798,4 +800,4 @@ def main():


if __name__ == '__main__':
main()
main()
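
For reference, the new --trust_remote_code flag is forwarded directly to the Hugging Face tokenizer loader, as the diff above shows. A minimal sketch of the equivalent call, assuming a hypothetical tokenizer repository that ships custom tokenizer code:

from transformers import AutoTokenizer

# equivalent to: python benchmark_serving.py --tokenizer some-org/custom-model --trust_remote_code ...
tokenizer = AutoTokenizer.from_pretrained(
    "some-org/custom-model",   # hypothetical repo name, for illustration only
    trust_remote_code=True,    # allow the repo's own tokenizer implementation to run
)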
11 changes: 7 additions & 4 deletions llumnix/arg_utils.py
@@ -70,13 +70,16 @@ def create_engine_manager_configs(
return global_scheduler_config

def create_migration_configs(
self,
self, instance_rank_map, pp_or_tp_enabled, group_name
) -> MigrationConfig:
migration_config = MigrationConfig(self.migrate_policy,
self.migration_backend,
self.migration_cache_blocks,
self.last_stage_max_blocks,
self.max_stages)
self.max_stages,
instance_rank_map,
pp_or_tp_enabled,
group_name)
return migration_config

@classmethod
@@ -107,7 +110,7 @@ def add_cli_args(
parser.add_argument('--dispatch-policy',
type=str,
default=EngineManagerArgs.dispatch_policy,
choices=['balanced', 'load', 'queue'],
choices=['balanced', 'load', 'queue', 'flood'],
help='dispatch policy')

parser.add_argument('--enable-migrate',
@@ -192,7 +195,7 @@ def add_cli_args(
parser.add_argument('--migration-backend',
type=str,
default=EngineManagerArgs.migration_backend,
choices=['gloo','rpc'],
choices=['gloo','nccl','rpc'],
help='communication backend during migration')
parser.add_argument('--migration-cache_blocks',
type=int,
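
As the updated create_migration_configs shows, MigrationConfig now also carries an instance rank map, a PP/TP flag, and the collective group name used by the gloo/nccl backends. A minimal construction sketch, assuming the positional order matches the call above; every concrete value below is illustrative:

from llumnix.config import MigrationConfig

migration_config = MigrationConfig(
    "LCFS",                               # migrate_policy (hypothetical value)
    "nccl",                               # migration_backend: 'gloo', 'nccl', or 'rpc'
    32,                                   # migration_cache_blocks
    16,                                   # last_stage_max_blocks
    3,                                    # max_stages
    {"instance_0": 0, "instance_1": 1},   # instance_rank_map (hypothetical)
    False,                                # pp_or_tp_enabled; nccl falls back to gloo when True
    "llumnix_migration_group",            # group_name (hypothetical)
)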
12 changes: 5 additions & 7 deletions llumnix/backends/vllm/llm_engine.py
@@ -174,12 +174,11 @@ def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None:
class BackendVLLM(BackendInterface):
def __init__(
self,
instance_id: int,
instance_id: str,
migration_config: MigrationConfig,
engine_args: EngineArgs,
placement_group: "PlacementGroup"
) -> None:
assert migration_config.migration_backend == "rpc", "Gloo support will be released later."
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, instance_id=instance_id,
placement_group=placement_group)
# multi-instance args
@@ -189,7 +188,7 @@ def __init__(
self.worker_handle_list = self.engine.model_executor.workers.copy()
if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size:
self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix"))
self._run_workers("init_migration", num_migration_cache_blocks=migration_config.migration_cache_blocks,\
self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group)
self._thread = threading.Thread(
@@ -201,9 +200,8 @@ def _start_engine_loop(self) -> None:
while True:
self.engine.step()

def send_cpu_cache(self, *args, **kwargs):
# driver worker migration interface
return self.engine.model_executor.driver_worker.execute_method("send_cpu_cache", *args, **kwargs)
def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)

def stop_shutdown(self) -> None:
self.engine.scaling_down = False
@@ -250,7 +248,7 @@ def commit_dst_request(self, backend_request: SequenceGroup, server_info: Server

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers",
"migrate_gpu_cache_ray_rpc",
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list))
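
The old send_cpu_cache shortcut is replaced by a generic dispatcher, so any worker method can be routed through the driver worker by name; the ProxyActor in migrate_backend.py below relies on exactly this path. A minimal sketch of such a remote call (the actor name format follows the diff, but the instance id, block ids, and return value are illustrative):

import ray

# instance actors are registered as f"instance_{instance_id}" in the "llumnix" namespace
instance_actor = ray.get_actor("instance_0", namespace="llumnix")
blocks = [0, 1, 2, 3]   # hypothetical block ids
cpu_cache = ray.get(
    instance_actor.execute_engine_method.remote(
        "execute_worker_method", "do_send", None, blocks
    )
)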
278 changes: 278 additions & 0 deletions llumnix/backends/vllm/migrate_backend.py
@@ -0,0 +1,278 @@
from abc import ABC, abstractmethod
from typing import List
import torch
import numpy as np

import ray
import ray.util.collective as col
from vllm.logger import init_logger
from vllm.worker.cache_engine import CacheEngine
from llumnix.config import MigrationConfig

logger = init_logger(__name__)

class MigrateBackendBase(ABC):
@abstractmethod
def init_col(self, name) -> None:
raise NotImplementedError

@abstractmethod
def destroy_col(self) -> None:
raise NotImplementedError

@abstractmethod
def warmup(self) -> None:
raise NotImplementedError

@abstractmethod
def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
raise NotImplementedError

@abstractmethod
def do_send(self, dst_handle, blocks: List[int]):
raise NotImplementedError

@abstractmethod
def do_recv(self, src_handle, blocks: List[int]):
raise NotImplementedError

@ray.remote(num_cpus=0)
class ProxyActor:
def exec_method(self, is_driver_worker, handle, *args, **kwargs):
try:
if is_driver_worker:
ret = ray.get(handle.execute_engine_method.remote("execute_worker_method", *args, **kwargs))
else:
ret = ray.get(handle.execute_method.remote(*args, **kwargs))
# pylint: disable=try-except-raise
except:
raise

return ret

class RPCMigrateBackend(MigrateBackendBase):
def __init__(self, migrate_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \
scheduling_strategy, dtype, is_driver_worker, gpu_cache) -> None:
super().__init__()

self.migrate_config = migrate_config
self.cache_engine = cache_engine

self.worker_rank = worker_rank
self.worker_handle_list = worker_handle_list
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()

self.dtype = dtype
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.init_col(None)

def init_col(self, name) -> None:
self.cache_device = "cpu"

self.num_migration_cache_blocks = self.migrate_config.migration_cache_blocks
self.num_layers = self.cache_engine.num_layers
migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size

self.dummy_cache = torch.empty(
size=(2*self.num_migration_cache_blocks, self.num_layers, migration_cache_size),
dtype=self.cache_engine.dtype,
device=self.cache_device,
pin_memory=True
)

self.migration_stream = torch.cuda.Stream()
logger.info("create rpc migrate backend success.")

def destroy_col(self) -> None:
pass

def warmup(self) -> None:
self_handle = self.worker_handle_list[self.worker_rank]
self.actor.exec_method.remote(self.is_driver_worker, self_handle, "do_send", None,
list(range(self.num_migration_cache_blocks)))

def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
rpc_numpy_cache = None
# pipeline the transfer: fetch the next chunk from the source worker via RPC
# while the previously fetched chunk is being copied into the local GPU cache
for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks):
offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx)
send_blocks = src_blocks[start_idx:start_idx+offset]
ray_obj = self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", None, send_blocks)
if rpc_numpy_cache is not None:
self.do_recv(rpc_numpy_cache, recv_blocks)
rpc_numpy_cache = ray.get(ray_obj)
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.do_recv(rpc_numpy_cache, recv_blocks)

def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
data = self.dummy_cache[:2*num_blocks]
dummy_key_cpu = self.dummy_cache[:num_blocks]
dummy_value_cpu = self.dummy_cache[num_blocks:2*num_blocks]
with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
for idx, block_num in enumerate(blocks):
dummy_key_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][0][block_num], non_blocking=True)
dummy_value_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][1][block_num], non_blocking=True)
torch.cuda.Stream.synchronize(self.migration_stream)
return data.to(self.dtype).numpy()

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)

# use pin memory dummy_cache to speed up data transfer
data = self.dummy_cache[:2*num_blocks].copy_(torch.from_numpy(src_handle))
dummy_key = data[:num_blocks]
dummy_value = data[num_blocks:2*num_blocks]

with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.num_layers):
for idx, block_num in enumerate(blocks):
self.gpu_cache[layer_idx][0][block_num].copy_(dummy_key[idx][layer_idx], non_blocking=True)
self.gpu_cache[layer_idx][1][block_num].copy_(dummy_value[idx][layer_idx], non_blocking=True)
torch.cuda.Stream.synchronize(self.migration_stream)

class RayMigrateBackend(MigrateBackendBase):
def __init__(self, migrate_config: MigrationConfig, cache_engine: CacheEngine, ray_world_size, ray_rank, \
local_rank, scheduling_strategy, dtype, is_driver_worker, gpu_cache) -> None:
super().__init__()

self.num_migration_cache_blocks = migrate_config.migration_cache_blocks
self.migrate_config = migrate_config
self.cache_engine = cache_engine
self.backend = migrate_config.migration_backend

self.ray_world_size = ray_world_size
self.ray_rank = ray_rank
self.group_name = migrate_config.group_name

self.local_rank = local_rank
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()
self.dtype = dtype
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache

self.init_col(migrate_config.group_name)

def init_col(self, name) -> None:
self.group_name = name

migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size

if self.backend == 'gloo':
self.cache_device = "cpu"
elif self.backend == 'nccl':
self.cache_device = torch.device(f"cuda:{self.local_rank}")
else:
raise ValueError("backend must be 'gloo' or 'nccl'")

pin_memory = self.backend == 'gloo'
self.dummy_cache = torch.empty(
size=(2*self.num_migration_cache_blocks, self.cache_engine.num_layers, migration_cache_size),
dtype=self.cache_engine.dtype,
device=self.cache_device,
pin_memory=pin_memory
)

self.migration_stream = torch.cuda.Stream()

col.init_collective_group(world_size=self.ray_world_size, rank=self.ray_rank,
backend=self.backend, group_name=self.group_name)

logger.info("create ray collective group success (group_name:{}, backbend: {})."
.format(self.group_name, self.backend))

def warmup(self) -> None:
enable_warmup = self.ray_world_size > 1
need_warmup = not (self.ray_world_size % 2 != 0 and self.ray_rank == self.ray_world_size - 1)

if enable_warmup and need_warmup:
if self.ray_rank % 2 == 0:
self.do_send(self.ray_rank+1, list(range(self.migrate_config.migration_cache_blocks)))
else:
self.do_recv(self.ray_rank-1, list(range(self.migrate_config.migration_cache_blocks)))

def destroy_col(self) -> None:
col.destroy_collective_group(self.group_name)

def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
src_rank = ray.get(self.actor.exec_method.remote(self.is_driver_worker, src_handle, "get_ray_rank"))

for start_idx in range(0, tot_blocks, self.num_migration_cache_blocks):
offset = min(self.num_migration_cache_blocks, tot_blocks - start_idx)
send_blocks = src_blocks[start_idx:start_idx+offset]
recv_blocks = dst_blocks[start_idx:start_idx+offset]
self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.ray_rank, send_blocks)
self.do_recv(src_rank, recv_blocks)

def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
data = self.dummy_cache[:2*num_blocks]
dummy_key_cpu = data[:num_blocks]
dummy_value_cpu = data[num_blocks:2*num_blocks]
with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.cache_engine.num_layers):
for idx, block_num in enumerate(blocks):
dummy_key_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][0][block_num], non_blocking=True)
dummy_value_cpu[idx][layer_idx].copy_(self.gpu_cache[layer_idx][1][block_num], non_blocking=True)
torch.cuda.Stream.synchronize(self.migration_stream)
data = self._may_use_numpy_for_transfer(data)
col.send(data, dst_handle, self.group_name)

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
data = self.dummy_cache[:2*num_blocks]
data = self._may_use_numpy_for_transfer(data)
# note that col.recv uses ray.collective's internal stream, not migration_stream
col.recv(data, src_handle, self.group_name)
data = to_tensor(data)

dummy_key_cpu = data[:num_blocks]
dummy_value_cpu = data[num_blocks:2*num_blocks]
with torch.cuda.stream(self.migration_stream):
for layer_idx in range(self.cache_engine.num_layers):
for idx, block_num in enumerate(blocks):
self.gpu_cache[layer_idx][0][block_num].copy_(dummy_key_cpu[idx][layer_idx], non_blocking=True)
self.gpu_cache[layer_idx][1][block_num].copy_(dummy_value_cpu[idx][layer_idx], non_blocking=True)
torch.cuda.Stream.synchronize(self.migration_stream)

def _may_use_numpy_for_transfer(self, data):
ret = data
if self.backend == 'gloo':
ret = data.to(self.dtype).numpy()
return ret

def to_tensor(data):
if isinstance(data, torch.Tensor):
return data

if isinstance(data, np.ndarray):
return torch.from_numpy(data)

raise TypeError("Input data must be either a numpy array or a PyTorch tensor")

def get_migrate_collective(migrate_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, \
dtype, is_driver_worker, gpu_cache, ray_world_size, ray_rank, worker_rank, local_rank) -> MigrateBackendBase:
if migrate_config.pp_or_tp_enabled and migrate_config.migration_backend == 'nccl':
logger.warning("NCCL backend is not supported for PP or TP enabled model, using gloo instead.")
migrate_config.migration_backend = 'gloo'

if cache_engine.num_gpu_blocks < migrate_config.migration_cache_blocks:
logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
.format(migrate_config.migration_cache_blocks, cache_engine.num_gpu_blocks))
migrate_config.migration_cache_blocks = cache_engine.num_gpu_blocks

target_col = None
backend = migrate_config.migration_backend
if backend in ['nccl', 'gloo']:
target_col = RayMigrateBackend(migrate_config, cache_engine, ray_world_size, ray_rank, \
local_rank, scheduling_strategy, dtype, is_driver_worker, gpu_cache)
elif backend == 'rpc':
target_col = RPCMigrateBackend(migrate_config, cache_engine, worker_rank, worker_handle_list, scheduling_strategy, \
dtype, is_driver_worker, gpu_cache)
else:
raise ValueError(f"Unsupported backend {backend}")

return target_col
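
For intuition, RayMigrateBackend.warmup pairs ranks so that each even rank sends one dummy buffer to rank+1 and each odd rank receives from rank-1, with the last rank idle when the world size is odd. A standalone sketch of that pairing logic (the helper below is made up for illustration and is not part of the commit):

def warmup_pairs(world_size: int):
    # (sender, receiver) pairs produced by the warmup scheme in RayMigrateBackend.warmup
    pairs = []
    for rank in range(world_size):
        enable_warmup = world_size > 1
        need_warmup = not (world_size % 2 != 0 and rank == world_size - 1)
        if enable_warmup and need_warmup and rank % 2 == 0:
            pairs.append((rank, rank + 1))
    return pairs

assert warmup_pairs(1) == []                  # a single instance has no peer to warm up with
assert warmup_pairs(4) == [(0, 1), (2, 3)]
assert warmup_pairs(5) == [(0, 1), (2, 3)]    # the trailing odd rank sits out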