Skip to content

Commit

Permalink
[Core] Add Gloo and NCCL migration backends with elastic support (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui authored Aug 23, 2024
1 parent 0b48bbc commit 6b7b099
Show file tree
Hide file tree
Showing 22 changed files with 765 additions and 253 deletions.
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[submodule "third_party/pygloo"]
path = third_party/pygloo
url = https://github.com/ZeldaHuang/pygloo
branch = llumnix
51 changes: 51 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fetch (and initialize, if needed) all git submodules, recursively.
.PHONY: init
init:
	@git submodule update --init --recursive

# Install this package into the current Python environment in editable mode.
.PHONY: install
install:
	pip install -e .

# Lint the ./llumnix package using the repository's .pylintrc.
# Depends on check_pylint_installed so pylint is present before it runs.
.PHONY: lint
lint: check_pylint_installed
	pylint --rcfile=.pylintrc ./llumnix

# Run the test suite verbosely without output capture (-vs), skipping
# anything under third_party/ and suppressing the warnings summary.
.PHONY: test
test:
	pytest -vs --ignore=third_party/ --disable-warnings

#################### pygloo install for gloo migration backend begin ####################

# NOTE(review): BAZEL_CMD and PYGLOO_DIR are not referenced anywhere else in
# this Makefile and are not exported, so the install script will not see them
# via the environment -- confirm whether they are still needed.
BAZEL_CMD = bazel
PYGLOO_DIR = third_party/pygloo

# Build/install pygloo via the helper script. Depends on `init` so the
# submodules are updated before the script runs.
.PHONY: pygloo
pygloo: init
	./tools/pygloo_install.sh

##################### pygloo install for gloo migration backend end #####################

##################################### pylint begin ######################################

# pylint version that gets installed when pylint is not found on PATH.
PYLINT_VERSION = 2.12.2

# Install the pinned pylint version only if no `pylint` executable is found
# on PATH; an already-installed pylint (of any version) is left untouched.
.PHONY: check_pylint_installed
check_pylint_installed:
	@command -v pylint >/dev/null 2>&1 || { \
		echo "pylint is not installed. Installing pylint $(PYLINT_VERSION)..."; \
		python3 -m pip install pylint==$(PYLINT_VERSION); }

###################################### pylint end #######################################
6 changes: 4 additions & 2 deletions benchmark/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer", type=str, required=True,
help="Name or path of the tokenizer.")
parser.add_argument('--trust_remote_code',
action='store_true')
parser.add_argument('-v', '--verbose', action='store_true')
parser.add_argument('--backend', type=GenerationBackend,
choices=[e.name for e in GenerationBackend], default='vLLM')
Expand Down Expand Up @@ -701,7 +703,7 @@ def main():
assert args.random_prompt_count is not None

backend = GenerationBackend[args.backend]
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
print(tokenizer)

if args.dataset_type:
Expand Down Expand Up @@ -798,4 +800,4 @@ def main():


if __name__ == '__main__':
main()
main()
27 changes: 19 additions & 8 deletions docs/Arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ Note: since Llumnix is still in alpha stage, the interface and arguments are *su

```
usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--fixed-node-init-instance]
[--init-instance-by-manager]
[--disable-fixed-node-init-instance]
[--disable-init-instance-by-manager]
[--initial-instances INITIAL_INSTANCES]
[--load-metric {remaining_steps,usage_ratio}]
[--polling-interval POLLING_INTERVAL]
Expand All @@ -30,17 +30,20 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--log-filename LOG_FILENAME]
[--profiling-result-file-path PROFILING_RESULT_FILE_PATH]
[--gpu-type GPU_TYPE]
[--migration-backend {gloo,rpc}]
[--migration-cache_blocks MIGRATION_CACHE_BLOCKS]
[--polling-interval POLLING_INTERVAL]
[--migration-backend {gloo,nccl,rpc}]
[--migration-cache-blocks MIGRATION_CACHE_BLOCKS]
[--migration-backend-init-timeout MIGRATION_BACKEND_INIT_TIMEOUT]
[--migration-num-layers MIGRATION_NUM_LAYERS]
[--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
[--max-stages MAX_STAGES]
```

`--fixed-node-init-instance`
- Fix the placement of instance to current node.
`--disable-fixed-node-init-instance`
- Disable fixing the instance's placement to the current node.

