[Misc] Add e2e test for prefill-decoding migration (#65)
Xinyi-ECNU authored Nov 6, 2024
1 parent 188b08e commit 2135d8b
Showing 7 changed files with 30 additions and 9 deletions.
8 changes: 8 additions & 0 deletions docs/Arguments.md
@@ -38,6 +38,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--migration-num-layers MIGRATION_NUM_LAYERS]
[--last-stage-max-blocks LAST_STAGE_MAX_BLOCKS]
[--max-stages MAX_STAGES]
[--enable-pd-disagg]
[--num-dispatch-instances NUM_DISPATCH_INSTANCES]
[--log-request-timestamps]
```
@@ -168,6 +170,12 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
`--log-request-timestamps`
- Enable logging request timestamps.

`--enable-pd-disagg`
- Enable prefill decoding disaggregation.

`--num-dispatch-instances`
- Number of available instances for dispatch.

# Unsupported vLLM feature options

`--device`
5 changes: 4 additions & 1 deletion llumnix/arg_utils.py
@@ -306,6 +306,9 @@ def add_cli_args(
type=int,
help='drop migration if the number of stages > max_stages')
parser.add_argument('--enable-pd-disagg',
action='store_true',
help='enable prefill decoding disaggregation')
parser.add_argument('--num-dispatch-instances',
type=int,
help='number of available instances for dispatch')
return parser
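For context on the flag definition above, here is a minimal standalone argparse sketch (not llumnix code; the default for `--num-dispatch-instances` is a placeholder). It shows why a boolean switch uses `action='store_true'`: with `type=bool`, argparse would call `bool()` on the raw string, which is truthy for any non-empty value such as "False".

```python
import argparse

# Minimal standalone sketch (not llumnix code): boolean switches use
# action='store_true'; type=bool would turn the string "False" into True.
parser = argparse.ArgumentParser()
parser.add_argument('--enable-pd-disagg', action='store_true',
                    help='enable prefill decoding disaggregation')
parser.add_argument('--num-dispatch-instances', type=int, default=None,  # placeholder default
                    help='number of available instances for dispatch')

args = parser.parse_args(['--enable-pd-disagg', '--num-dispatch-instances', '2'])
assert args.enable_pd_disagg is True and args.num_dispatch_instances == 2
```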
2 changes: 1 addition & 1 deletion llumnix/config/default.py
@@ -80,7 +80,7 @@
_C.MANAGER.LOAD_METRIC = 'remaining_steps'
# Request dispatch policy
_C.MANAGER.DISPATCH_POLICY = 'load'
# Number of available dispatch instances. math.inf indicates that all instances can be used for dispatching
_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf

# -----------------------------------------------------------------------------
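As a quick illustration of the sentinel semantics in plain Python: any finite instance count compares below `math.inf`, so the default leaves every instance eligible for dispatch.

```python
import math

# With NUM_DISPATCH_INSTANCES = math.inf, every finite instance count stays
# under the cap, i.e. all instances remain available for dispatch.
num_dispatch_instances = math.inf
assert all(n < num_dispatch_instances for n in range(1, 1000))
```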
1 change: 1 addition & 0 deletions llumnix/global_scheduler/global_scheduler.py
@@ -48,6 +48,7 @@ def __init__(self,
global_scheduler_config.scale_down_threshold,
global_scheduler_config.scaling_policy,
self.instance_load_calculator,
self.enable_pd_disagg,
global_scheduler_config.num_dispatch_instances)

self.num_instances = 0
5 changes: 3 additions & 2 deletions llumnix/global_scheduler/scaling_scheduler.py
@@ -14,7 +14,6 @@
from typing import Dict, List, Tuple, Set
from abc import ABC, abstractmethod
from enum import Enum
import numpy as np

from llumnix.logger import init_logger
@@ -36,6 +35,7 @@ def __init__(self,
scale_down_threshold: float,
scaling_policy: str,
instance_load_calculator: InstanceLoadCalculator,
enable_pd_disagg: bool,
maximum_prefill_instance_num: int) -> None:
self.scale_up_threshold = scale_up_threshold
self.scale_down_threshold = scale_down_threshold
@@ -46,6 +46,7 @@
self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.maximum_prefill_instance_num = maximum_prefill_instance_num
self.enable_pd_disagg = enable_pd_disagg
# instance info args
self.instance_info: Dict[str, InstanceInfo] = None
self.sorted_instance_infos: List[InstanceInfo] = None
@@ -78,7 +79,7 @@ def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
instance_type = None
if not self.enable_pd_disagg:
instance_type = InstanceType.NO_CONSTRAINTS
else:
if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num:
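The tail of `add_instance` is collapsed in this view. Below is a hedged, self-contained sketch of the instance-type decision it appears to implement; the enum values and the DECODE fallback branch are assumptions, since they are not visible in the diff.

```python
from enum import Enum

class InstanceType(Enum):
    NO_CONSTRAINTS = "no_constraints"
    PREFILL = "prefill"
    DECODE = "decode"

def pick_instance_type(enable_pd_disagg: bool,
                       num_prefill_instances: int,
                       maximum_prefill_instance_num: int) -> InstanceType:
    # Without prefill-decoding disaggregation, every instance is unconstrained.
    if not enable_pd_disagg:
        return InstanceType.NO_CONSTRAINTS
    # With disaggregation, fill the prefill pool first, up to the dispatch cap.
    if num_prefill_instances < maximum_prefill_instance_num:
        return InstanceType.PREFILL
    return InstanceType.DECODE  # assumed branch, collapsed in the diff view
```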
10 changes: 7 additions & 3 deletions tests/e2e_test/test_e2e.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import subprocess
import asyncio
import pytest
@@ -40,7 +41,8 @@ def parse_launch_mode(launch_mode: str):
def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool = True, HEAD_NODE_IP: str = "127.0.0.1",
ip: str = "127.0.0.1", port: int = 37000, instances_num = 1, dispatch_policy: str = "load",
migration_backend = "gloo", model = "facebook/opt-125m", max_model_len: int = 2048,
launch_mode: str = 'eief', log_instance_info: bool = False, enable_pd_disagg: bool = False,
num_dispatch_instances: int = math.inf):
disable_init_instance_by_manager, disable_fixed_node_init_instance = parse_launch_mode(launch_mode)
command = (
f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 "
@@ -64,6 +66,8 @@ def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool
f"--migration-cache-blocks 32 "
f"--tensor-parallel-size 1 "
f"--request-output-queue-port {1234+port} "
f"{'--enable-pd-disagg ' if enable_pd_disagg else ''} "
f"{f'--num-dispatch-instances {num_dispatch_instances} ' if num_dispatch_instances != math.inf else ''} "
f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}"
f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &"
)
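To make the two conditional pieces concrete, here is a hedged sketch of the same flag-string pattern in isolation (the function name is illustrative, not part of the test code):

```python
import math

def optional_flags(enable_pd_disagg: bool, num_dispatch_instances: float) -> str:
    # Mirrors the f-string pattern above: emit a flag only when it deviates
    # from the default (switch off / unlimited dispatch instances).
    flags = "--enable-pd-disagg " if enable_pd_disagg else ""
    if num_dispatch_instances != math.inf:
        flags += f"--num-dispatch-instances {num_dispatch_instances} "
    return flags

print(optional_flags(True, 2))          # "--enable-pd-disagg --num-dispatch-instances 2 "
print(optional_flags(False, math.inf))  # ""
```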
@@ -98,7 +102,7 @@ def clear_ray_state():
continue
ray.shutdown()

async def get_llumnix_response(prompt, sampling_params, ip_ports):
timeout = aiohttp.ClientTimeout(total=60)

request = {
@@ -155,7 +159,7 @@ async def test_e2e(model, migration_backend, launch_mode):

llumnix_output = {}
for prompt in prompts:
response = await asyncio.wait_for(get_llumnix_response(prompt, sampling_params, f"127.0.0.1:{base_port}"),
timeout=60*5)
llumnix_output[prompt] = response['text'][0]

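The body of `get_llumnix_response` is collapsed above. A hedged sketch of what such an aiohttp helper typically looks like follows; the `/generate` endpoint path and the payload keys are assumptions, not taken from this diff.

```python
import aiohttp

async def get_llumnix_response(prompt, sampling_params, ip_ports):
    # Hedged sketch: POST the prompt to the api_server and return the parsed
    # JSON body; endpoint path and payload keys are assumptions.
    timeout = aiohttp.ClientTimeout(total=60)
    request = {"prompt": prompt, "stream": False, **sampling_params}
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(f"http://{ip_ports}/generate", json=request) as resp:
            return await resp.json()
```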
8 changes: 6 additions & 2 deletions tests/e2e_test/test_migration.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import asyncio
from collections import defaultdict
import re
@@ -66,17 +67,20 @@ def parse_manager_log_file(log_file):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
@pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
@pytest.mark.parametrize("enable_pd_disagg", [False, True])
async def test_migration_benchmark(model, migration_backend, enable_pd_disagg):
base_port = 37037
instance_output_logs = []

device_count = torch.cuda.device_count()
num_dispatch_instances = device_count//2 if enable_pd_disagg else math.inf
for i in range(device_count):
output_log = f"{base_port+i}.out"
instance_output_logs.append("instance_"+output_log)
launch_command = generate_launch_command(result_filename=output_log, launch_ray_cluster=False, port=base_port+i,
model=model, dispatch_policy="flood", migration_backend=migration_backend,
log_instance_info=True, enable_pd_disagg=enable_pd_disagg,
num_dispatch_instances=num_dispatch_instances)
subprocess.run(launch_command, shell=True, check=True)
await asyncio.sleep(60)

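A small worked example of the parametrization above (the device count is a placeholder for `torch.cuda.device_count()`): on a 4-GPU host the disaggregated run caps dispatch/prefill instances at 2, while the aggregated run leaves the cap unbounded.

```python
import math

device_count = 4  # placeholder for torch.cuda.device_count()
for enable_pd_disagg in (False, True):
    num_dispatch_instances = device_count // 2 if enable_pd_disagg else math.inf
    print(enable_pd_disagg, num_dispatch_instances)  # False inf, then True 2
```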
