
Commit

fix
Xinyi-ECNU committed Oct 8, 2024
1 parent c52edbe commit 3e1623b
Showing 8 changed files with 144 additions and 34 deletions.
7 changes: 6 additions & 1 deletion llumnix/arg_utils.py
@@ -62,9 +62,14 @@ class EngineManagerArgs:
last_stage_max_blocks: int = None
max_stages: int = None

enable_pd_disagg: bool = False
enable_pd_disagg: bool = None

def __post_init__(self):
# Check if all fields default to None
for field_info in dataclasses.fields(self):
if field_info.default is not None:
raise ValueError(f"The default value of '{field_info.name}' should be None")

for attr in dataclasses.fields(self):
if getattr(self, attr.name) is None:
setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper()))
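
Note: the arg_utils.py hunk above makes every EngineManagerArgs field default to None so that __post_init__ can detect unset fields and fill them from the global config (_C.MANAGER). A minimal, self-contained sketch of this pattern follows; the config stand-in (_MANAGER_DEFAULTS) and example field values are hypothetical, not llumnix's actual defaults.

import dataclasses
from types import SimpleNamespace

# Hypothetical stand-in for the llumnix global config (_C.MANAGER); values are illustrative only.
_MANAGER_DEFAULTS = SimpleNamespace(MAX_STAGES=3, ENABLE_PD_DISAGG=False)

@dataclasses.dataclass
class ExampleArgs:
    max_stages: int = None
    enable_pd_disagg: bool = None

    def __post_init__(self):
        # Every field must default to None so unset values can be detected below.
        for field_info in dataclasses.fields(self):
            if field_info.default is not None:
                raise ValueError(f"The default value of '{field_info.name}' should be None")
        # Fill any unset field from the upper-cased config attribute of the same name.
        for field_info in dataclasses.fields(self):
            if getattr(self, field_info.name) is None:
                setattr(self, field_info.name, getattr(_MANAGER_DEFAULTS, field_info.name.upper()))

args = ExampleArgs(max_stages=5)
assert args.max_stages == 5 and args.enable_pd_disagg is False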
2 changes: 2 additions & 0 deletions llumnix/global_scheduler/dispatch_scheduler.py
@@ -69,6 +69,8 @@ def remove_instance(self, instance_id: str) -> None:
self.num_instances = len(self.instance_id_set)
if instance_id in self.instance_num_requests:
del self.instance_num_requests[instance_id]
if instance_id in self.available_dispatch_instance_set:
self.available_dispatch_instance_set.remove(instance_id)

def _sort_instance_infos(self,
descending: bool = True) -> None:
2 changes: 1 addition & 1 deletion llumnix/global_scheduler/global_scheduler.py
@@ -69,7 +69,7 @@ def dispatch(self) -> str:
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self, pair_migration_type:str) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: str) -> List[Tuple[str, str]]:
self.migration_scheduler.update_instance_infos(self.instance_info)
migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type)
return migrate_instance_pairs
4 changes: 2 additions & 2 deletions llumnix/global_scheduler/migration_scheduler.py
@@ -56,12 +56,12 @@ def __init__(self,
self.instance_info: Dict[str, InstanceInfo] = None
self.sorted_instance_infos: List[InstanceInfo] = None

def pair_migration(self, pair_migration_type:str) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: str) -> List[Tuple[str, str]]:
self._sort_instance_infos(descending=False)
sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type)
return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos)

def _get_migration_instance_infos(self, pair_migration_type:str) -> Dict[str, InstanceInfo]:
def _get_migration_instance_infos(self, pair_migration_type: str) -> Dict[str, InstanceInfo]:
filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(pair_migration_type,
migrate_out_load_threshold=self.migrate_out_load_threshold)
return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos,pair_migration_type)
4 changes: 4 additions & 0 deletions llumnix/global_scheduler/scaling_scheduler.py
@@ -89,6 +89,10 @@ def add_instance(self, instance_id: str) -> None:

def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
for instance_type in InstanceType:
if instance_id in self.instance_type_id_set[instance_type]:
self.instance_type_id_set[instance_type].remove(instance_id)
break
self.num_instances = len(self.instance_id_set)

def get_empty_instance_info(self) -> InstanceInfo:
2 changes: 1 addition & 1 deletion llumnix/llm_engine_manager.py
@@ -221,7 +221,7 @@ async def _push_migrations(self) -> None:
else:
asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1))

async def _migrate(self, pair_migration_type:str, migrate_in_num_requests:int) -> None:
async def _migrate(self, pair_migration_type: str, migrate_in_num_requests: int) -> None:
async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None:
if migrate_instance_pair[0] in self.instance_migrating:
self.instance_migrating[migrate_instance_pair[0]] = False
141 changes: 114 additions & 27 deletions tests/unit_test/backends/vllm/test_migration.py
@@ -12,11 +12,14 @@
# limitations under the License.

from typing import List
import asyncio
import math
import pytest
import ray
import torch

from vllm import EngineArgs, SamplingParams
from llumnix.utils import random_uuid
from vllm.utils import random_uuid

from llumnix.backends.vllm.llm_engine import BackendVLLM
from llumnix.llumlet.llumlet import Llumlet
@@ -25,7 +28,7 @@
from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType

# pylint: disable=unused-import
from tests.unit_test.rpc.test_queue import request_output_queue_server
from tests.unit_test.rpc.test_queue import init_request_output_queue, init_server_info
# pylint: disable=unused-import
from tests.conftest import setup_ray_env

Expand All @@ -49,15 +52,17 @@ def __init__(self):
self.instance_id = "0"
self.backend_engine = MockBackendVLLM()

# @pytest.mark.skipif(torch.cuda.device_count() < 2,
# reason="Need at least 2 GPUs to run the test.")
# FIXME(ZeldaHuang) this test is currently unstable
@pytest.mark.skip(reason="Regression Test")
def test_migration_correctness(setup_ray_env, request_output_queue_server, server_info):
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
@pytest.mark.asyncio
async def test_migration_correctness(setup_ray_env, migration_backend):
engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True)
id_rank_map = {"0":0,"1":1}
migration_config = MigrationConfig("LCFS", "gloo",16,1,4,5,20)
que = request_output_queue_server
migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20)
server_info = init_server_info()
que = init_request_output_queue(server_info)
asyncio.create_task(que.run_server_loop())

llumlet_0:Llumlet = Llumlet.from_args(
False,
@@ -86,49 +91,131 @@
ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"),
llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")])
# empty instance migrate out
res = ray.get(llumlet_0.migrate_out.remote("instance_1"))
res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=-1))
assert not res

# running without migration
def test_correctness(prompt):
async def test_correctness(prompt):
sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)
request_id0 = random_uuid()
llumlet_0.generate.remote(request_id0, server_info, prompt, sampling_params)
llumlet_0.generate.remote(request_id0, server_info, math.inf, prompt, sampling_params)
request_output_queue = que
origin_output = None
finished = False
while not finished:
qsize = request_output_queue.qsize()
request_outputs = request_output_queue.get_nowait_batch(qsize)
for request_output in request_outputs:
origin_output = request_output.outputs[0]
finished = request_output.finished
request_output = await request_output_queue.get()
origin_output = request_output.outputs[0]
finished = request_output.finished

request_id1 = random_uuid()
ray.get(llumlet_0.generate.remote(request_id1, server_info, prompt, sampling_params))
ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params))
# wait prefill done
while True:
running_queue: List[LlumnixRequest] = ray.get(llumlet_0.execute_engine_method.remote("get_running_queue"))
if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE:
break
# migrate request
res = ray.get(llumlet_0.migrate_out.remote("instance_1"))
res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=-1))
assert len(res) == 1
request_output_queue = que
output = None
finished = False
while not finished:
qsize = request_output_queue.qsize()
request_outputs = request_output_queue.get_nowait_batch(qsize)
for request_output in request_outputs:
if request_output.request_id != request_id1:
continue
output = request_output.outputs[0]
finished = request_output.finished
request_output = await request_output_queue.get()
origin_output = request_output.outputs[0]
finished = request_output.finished
if request_output.request_id != request_id1:
continue
output = request_output.outputs[0]
finished = request_output.finished

assert output.text == origin_output.text
assert output.cumulative_logprob == origin_output.cumulative_logprob
for prompt in TEST_PROMPTS:
await test_correctness(prompt)
que.cleanup()

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
@pytest.mark.asyncio
async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend):
engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True)
id_rank_map = {"0":0,"1":1}
migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20)
server_info = init_server_info()
que = init_request_output_queue(server_info)
asyncio.create_task(que.run_server_loop())

llumlet_0:Llumlet = Llumlet.from_args(
False,
True,
ray.get_runtime_context().get_node_id(),
"0",
BackendType.VLLM,
1,
migration_config,
engine_args,)

llumlet_1:Llumlet = Llumlet.from_args(
False,
True,
ray.get_runtime_context().get_node_id(),
"1",
BackendType.VLLM,
1,
migration_config,
engine_args,
)
while True:
res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()])
if all(res):
break
ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"),
llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")])
# empty instance migrate out
res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=-1))
assert not res

# running without migration
async def test_correctness(prompt):
sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)
request_id0 = random_uuid()
request_expected_steps_id0 = math.inf
llumlet_0.generate.remote(request_id0, server_info, request_expected_steps_id0, prompt, sampling_params)
request_output_queue = que
origin_output = None
finished = False
while not finished:
request_output = await request_output_queue.get()
origin_output = request_output.outputs[0]
finished = request_output.finished

request_id1 = random_uuid()
request_expected_steps_id1 = 1
ray.get(llumlet_0.generate.remote(request_id1, server_info, request_expected_steps_id1, prompt, sampling_params))
# migrate request for decoding
while True:
res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests = -1))
if len(res) == 1:
break
request_output_queue = que
output = None
finished = False
while not finished:
request_output = await request_output_queue.get()
origin_output = request_output.outputs[0]
finished = request_output.finished
if request_output.request_id != request_id1:
continue
output = request_output.outputs[0]
finished = request_output.finished

assert output.text == origin_output.text
assert output.cumulative_logprob == origin_output.cumulative_logprob
for prompt in TEST_PROMPTS:
test_correctness(prompt)
await test_correctness(prompt)
que.cleanup()

def test_clear_migration_states():
llumlet = MockLlumlet()
16 changes: 14 additions & 2 deletions tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -33,15 +33,27 @@ def test_add_instance_and_remove_instance(dispatch_scheduler):
dispatch_scheduler.add_instance('instance_1')
assert dispatch_scheduler.num_instances == 1
assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
dispatch_scheduler.remove_instance('instance_1')
assert dispatch_scheduler.num_instances == 0
assert len(dispatch_scheduler.available_dispatch_instance_set) == 0

dispatch_scheduler.add_instance('instance_2')
assert dispatch_scheduler.num_instances == 1
assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
dispatch_scheduler.add_instance('instance_3')
assert dispatch_scheduler.num_instances == 2
if dispatch_scheduler.num_dispatch_instances <= 0:
assert len(dispatch_scheduler.available_dispatch_instance_set) == 2
else:
assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances)
dispatch_scheduler.remove_instance('instance_1')
assert dispatch_scheduler.num_instances == 1

dispatch_scheduler.remove_instance('instance_2')
assert dispatch_scheduler.num_instances == 1
if dispatch_scheduler.num_dispatch_instances <= 0:
assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
else:
assert len(dispatch_scheduler.available_dispatch_instance_set) == min(1, dispatch_scheduler.num_dispatch_instances-1)
dispatch_scheduler.remove_instance('instance_3')
assert dispatch_scheduler.num_instances == 0

def test_dispatch_balanced():