`--init-instance-by-manager`
- initialize instance by manager.
`--disable-init-instance-by-manager`
- Disable the initialization of instance by the manager.

`--initial-instances`
- Number of model instances created at initialization.
Expand Down Expand Up @@ -138,6 +141,14 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Number of cache blocks in migration.
- Default: 32

`--migration-backend-init-timeout`
- Timeout (in seconds) for initializing the migration backend.
- Default: 10.0

`--migration-num-layers`
- Number of KV-cache layers to transfer in each round during migration.
- Default: 1

`--last-stage-max-blocks`
- If the number of remaining blocks < last_stage_max_blocks, do last stage migration.
- Default: 16
Expand Down
8 changes: 6 additions & 2 deletions docs/Quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Requirements

Llumnix is currently built on top of vLLM (version 0.4.2). Therefore, the installation requirements are almost identical to those of vLLM. You can view the specific installation requirements for vLLM at the following link:
Llumnix requires Python `3.8.1~3.10.0` and is currently built on top of vLLM (version 0.4.2). Therefore, the installation requirements are almost identical to those of vLLM. You can view the specific installation requirements for vLLM at the following link:

[vLLM Installation](https://docs.vllm.ai/en/v0.4.2/getting_started/installation.html)

Expand All @@ -12,9 +12,13 @@ You can build and install Llumnix from source:
```
git clone https://github.com/AlibabaPAI/llumnix.git
cd llumnix
pip install -e .
make install
```

If you want to use gloo as the migration backend, please install [Bazel](https://github.com/bazelbuild/bazel) >= 5.1.0 first. Then run `make pygloo` to install [pygloo](https://github.com/ZeldaHuang/pygloo).

Note: Using conda is not recommended, as it cannot properly satisfy pygloo's dependency on GCC's `libstdc++.so.6` (version `GLIBCXX_3.4.30`).

We will provide official releases through pypi soon.

After installation, you can follow this guide to use Llumnix for multi-instance LLM serving quickly.
Expand Down
45 changes: 27 additions & 18 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
@dataclass
class EngineManagerArgs:
launch_ray_cluster: bool = True
init_instance_by_manager: bool = True
disable_init_instance_by_manager: bool = False
initial_instances: int = 1
fixed_node_init_instance: bool = False
disable_fixed_node_init_instance: bool = False

load_metric: str = 'remaining_steps'
polling_interval: float = 0.05
Expand Down Expand Up @@ -52,10 +52,11 @@ class EngineManagerArgs:
profiling_result_file_path: str = ""

gpu_type: str = "a10"

migration_backend_init_timeout: float = 10.0
migration_backend: str = "rpc"
migration_cache_blocks: int = 512
last_stage_max_blocks: int = 4
migration_cache_blocks: int = 32
migration_num_layers: int = 1
last_stage_max_blocks: int = 16
max_stages: int = 3

def create_engine_manager_configs(
Expand All @@ -72,14 +73,14 @@ def create_engine_manager_configs(
self.scale_down_threshold)
return global_scheduler_config

def create_migration_configs(
self,
) -> MigrationConfig:
def create_migration_config(self) -> MigrationConfig:
migration_config = MigrationConfig(self.request_migration_policy,
self.migration_backend,
self.migration_cache_blocks,
self.migration_num_layers,
self.last_stage_max_blocks,
self.max_stages)
self.max_stages,
self.migration_backend_init_timeout)
return migration_config

@classmethod
Expand All @@ -93,12 +94,12 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineManagerArgs':
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--fixed-node-init-instance',
parser.add_argument('--disable-fixed-node-init-instance',
action='store_true',
help='fix the placement of instance to current node')
parser.add_argument('--init-instance-by-manager',
help='disable fixing the placement of instance to current node')
parser.add_argument('--disable-init-instance-by-manager',
action='store_true',
help='initialize instance by manager')
help='disable the initialization of the instance by the manager')
parser.add_argument('--initial-instances',
type=int,
default=EngineManagerArgs.initial_instances,
Expand All @@ -117,7 +118,7 @@ def add_cli_args(
parser.add_argument('--dispatch-policy',
type=str,
default=EngineManagerArgs.dispatch_policy,
choices=['balanced', 'load', 'queue'],
choices=['balanced', 'load', 'queue', 'flood'],
help='request dispatch policy')

parser.add_argument('--enable-migration',
Expand Down Expand Up @@ -198,12 +199,20 @@ def add_cli_args(
parser.add_argument('--migration-backend',
type=str,
default=EngineManagerArgs.migration_backend,
choices=['gloo','rpc'],
help='communication backend of migration')
parser.add_argument('--migration-cache_blocks',
choices=['gloo','nccl','rpc'],
help='communication backend during migration')
parser.add_argument('--migration-backend-init-timeout',
type=float,
default=EngineManagerArgs.migration_backend_init_timeout,
help='timeout(s) for initializing migration backend')
parser.add_argument('--migration-cache-blocks',
type=int,
default=EngineManagerArgs.migration_cache_blocks,
help='number of cache blocks in migration')
help='cache blocks num during migration')
parser.add_argument('--migration-num-layers',
type=int,
default=EngineManagerArgs.migration_num_layers,
help='number of kv-cache layers to transfer in each round during migration')
parser.add_argument('--last-stage-max-blocks',
type=int,
default=EngineManagerArgs.last_stage_max_blocks,
Expand Down
41 changes: 41 additions & 0 deletions llumnix/backends/migration_backend_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List


class MigrationBackendBase(ABC):
    """Abstract interface for a migration backend.

    Every method is abstract and only raises ``NotImplementedError`` here;
    concrete backends must override all of them before they can be
    instantiated.
    """

    @abstractmethod
    def init_backend(self, group_name, world_size, rank) -> bool:
        """Initialize the backend for the given group, world size and rank.

        Returns a bool; presumably True on success -- confirm against the
        concrete implementations.
        """
        raise NotImplementedError

    # NOTE(review): "destory" is a misspelling of "destroy". The name is part
    # of the abstract interface, so renaming it requires updating every
    # implementation and caller together.
    @abstractmethod
    def destory_backend(self) -> None:
        """Tear down the backend; counterpart of ``init_backend``."""
        raise NotImplementedError

    @abstractmethod
    def warmup(self) -> bool:
        """Warm up the backend.

        Returns a bool; presumably True when the warm-up succeeded -- confirm
        against the concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
        """Migrate cache blocks from the source identified by ``src_handle``.

        ``src_blocks`` and ``dst_blocks`` are lists of integer block ids;
        presumably they pair up element-wise (source block -> destination
        block) -- TODO confirm.
        """
        raise NotImplementedError

    @abstractmethod
    def do_send(self, dst_handle, blocks: List[int]):
        """Send the given blocks (integer ids) towards ``dst_handle``."""
        raise NotImplementedError

    @abstractmethod
    def do_recv(self, src_handle, blocks: List[int]):
        """Receive the given blocks (integer ids) from ``src_handle``."""
        raise NotImplementedError
7 changes: 2 additions & 5 deletions llumnix/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kw
raise ValueError(f'unimplemented backend {backend_type}')
return backend_engine

def initialize_cluster(
def initialize_placement_group(
world_size: int = 1,
ray_address: Optional[str] = None,
detached: bool = False,
detached: bool = False
) -> Tuple[str, Optional["PlacementGroup"]]:
"""Initialize the distributed cluster probably with Ray.
Expand All @@ -55,8 +54,6 @@ def initialize_cluster(
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=ray_address, ignore_reinit_error=True, namespace='llumnix')

lifetime = "detached" if detached else None
# Create placement group for worker processes
Expand Down
7 changes: 7 additions & 0 deletions llumnix/backends/vllm/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest
from vllm.config import _GB

from llumnix.config import MigrationConfig
from llumnix.logger import init_logger
from llumnix.backends.vllm.utils import get_cache_block_size
from llumnix.backends.profiling import LatencyMemData, SimCacheConfig, model_prefill, model_decode, _pad_to_alignment
Expand All @@ -35,6 +36,7 @@

class LlumnixRayGPUExecutor(RayGPUExecutor):
node_id: str = None
migration_config: MigrationConfig = None

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
Expand Down Expand Up @@ -149,6 +151,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
self._run_workers("reserve_memory_for_migration",
migration_config=self.migration_config,
model_config=self.model_config,
cache_config=self.cache_config,
parallel_config=self.parallel_config)

def execute_model(self, *args, **kwargs):
t0 = time.time()
Expand Down
Loading

0 comments on commit 6b7b099

Please sign in to comment.