Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Nov 15, 2024
1 parent be79dee commit ed4b569
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions llumnix/global_scheduler/migration_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[
def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

class MigrationFilter(ABC):
class MigrationInstanceFilter(ABC):
def __init__(self, filter_config: MigrationFilterConfig) -> None:
self.filter_config = filter_config
self.registered_filters: Dict[str, MigrationFilterPolicy] = {}
Expand All @@ -57,9 +57,9 @@ def filter_instances(self, instance_infos: List[InstanceInfo],
dst_filter_conditions = [filter.filter_dst_condition() for filter in self.registered_filters.values()]

if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
policy_filter = MigrationFilterPolicyFactory.get_policy("loadconstraint")
policy_filter = MigrationFilterPolicyFactory.get_policy("load")
elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]:
policy_filter = MigrationFilterPolicyFactory.get_policy('pdd')
policy_filter = MigrationFilterPolicyFactory.get_policy('prefill_decode')
else:
raise ValueError(f"Unsupported pair migration type: {pair_migration_type}")
src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type))
Expand Down Expand Up @@ -93,7 +93,7 @@ def filter_src_condition(self, filter_config: MigrationFilterConfig,
instance_type_filter = lambda instance_info: instance_info.instance_type == src_type

if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
inner_policy = MigrationFilterPolicyFactory.get_policy('loadconstraint')
inner_policy = MigrationFilterPolicyFactory.get_policy('load')
policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type)
else:
policy_filter = lambda instance_info: True
Expand All @@ -106,7 +106,7 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig,
instance_type_filter = lambda instance_info: instance_info.instance_type == dst_type

if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
inner_policy = MigrationFilterPolicyFactory.get_policy('loadconstraint')
inner_policy = MigrationFilterPolicyFactory.get_policy('load')
policy_filter = inner_policy.filter_dst_condition(filter_config, pair_migration_type)
else:
policy_filter = lambda instance_info: instance_info.num_killed_requests == 0
Expand Down Expand Up @@ -136,8 +136,8 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig,

class MigrationFilterPolicyFactory:
_POLICY_REGISTRY = {
'loadconstraint': LoadConstrainedFilter,
'pdd': PddFilter,
'load': LoadConstrainedFilter,
'prefill_decode': PddFilter,
'custom': CustomFilter,
}

Expand Down
4 changes: 2 additions & 2 deletions llumnix/global_scheduler/migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator
from llumnix.global_scheduler.migration_filter import MigrationFilter, MigrationFilterConfig
from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints, PairMigrationPolicyFactory

logger = init_logger(__name__)
Expand All @@ -26,7 +26,7 @@ def __init__(self,
migrate_out_load_threshold: float,
instance_load_calculator: InstanceLoadCalculator) -> None:
self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold)
self.migration_filter = MigrationFilter(self.filter_config)
self.migration_filter = MigrationInstanceFilter(self.filter_config)

self.instance_load_calculator = instance_load_calculator
self.enable_defrag = instance_load_calculator.enable_defrag
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_test/global_scheduler/test_migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
from llumnix.global_scheduler.scaling_scheduler import InstanceType
from llumnix.global_scheduler.migration_filter import MigrationFilter, MigrationFilterConfig
from llumnix.global_scheduler.migration_filter import MigrationInstanceFilter, MigrationFilterConfig
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints

MIGRATE_OUT_LOAD_THRESHOLD = 3.0
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_add_instance_and_remove_instance(migration_scheduler):
@pytest.mark.parametrize("pair_migration_type", ['NO_CONSTRAINTS', 'DECODING_2_DECODING', 'PREFILL_2_DECODING'])
def test_migration_filter(pair_migration_type):
num_tests = 1000
migration_filter = MigrationFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD))
migration_filter = MigrationInstanceFilter(MigrationFilterConfig(MIGRATE_OUT_LOAD_THRESHOLD))

for _ in range(num_tests):
instance_infos = []
Expand Down

0 comments on commit ed4b569

Please sign in to comment.