diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 77fd9b25..2811d466 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Set +from typing import Callable, Dict, List, Optional, Tuple, Set from abc import ABC, abstractmethod from enum import Enum import copy @@ -31,12 +31,18 @@ class PairMigrationConstraints(str, Enum): DECODING_2_DECODING = "DECODING_2_DECODING" PREFILL_2_DECODING = "PREFILL_2_DECODING" +class MigrationFilterConfig: + def __init__(self, migrate_out_load_threshold): + self.migrate_out_load_threshold: float = migrate_out_load_threshold + class MigrationScheduler: def __init__(self, pair_migration_policy: str, migrate_out_load_threshold: float, instance_load_calculator: InstanceLoadCalculator) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold + self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) + self.migration_filter = MigrationFilter(self.filter_config) + self.instance_load_calculator = instance_load_calculator self.enable_defrag = instance_load_calculator.enable_defrag if not self.enable_defrag: @@ -57,14 +63,9 @@ def __init__(self, self.sorted_instance_infos: List[InstanceInfo] = None def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: - self._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type) - return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) - - def _get_migration_instance_infos(self, pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: - filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(pair_migration_type, - migrate_out_load_threshold=self.migrate_out_load_threshold) - return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos,pair_migration_type) + src_instance_infos, dst_instance_infos = self.migration_filter.filter_instances( + self.instance_info.values(), pair_migration_type) + return self.pair_migration_policy.pair_migration(src_instance_infos, dst_instance_infos) def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: @@ -78,74 +79,106 @@ def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) self.num_instances = len(self.instance_id_set) - def _sort_instance_infos(self, - descending: bool = True) -> None: - instance_infos: List[InstanceInfo] = list(self.instance_info.values()) - key_attr = 'instance_load_migrate' - self.sorted_instance_infos = sorted( - instance_infos, - key=lambda instance_info: getattr(instance_info, key_attr), - reverse=descending - ) - -class FilteringInstanceInfosPolicy(ABC): - def __init__(self, - migrate_out_load_threshold: float) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold - self.filter_instances_rules = { - PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS), - PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE), - PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), - } - - def filter_instances(self, sorted_instance_infos: List[InstanceInfo], - pair_migration_type: PairMigrationConstraints = None) -> Dict[str, InstanceInfo]: - src_type, dst_type = self.filter_instances_rules[pair_migration_type] - filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type] - filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type] - src_instance_infos = self.filter_src_instances(filtered_src_instance_infos) - dst_instance_infos = self.filter_dst_instances(filtered_dst_instance_infos) - return src_instance_infos, dst_instance_infos - +MIGRATION_FILTER_POLICY_MAPPING = { + PairMigrationConstraints.NO_CONSTRAINTS: 'constraint', + PairMigrationConstraints.DECODING_2_DECODING: 'constraint', + PairMigrationConstraints.PREFILL_2_DECODING: 'relax', +} + +class MigrationFilter(ABC): + def __init__(self, filter_config: MigrationFilterConfig) -> None: + self.filter_config = filter_config + self.addition_src_filter: Callable[[InstanceInfo], bool] = None + self.addition_dst_filter: Callable[[InstanceInfo], bool] = None + self.default_filter = MigrationFilterPolicyFactory.get_policy('instancetype') + + def set_addition_filter(self, src_filter: Optional[Callable[[InstanceInfo], bool]] = None, + dst_filter: Optional[Callable[[InstanceInfo], bool]] = None) -> None: + if src_filter: + self.addition_src_filter = src_filter + if dst_filter: + self.addition_dst_filter = dst_filter + + def remove_addition_filter(self) -> None: + self.addition_src_filter = None + self.addition_dst_filter = None + + def filter_instances(self, instance_infos: List[InstanceInfo], + pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: + src_filter_conditions = [self.default_filter.filter_src_condition(self.filter_config, pair_migration_type)] + dst_filter_conditions = [self.default_filter.filter_dst_condition(self.filter_config, pair_migration_type)] + + if self.addition_src_filter: + src_filter_conditions.append(self.addition_src_filter) + if self.addition_dst_filter: + dst_filter_conditions.append(self.addition_dst_filter) + + policy_filter = MigrationFilterPolicyFactory.get_policy(MIGRATION_FILTER_POLICY_MAPPING[pair_migration_type]) + src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type)) + dst_filter_conditions.append(policy_filter.filter_dst_condition(self.filter_config, pair_migration_type)) + + filtered_src_instance_infos = [info for info in instance_infos if all(cond(info) for cond in src_filter_conditions)] + filtered_dst_instance_infos = [info for info in instance_infos if all(cond(info) for cond in dst_filter_conditions)] + + return filtered_src_instance_infos, filtered_dst_instance_infos + +class MigrationFilterPolicy(ABC): @abstractmethod - def filter_src_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]: + def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: raise NotImplementedError @abstractmethod - def filter_dst_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]: + def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: raise NotImplementedError -class FilterConstrained(FilteringInstanceInfosPolicy): - def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - src_instance_infos = [i for i in reversed(filtered_instance_infos) - if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] - return src_instance_infos +class InstanceTypeFilter(MigrationFilterPolicy): + INSTANCE_FILTER_RULES = { + PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS), + PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE), + PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), + } + + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + src_type, _ = self.INSTANCE_FILTER_RULES[pair_migration_type] + return lambda instance_info: instance_info.instance_type == src_type + - def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - dst_instance_infos = [i for i in filtered_instance_infos - if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] - return dst_instance_infos + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + _, dst_type = self.INSTANCE_FILTER_RULES[pair_migration_type] + return lambda instance_info: instance_info.instance_type == dst_type -class FilterRelaxed(FilteringInstanceInfosPolicy): +class ConstrainedFilter(MigrationFilterPolicy): + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests > 0 \ + or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold + + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests == 0 \ + and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold + +class RelaxedFilter(MigrationFilterPolicy): # The policy is currently used to select the decoding instances to migrate requests from the prefill instances. - def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - src_instance_infos = list(reversed(filtered_instance_infos)) - return src_instance_infos + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: True - def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - dst_instance_infos = [i for i in filtered_instance_infos - if i.num_killed_requests == 0] - return dst_instance_infos + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests == 0 -class FilteringInstanceInfosPolicyFactory: +class MigrationFilterPolicyFactory: _POLICY_REGISTRY = { - PairMigrationConstraints.NO_CONSTRAINTS: FilterConstrained, - PairMigrationConstraints.DECODING_2_DECODING: FilterConstrained, - PairMigrationConstraints.PREFILL_2_DECODING: FilterRelaxed, + 'instancetype': InstanceTypeFilter, + 'constraint': ConstrainedFilter, + 'relax': RelaxedFilter, } @classmethod - def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> FilteringInstanceInfosPolicy: + def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> MigrationFilterPolicy: return cls._POLICY_REGISTRY[policy_name](**kwargs) class PairMigrationPolicy(ABC): @@ -157,16 +190,28 @@ def __init__(self, @abstractmethod def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], ) -> List[Tuple[str, str]]: raise NotImplementedError + @classmethod + def sort_instance_infos(cls, instance_infos: List[InstanceInfo], descending: bool = True) -> None: + key_attr = 'instance_load_migrate' + sorted_instance_infos = sorted( + instance_infos, + key=lambda instance_info: getattr(instance_info, key_attr), + reverse=descending + ) + return sorted_instance_infos + class Balanced(PairMigrationPolicy): def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], ) -> List[Tuple[str, str]]: + sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True) + sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) migrate_instance_pairs = [] for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate @@ -198,9 +243,11 @@ def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_m class DefragConstrained(PairMigrationPolicy): def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], ) -> List[Tuple[str, str]]: + sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True) + sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) migrate_instance_pairs = [] for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): # without any constrain in order to make prefill migrate happens as soon as possible diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 08f5283f..4412c13b 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -51,6 +51,8 @@ def __init__( self.dispatch_policy = dispatch_policy self.pair_migration_policy = pair_migration_policy + # TODO(KuilongCui): Use a better way to set the threshold, as having both positive and negative + # values can cause confusion. self.migrate_out_load_threshold = migrate_out_threshold*(-1) self.enable_defrag = enable_defrag diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index d98f3a8e..d89c4ee1 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -363,7 +363,8 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes # caused by this scale-up (see rebuild_migrate_backend for details). Therefore, we simply return in this case. Specifically, # for RPC, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. - if self.engine_manager_args.migration_backend != 'rpc' and indeed_update and no_pending_instance: + if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc' \ + and indeed_update and no_pending_instance: asyncio.create_task(self.rebuild_migrate_backend()) return self.num_instances @@ -386,7 +387,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac self.global_scheduler.scale_down(instance_ids) self.num_instances = len(self.instances) - if self.engine_manager_args.migration_backend != 'rpc': + if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc': if len(self.instances) == 0: self.pending_rebuild_migration_instances = 0 diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index 8fd32105..b6e7a772 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -17,11 +17,12 @@ import numpy as np from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo -from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints +from llumnix.global_scheduler.migration_scheduler import \ + (MigrationScheduler,PairMigrationConstraints, MigrationFilter, MigrationFilterConfig) from llumnix.global_scheduler.scaling_scheduler import InstanceType MIGRATE_OUT_LOAD_THRESHOLD = 3.0 -INSTANCE_NUM = 4 +INSTANCE_NUM = 16 def init_migration_scheduler(policy='balanced'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) @@ -43,57 +44,66 @@ def test_add_instance_and_remove_instance(migration_scheduler): migration_scheduler.remove_instance('instance_2') assert migration_scheduler.num_instances == 0 -@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS','DECODING_2_DECODING','PREFILL_2_DECODING']) -def test_get_migration_instance_infos(pair_migration_type): +@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS', 'DECODING_2_DECODING', 'PREFILL_2_DECODING']) +def test_migration_filter(pair_migration_type): num_tests = 1000 + migration_filter = MigrationFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD)) + for _ in range(num_tests): - instance_info_dict = {} - for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: + instance_infos = [] + + total_prefill_instance_num = 0 + + for instance_id in range(1, INSTANCE_NUM + 1): instance_info = InstanceInfo() instance_info.instance_id = instance_id instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1) instance_info.num_killed_requests = random.randint(0, 1) + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: constraint_prefill_instance_num = math.inf else: constraint_prefill_instance_num = random.randint(1, INSTANCE_NUM) - migration_scheduler = init_migration_scheduler() + if constraint_prefill_instance_num == math.inf: instance_info.instance_type = InstanceType.NO_CONSTRAINTS else: - if len([info for info in instance_info_dict.values() - if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num: + if total_prefill_instance_num < constraint_prefill_instance_num: instance_info.instance_type = InstanceType.PREFILL + total_prefill_instance_num += 1 else: instance_info.instance_type = InstanceType.DECODE - instance_info_dict[instance_id] = instance_info - migration_scheduler.instance_info = instance_info_dict - migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos = migration_scheduler._get_migration_instance_infos(pair_migration_type) - for instance in sorted_src_instance_infos: - if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: - assert instance.num_killed_requests > 0 \ - or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD - if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - assert instance.instance_type == InstanceType.NO_CONSTRAINTS - elif migration_scheduler == PairMigrationConstraints.DECODING_2_DECODING: - assert instance.instance_type == InstanceType.DECODE - else: - assert instance.instance_type == InstanceType.PREFILL - for instance in sorted_dst_instance_infos: - if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: - assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD - if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - assert instance.instance_type == InstanceType.NO_CONSTRAINTS - elif migration_scheduler == PairMigrationConstraints.DECODING_2_DECODING: + + instance_infos.append(instance_info) + + src_instance_infos, dst_instance_infos = migration_filter.filter_instances(instance_infos, pair_migration_type) + + for instance in src_instance_infos: + if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: + assert instance.num_killed_requests > 0 \ + or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + assert instance.instance_type == InstanceType.NO_CONSTRAINTS + elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + assert instance.instance_type == InstanceType.DECODE + else: + assert instance.instance_type == InstanceType.PREFILL + + for instance in dst_instance_infos: + if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: + assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + assert instance.instance_type == InstanceType.NO_CONSTRAINTS + elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + assert instance.instance_type == InstanceType.DECODE + else: assert instance.instance_type == InstanceType.DECODE - else: - assert instance.instance_type == InstanceType.DECODE - assert instance.num_killed_requests == 0 + assert instance.num_killed_requests == 0 -@pytest.mark.parametrize("policy", ['balanced','defrag_constrained']) +@pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained']) def test_pair_migration(policy): num_tests = 1000 + for _ in range(num_tests): migration_scheduler = init_migration_scheduler(policy) instance_info_dict = {} @@ -106,14 +116,9 @@ def test_pair_migration(policy): instance_info.instance_type = InstanceType.NO_CONSTRAINTS instance_info_dict[instance_id] = instance_info migration_scheduler.instance_info = instance_info_dict - migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos = [i for i in reversed(migration_scheduler.sorted_instance_infos) - if i.instance_type == InstanceType.NO_CONSTRAINTS - and (i.num_killed_requests > 0 or i.instance_load_migrate > migration_scheduler.migrate_out_load_threshold)] - sorted_dst_instance_infos = [i for i in migration_scheduler.sorted_instance_infos - if i.instance_type == InstanceType.NO_CONSTRAINTS - and (i.num_killed_requests == 0 and i.instance_load_migrate < migration_scheduler.migrate_out_load_threshold)] - migrate_instance_pairs = migration_scheduler.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) + + migrate_instance_pairs = migration_scheduler.pair_migration(PairMigrationConstraints.NO_CONSTRAINTS) + for migrate_out_instance, migrate_in_instance in migrate_instance_pairs: assert migrate_out_instance != migrate_in_instance if policy == 'balanced':