
Commit

[CI] Add test for migration backend and worker (#22)
KuilongCui authored Aug 29, 2024
1 parent be0d326 commit db660ac
Showing 12 changed files with 293 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -20,7 +20,7 @@ persistent=yes

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
-load-plugins=
+load-plugins=pylint_pytest

disable=consider-using-f-string,logging-format-interpolation

10 changes: 5 additions & 5 deletions Makefile
@@ -21,11 +21,7 @@ install: cupy

.PHONY: lint
lint: check_pylint_installed
-	@pylint --rcfile=.pylintrc -s n ./llumnix
-
-	@pylint -s n --disable=all \
-		--enable=trailing-whitespace,unused-variable,wrong-import-order,missing-final-newline,line-too-long,\
-		unused-import,singleton-comparison,unnecessary-comprehension ./tests
+	@pylint --rcfile=.pylintrc -s n ./llumnix ./tests

.PHONY: test
test:
@@ -60,4 +56,8 @@ check_pylint_installed:
echo "pylint is not installed. Installing pylint $(PYLINT_VERSION)..."; \
python3 -m pip install pylint==$(PYLINT_VERSION); }

+	@python3 -c "import pylint_pytest" >/dev/null 2>&1 || { \
+	echo "pylint-pytest is not installed. Installing pylint-pytest ..."; \
+	python3 -m pip install pylint-pytest; }
+
###################################### pylint end #######################################
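For context, pylint-pytest is the plugin being wired in above. A minimal sketch (not from this repository) of the kind of false positive it exists to suppress: plain pylint reports redefined-outer-name (W0621) when a test takes a fixture as an argument, which is ordinary pytest style.

# Hypothetical example, not part of the commit: the parameter name deliberately
# matches the fixture function, which looks like a shadowed outer name to a
# pytest-unaware pylint run.
import pytest

@pytest.fixture
def engine_args():
    # Illustrative fixture; the values are placeholders.
    return {"model": "facebook/opt-125m", "max_model_len": 8}

def test_uses_fixture(engine_args):
    assert engine_args["max_model_len"] == 8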
8 changes: 6 additions & 2 deletions llumnix/llm_engine_manager.py
@@ -285,8 +285,12 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs):
self.scale_down(dead_instances, rebuild_migrate_backend=False)

if self.engine_manager_args.migration_backend == 'gloo':
-            # clear gloo migrate backend intermediate state
-            ray.kill(ray.get_actor("gloo_queue", "llumnix"))
+            try:
+                # clear gloo migrate backend intermediate state
+                ray.kill(ray.get_actor("gloo_queue", "llumnix"))
+            except ValueError:
+                # gloo_queue may not have been created yet; just ignore this error.
+                pass

return dead_instances

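For reference, a standalone sketch (not from the repository) of the behavior the new try/except relies on: ray.get_actor raises ValueError when no actor with the given name exists in the namespace, so catching ValueError lets scale-down proceed when the gloo queue was never created.

import ray

ray.init(namespace="llumnix", ignore_reinit_error=True)

try:
    # Mirrors the cleanup in the diff; succeeds only if the named actor exists.
    ray.kill(ray.get_actor("gloo_queue", namespace="llumnix"))
except ValueError:
    # No "gloo_queue" actor has been created yet, so there is nothing to clean up.
    pass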
4 changes: 1 addition & 3 deletions tests/backends/vllm/test_llm_engine.py
@@ -11,8 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import pytest
-
from unittest.mock import MagicMock

from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
@@ -88,4 +86,4 @@ def test_llm_engine_from_engine_args():

latency_data = LatencyMemData({},{},{})
llm_engine = MockEngine.from_engine_args(engine_args, instance_id="0", migration_config=None, latency_mem=latency_data)
-    assert llm_engine.executor_class == SimGPUExecutor
+    assert llm_engine.executor_class == SimGPUExecutor
3 changes: 1 addition & 2 deletions tests/backends/vllm/test_migration.py
@@ -13,7 +13,6 @@

import pytest
import torch
-import time
import ray
from ray.util.queue import Queue as RayQueue
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
@@ -141,4 +140,4 @@ def test_clear_migration_states():
_, seq_group = create_dummy_prompt("0",7,block_size)
llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group)
llumlet.clear_migration_states(is_migrate_in=False)
-    assert llumlet.backend_engine.get_last_running_request() is not None
+    assert llumlet.backend_engine.get_last_running_request() is not None
110 changes: 110 additions & 0 deletions tests/backends/vllm/test_migration_backend.py
@@ -0,0 +1,110 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import pytest
import torch
import ray

from vllm.engine.arg_utils import EngineArgs

from llumnix.backends.vllm.worker import MigrationWorker
from llumnix.arg_utils import EngineManagerArgs
from llumnix.utils import random_uuid

from tests.backends.vllm.test_worker import create_worker

class MockMigrationWorker(MigrationWorker):
def set_gpu_cache(self, data):
for layer_idx in range(self.cache_engine.num_layers):
self.gpu_cache[layer_idx].copy_(data[layer_idx])
torch.cuda.synchronize()

def get_gpu_cache(self):
torch.cuda.synchronize()
return self.gpu_cache

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.")
@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_migrate_cache(backend):
ray.init(namespace="llumnix", ignore_reinit_error=True)

engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
migraiton_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config()
migraiton_config.migration_backend = backend

worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config,
worker_module_name="tests.backends.vllm.test_migration_backend",
worker_class_name="MockMigrationWorker")
worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config,
worker_module_name="tests.backends.vllm.test_migration_backend",
worker_class_name="MockMigrationWorker")

ray.get(worker0.execute_method.remote('init_device'))
ray.get(worker1.execute_method.remote('init_device'))

num_gpu_blocks = 8
ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))
ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))

worker0_id = random_uuid()
ray.get(worker0.execute_method.remote(
'init_migration',
instance_id=worker0_id,
migration_config=migraiton_config,
src_worker_handle_list=[worker0],
node_id=ray.get_runtime_context().get_node_id()))

worker1_id = random_uuid()
ray.get(worker1.execute_method.remote(
'init_migration',
instance_id=worker1_id,
migration_config=migraiton_config,
src_worker_handle_list=[worker1],
node_id=ray.get_runtime_context().get_node_id()))

instance_rank = {worker0_id: 0, worker1_id: 1}
group_name = random_uuid()
assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name),
worker1.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name)]))
assert all(ray.get([worker0.execute_method.remote('warmup'),
worker1.execute_method.remote('warmup')]))

num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config)
head_size = engine_config.model_config.get_head_size()
num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config)
block_size = engine_config.cache_config.block_size

dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size))
ray.get(worker0.execute_method.remote('set_gpu_cache', data=dummy_data))
worker0_data = ray.get(worker0.execute_method.remote('get_gpu_cache'))

dst_blocks = list(range(num_gpu_blocks))
random.shuffle(dst_blocks)
src_to_dst = dict(enumerate(dst_blocks))
ray.get(worker1.execute_method.remote(
'migrate_cache',
src_worker_handle_list=[worker0],
src_blocks=list(src_to_dst.keys()),
dst_blocks=list(src_to_dst.values())))

worker1_data = ray.get(worker1.execute_method.remote('get_gpu_cache'))

for layer_idx in range(num_layers):
for src_idx, dst_idx in src_to_dst.items():
assert torch.allclose(worker0_data[layer_idx][0][src_idx], worker1_data[layer_idx][0][dst_idx])
assert torch.allclose(worker0_data[layer_idx][1][src_idx], worker1_data[layer_idx][1][dst_idx])

ray.shutdown()
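To run the new backend test on its own, an invocation along these lines should work (the path follows the repository layout above; at least two visible GPUs are needed, otherwise the skipif marker skips the test):

import sys

import pytest

# Hypothetical local runner, not part of the commit; -v gives per-test output.
sys.exit(pytest.main(["-v", "tests/backends/vllm/test_migration_backend.py"]))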
44 changes: 22 additions & 22 deletions tests/backends/vllm/test_scheduler.py
@@ -77,19 +77,19 @@ def test_scheduler_policy():

# all seq_group in waiting queue
migrating_request = scheduler.get_last_running_request()
-    assert migrating_request == None
+    assert migrating_request is None
migrating_request = scheduler.get_shortest_running_request()
-    assert migrating_request == None
+    assert migrating_request is None
migrating_request = scheduler.get_longest_running_request()
-    assert migrating_request == None
+    assert migrating_request is None
# all seq_group in prefilling stage
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
migrating_request = scheduler.get_last_running_request()
-    assert migrating_request == None
+    assert migrating_request is None
migrating_request = scheduler.get_shortest_running_request()
-    assert migrating_request == None
+    assert migrating_request is None
migrating_request = scheduler.get_longest_running_request()
-    assert migrating_request == None
+    assert migrating_request is None
append_new_token(out, 1)
schedule_and_update_computed_tokens(scheduler)
# all in running queue
@@ -109,11 +109,11 @@ def test_scheduler_num_killed_request():
_, seq_group = create_dummy_prompt(str(idx), prompt_length=8, block_size=block_size)
scheduler.add_seq_group(seq_group)
# remain 0 blocks
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
append_new_token(out, 1)
assert scheduler._get_num_killed_requests() == 0
# preempt 2 requests
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
assert scheduler._get_num_killed_requests() == 2

def test_scheduler_running_request():
Expand All @@ -123,7 +123,7 @@ def test_scheduler_running_request():
for idx in range(1, num_seq_group + 1):
_, seq_group = create_dummy_prompt(str(idx), prompt_length=idx, block_size=block_size)
scheduler.add_seq_group(seq_group)
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    schedule_and_update_computed_tokens(scheduler)
assert scheduler.get_num_unfinished_seq_groups() == 4
scheduler.remove_running_request("1")
assert scheduler.get_num_unfinished_seq_groups() == 3
@@ -162,46 +162,46 @@ def test_scheduler_should_abort_migration():
_, seq_group_1 = create_dummy_prompt("1", prompt_length=17, block_size=block_size)
scheduler.add_seq_group(seq_group_1)
# remain 0 blocks
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
append_new_token(out, 1)

assert scheduler._get_num_killed_requests() == 0
# assert scheduler.block_manager.get_num_free_gpu_blocks() == 0
# all in running queue
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
append_new_token(out, 1)
assert scheduler._get_num_killed_requests() == 0
migrating_request = scheduler.get_last_running_request()
last_stage_time = time.time()
assert migrating_request.request_id == "1"
# preempt request 1
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
append_new_token(out, 1)
-    assert scheduler.should_abort_migration(seq_group_1, last_stage_time) == True
-    assert scheduler.should_abort_migration(seq_group_0, last_stage_time) == False
+    assert scheduler.should_abort_migration(seq_group_1, last_stage_time)
+    assert not scheduler.should_abort_migration(seq_group_0, last_stage_time)
assert scheduler._get_num_killed_requests() == 1
scheduler.remove_running_request(seq_group_0)
scheduler.free_src_request(seq_group_0)
# free request 0, requset 1 prefill
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    _, out = schedule_and_update_computed_tokens(scheduler)
append_new_token(out, 1)
assert scheduler._get_num_killed_requests() == 0
-    assert scheduler.should_abort_migration(seq_group_1, last_stage_time) == True
+    assert scheduler.should_abort_migration(seq_group_1, last_stage_time)

def test_free_dst_pre_alloc_cache():
scheduler = initialize_scheduler()
-    blocks = scheduler.pre_alloc("1", 2)
-    blocks = scheduler.pre_alloc("1", 4)
+    scheduler.pre_alloc("1", 2)
+    scheduler.pre_alloc("1", 4)
assert len(scheduler.pre_alloc_cache_dict["1"]) == 6
scheduler.free_dst_pre_alloc_cache("1")
-    assert scheduler.pre_alloc_cache_dict.get("1",None) == None
+    assert scheduler.pre_alloc_cache_dict.get("1", None) is None
assert scheduler.block_manager.get_num_free_gpu_blocks() == 8

def test_get_request_incremental_blocks():
scheduler = initialize_scheduler()
block_size = 4
_, seq_group = create_dummy_prompt("0", prompt_length=16, block_size=block_size)
scheduler.add_seq_group(seq_group)
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    schedule_and_update_computed_tokens(scheduler)
incremental_blocks = scheduler.get_request_incremental_blocks(seq_group, 2)
-    assert len(incremental_blocks) == 2
+    assert len(incremental_blocks) == 2
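The == None and == True comparisons above were rewritten as identity checks and bare asserts, in line with pylint's singleton-comparison check. A tiny sketch (not from the repository) of why identity comparison with None is the safer idiom:

class AlwaysEqual:
    # Illustrative only: a custom __eq__ can make equality with None lie.
    def __eq__(self, other):
        return True

obj = AlwaysEqual()
print(obj == None)  # True, which is misleading (and flagged by pylint)
print(obj is None)  # False, which is what the tests actually mean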
(The remaining 5 changed files are not shown.)