Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Nov 6, 2024
1 parent 7e62a67 commit e7145f7
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 60 deletions.
6 changes: 3 additions & 3 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def check_args(cls, args: 'EngineManagerArgs', parser: argparse.ArgumentParser):
if hasattr(action, 'choices') and action.choices is not None and hasattr(args, action.dest):
assert getattr(args, action.dest) in action.choices, f"{action.dest} should be one of {action.choices}."

if args.migration_backend == 'nccl':
logger.warning("NCCL migration backend is deprecated, use gloo instead.")
args.migration_backend = 'gloo'
if args.migration_backend == 'nccl' and args.migration_internal_cache_size != 1:
logger.warning("The NCCL migration backend does not support concurrency. Set migration_internal_cache_size to 1.")
args.migration_internal_cache_size = 1

assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
Expand Down
7 changes: 4 additions & 3 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import time
from typing import Dict, List
import math
import ray
import torch

from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
Expand Down Expand Up @@ -112,8 +111,10 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block
start_time = time.time()
try:
self.migration_backend.migrate_cache(src_worker_handle, src_blocks, dst_blocks)
except ray.exceptions.RayActorError:
logger.info("[migrate_cache] self.rank: {}, src_worker_handle {} is dead".format(self.rank, src_worker_handle))
# pylint: disable=broad-except
except Exception as e:
logger.info("[migrate_cache] self.rank: {}, src_worker_handle {}, meet err : {}"
.format(self.rank, src_worker_handle, e))
end_time = time.time()

total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size(
Expand Down
4 changes: 2 additions & 2 deletions llumnix/global_scheduler/global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def dispatch(self) -> str:
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self, pair_migration_type: PairMigrationConstraints, inflight_migrating: Dict[str, int]) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
    """Select (migrate-out, migrate-in) instance pairs under the given constraints.

    Pushes the latest instance-info snapshot into the migration scheduler,
    then delegates the actual pairing decision to it.
    """
    self.migration_scheduler.update_instance_infos(self.instance_info)
    return self.migration_scheduler.pair_migration(pair_migration_type)
return migrate_instance_pairs

def check_scale(self) -> Tuple[str, str]:
Expand Down
33 changes: 8 additions & 25 deletions llumnix/global_scheduler/migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def __init__(self,
self.instance_info: Dict[str, InstanceInfo] = None
self.sorted_instance_infos: List[InstanceInfo] = None

def pair_migration(self, pair_migration_type: PairMigrationConstraints, inflight_migrating: Dict[str, int]) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
self._sort_instance_infos(descending=False)
sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type)

migrate_instance_pairs = self.pair_migration_policy.pair_migration(
sorted_src_instance_infos, sorted_dst_instance_infos, inflight_migrating)
sorted_src_instance_infos, sorted_dst_instance_infos)

return migrate_instance_pairs

Expand Down Expand Up @@ -162,34 +162,21 @@ def __init__(self,
@abstractmethod
def pair_migration(self,
sorted_src_instance_infos: List[InstanceInfo],
sorted_dst_instance_infos: List[InstanceInfo],
inflight_migration: Dict[str, int],
sorted_dst_instance_infos: List[InstanceInfo]
) -> List[Tuple[str, str]]:
raise NotImplementedError

class Balanced(PairMigrationPolicy):
def pair_migration(self,
sorted_src_instance_infos: List[InstanceInfo],
sorted_dst_instance_infos: List[InstanceInfo],
inflight_migration: Dict[str, int],
sorted_dst_instance_infos: List[InstanceInfo]
) -> List[Tuple[str, str]]:
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate

src_num_migration_request = 0
if sorted_src_instance_infos[i].instance_id in inflight_migration:
src_num_migration_request = inflight_migration[sorted_src_instance_infos[i].instance_id][0] - \
inflight_migration[sorted_src_instance_infos[i].instance_id][1]
left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i],
is_migrate_in=False, inflight_migration=src_num_migration_request)

dst_num_migration_request = 0
if sorted_dst_instance_infos[i].instance_id in inflight_migration:
dst_num_migration_request = inflight_migration[sorted_dst_instance_infos[i].instance_id][0] - \
inflight_migration[sorted_dst_instance_infos[i].instance_id][1]
right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i],
is_migrate_in=True, inflight_migration=dst_num_migration_request)
left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False)
right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True)

# Add some constraints to reduce unnecessary migrations
if right_load_after_mig > self.migrate_out_load_threshold:
Expand All @@ -200,7 +187,7 @@ def pair_migration(self,
sorted_dst_instance_infos[i].instance_id))
return migrate_instance_pairs

def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool, inflight_migration: int) -> float:
def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float:
instance_info_after_migrate = copy.deepcopy(instance_info)
num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request

Expand All @@ -211,16 +198,12 @@ def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_m
instance_info_after_migrate.num_running_requests -= 1
instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request

instance_info_after_migrate.num_running_requests -= inflight_migration
instance_info_after_migrate.num_free_gpu_blocks += inflight_migration * num_blocks_last_running_request

return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate')

