From be79deeda063c126217072a151365fecfc83435f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 5 Nov 2024 05:57:28 +0000 Subject: [PATCH 1/3] [misc] other --- llumnix/global_scheduler/global_scheduler.py | 3 +- llumnix/global_scheduler/migration_filter.py | 146 ++++++++++++++++ llumnix/global_scheduler/migration_policy.py | 113 ++++++++++++ .../global_scheduler/migration_scheduler.py | 165 +----------------- llumnix/internal_config.py | 2 + llumnix/llm_engine_manager.py | 5 +- .../test_migration_scheduler.py | 88 +++++----- 7 files changed, 321 insertions(+), 201 deletions(-) create mode 100644 llumnix/global_scheduler/migration_filter.py create mode 100644 llumnix/global_scheduler/migration_policy.py diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index 201b57de..ec1568bb 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -18,7 +18,8 @@ from llumnix.internal_config import GlobalSchedulerConfig from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler -from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints +from llumnix.global_scheduler.migration_scheduler import MigrationScheduler +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler logger = init_logger(__name__) diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py new file mode 100644 index 00000000..2f833a0b --- /dev/null +++ b/llumnix/global_scheduler/migration_filter.py @@ -0,0 +1,146 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional +from abc import ABC, abstractmethod + +from llumnix.logger import init_logger +from llumnix.instance_info import InstanceInfo +from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints + +logger = init_logger(__name__) + +class MigrationFilterConfig: + def __init__(self, migrate_out_load_threshold): + self.migrate_out_load_threshold: float = migrate_out_load_threshold + +# TODO(KuilongCui): A filter might contain other filters; leave this for the future. +class MigrationFilterPolicy(ABC): + @abstractmethod + def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: + raise NotImplementedError + + @abstractmethod + def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: + raise NotImplementedError + +class MigrationFilter(ABC): + def __init__(self, filter_config: MigrationFilterConfig) -> None: + self.filter_config = filter_config + self.registered_filters: Dict[str, MigrationFilterPolicy] = {} + + def register_filter(self, filter_name: str, migration_filter: MigrationFilterPolicy) -> bool: + if filter_name in self.registered_filters: + logger.warning("migration filter {} has been registered.".format(filter_name)) + return False + + self.registered_filters[filter_name] = migration_filter + return True + + def unregister_filter(self, filter_name: str) -> None: + self.registered_filters.pop(filter_name, None) + + def filter_instances(self, instance_infos: List[InstanceInfo], + pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: + src_filter_conditions = [filter.filter_src_condition() for filter in self.registered_filters.values()] + dst_filter_conditions = [filter.filter_dst_condition() for filter in self.registered_filters.values()] + + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + policy_filter = MigrationFilterPolicyFactory.get_policy("loadconstraint") + elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]: + policy_filter = MigrationFilterPolicyFactory.get_policy('pdd') + else: + raise ValueError(f"Unsupported pair migration type: {pair_migration_type}") + src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type)) + dst_filter_conditions.append(policy_filter.filter_dst_condition(self.filter_config, pair_migration_type)) + + filtered_src_instance_infos = [info for info in instance_infos if all(cond(info) for cond in src_filter_conditions)] + filtered_dst_instance_infos = [info for info in instance_infos if all(cond(info) for cond in dst_filter_conditions)] + + return filtered_src_instance_infos, filtered_dst_instance_infos + +class LoadConstrainedFilter(MigrationFilterPolicy): + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests > 0 \ + or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold + + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return lambda instance_info: instance_info.num_killed_requests == 0 \ + and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold + +class PddFilter(MigrationFilterPolicy): + INSTANCE_FILTER_RULES = { + PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE), + PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), + } + + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + src_type, _ = self.INSTANCE_FILTER_RULES[pair_migration_type] + instance_type_filter = lambda instance_info: instance_info.instance_type == src_type + + if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + inner_policy = MigrationFilterPolicyFactory.get_policy('loadconstraint') + policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type) + else: + policy_filter = lambda instance_info: True + + return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info) + + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + _, dst_type = self.INSTANCE_FILTER_RULES[pair_migration_type] + instance_type_filter = lambda instance_info: instance_info.instance_type == dst_type + + if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + inner_policy = MigrationFilterPolicyFactory.get_policy('loadconstraint') + policy_filter = inner_policy.filter_dst_condition(filter_config, pair_migration_type) + else: + policy_filter = lambda instance_info: instance_info.num_killed_requests == 0 + + return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info) + +class CustomFilter(MigrationFilterPolicy): + def __init__(self): + super().__init__() + self.src_filter = lambda _: True + self.dst_filter = lambda _: True + + def set_filter_condtition(self, src_filter: Optional[Callable[[InstanceInfo], bool]] = None, + dst_filter: Optional[Callable[[InstanceInfo], bool]] = None) -> None: + if src_filter: + self.src_filter = src_filter + if dst_filter: + self.dst_filter = dst_filter + + def filter_src_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return self.src_filter + + def filter_dst_condition(self, filter_config: MigrationFilterConfig, + pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: + return self.dst_filter + +class MigrationFilterPolicyFactory: + _POLICY_REGISTRY = { + 'loadconstraint': LoadConstrainedFilter, + 'pdd': PddFilter, + 'custom': CustomFilter, + } + + @classmethod + def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> MigrationFilterPolicy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/llumnix/global_scheduler/migration_policy.py b/llumnix/global_scheduler/migration_policy.py new file mode 100644 index 00000000..c917cce7 --- /dev/null +++ b/llumnix/global_scheduler/migration_policy.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple +from abc import ABC, abstractmethod +from enum import Enum +import copy +import numpy as np + +from llumnix.logger import init_logger +from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator + +logger = init_logger(__name__) + +class PairMigrationConstraints(str, Enum): + """Target of Migration.""" + NO_CONSTRAINTS = "NO_CONSTRAINTS" + + # Enable the prefill-decoding disaggregration. + DECODING_2_DECODING = "DECODING_2_DECODING" + PREFILL_2_DECODING = "PREFILL_2_DECODING" + +class PairMigrationPolicy(ABC): + def __init__(self, + migrate_out_load_threshold: float, + instance_load_calculator: InstanceLoadCalculator) -> None: + self.migrate_out_load_threshold = migrate_out_load_threshold + self.instance_load_calculator = instance_load_calculator + + @abstractmethod + def pair_migration(self, + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], + ) -> List[Tuple[str, str]]: + raise NotImplementedError + + def sort_instance_infos(self, instance_infos: List[InstanceInfo], descending: bool = True) -> None: + key_attr = 'instance_load_migrate' + sorted_instance_infos = sorted( + instance_infos, + key=lambda instance_info: getattr(instance_info, key_attr), + reverse=descending + ) + return sorted_instance_infos + +class Balanced(PairMigrationPolicy): + def pair_migration(self, + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], + ) -> List[Tuple[str, str]]: + sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True) + sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) + migrate_instance_pairs = [] + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate + + left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) + right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) + + # Add some constrains to reduce unnecessary migrations + if right_load_after_mig > self.migrate_out_load_threshold: + continue + load_diff_after_mig = left_load_after_mig - right_load_after_mig + if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, + sorted_dst_instance_infos[i].instance_id)) + return migrate_instance_pairs + + def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: + instance_info_after_migrate = copy.deepcopy(instance_info) + num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request + + if is_migrate_in: + instance_info_after_migrate.num_running_requests += 1 + instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request + else: + instance_info_after_migrate.num_running_requests -= 1 + instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request + + return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') + +class DefragConstrained(PairMigrationPolicy): + def pair_migration(self, + src_instance_infos: List[InstanceInfo], + dst_instance_infos: List[InstanceInfo], + ) -> List[Tuple[str, str]]: + sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True) + sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False) + migrate_instance_pairs = [] + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + # without any constrain in order to make prefill migrate happens as soon as possible + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id)) + return migrate_instance_pairs + +class PairMigrationPolicyFactory: + _POLICY_REGISTRY = { + 'balanced': Balanced, + 'defrag_constrained': DefragConstrained, + } + + @classmethod + def get_policy(cls, policy_name: str, **kwargs) -> PairMigrationPolicy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 3445b210..565eb1dc 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -12,31 +12,22 @@ # limitations under the License. from typing import Dict, List, Tuple, Set -from abc import ABC, abstractmethod -from enum import Enum -import copy -import numpy as np from llumnix.logger import init_logger from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator -from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.global_scheduler.migration_filter import MigrationFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints, PairMigrationPolicyFactory logger = init_logger(__name__) -class PairMigrationConstraints(str, Enum): - """Target of Migration.""" - NO_CONSTRAINTS = "NO_CONSTRAINTS" - - # Enable the prefill-decoding disaggregration. - DECODING_2_DECODING = "DECODING_2_DECODING" - PREFILL_2_DECODING = "PREFILL_2_DECODING" - class MigrationScheduler: def __init__(self, pair_migration_policy: str, migrate_out_load_threshold: float, instance_load_calculator: InstanceLoadCalculator) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold + self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) + self.migration_filter = MigrationFilter(self.filter_config) + self.instance_load_calculator = instance_load_calculator self.enable_defrag = instance_load_calculator.enable_defrag if not self.enable_defrag: @@ -57,14 +48,9 @@ def __init__(self, self.sorted_instance_infos: List[InstanceInfo] = None def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: - self._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type) - return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) - - def _get_migration_instance_infos(self, pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: - filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(pair_migration_type, - migrate_out_load_threshold=self.migrate_out_load_threshold) - return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos,pair_migration_type) + src_instance_infos, dst_instance_infos = self.migration_filter.filter_instances( + self.instance_info.values(), pair_migration_type) + return self.pair_migration_policy.pair_migration(src_instance_infos, dst_instance_infos) def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: @@ -77,138 +63,3 @@ def add_instance(self, instance_id: str) -> None: def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) self.num_instances = len(self.instance_id_set) - - def _sort_instance_infos(self, - descending: bool = True) -> None: - instance_infos: List[InstanceInfo] = list(self.instance_info.values()) - key_attr = 'instance_load_migrate' - self.sorted_instance_infos = sorted( - instance_infos, - key=lambda instance_info: getattr(instance_info, key_attr), - reverse=descending - ) - -class FilteringInstanceInfosPolicy(ABC): - def __init__(self, - migrate_out_load_threshold: float) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold - self.filter_instances_rules = { - PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS), - PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE), - PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), - } - - def filter_instances(self, sorted_instance_infos: List[InstanceInfo], - pair_migration_type: PairMigrationConstraints = None) -> Dict[str, InstanceInfo]: - src_type, dst_type = self.filter_instances_rules[pair_migration_type] - filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type] - filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type] - src_instance_infos = self.filter_src_instances(filtered_src_instance_infos) - dst_instance_infos = self.filter_dst_instances(filtered_dst_instance_infos) - return src_instance_infos, dst_instance_infos - - @abstractmethod - def filter_src_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]: - raise NotImplementedError - - @abstractmethod - def filter_dst_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]: - raise NotImplementedError - -class FilterConstrained(FilteringInstanceInfosPolicy): - def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - src_instance_infos = [i for i in reversed(filtered_instance_infos) - if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] - return src_instance_infos - - def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - dst_instance_infos = [i for i in filtered_instance_infos - if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] - return dst_instance_infos - -class FilterRelaxed(FilteringInstanceInfosPolicy): - # The policy is currently used to select the decoding instances to migrate requests from the prefill instances. - def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - src_instance_infos = list(reversed(filtered_instance_infos)) - return src_instance_infos - - def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]: - dst_instance_infos = [i for i in filtered_instance_infos - if i.num_killed_requests == 0] - return dst_instance_infos - -class FilteringInstanceInfosPolicyFactory: - _POLICY_REGISTRY = { - PairMigrationConstraints.NO_CONSTRAINTS: FilterConstrained, - PairMigrationConstraints.DECODING_2_DECODING: FilterConstrained, - PairMigrationConstraints.PREFILL_2_DECODING: FilterRelaxed, - } - - @classmethod - def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> FilteringInstanceInfosPolicy: - return cls._POLICY_REGISTRY[policy_name](**kwargs) - -class PairMigrationPolicy(ABC): - def __init__(self, - migrate_out_load_threshold: float, - instance_load_calculator: InstanceLoadCalculator) -> None: - self.migrate_out_load_threshold = migrate_out_load_threshold - self.instance_load_calculator = instance_load_calculator - - @abstractmethod - def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], - ) -> List[Tuple[str, str]]: - raise NotImplementedError - -class Balanced(PairMigrationPolicy): - def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], - ) -> List[Tuple[str, str]]: - migrate_instance_pairs = [] - for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): - load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate - left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) - right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) - # Add some constrains to reduce unnecessary migrations - if right_load_after_mig > self.migrate_out_load_threshold: - continue - load_diff_after_mig = left_load_after_mig - right_load_after_mig - if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, - sorted_dst_instance_infos[i].instance_id)) - return migrate_instance_pairs - - def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: - instance_info_after_migrate = copy.deepcopy(instance_info) - num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request - if is_migrate_in: - instance_info_after_migrate.num_running_requests += 1 - instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request - else: - instance_info_after_migrate.num_running_requests -= 1 - instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request - return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') - -class DefragConstrained(PairMigrationPolicy): - def pair_migration(self, - sorted_src_instance_infos: List[InstanceInfo], - sorted_dst_instance_infos: List[InstanceInfo], - ) -> List[Tuple[str, str]]: - migrate_instance_pairs = [] - for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): - # without any constrain in order to make prefill migrate happens as soon as possible - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id)) - return migrate_instance_pairs - -class PairMigrationPolicyFactory: - _POLICY_REGISTRY = { - 'balanced': Balanced, - 'defrag_constrained': DefragConstrained, - } - - @classmethod - def get_policy(cls, policy_name: str, **kwargs) -> PairMigrationPolicy: - return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 08f5283f..4412c13b 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -51,6 +51,8 @@ def __init__( self.dispatch_policy = dispatch_policy self.pair_migration_policy = pair_migration_policy + # TODO(KuilongCui): Use a better way to set the threshold, as having both positive and negative + # values can cause confusion. self.migrate_out_load_threshold = migrate_out_threshold*(-1) self.enable_defrag = enable_defrag diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 66739632..5d8c48a5 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -363,7 +363,8 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles # a coroutine is already handling the changes in the number of instances in the cluster and it will account for the changes # caused by this scale-up (see rebuild_migrate_backend for details). Therefore, we simply return in this case. Specifically, # for RPC, the Ray actor handle is used for the migration cache, so there is no need to rebuild the group. - if self.engine_manager_args.migration_backend != 'rpc' and indeed_update and no_pending_instance: + if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc' \ + and indeed_update and no_pending_instance: asyncio.create_task(self.rebuild_migrate_backend()) return self.num_instances @@ -386,7 +387,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac self.global_scheduler.scale_down(instance_ids) self.num_instances = len(self.instances) - if self.engine_manager_args.migration_backend != 'rpc': + if self.enable_migration and self.engine_manager_args.migration_backend != 'rpc': if len(self.instances) == 0: self.pending_rebuild_migration_instances = 0 diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index 8fd32105..c584207e 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -17,11 +17,13 @@ import numpy as np from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo -from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints +from llumnix.global_scheduler.migration_scheduler import MigrationScheduler from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.global_scheduler.migration_filter import MigrationFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_policy import PairMigrationConstraints MIGRATE_OUT_LOAD_THRESHOLD = 3.0 -INSTANCE_NUM = 4 +INSTANCE_NUM = 16 def init_migration_scheduler(policy='balanced'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) @@ -43,57 +45,66 @@ def test_add_instance_and_remove_instance(migration_scheduler): migration_scheduler.remove_instance('instance_2') assert migration_scheduler.num_instances == 0 -@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS','DECODING_2_DECODING','PREFILL_2_DECODING']) -def test_get_migration_instance_infos(pair_migration_type): +@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS', 'DECODING_2_DECODING', 'PREFILL_2_DECODING']) +def test_migration_filter(pair_migration_type): num_tests = 1000 + migration_filter = MigrationFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD)) + for _ in range(num_tests): - instance_info_dict = {} - for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: + instance_infos = [] + + total_prefill_instance_num = 0 + + for instance_id in range(1, INSTANCE_NUM + 1): instance_info = InstanceInfo() instance_info.instance_id = instance_id instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1) instance_info.num_killed_requests = random.randint(0, 1) + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: constraint_prefill_instance_num = math.inf else: constraint_prefill_instance_num = random.randint(1, INSTANCE_NUM) - migration_scheduler = init_migration_scheduler() + if constraint_prefill_instance_num == math.inf: instance_info.instance_type = InstanceType.NO_CONSTRAINTS else: - if len([info for info in instance_info_dict.values() - if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num: + if total_prefill_instance_num < constraint_prefill_instance_num: instance_info.instance_type = InstanceType.PREFILL + total_prefill_instance_num += 1 else: instance_info.instance_type = InstanceType.DECODE - instance_info_dict[instance_id] = instance_info - migration_scheduler.instance_info = instance_info_dict - migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos = migration_scheduler._get_migration_instance_infos(pair_migration_type) - for instance in sorted_src_instance_infos: - if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: - assert instance.num_killed_requests > 0 \ - or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD - if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - assert instance.instance_type == InstanceType.NO_CONSTRAINTS - elif migration_scheduler == PairMigrationConstraints.DECODING_2_DECODING: - assert instance.instance_type == InstanceType.DECODE - else: - assert instance.instance_type == InstanceType.PREFILL - for instance in sorted_dst_instance_infos: - if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: - assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD - if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - assert instance.instance_type == InstanceType.NO_CONSTRAINTS - elif migration_scheduler == PairMigrationConstraints.DECODING_2_DECODING: + + instance_infos.append(instance_info) + + src_instance_infos, dst_instance_infos = migration_filter.filter_instances(instance_infos, pair_migration_type) + + for instance in src_instance_infos: + if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: + assert instance.num_killed_requests > 0 \ + or instance.instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + assert instance.instance_type == InstanceType.NO_CONSTRAINTS + elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + assert instance.instance_type == InstanceType.DECODE + else: + assert instance.instance_type == InstanceType.PREFILL + + for instance in dst_instance_infos: + if pair_migration_type != PairMigrationConstraints.PREFILL_2_DECODING: + assert instance.num_killed_requests == 0 and instance.instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD + if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: + assert instance.instance_type == InstanceType.NO_CONSTRAINTS + elif pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: + assert instance.instance_type == InstanceType.DECODE + else: assert instance.instance_type == InstanceType.DECODE - else: - assert instance.instance_type == InstanceType.DECODE - assert instance.num_killed_requests == 0 + assert instance.num_killed_requests == 0 -@pytest.mark.parametrize("policy", ['balanced','defrag_constrained']) +@pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained']) def test_pair_migration(policy): num_tests = 1000 + for _ in range(num_tests): migration_scheduler = init_migration_scheduler(policy) instance_info_dict = {} @@ -106,14 +117,9 @@ def test_pair_migration(policy): instance_info.instance_type = InstanceType.NO_CONSTRAINTS instance_info_dict[instance_id] = instance_info migration_scheduler.instance_info = instance_info_dict - migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos = [i for i in reversed(migration_scheduler.sorted_instance_infos) - if i.instance_type == InstanceType.NO_CONSTRAINTS - and (i.num_killed_requests > 0 or i.instance_load_migrate > migration_scheduler.migrate_out_load_threshold)] - sorted_dst_instance_infos = [i for i in migration_scheduler.sorted_instance_infos - if i.instance_type == InstanceType.NO_CONSTRAINTS - and (i.num_killed_requests == 0 and i.instance_load_migrate < migration_scheduler.migrate_out_load_threshold)] - migrate_instance_pairs = migration_scheduler.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) + + migrate_instance_pairs = migration_scheduler.pair_migration(PairMigrationConstraints.NO_CONSTRAINTS) + for migrate_out_instance, migrate_in_instance in migrate_instance_pairs: assert migrate_out_instance != migrate_in_instance if policy == 'balanced': From ed4b5699246ebda9a175414346653079a40b2cd8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 15 Nov 2024 05:37:23 +0000 Subject: [PATCH 2/3] fix comment --- llumnix/global_scheduler/migration_filter.py | 14 +++++++------- llumnix/global_scheduler/migration_scheduler.py | 4 ++-- .../global_scheduler/test_migration_scheduler.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py index 2f833a0b..f7861679 100644 --- a/llumnix/global_scheduler/migration_filter.py +++ b/llumnix/global_scheduler/migration_filter.py @@ -35,7 +35,7 @@ def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[ def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]: raise NotImplementedError -class MigrationFilter(ABC): +class MigrationInstanceFilter(ABC): def __init__(self, filter_config: MigrationFilterConfig) -> None: self.filter_config = filter_config self.registered_filters: Dict[str, MigrationFilterPolicy] = {} @@ -57,9 +57,9 @@ def filter_instances(self, instance_infos: List[InstanceInfo], dst_filter_conditions = [filter.filter_dst_condition() for filter in self.registered_filters.values()] if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - policy_filter = MigrationFilterPolicyFactory.get_policy("loadconstraint") + policy_filter = MigrationFilterPolicyFactory.get_policy("load") elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]: - policy_filter = MigrationFilterPolicyFactory.get_policy('pdd') + policy_filter = MigrationFilterPolicyFactory.get_policy('prefill_decode') else: raise ValueError(f"Unsupported pair migration type: {pair_migration_type}") src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type)) @@ -93,7 +93,7 @@ def filter_src_condition(self, filter_config: MigrationFilterConfig, instance_type_filter = lambda instance_info: instance_info.instance_type == src_type if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: - inner_policy = MigrationFilterPolicyFactory.get_policy('loadconstraint') + inner_policy = MigrationFilterPolicyFactory.get_policy('load') policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type) else: policy_filter = lambda instance_info: True @@ -106,7 +106,7 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig, instance_type_filter = lambda instance_info: instance_info.instance_type == dst_type if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING: - inner_policy = MigrationFilterPolicyFactory.get_policy('loadconstraint') + inner_policy = MigrationFilterPolicyFactory.get_policy('load') policy_filter = inner_policy.filter_dst_condition(filter_config, pair_migration_type) else: policy_filter = lambda instance_info: instance_info.num_killed_requests == 0 @@ -136,8 +136,8 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig, class MigrationFilterPolicyFactory: _POLICY_REGISTRY = { - 'loadconstraint': LoadConstrainedFilter, - 'pdd': PddFilter, + 'load': LoadConstrainedFilter, + 'prefill_decode': PddFilter, 'custom': CustomFilter, } diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 565eb1dc..ad538f06 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -15,7 +15,7 @@ from llumnix.logger import init_logger from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator -from llumnix.global_scheduler.migration_filter import MigrationFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig from llumnix.global_scheduler.migration_policy import PairMigrationConstraints, PairMigrationPolicyFactory logger = init_logger(__name__) @@ -26,7 +26,7 @@ def __init__(self, migrate_out_load_threshold: float, instance_load_calculator: InstanceLoadCalculator) -> None: self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) - self.migration_filter = MigrationFilter(self.filter_config) + self.migration_filter = MigrationInstanceFilter(self.filter_config) self.instance_load_calculator = instance_load_calculator self.enable_defrag = instance_load_calculator.enable_defrag diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index c584207e..fa25e1f8 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -19,7 +19,7 @@ from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo from llumnix.global_scheduler.migration_scheduler import MigrationScheduler from llumnix.global_scheduler.scaling_scheduler import InstanceType -from llumnix.global_scheduler.migration_filter import MigrationFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig from llumnix.global_scheduler.migration_policy import PairMigrationConstraints MIGRATE_OUT_LOAD_THRESHOLD = 3.0 @@ -48,7 +48,7 @@ def test_add_instance_and_remove_instance(migration_scheduler): @pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS', 'DECODING_2_DECODING', 'PREFILL_2_DECODING']) def test_migration_filter(pair_migration_type): num_tests = 1000 - migration_filter = MigrationFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD)) + migration_filter = MigrationInstanceFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD)) for _ in range(num_tests): instance_infos = [] From 96d931d4aefb31ed4c85e2fbd9a189bad3bcb8ca Mon Sep 17 00:00:00 2001 From: root Date: Fri, 15 Nov 2024 05:53:26 +0000 Subject: [PATCH 3/3] fix --- llumnix/global_scheduler/migration_filter.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py index f7861679..ea82e55b 100644 --- a/llumnix/global_scheduler/migration_filter.py +++ b/llumnix/global_scheduler/migration_filter.py @@ -51,6 +51,9 @@ def register_filter(self, filter_name: str, migration_filter: MigrationFilterPol def unregister_filter(self, filter_name: str) -> None: self.registered_filters.pop(filter_name, None) + def get_filter(self, filter_name: str) -> Optional[MigrationFilterPolicy]: + return self.registered_filters.get(filter_name, None) + def filter_instances(self, instance_infos: List[InstanceInfo], pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]: src_filter_conditions = [filter.filter_src_condition() for filter in self.registered_filters.values()]