diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 5394ea24..231080d4 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -167,7 +167,8 @@ def create_global_scheduler_configs( self.scaling_policy, self.scale_up_threshold, self.scale_down_threshold, - self.enable_pd_disagg) + self.enable_pd_disagg, + self.migration_backend,) return global_scheduler_config def create_migration_config(self) -> MigrationConfig: diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index ec1568bb..419dd41a 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -43,7 +43,8 @@ def __init__(self, # migrate args self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy, global_scheduler_config.migrate_out_load_threshold, - self.instance_load_calculator) + self.instance_load_calculator, + global_scheduler_config.migration_backend) # auto-scaling args self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold, global_scheduler_config.scale_down_threshold, diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index ad538f06..74ff3f21 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -15,7 +15,7 @@ from llumnix.logger import init_logger from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator -from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig +from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig, CustomFilter from llumnix.global_scheduler.migration_policy import PairMigrationConstraints, PairMigrationPolicyFactory logger = init_logger(__name__) @@ -24,9 +24,16 @@ class MigrationScheduler: def __init__(self, pair_migration_policy: str, migrate_out_load_threshold: float, - instance_load_calculator: InstanceLoadCalculator) -> None: + instance_load_calculator: InstanceLoadCalculator, + migration_backend: str,) -> None: self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold) self.migration_filter = MigrationInstanceFilter(self.filter_config) + migration_backend_init_filter = CustomFilter() + migration_backend_init_filter.set_filter_condtition( + src_filter=lambda _: migration_backend == 'rpc', + dst_filter=lambda _: migration_backend == 'rpc') + self.migration_filter.register_filter("migration_backend_init_filter", + migration_backend_init_filter) self.instance_load_calculator = instance_load_calculator self.enable_defrag = instance_load_calculator.enable_defrag diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 4412c13b..54fbc5fc 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -44,7 +44,8 @@ def __init__( scaling_policy: str, scale_up_threshold: float, scale_down_threshold: float, - enable_pd_disagg: bool) -> None: + enable_pd_disagg: bool, + migration_backend: str,) -> None: self.initial_instances = initial_instances self.load_metric = load_metric @@ -62,3 +63,5 @@ def __init__( self.enable_pd_disagg = enable_pd_disagg self.num_dispatch_instances = num_dispatch_instances + + self.migration_backend = migration_backend diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 5d8c48a5..820b9b72 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -25,6 +25,7 @@ from llumnix.logger import init_logger from llumnix.global_scheduler.global_scheduler import GlobalScheduler from llumnix.global_scheduler.migration_scheduler import PairMigrationConstraints +from llumnix.global_scheduler.migration_filter import CustomFilter from llumnix.instance_info import InstanceInfo from llumnix.internal_config import GlobalSchedulerConfig from llumnix.arg_utils import EngineManagerArgs @@ -335,6 +336,12 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): self.pending_rebuild_migration_instances = 0 group_name = None + migration_filter: CustomFilter = self.global_scheduler.migration_scheduler\ + .migration_filter.get_filter("migration_backend_init_filter") + migration_filter.set_filter_condtition( + src_filter=lambda instance_info: instance_info.instance_id in alive_instances, + dst_filter=lambda instance_info: instance_info.instance_id in alive_instances) + logger.info("rebuild {} migrate backend done, group_name: {}, alive instance ({}): {}" .format(self.engine_manager_args.migration_backend, group_name, len(alive_instances), alive_instances)) diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py index adb1f1cc..18c83f85 100644 --- a/tests/unit_test/global_scheduler/test_global_scheduler.py +++ b/tests/unit_test/global_scheduler/test_global_scheduler.py @@ -25,7 +25,7 @@ def init_global_scheduler(): global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', math.inf, 'defrag_constrained', 3.0, True, 'avg_load', - 10, 60, False) + 10, 60, False, 'rpc') global_scheduler = GlobalScheduler(global_scheduler_config) return global_scheduler diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index fa25e1f8..89b813c3 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -27,7 +27,7 @@ def init_migration_scheduler(policy='balanced'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator) + migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator, 'rpc') return migration_scheduler @pytest.fixture