Commit

Merge branch 'main' into fix-failover
s5u13b committed Aug 29, 2024
2 parents 76599eb + db660ac commit 1ce9d7a
Showing 15 changed files with 375 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -20,7 +20,7 @@ persistent=yes

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
load-plugins=pylint_pytest

disable=consider-using-f-string,logging-format-interpolation

22 changes: 17 additions & 5 deletions Makefile
@@ -16,16 +16,16 @@ init:
@git submodule update --init --recursive

.PHONY: install
install:
pip install -e .
install: cupy
@pip install -e .

.PHONY: lint
lint: check_pylint_installed
pylint --rcfile=.pylintrc ./llumnix
@pylint --rcfile=.pylintrc -s n ./llumnix ./tests

.PHONY: test
test:
pytest -vs --ignore=third_party/ --disable-warnings
@pytest -q -x --ignore=third_party/ --disable-warnings

#################### pygloo install for gloo migration backend begin ####################

@@ -34,10 +34,18 @@ PYGLOO_DIR = third_party/pygloo

.PHONY: pygloo
pygloo: init
./tools/pygloo_install.sh
@./tools/pygloo_install.sh

##################### pygloo install for gloo migration backend end #####################

###################################### cupy begin #######################################

.PHONY: cupy
cupy:
@./tools/cupy_install.sh

####################################### cupy end ########################################

##################################### pylint begin ######################################

PYLINT_VERSION = 2.12.2
Expand All @@ -48,4 +56,8 @@ check_pylint_installed:
echo "pylint is not installed. Installing pylint $(PYLINT_VERSION)..."; \
python3 -m pip install pylint==$(PYLINT_VERSION); }

@python3 -c "import pylint_pytest" >/dev/null 2>&1 || { \
echo "pylint-pytest is not installed. Installing pylint-pytest ..."; \
python3 -m pip install pylint-pytest; }

###################################### pylint end #######################################
13 changes: 8 additions & 5 deletions llumnix/backends/vllm/migration_backend.py
@@ -13,6 +13,7 @@

from typing import List
import torch
import cupy
from func_timeout import func_set_timeout, FunctionTimedOut

import ray
@@ -176,7 +177,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine,
pin_memory=pin_memory
)

self.migration_stream = torch.cuda.Stream()
self.migration_stream = cupy.cuda.Stream()

def init_backend(self, group_name, world_size, rank) -> bool:
@func_set_timeout(self.migration_config.migration_backend_init_timeout)
@@ -249,26 +250,28 @@ def do_send(self, dst_handle, blocks: List[int]):
num_blocks = len(blocks)
send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)}
with torch.cuda.stream(self.migration_stream):

with self.migration_stream:
for layer_idx in range(self.cache_engine.num_layers):
cache_idx = layer_idx % self.migration_num_layers
self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[cache_idx], src_to_dst)
if cache_idx + 1 == self.migration_num_layers or layer_idx + 1 == self.cache_engine.num_layers:
# TODO(KuilongCui): check the error code if peer is dead
col.send(send_cache, dst_handle, self.group_name)
torch.cuda.Stream.synchronize(self.migration_stream)
self.migration_stream.synchronize()

def do_recv(self, src_handle, blocks: List[int]):
num_blocks = len(blocks)
src_to_dst = dict(enumerate(blocks))
recv_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size)
with torch.cuda.stream(self.migration_stream):

with self.migration_stream:
for layer_idx in range(self.cache_engine.num_layers):
cache_idx = layer_idx % self.migration_num_layers
if cache_idx == 0:
col.recv(recv_cache, src_handle, self.group_name)
self.cache_engine.attn_backend.swap_blocks(recv_cache[cache_idx], self.gpu_cache[layer_idx], src_to_dst)
torch.cuda.Stream.synchronize(self.migration_stream)
self.migration_stream.synchronize()

def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy,
is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase:
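The do_send/do_recv changes above swap the torch CUDA stream for a cupy one: cupy.cuda.Stream is used directly as a context manager, and synchronization happens through the stream's own synchronize() method instead of torch.cuda.Stream.synchronize(). A minimal standalone sketch of that pattern (illustrative only, not part of this commit):

import cupy

stream = cupy.cuda.Stream()
with stream:
    # Work launched inside the block is enqueued on `stream` rather than the default stream.
    data = cupy.arange(1024) * 2
stream.synchronize()  # Block until everything queued on `stream` has finished.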
19 changes: 10 additions & 9 deletions llumnix/backends/vllm/worker.py
@@ -42,25 +42,24 @@ def load_model(self):
torch.cuda.set_device(self.device)
return super().load_model()

# TODO(KuilongCui): The global rank can be obtained from the manager; there is no need to fetch it for
# gloo and nccl on every migrate_cache call.
def get_global_rank(self):
return self.global_rank

def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_config: ModelConfig,
cache_config: CacheConfig, parallel_config: ParallelConfig):
cache_config: CacheConfig, parallel_config: ParallelConfig) -> int:
# TODO(s5u13b): move this to arguments checker
if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl':
logger.warning("nccl backend is not supported for PP or TP enabled model, use gloo instead.")
migration_config.migration_backend = 'gloo'

# for nccl migration backend, reserve gpu memory for dummy cache in migration backend
if migration_config.migration_backend == "nccl" and parallel_config.world_size == 1:
migrate_cache_blocks_size = migration_config.migration_cache_blocks
migrate_num_layers = migration_config.migration_num_layers
dummy_cache_size = migrate_num_layers * migrate_cache_blocks_size * CacheEngine.get_cache_block_size(
cache_config, model_config, parallel_config) // model_config.get_num_layers(parallel_config)
migrate_cache_blocks_size = migration_config.migration_cache_blocks
migrate_num_layers = migration_config.migration_num_layers
dummy_cache_size = migrate_num_layers * migrate_cache_blocks_size * CacheEngine.get_cache_block_size(
cache_config, model_config, parallel_config) // model_config.get_num_layers(parallel_config)

# For nccl migration backend, reserve gpu memory for dummy cache in migration backend. For other backends,
# CPU memory is used for the dummy cache, which is almost unlimited, so no special action is needed.
if migration_config.migration_backend == "nccl" and parallel_config.world_size == 1:
device = torch.device(f"cuda:{self.local_rank}")
_, total_memory = torch.cuda.mem_get_info(device)
migrate_ratio = math.ceil(dummy_cache_size / total_memory * 10000) / 10000
@@ -74,6 +73,8 @@ def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_
logger.info("nccl migration backend take {:.4f} gpu memory, left gpu_memory_utilization {:.4f} for kv cache."
.format(migrate_ratio, cache_config.gpu_memory_utilization))

return dummy_cache_size

def init_migration(self, instance_id: str, migration_config: MigrationConfig, src_worker_handle_list,
placement_group=None, node_id=None) -> None:
if placement_group:
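As the new comment in reserve_memory_for_migration notes, only the nccl backend keeps the dummy cache in GPU memory, so its size is translated into a gpu_memory_utilization deduction. A rough worked example of that arithmetic with made-up numbers (the block size and device memory below are hypothetical, not taken from the commit):

import math

migration_num_layers = 1
migration_cache_blocks = 32
cache_block_size = 2 * 12 * 16 * 128 * 2   # hypothetical bytes per KV-cache block across all layers
num_model_layers = 12
total_gpu_memory = 24 * 1024 ** 3          # hypothetical 24 GiB device

# Mirrors the dummy_cache_size formula above: per-layer share of a block, times migration layers and blocks.
dummy_cache_size = migration_num_layers * migration_cache_blocks * cache_block_size // num_model_layers
# Round the reserved fraction up to 4 decimal places before subtracting it from gpu_memory_utilization.
migrate_ratio = math.ceil(dummy_cache_size / total_gpu_memory * 10000) / 10000
print(dummy_cache_size, migrate_ratio)     # 262144 bytes -> ratio 0.0001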
8 changes: 6 additions & 2 deletions llumnix/llm_engine_manager.py
@@ -283,8 +283,12 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs):
self.scale_down(dead_instances, rebuild_migrate_backend=False)

if self.engine_manager_args.migration_backend == 'gloo':
# clear gloo migrate backend intermediate state
ray.kill(ray.get_actor("gloo_queue", "llumnix"))
try:
# clear gloo migrate backend intermediate state
ray.kill(ray.get_actor("gloo_queue", "llumnix"))
except ValueError:
# gloo_queue may not have been created yet; just ignore this error.
pass

return dead_instances

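The try/except added above relies on ray.get_actor raising ValueError when no actor with the given name exists in the namespace, so the kill is simply skipped in that case. A small self-contained sketch of the cleanup pattern (the helper name is ours, not from the repository):

import ray

def kill_named_actor_if_exists(name: str, namespace: str) -> None:
    # ray.get_actor raises ValueError if the named actor was never created (or is already gone),
    # in which case there is nothing to clean up.
    try:
        ray.kill(ray.get_actor(name, namespace=namespace))
    except ValueError:
        pass

# Usage corresponding to the change above:
# kill_named_actor_if_exists("gloo_queue", "llumnix")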
4 changes: 1 addition & 3 deletions tests/backends/vllm/test_llm_engine.py
@@ -11,8 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from unittest.mock import MagicMock

from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
@@ -88,4 +86,4 @@ def test_llm_engine_from_engine_args():

latency_data = LatencyMemData({},{},{})
llm_engine = MockEngine.from_engine_args(engine_args, instance_id="0", migration_config=None, latency_mem=latency_data)
assert llm_engine.executor_class == SimGPUExecutor
assert llm_engine.executor_class == SimGPUExecutor
3 changes: 1 addition & 2 deletions tests/backends/vllm/test_migration.py
@@ -13,7 +13,6 @@

import pytest
import torch
import time
import ray
from ray.util.queue import Queue as RayQueue
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
Expand Down Expand Up @@ -141,4 +140,4 @@ def test_clear_migration_states():
_, seq_group = create_dummy_prompt("0",7,block_size)
llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group)
llumlet.clear_migration_states(is_migrate_in=False)
assert llumlet.backend_engine.get_last_running_request() is not None
assert llumlet.backend_engine.get_last_running_request() is not None
110 changes: 110 additions & 0 deletions tests/backends/vllm/test_migration_backend.py
@@ -0,0 +1,110 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import pytest
import torch
import ray

from vllm.engine.arg_utils import EngineArgs

from llumnix.backends.vllm.worker import MigrationWorker
from llumnix.arg_utils import EngineManagerArgs
from llumnix.utils import random_uuid

from tests.backends.vllm.test_worker import create_worker

class MockMigrationWorker(MigrationWorker):
def set_gpu_cache(self, data):
for layer_idx in range(self.cache_engine.num_layers):
self.gpu_cache[layer_idx].copy_(data[layer_idx])
torch.cuda.synchronize()

def get_gpu_cache(self):
torch.cuda.synchronize()
return self.gpu_cache

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("backend", ['rpc', 'gloo', 'nccl'])
def test_migrate_cache(backend):
ray.init(namespace="llumnix", ignore_reinit_error=True)

engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
migration_config = EngineManagerArgs(migration_cache_blocks=3, migration_num_layers=5).create_migration_config()
migration_config.migration_backend = backend

worker0 = create_worker(rank=0, local_rank=0, engine_config=engine_config,
worker_module_name="tests.backends.vllm.test_migration_backend",
worker_class_name="MockMigrationWorker")
worker1 = create_worker(rank=0, local_rank=0, engine_config=engine_config,
worker_module_name="tests.backends.vllm.test_migration_backend",
worker_class_name="MockMigrationWorker")

ray.get(worker0.execute_method.remote('init_device'))
ray.get(worker1.execute_method.remote('init_device'))

num_gpu_blocks = 8
ray.get(worker0.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))
ray.get(worker1.execute_method.remote('initialize_cache', num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0))

worker0_id = random_uuid()
ray.get(worker0.execute_method.remote(
'init_migration',
instance_id=worker0_id,
migration_config=migration_config,
src_worker_handle_list=[worker0],
node_id=ray.get_runtime_context().get_node_id()))

worker1_id = random_uuid()
ray.get(worker1.execute_method.remote(
'init_migration',
instance_id=worker1_id,
migration_config=migration_config,
src_worker_handle_list=[worker1],
node_id=ray.get_runtime_context().get_node_id()))

instance_rank = {worker0_id: 0, worker1_id: 1}
group_name = random_uuid()
assert all(ray.get([worker0.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name),
worker1.execute_method.remote('rebuild_migration_backend',
instance_rank=instance_rank, group_name=group_name)]))
assert all(ray.get([worker0.execute_method.remote('warmup'),
worker1.execute_method.remote('warmup')]))

num_layers = engine_config.model_config.get_num_layers(engine_config.parallel_config)
head_size = engine_config.model_config.get_head_size()
num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config)
block_size = engine_config.cache_config.block_size

dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size))
ray.get(worker0.execute_method.remote('set_gpu_cache', data=dummy_data))
worker0_data = ray.get(worker0.execute_method.remote('get_gpu_cache'))

dst_blocks = list(range(num_gpu_blocks))
random.shuffle(dst_blocks)
src_to_dst = dict(enumerate(dst_blocks))
ray.get(worker1.execute_method.remote(
'migrate_cache',
src_worker_handle_list=[worker0],
src_blocks=list(src_to_dst.keys()),
dst_blocks=list(src_to_dst.values())))

worker1_data = ray.get(worker1.execute_method.remote('get_gpu_cache'))

for layer_idx in range(num_layers):
for src_idx, dst_idx in src_to_dst.items():
assert torch.allclose(worker0_data[layer_idx][0][src_idx], worker1_data[layer_idx][0][dst_idx])
assert torch.allclose(worker0_data[layer_idx][1][src_idx], worker1_data[layer_idx][1][dst_idx])

ray.shutdown()