class DefragConstrained(PairMigrationPolicy):
def pair_migration(self,
sorted_src_instance_infos: List[InstanceInfo],
sorted_dst_instance_infos: List[InstanceInfo],
inflight_migration: Dict[str, int],
sorted_dst_instance_infos: List[InstanceInfo]
) -> List[Tuple[str, str]]:
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
Expand Down
25 changes: 2 additions & 23 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def __init__(self,
logger.info("max_instances: {}, min_instances: {}".format(self.max_instances, self.min_instances))

self.instances: Dict[str, Llumlet] = {}
self.num_migrating_pair = 0
# instance_id -> [num_migration_out, num_migration_in]
self.inflight_migrating: Dict[str, List] = {}
self.pending_rebuild_migration_instances = 0
self.global_scheduler = GlobalScheduler(global_scheduler_config)

Expand Down Expand Up @@ -231,14 +228,6 @@ async def _push_migrations(self) -> None:

async def _migrate(self, pair_migration_type: PairMigrationConstraints, migrate_in_num_requests: int) -> None:
async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None:
if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING:
migrate_out_instance_id, migrate_in_instance_id = migrate_instance_pair
self.num_migrating_pair -= 1
if migrate_out_instance_id in self.inflight_migrating:
self.inflight_migrating[migrate_out_instance_id][0] -= 1
if migrate_in_instance_id in self.inflight_migrating:
self.inflight_migrating[migrate_in_instance_id][1] -= 1

if isinstance(ret, (ray.exceptions.RayActorError, KeyError)):
has_error_pair = await self._check_instance_error(migrate_instance_pair)
for i, has_error in enumerate(has_error_pair):
Expand Down Expand Up @@ -266,18 +255,13 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -
loop = asyncio.get_event_loop()
loop.create_task(migrate_done_callback(ret, migrate_instance_pair))

migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type, self.inflight_migrating)

try:
migrate_instance_pairs = self.global_scheduler.pair_migration(pair_migration_type)

migration_tasks = []
for _, migrate_instance_pair in enumerate(migrate_instance_pairs):
migrate_out_instance_id, migrate_in_instance_id = migrate_instance_pair

if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING:
self.num_migrating_pair += 1
self.inflight_migrating[migrate_out_instance_id][0] += 1
self.inflight_migrating[migrate_in_instance_id][1] += 1

migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id)
# Use asyncio.gather to wrap ray remote call to add done callback.
task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests),
Expand Down Expand Up @@ -366,7 +350,6 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles
if ins_id not in self.instances:
indeed_update = True
self.instances[ins_id] = llumlet_actor_handles[idx]
self.inflight_migrating[ins_id] = [0, 0]
if self.log_instance_info:
self.instance_last_logged_empty[ins_id] = False
self.pending_rebuild_migration_instances += 1
Expand Down Expand Up @@ -394,10 +377,6 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac
if ins_id in self.instances:
indeed_update = True
del self.instances[ins_id]

if ins_id in self.inflight_migrating:
del self.inflight_migrating[ins_id]

if self.log_instance_info:
del self.instance_last_logged_empty[ins_id]
self.pending_rebuild_migration_instances += 1
Expand Down
8 changes: 7 additions & 1 deletion tests/e2e_test/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ async def get_llumnix_responce(prompt, sampling_params, ip_ports):
"The future of AI is",
]

vllm_output = {}

@ray.remote(num_gpus=1)
def run_vllm(model, max_model_len, sampling_params):
vllm_output = {}
Expand Down Expand Up @@ -161,7 +163,11 @@ async def test_e2e(model, migration_backend, launch_mode):

shutdown_llumnix_service()

vllm_output = ray.get(run_vllm.remote(model, max_model_len, sampling_params))
global vllm_output

if len(vllm_output) == 0:
vllm_output = ray.get(run_vllm.remote(model, max_model_len, sampling_params))

clear_ray_state()

# compare
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def parse_manager_log_file(log_file):
@pytest.mark.asyncio
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
@pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo'])
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
async def test_migration_benchmark(model, migration_backend):
base_port = 37037
instance_output_logs = []
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_test/global_scheduler/test_global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ def test_pair_migration(global_scheduler):
instance_infos = [instance_info_migrate_in, instance_info_migrate_out]
global_scheduler.scale_up(instance_ids)
global_scheduler.update_instance_infos(instance_infos)
migrate_instace_pairs = global_scheduler.pair_migration("NO_CONSTRAINTS", {})
migrate_instace_pairs = global_scheduler.pair_migration("NO_CONSTRAINTS")
assert migrate_instace_pairs[0][0] == instance_id_1
assert migrate_instace_pairs[0][1] == instance_id
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_pair_migration(policy):
sorted_dst_instance_infos = [i for i in migration_scheduler.sorted_instance_infos
if i.instance_type == InstanceType.NO_CONSTRAINTS
and (i.num_killed_requests == 0 and i.instance_load_migrate < migration_scheduler.migrate_out_load_threshold)]
migrate_instance_pairs = migration_scheduler.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos, {})
migrate_instance_pairs = migration_scheduler.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos)
for migrate_out_instance, migrate_in_instance in migrate_instance_pairs:
assert migrate_out_instance != migrate_in_instance
if policy == 'balanced':
Expand Down

0 comments on commit e7145f7

Please sign in to comment.