[Core] Support one-to-many and many-to-one migration (#63)
KuilongCui authored Nov 11, 2024
1 parent 2135d8b commit 844c836
Showing 30 changed files with 461 additions and 257 deletions.
14 changes: 7 additions & 7 deletions Makefile
@@ -21,34 +21,34 @@ install:

.PHONY: lint
lint: check_pylint_installed check_pytest_installed
@pylint --rcfile=.pylintrc -s n --jobs=32 ./llumnix
@pylint --rcfile=.pylintrc -s n --jobs=128 ./llumnix

@pylint --rcfile=.pylintrc \
--disable=protected-access,super-init-not-called,unused-argument,redefined-outer-name,invalid-name \
-s n --jobs=32 ./tests
-s n --jobs=128 ./tests

.PHONY: test
test: check_pytest_installed
@pytest -x -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
@pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
@python examlpes/offline_inference.py
@pytest -v tests/e2e_test/test_e2e.py
@pytest -v -x tests/e2e_test/test_e2e.py
@pytest -v -x ./tests/e2e_test/test_migration.py

.PHONY: unit_test
unit_test: check_pytest_installed
@pytest -x -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
@pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings

.PHONY: offline_test
offline_test:
@python examlpes/offline_inference.py

.PHONY: e2e_test
e2e_test:
@pytest -v tests/e2e_test/test_e2e.py
@pytest -v -x tests/e2e_test/test_e2e.py

.PHONY: bench_test
bench_test:
@pytest -v ./tests/e2e_test/test_bench.py
@pytest -v -x ./tests/e2e_test/test_bench.py

.PHONY: migration_test
migration_test:
5 changes: 2 additions & 3 deletions configs/base.yml
@@ -2,8 +2,6 @@ SERVER:
HOST: '127.0.0.1'
PORT: 1234
QUEUE_TYPE: "rayqueue"

RAY:
RAY_CLUSTER_PORT: 6379
LAUNCH_RAY_CLUSTER: True

@@ -21,6 +19,7 @@ MANAGER:
REQUEST_MIGRATION_POLICY: 'SJF'

MIGRATION_BACKEND: 'gloo'
MIGRATION_CACHE_BLOCKS: 512
MIGRATION_BUFFER_BLOCKS: 512
MIGRATION_INTERNAL_BUFFER_NUM: 2

ENABLE_SCALING: False
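
For reference, a minimal sketch of reading the renamed keys after this commit. It assumes `llumnix.config.default._C` is a yacs-style `CfgNode`, as the `_C` import in `arg_utils.py` suggests; the loading calls below are illustrative, not part of the diff.

```python
# Sketch only: assumes _C is a yacs-style CfgNode; clone()/merge_from_file()
# are the standard yacs calls, not APIs confirmed by this commit.
from llumnix.config.default import _C

cfg = _C.clone()
cfg.merge_from_file("configs/base.yml")

print(cfg.MANAGER.MIGRATION_BUFFER_BLOCKS)        # 512, formerly MIGRATION_CACHE_BLOCKS
print(cfg.MANAGER.MIGRATION_INTERNAL_BUFFER_NUM)  # 2, new in this commit
```
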
13 changes: 9 additions & 4 deletions docs/Arguments.md
@@ -32,14 +32,15 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--profiling-result-file-path PROFILING_RESULT_FILE_PATH]
[--gpu-type GPU_TYPE]
[--polling-interval POLLING_INTERVAL]
[--migration-backend {gloo,nccl,rpc}]
[--migration-cache-blocks MIGRATION_CACHE_BLOCKS]
[--migration-backend {gloo,rpc}]
[--migration-buffer-blocks MIGRATION_BUFFER_BLOCKS]
[--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT]
[--migration-num-layers MIGRATION_NUM_LAYERS]
[--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
[--max-stages MAX_STAGES]
[--enable-pd-disagg]
[--num-dispatch-instances NUM_DISPATCH_INSTANCES]
[--migration-internal-buffer-num MIGRATION_INTERNAL_BUFFER_NUM]
[--log-request-timestamps]
```
@@ -147,8 +148,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Possible choices: gloo, rpc
- Default: "rpc"

`--migration-cache-blocks`
- Number of cache blocks in migration.
`--migration-buffer-blocks`
- Number of cache blocks in each migration buffer.
- Default: 512

`--migration-backend-init-timeout`
@@ -167,6 +168,10 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Drop migration if the number of stages > max_stages.
- Default: 3

`--migration-internal-buffer-num`
- Number of buffers used by the migration backend for sending and receiving.
- Default: 2

`--log-request-timestamps`
- Enable logging request timestamps.
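
For illustration, a hedged sketch of the programmatic equivalent of the flags above: the field names mirror the CLI options (dashes become underscores) as defined in `llumnix/arg_utils.py`, and the chosen values are examples rather than recommendations.

```python
# Sketch only: EngineManagerArgs is the dataclass behind these CLI flags;
# fields not set here keep their defaults.
from llumnix.arg_utils import EngineManagerArgs

engine_manager_args = EngineManagerArgs(
    migration_backend="gloo",           # 'nccl' is rejected by check_args after this commit
    migration_buffer_blocks=512,        # renamed from migration_cache_blocks
    migration_internal_buffer_num=2,    # new in this commit
)
migration_config = engine_manager_args.create_migration_config()
```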

21 changes: 14 additions & 7 deletions llumnix/arg_utils.py
@@ -22,7 +22,6 @@
from llumnix.config import LlumnixConfig, get_llumnix_config
from llumnix.config.default import _C


class LlumnixArgumentParser(argparse.ArgumentParser):
def __init__(self, *args, **kwargs):
self.cur_namespace = "llumnix"
@@ -134,10 +133,11 @@ class EngineManagerArgs:

migration_backend_init_timeout: float = None
migration_backend: str = None
migration_cache_blocks: int = None
migration_buffer_blocks: int = None
migration_num_layers: int = None
last_stage_max_blocks: int = None
max_stages: int = None
migration_internal_buffer_num: int = None

enable_pd_disagg: bool = None

@@ -172,11 +172,12 @@ def create_global_scheduler_configs(
def create_migration_config(self) -> MigrationConfig:
migration_config = MigrationConfig(self.request_migration_policy,
self.migration_backend,
self.migration_cache_blocks,
self.migration_buffer_blocks,
self.migration_num_layers,
self.last_stage_max_blocks,
self.max_stages,
self.migration_backend_init_timeout)
self.migration_backend_init_timeout,
self.migration_internal_buffer_num)
return migration_config

@classmethod
@@ -195,6 +196,9 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser):
if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest):
assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}."

assert args.migration_backend != 'nccl', 'NCCL has been temporarily deprecated due to its incompatibility with \
concurrent migrations in Llumnix.'

assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
("When using gloo as migration backend, "
@@ -288,20 +292,23 @@ def add_cli_args(

parser.add_argument('--migration-backend',
type=str,
choices=['gloo','nccl','rpc'],
choices=['gloo', 'nccl', 'rpc'],
help='communication backend of migration')
parser.add_argument('--migration-backend-init-timeout',
type=float,
help='timeout(s) for initializing migration backend')
parser.add_argument('--migration-cache-blocks',
parser.add_argument('--migration-buffer-blocks',
type=int,
help='number of cache blocks in migration')
help='number of cache blocks in each migration buffer')
parser.add_argument('--migration-num-layers',
type=int,
help='number of kv-cache layers to transfer in each round during migration')
parser.add_argument('--last-stage-max-blocks',
type=int,
help='if the number of remaining blocks < last_stage_max_blocks, do last stage migration')
parser.add_argument('--migration-internal-buffer-num',
type=int,
help='number of buffers used by the migration backend for sending and receiving')
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')
23 changes: 23 additions & 0 deletions llumnix/backends/migration_backend_interface.py
@@ -13,7 +13,9 @@

from abc import ABC, abstractmethod
from typing import List
import queue

import torch

class MigrationBackendBase(ABC):
@abstractmethod
@@ -39,3 +41,24 @@ def do_send(self, dst_handle, blocks: List[int]):
@abstractmethod
def do_recv(self, src_handle, blocks: List[int]):
raise NotImplementedError

class BufferMigrationBackend(MigrationBackendBase):
def __init__(self, num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory, *args, **kwargs):
super().__init__(*args, **kwargs)

self.num_buffer = num_buffer

self.dummy_buffer = [
torch.empty(size=buffer_shape, dtype=buffer_dtype, device=buffer_device, pin_memory=pin_memory)
for _ in range(self.num_buffer)
]

self.avaiable_buffer_queue = queue.Queue()
for i in range(self.num_buffer):
self.avaiable_buffer_queue.put_nowait(i)

def get_available_cache(self):
return self.avaiable_buffer_queue.get()

def put_back_cache(self, buffer_id):
self.avaiable_buffer_queue.put_nowait(buffer_id)
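
For context, a hedged sketch of how a concrete backend might use this buffer pool. `ExampleBackend` and its transfer step are hypothetical, not part of the commit, and the remaining abstract methods are omitted.

```python
# Sketch only: illustrates the acquire/release discipline around the
# pre-allocated dummy_buffer pool; the actual copy/transmit step is elided.
from typing import List

from llumnix.backends.migration_backend_interface import BufferMigrationBackend

class ExampleBackend(BufferMigrationBackend):
    def do_send(self, dst_handle, blocks: List[int]):
        buffer_id = self.get_available_cache()   # waits until one of the num_buffer buffers is free
        try:
            send_buf = self.dummy_buffer[buffer_id]
            # ... gather the requested KV-cache blocks into send_buf and transmit them ...
        finally:
            self.put_back_cache(buffer_id)        # release so a concurrent migration can reuse it
```
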
8 changes: 4 additions & 4 deletions llumnix/backends/vllm/llm_engine.py
@@ -355,10 +355,10 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:

async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
await dst_ray_actor.execute_engine_method.remote("_run_workers",
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list)
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list)

def _run_workers(self, *args, **kwargs):
# pylint: disable=protected-access