
fix
KuilongCui committed Nov 5, 2024
1 parent 7b45ae6 commit a2f7f94
Showing 3 changed files with 125 additions and 55 deletions.
6 changes: 6 additions & 0 deletions llumnix/arg_utils.py
@@ -18,10 +18,12 @@
import argparse
from typing import Tuple

from llumnix.logger import init_logger
from llumnix.internal_config import GlobalSchedulerConfig, MigrationConfig
from llumnix.config import LlumnixConfig, get_llumnix_config
from llumnix.config.default import _C

logger = init_logger("llumnix.entrypoints.vllm.api_server")

class LlumnixArgumentParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
@@ -197,6 +199,10 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser):
            if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest):
                assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}."

        if args.migration_backend == 'nccl':
            logger.warning("NCCL migration backend is deprecated, use gloo instead.")
            args.migration_backend = 'gloo'

        assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
            and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
            ("When using gloo as migration backend, "
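As a rough usage sketch (not part of this commit): when the deprecated backend is still requested, check_args now downgrades it in place. The snippet below is a hypothetical illustration; it assumes EngineManagerArgs accepts migration_backend as a keyword with defaults for everything else, that check_args can be handed a bare argparse.ArgumentParser for the choices check, and that disable_init_instance_by_manager / disable_fixed_node_init_instance default to a falsy value so the gloo precondition holds.

import argparse
from llumnix.arg_utils import EngineManagerArgs

args = EngineManagerArgs(migration_backend='nccl')             # hypothetical construction
EngineManagerArgs.check_args(args, argparse.ArgumentParser())  # logs the deprecation warning
assert args.migration_backend == 'gloo'                        # 'nccl' is coerced to 'gloo'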
2 changes: 1 addition & 1 deletion tests/e2e_test/test_migration.py
@@ -65,7 +65,7 @@ def parse_manager_log_file(log_file):
@pytest.mark.asyncio
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
@pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo'])
async def test_migration_benchmark(model, migration_backend):
    base_port = 37037
    instance_output_logs = []
172 changes: 118 additions & 54 deletions tests/unit_test/backends/vllm/test_migration_backend.py
@@ -26,6 +26,46 @@
from tests.conftest import setup_ray_env
from .test_worker import create_worker

def get_ready_workers(num_worker, num_gpu_blocks, engine_config, migration_config):
    """Create num_worker MockMigrationWorker actors, initialize their caches and
    migration backends, rebuild one shared migration group, and warm it up."""
    workers = []
    worker_ids = []

    for _ in range(num_worker):
        worker_id = random_uuid()
        worker = create_worker(rank=0, local_rank=0, engine_config=engine_config,
                               worker_module_name="tests.unit_test.backends.vllm.test_migration_backend",
                               worker_class_name="MockMigrationWorker")

        ray.get(worker.execute_method.remote('init_device'))
        ray.get(worker.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))
        ray.get(worker.execute_method.remote(
            'init_migration',
            instance_id=worker_id,
            migration_config=migration_config,
            src_worker_handle_list=[worker],
            node_id=ray.get_runtime_context().get_node_id()))

        workers.append(worker)
        worker_ids.append(worker_id)

    instance_rank = {}
    for idx, worker_id in enumerate(worker_ids):
        instance_rank[worker_id] = idx
    group_name = random_uuid()

    init_group_tasks = []
    for worker in workers:
        init_group_tasks.append(worker.execute_method.remote('rebuild_migration_backend',
                                                              instance_rank=instance_rank, group_name=group_name))
    assert all(ray.get(init_group_tasks))

    warmup_tasks = []
    for worker in workers:
        warmup_tasks.append(worker.execute_method.remote('warmup'))
    assert all(ray.get(warmup_tasks))

    return workers, worker_ids

class MockMigrationWorker(MigrationWorker):
    def set_gpu_cache(self, data):
        for layer_idx in range(self.cache_engine.num_layers):
@@ -38,71 +78,95 @@ def get_gpu_cache(self):

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.")
@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_one_to_many_migrate_cache(setup_ray_env, backend):
    engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
    migration_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config()
    migration_config.migration_backend = backend

    num_worker = 3 if backend != 'nccl' else 2
    num_gpu_blocks = 300
    workers, _ = get_ready_workers(num_worker, num_gpu_blocks, engine_config, migration_config)

    num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config)
    head_size = engine_config.model_config.get_head_size()
    num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config)
    block_size = engine_config.cache_config.block_size
    dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size))
    ray.get(workers[0].execute_method.remote('set_gpu_cache', data=dummy_data))
    worker0_data = ray.get(workers[0].execute_method.remote('get_gpu_cache'))

    dst_blocks = list(range(num_gpu_blocks))
    random.shuffle(dst_blocks)

    # Scatter: every receiving worker pulls an equal share of workers[0]'s blocks
    # into its own shuffled destination slots.
    single_worker_num_blocks = len(dst_blocks)//(num_worker-1)
    migration_tasks = []
    worker_idx = 1
    for offset in range(0, len(dst_blocks), single_worker_num_blocks):
        src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks]))
        migration_tasks.append(workers[worker_idx].execute_method.remote(
            'migrate_cache',
            src_worker_handle_list=[workers[0]],
            src_blocks=list(src_to_dst.keys()),
            dst_blocks=list(src_to_dst.values())))
        worker_idx += 1
    ray.get(migration_tasks)

    worker_idx = 1
    for offset in range(0, len(dst_blocks), single_worker_num_blocks):
        src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks]))
        dst_worker_data = ray.get(workers[worker_idx].execute_method.remote('get_gpu_cache'))
        for layer_idx in range(num_layers):
            for src_idx, dst_idx in src_to_dst.items():
                assert torch.allclose(worker0_data[layer_idx][0][src_idx], dst_worker_data[layer_idx][0][dst_idx])
                assert torch.allclose(worker0_data[layer_idx][1][src_idx], dst_worker_data[layer_idx][1][dst_idx])
        worker_idx += 1
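Both the scatter test above and the gather test below split the shuffled dst_blocks into equal chunks, one per non-destination worker. A small self-contained sketch with hypothetical numbers (not part of the commit) shows the per-worker src-to-dst maps this produces:

import random

num_worker = 3                       # 1 source plus 2 receivers, as in the 'rpc'/'gloo' case
num_gpu_blocks = 6                   # small stand-in for the 300 blocks used in the test
dst_blocks = list(range(num_gpu_blocks))
random.shuffle(dst_blocks)           # e.g. [4, 0, 5, 2, 1, 3]

single_worker_num_blocks = len(dst_blocks) // (num_worker - 1)   # 3 blocks per receiver
for worker_idx, offset in enumerate(range(0, len(dst_blocks), single_worker_num_blocks), start=1):
    chunk = dst_blocks[offset:offset + single_worker_num_blocks]
    src_to_dst = dict(enumerate(chunk))   # source blocks 0..2 map onto this worker's chunk
    print(worker_idx, src_to_dst)         # e.g. worker 1: {0: 4, 1: 0, 2: 5}; worker 2: {0: 2, 1: 1, 2: 3}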

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.")
@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_many_to_one_migrate_cache(setup_ray_env, backend):
    engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
    migration_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config()
    migration_config.migration_backend = backend

    num_worker = 3 if backend != 'nccl' else 2
    num_gpu_blocks = 300
    workers, _ = get_ready_workers(num_worker, num_gpu_blocks, engine_config, migration_config)

    num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config)
    head_size = engine_config.model_config.get_head_size()
    num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config)
    block_size = engine_config.cache_config.block_size
    dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size))

    # Index 0 is a placeholder so worker_datas[i] lines up with workers[i];
    # workers[0] is the destination and gets no dummy data.
    worker_datas = [0]
    for idx in range(1, num_worker):
        ray.get(workers[idx].execute_method.remote('set_gpu_cache', data=dummy_data))
        worker_datas.append(ray.get(workers[idx].execute_method.remote('get_gpu_cache')))

    dst_blocks = list(range(num_gpu_blocks))
    random.shuffle(dst_blocks)

    # Gather: workers[0] pulls an equal share of blocks from every source worker
    # into disjoint chunks of its own shuffled destination slots.
    single_worker_num_blocks = len(dst_blocks)//(num_worker-1)
    migration_tasks = []
    worker_idx = 1
    for offset in range(0, len(dst_blocks), single_worker_num_blocks):
        src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks]))
        migration_tasks.append(workers[0].execute_method.remote(
            'migrate_cache',
            src_worker_handle_list=[workers[worker_idx]],
            src_blocks=list(src_to_dst.keys()),
            dst_blocks=list(src_to_dst.values())))
        worker_idx += 1
    ray.get(migration_tasks)

    dst_worker_data = ray.get(workers[0].execute_method.remote('get_gpu_cache'))

    worker_idx = 1
    for offset in range(0, len(dst_blocks), single_worker_num_blocks):
        src_to_dst = dict(enumerate(dst_blocks[offset:offset+single_worker_num_blocks]))

        for layer_idx in range(num_layers):
            for src_idx, dst_idx in src_to_dst.items():
                assert torch.allclose(worker_datas[worker_idx][layer_idx][0][src_idx], dst_worker_data[layer_idx][0][dst_idx])
                assert torch.allclose(worker_datas[worker_idx][layer_idx][1][src_idx], dst_worker_data[layer_idx][1][dst_idx])
        worker_idx += 1
