From 2135d8b41ec21607647b749388fb3eea31e3b216 Mon Sep 17 00:00:00 2001
From: Xinyi Zhang <114055322+Xinyi-ECNU@users.noreply.github.com>
Date: Wed, 6 Nov 2024 13:08:16 +0800
Subject: [PATCH] [Misc] Add e2e test for prefill-decoding migration (#65)

---
 docs/Arguments.md                             |  8 ++++++++
 llumnix/arg_utils.py                          |  5 ++++-
 llumnix/config/default.py                     |  2 +-
 llumnix/global_scheduler/global_scheduler.py  |  1 +
 llumnix/global_scheduler/scaling_scheduler.py |  5 +++--
 tests/e2e_test/test_e2e.py                    | 10 +++++++---
 tests/e2e_test/test_migration.py              |  8 ++++++--
 7 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/docs/Arguments.md b/docs/Arguments.md
index 56474d82..c8397bfa 100644
--- a/docs/Arguments.md
+++ b/docs/Arguments.md
@@ -38,6 +38,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
         [--migration-num-layers MIGRATION_NUM_LAYERS]
         [--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
         [--max-stages MAX_STAGES]
+        [--enable-pd-disagg]
+        [--num-dispatch-instances NUM_DISPATCH_INSTANCES]
         [--log-request-timestamps]
 ```
 
@@ -168,6 +170,12 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 `--log-request-timestamps`
 - Enable logging request timestamps.
 
+`--enable-pd-disagg`
+- Enable prefill decoding disaggregation.
+
+`--num-dispatch-instances`
+- Number of available instances for dispatch.
+
 # Unsupported vLLM feature options
 
 `--device`
diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py
index 70a643cf..dd80276d 100644
--- a/llumnix/arg_utils.py
+++ b/llumnix/arg_utils.py
@@ -306,6 +306,9 @@ def add_cli_args(
                             type=int,
                             help='drop migration if the number of stages > max_stages')
         parser.add_argument('--enable-pd-disagg',
-                            type=bool,
+                            action='store_true',
                             help='enable prefill decoding disaggregation')
+        parser.add_argument('--num-dispatch-instances',
+                            type=int,
+                            help='number of available instances for dispatch')
         return parser
diff --git a/llumnix/config/default.py b/llumnix/config/default.py
index 17849463..fb94443b 100644
--- a/llumnix/config/default.py
+++ b/llumnix/config/default.py
@@ -80,7 +80,7 @@
 _C.MANAGER.LOAD_METRIC = 'remaining_steps'
 # Request dispatch policy
 _C.MANAGER.DISPATCH_POLICY = 'load'
-# Number of available dispatch instances. -1 indicates that all instances can be used for dispatching
+# Number of available dispatch instances. math.inf indicates that all instances can be used for dispatching
 _C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf
 
 # -----------------------------------------------------------------------------
diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py
index 79d6e88e..201b57de 100644
--- a/llumnix/global_scheduler/global_scheduler.py
+++ b/llumnix/global_scheduler/global_scheduler.py
@@ -48,6 +48,7 @@ def __init__(self,
                                                 global_scheduler_config.scale_down_threshold,
                                                 global_scheduler_config.scaling_policy,
                                                 self.instance_load_calculator,
+                                                self.enable_pd_disagg,
                                                 global_scheduler_config.num_dispatch_instances)
         self.num_instances = 0
 
diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py
index edcc9627..7607d88a 100644
--- a/llumnix/global_scheduler/scaling_scheduler.py
+++ b/llumnix/global_scheduler/scaling_scheduler.py
@@ -14,7 +14,6 @@
 from typing import Dict, List, Tuple, Set
 from abc import ABC, abstractmethod
 from enum import Enum
-import math
 import numpy as np
 
 from llumnix.logger import init_logger
@@ -36,6 +35,7 @@ def __init__(self,
                  scale_down_threshold: float,
                  scaling_policy: str,
                  instance_load_calculator: InstanceLoadCalculator,
+                 enable_pd_disagg: bool,
                  maximum_prefill_instance_num: int) -> None:
         self.scale_up_threshold = scale_up_threshold
         self.scale_down_threshold = scale_down_threshold
@@ -46,6 +46,7 @@ def __init__(self,
         self.num_instances = 0
         self.instance_id_set: Set[str] = set()
         self.maximum_prefill_instance_num = maximum_prefill_instance_num
+        self.enable_pd_disagg = enable_pd_disagg
         # instance info args
         self.instance_info: Dict[str, InstanceInfo] = None
         self.sorted_instance_infos: List[InstanceInfo] = None
@@ -78,7 +79,7 @@ def add_instance(self, instance_id: str) -> None:
         self.instance_id_set.add(instance_id)
         self.num_instances = len(self.instance_id_set)
         instance_type = None
-        if self.maximum_prefill_instance_num == math.inf:
+        if not self.enable_pd_disagg:
             instance_type = InstanceType.NO_CONSTRAINTS
         else:
             if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num:
diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py
index 42f92512..741360f1 100644
--- a/tests/e2e_test/test_e2e.py
+++ b/tests/e2e_test/test_e2e.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import subprocess
 import asyncio
 import pytest
@@ -40,7 +41,8 @@ def parse_launch_mode(launch_mode: str):
 def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool = True, HEAD_NODE_IP: str = "127.0.0.1",
                             ip: str = "127.0.0.1", port: int = 37000, instances_num = 1, dispatch_policy: str = "load",
                             migration_backend = "gloo", model = "facebook/opt-125m", max_model_len: int = 2048,
-                            launch_mode: str = 'eief', log_instance_info: bool = False):
+                            launch_mode: str = 'eief', log_instance_info: bool = False, enable_pd_disagg: bool = False,
+                            num_dispatch_instances: float = math.inf):
     disable_init_instance_by_manager, disable_fixed_node_init_instance = parse_launch_mode(launch_mode)
     command = (
         f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 "
@@ -64,6 +66,8 @@ def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool
         f"--migration-cache-blocks 32 "
         f"--tensor-parallel-size 1 "
         f"--request-output-queue-port {1234+port} "
+        f"{'--enable-pd-disagg ' if enable_pd_disagg else ''} "
+        f"{f'--num-dispatch-instances {num_dispatch_instances} ' if num_dispatch_instances != math.inf else ''} "
         f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}"
         f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &"
     )
@@ -98,7 +102,7 @@ def clear_ray_state():
             continue
     ray.shutdown()
 
-async def get_llumnix_responce(prompt, sampling_params, ip_ports):
+async def get_llumnix_response(prompt, sampling_params, ip_ports):
     timeout = aiohttp.ClientTimeout(total=60)
 
     request = {
@@ -155,7 +159,7 @@ async def test_e2e(model, migration_backend, launch_mode):
 
     llumnix_output = {}
     for prompt in prompts:
-        response = await asyncio.wait_for(get_llumnix_responce(prompt, sampling_params, f"127.0.0.1:{base_port}"),
+        response = await asyncio.wait_for(get_llumnix_response(prompt, sampling_params, f"127.0.0.1:{base_port}"),
                                           timeout=60*5)
         llumnix_output[prompt] = response['text'][0]
 
diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py
index 7fe167bb..ddf7fb51 100644
--- a/tests/e2e_test/test_migration.py
+++ b/tests/e2e_test/test_migration.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import asyncio
 from collections import defaultdict
 import re
@@ -66,17 +67,20 @@ def parse_manager_log_file(log_file):
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
 @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
 @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
-async def test_migration_benchmark(model, migration_backend):
+@pytest.mark.parametrize("enable_pd_disagg", [False, True])
+async def test_migration_benchmark(model, migration_backend, enable_pd_disagg):
     base_port = 37037
     instance_output_logs = []
 
     device_count = torch.cuda.device_count()
+    num_dispatch_instances = device_count//2 if enable_pd_disagg else math.inf
     for i in range(device_count):
         output_log = f"{base_port+i}.out"
         instance_output_logs.append("instance_"+output_log)
         launch_command = generate_launch_command(result_filename=output_log,
                                                  launch_ray_cluster=False, port=base_port+i, model=model,
                                                  dispatch_policy="flood", migration_backend=migration_backend,
-                                                 log_instance_info=True)
+                                                 log_instance_info=True, enable_pd_disagg=enable_pd_disagg,
+                                                 num_dispatch_instances=num_dispatch_instances)
         subprocess.run(launch_command, shell=True, check=True)
         await asyncio.sleep(60)
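
The core scheduling change above: `ScalingScheduler.add_instance` now branches on the explicit `enable_pd_disagg` flag instead of comparing `maximum_prefill_instance_num` against `math.inf`. A minimal standalone sketch of the resulting typing rule follows; the function name and string labels are illustrative rather than taken from the codebase, and the decode branch is an assumption, since it falls outside the hunk shown:

```python
import math

def assign_instance_type(enable_pd_disagg: bool,
                         num_prefill_instances: int,
                         maximum_prefill_instance_num: float) -> str:
    # Mirrors the rewritten branch in ScalingScheduler.add_instance.
    if not enable_pd_disagg:
        # PD disaggregation off: any instance may serve prefill and decode.
        return "NO_CONSTRAINTS"
    if num_prefill_instances < maximum_prefill_instance_num:
        # Fill the prefill (dispatch) pool up to --num-dispatch-instances.
        return "PREFILL"
    # Assumed: remaining instances handle decode (branch not shown in the hunk).
    return "DECODE"

# As exercised by test_migration_benchmark on 4 GPUs, where
# num_dispatch_instances = device_count//2 = 2:
assert assign_instance_type(True, 0, 2) == "PREFILL"
assert assign_instance_type(True, 2, 2) == "DECODE"
assert assign_instance_type(False, 0, math.inf) == "NO_CONSTRAINTS"
```

Keying the branch on an explicit flag avoids overloading the `math.inf` sentinel with two meanings, "no dispatch limit" and "PD disaggregation disabled", which the replaced `maximum_prefill_instance_num == math.inf` check conflated.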