[Refactor] refactor migration scheduler #66

Merged · 3 commits · Nov 15, 2024
3 changes: 2 additions & 1 deletion llumnix/global_scheduler/global_scheduler.py
@@ -18,7 +18,8 @@
from llumnix.internal_config import GlobalSchedulerConfig
from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints
from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler

logger = init_logger(__name__)
149 changes: 149 additions & 0 deletions llumnix/global_scheduler/migration_filter.py
@@ -0,0 +1,149 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Optional, Tuple
from abc import ABC, abstractmethod

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.global_scheduler.scaling_scheduler import InstanceType
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints

logger = init_logger(__name__)

class MigrationFilterConfig:
def __init__(self, migrate_out_load_threshold):
self.migrate_out_load_threshold: float = migrate_out_load_threshold

# TODO(KuilongCui): A filter might contain other filters; leave this for the future.
class MigrationFilterPolicy(ABC):
@abstractmethod
def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

@abstractmethod
def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

class MigrationInstanceFilter(ABC):
def __init__(self, filter_config: MigrationFilterConfig) -> None:
self.filter_config = filter_config
self.registered_filters: Dict[str, MigrationFilterPolicy] = {}

def register_filter(self, filter_name: str, migration_filter: MigrationFilterPolicy) -> bool:
if filter_name in self.registered_filters:
            logger.warning("Migration filter {} is already registered.".format(filter_name))
return False

self.registered_filters[filter_name] = migration_filter
return True

def unregister_filter(self, filter_name: str) -> None:
self.registered_filters.pop(filter_name, None)

def get_filter(self, filter_name: str) -> Optional[MigrationFilterPolicy]:
return self.registered_filters.get(filter_name, None)

    def filter_instances(self, instance_infos: List[InstanceInfo],
                         pair_migration_type: PairMigrationConstraints) -> Tuple[List[InstanceInfo], List[InstanceInfo]]:
        src_filter_conditions = [f.filter_src_condition(self.filter_config, pair_migration_type)
                                 for f in self.registered_filters.values()]
        dst_filter_conditions = [f.filter_dst_condition(self.filter_config, pair_migration_type)
                                 for f in self.registered_filters.values()]

if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
policy_filter = MigrationFilterPolicyFactory.get_policy("load")
elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]:
policy_filter = MigrationFilterPolicyFactory.get_policy('prefill_decode')
else:
raise ValueError(f"Unsupported pair migration type: {pair_migration_type}")
src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type))
dst_filter_conditions.append(policy_filter.filter_dst_condition(self.filter_config, pair_migration_type))

filtered_src_instance_infos = [info for info in instance_infos if all(cond(info) for cond in src_filter_conditions)]
filtered_dst_instance_infos = [info for info in instance_infos if all(cond(info) for cond in dst_filter_conditions)]

return filtered_src_instance_infos, filtered_dst_instance_infos

class LoadConstrainedFilter(MigrationFilterPolicy):
def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests > 0 \
or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests == 0 \
and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold

class PddFilter(MigrationFilterPolicy):
INSTANCE_FILTER_RULES = {
PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE),
PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE),
}

def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
src_type, _ = self.INSTANCE_FILTER_RULES[pair_migration_type]
instance_type_filter = lambda instance_info: instance_info.instance_type == src_type

if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
inner_policy = MigrationFilterPolicyFactory.get_policy('load')
policy_filter = inner_policy.filter_src_condition(filter_config, pair_migration_type)
else:
policy_filter = lambda instance_info: True

return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info)

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
_, dst_type = self.INSTANCE_FILTER_RULES[pair_migration_type]
instance_type_filter = lambda instance_info: instance_info.instance_type == dst_type

if pair_migration_type == PairMigrationConstraints.DECODING_2_DECODING:
inner_policy = MigrationFilterPolicyFactory.get_policy('load')
policy_filter = inner_policy.filter_dst_condition(filter_config, pair_migration_type)
else:
policy_filter = lambda instance_info: instance_info.num_killed_requests == 0

return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info)

class CustomFilter(MigrationFilterPolicy):
def __init__(self):
super().__init__()
self.src_filter = lambda _: True
self.dst_filter = lambda _: True

    def set_filter_condition(self, src_filter: Optional[Callable[[InstanceInfo], bool]] = None,
                             dst_filter: Optional[Callable[[InstanceInfo], bool]] = None) -> None:
if src_filter:
self.src_filter = src_filter
if dst_filter:
self.dst_filter = dst_filter

def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return self.src_filter

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return self.dst_filter

class MigrationFilterPolicyFactory:
_POLICY_REGISTRY = {
'load': LoadConstrainedFilter,
'prefill_decode': PddFilter,
'custom': CustomFilter,
}

@classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> MigrationFilterPolicy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
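
For reviewers who want to see how the new filter classes fit together, here is a minimal usage sketch. It is not part of this diff: the threshold value, the `instance_infos` variable, and the wiring are illustrative assumptions; the actual call sites are presumably in the refactored migration_scheduler.py, which is not shown in this excerpt.

```python
# Illustrative sketch only (not code from this PR); assumes `instance_infos` is the list of
# InstanceInfo objects that the global scheduler already tracks.
from llumnix.global_scheduler.migration_filter import (
    MigrationFilterConfig, MigrationInstanceFilter, MigrationFilterPolicyFactory)
from llumnix.global_scheduler.migration_policy import PairMigrationConstraints

filter_config = MigrationFilterConfig(migrate_out_load_threshold=3.0)  # threshold chosen arbitrarily
instance_filter = MigrationInstanceFilter(filter_config)

# Optionally stack a custom filter on top of the built-in policy filter.
custom_filter = MigrationFilterPolicyFactory.get_policy('custom')
custom_filter.set_filter_condition(src_filter=lambda info: info.num_running_requests > 0)
instance_filter.register_filter('custom', custom_filter)

# Split instances into candidate migration sources and destinations.
src_infos, dst_infos = instance_filter.filter_instances(
    instance_infos, PairMigrationConstraints.NO_CONSTRAINTS)
```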
113 changes: 113 additions & 0 deletions llumnix/global_scheduler/migration_policy.py
@@ -0,0 +1,113 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from abc import ABC, abstractmethod
from enum import Enum
import copy
import numpy as np

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator

logger = init_logger(__name__)

class PairMigrationConstraints(str, Enum):
"""Target of Migration."""
NO_CONSTRAINTS = "NO_CONSTRAINTS"

    # Used when prefill-decoding disaggregation is enabled.
DECODING_2_DECODING = "DECODING_2_DECODING"
PREFILL_2_DECODING = "PREFILL_2_DECODING"

class PairMigrationPolicy(ABC):
def __init__(self,
migrate_out_load_threshold: float,
instance_load_calculator: InstanceLoadCalculator) -> None:
self.migrate_out_load_threshold = migrate_out_load_threshold
self.instance_load_calculator = instance_load_calculator

@abstractmethod
def pair_migration(self,
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
raise NotImplementedError

    def sort_instance_infos(self, instance_infos: List[InstanceInfo], descending: bool = True) -> List[InstanceInfo]:
key_attr = 'instance_load_migrate'
sorted_instance_infos = sorted(
instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
return sorted_instance_infos

class Balanced(PairMigrationPolicy):
def pair_migration(self,
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True)
sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate

left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False)
right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True)

            # Add some constraints to reduce unnecessary migrations.
if right_load_after_mig > self.migrate_out_load_threshold:
continue
load_diff_after_mig = left_load_after_mig - right_load_after_mig
if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf):
migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id,
sorted_dst_instance_infos[i].instance_id))
return migrate_instance_pairs

def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float:
instance_info_after_migrate = copy.deepcopy(instance_info)
num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request

if is_migrate_in:
instance_info_after_migrate.num_running_requests += 1
instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request
else:
instance_info_after_migrate.num_running_requests -= 1
instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request

return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate')

class DefragConstrained(PairMigrationPolicy):
def pair_migration(self,
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True)
sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
            # No constraints here, so that prefill migrations happen as soon as possible.
migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id))
return migrate_instance_pairs

class PairMigrationPolicyFactory:
_POLICY_REGISTRY = {
'balanced': Balanced,
'defrag_constrained': DefragConstrained,
}

@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> PairMigrationPolicy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
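
Continuing the illustrative sketch from the filter file above (again an assumption, not code in this PR), a pair-migration policy would consume the filtered source/destination lists and produce `(src_instance_id, dst_instance_id)` pairs:

```python
# Illustrative sketch only (not code from this PR); `src_infos` / `dst_infos` come from the
# MigrationInstanceFilter sketch above, and `instance_load_calculator` is assumed to be an
# InstanceLoadCalculator the scheduler already holds.
from llumnix.global_scheduler.migration_policy import PairMigrationPolicyFactory

pair_policy = PairMigrationPolicyFactory.get_policy(
    'defrag_constrained',
    migrate_out_load_threshold=3.0,
    instance_load_calculator=instance_load_calculator)

migrate_instance_pairs = pair_policy.pair_migration(src_infos, dst_infos)
for src_instance_id, dst_instance_id in migrate_instance_pairs:
    # The migration scheduler would issue one migration per (src, dst) pair here.
    pass
```

Note the difference between the two registered policies: `defrag_constrained` pairs instances greedily by sorted load, while `balanced` additionally estimates post-migration load and skips pairs that would not actually reduce the load gap.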