Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Nov 11, 2024
1 parent 55b9db2 commit 6cbd3da
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 110 deletions.
185 changes: 116 additions & 69 deletions llumnix/global_scheduler/migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple, Set
from typing import Callable, Dict, List, Optional, Tuple, Set
from abc import ABC, abstractmethod
from enum import Enum
import copy
Expand All @@ -31,12 +31,18 @@ class PairMigrationConstraints(str, Enum):
DECODING_2_DECODING = "DECODING_2_DECODING"
PREFILL_2_DECODING = "PREFILL_2_DECODING"

class MigrationFilterConfig:
def __init__(self, migrate_out_load_threshold):
self.migrate_out_load_threshold: float = migrate_out_load_threshold

class MigrationScheduler:
def __init__(self,
pair_migration_policy: str,
migrate_out_load_threshold: float,
instance_load_calculator: InstanceLoadCalculator) -> None:
self.migrate_out_load_threshold = migrate_out_load_threshold
self.filter_config = MigrationFilterConfig(migrate_out_load_threshold=migrate_out_load_threshold)
self.migration_filter = MigrationFilter(self.filter_config)

self.instance_load_calculator = instance_load_calculator
self.enable_defrag = instance_load_calculator.enable_defrag
if not self.enable_defrag:
Expand All @@ -57,14 +63,9 @@ def __init__(self,
self.sorted_instance_infos: List[InstanceInfo] = None

def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
self._sort_instance_infos(descending=False)
sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type)
return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos)

def _get_migration_instance_infos(self, pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]:
filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(pair_migration_type,
migrate_out_load_threshold=self.migrate_out_load_threshold)
return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos,pair_migration_type)
src_instance_infos, dst_instance_infos = self.migration_filter.filter_instances(
self.instance_info.values(), pair_migration_type)
return self.pair_migration_policy.pair_migration(src_instance_infos, dst_instance_infos)

def update_instance_infos(self,
instance_info: Dict[str, InstanceInfo]) -> None:
Expand All @@ -78,74 +79,106 @@ def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
self.num_instances = len(self.instance_id_set)

def _sort_instance_infos(self,
descending: bool = True) -> None:
instance_infos: List[InstanceInfo] = list(self.instance_info.values())
key_attr = 'instance_load_migrate'
self.sorted_instance_infos = sorted(
instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)

class FilteringInstanceInfosPolicy(ABC):
def __init__(self,
migrate_out_load_threshold: float) -> None:
self.migrate_out_load_threshold = migrate_out_load_threshold
self.filter_instances_rules = {
PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS),
PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE),
PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE),
}

def filter_instances(self, sorted_instance_infos: List[InstanceInfo],
pair_migration_type: PairMigrationConstraints = None) -> Dict[str, InstanceInfo]:
src_type, dst_type = self.filter_instances_rules[pair_migration_type]
filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type]
filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type]
src_instance_infos = self.filter_src_instances(filtered_src_instance_infos)
dst_instance_infos = self.filter_dst_instances(filtered_dst_instance_infos)
return src_instance_infos, dst_instance_infos

MIGRATION_FILTER_POLICY_MAPPING = {
PairMigrationConstraints.NO_CONSTRAINTS: 'constraint',
PairMigrationConstraints.DECODING_2_DECODING: 'constraint',
PairMigrationConstraints.PREFILL_2_DECODING: 'relax',
}

class MigrationFilter(ABC):
def __init__(self, filter_config: MigrationFilterConfig) -> None:
self.filter_config = filter_config
self.addition_src_filter: Callable[[InstanceInfo], bool] = None
self.addition_dst_filter: Callable[[InstanceInfo], bool] = None
self.default_filter = MigrationFilterPolicyFactory.get_policy('instancetype')

def set_addition_filter(self, src_filter: Optional[Callable[[InstanceInfo], bool]] = None,
dst_filter: Optional[Callable[[InstanceInfo], bool]] = None) -> None:
if src_filter:
self.addition_src_filter = src_filter
if dst_filter:
self.addition_dst_filter = dst_filter

def remove_addition_filter(self) -> None:
self.addition_src_filter = []
self.addition_dst_filter = []

def filter_instances(self, instance_infos: List[InstanceInfo],
pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]:
src_filter_conditions = [self.default_filter.filter_src_condition(self.filter_config, pair_migration_type)]
dst_filter_conditions = [self.default_filter.filter_dst_condition(self.filter_config, pair_migration_type)]

if self.addition_src_filter:
src_filter_conditions.append(self.addition_src_filter)
if self.addition_dst_filter:
dst_filter_conditions.append(self.addition_dst_filter)

policy_filter = MigrationFilterPolicyFactory.get_policy(MIGRATION_FILTER_POLICY_MAPPING[pair_migration_type])
src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type))
dst_filter_conditions.append(policy_filter.filter_dst_condition(self.filter_config, pair_migration_type))

filtered_src_instance_infos = [info for info in instance_infos if all(cond(info) for cond in src_filter_conditions)]
filtered_dst_instance_infos = [info for info in instance_infos if all(cond(info) for cond in dst_filter_conditions)]

return filtered_src_instance_infos, filtered_dst_instance_infos

class MigrationFilterPolicy(ABC):
@abstractmethod
def filter_src_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]:
def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

@abstractmethod
def filter_dst_instances(self, filtered_instance_infos) -> Dict[str, InstanceInfo]:
def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
raise NotImplementedError

class FilterConstrained(FilteringInstanceInfosPolicy):
def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]:
src_instance_infos = [i for i in reversed(filtered_instance_infos)
if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold]
return src_instance_infos
class InstanceTypeFilter(MigrationFilterPolicy):
INSTANCE_FILTER_RULES = {
PairMigrationConstraints.NO_CONSTRAINTS: (InstanceType.NO_CONSTRAINTS, InstanceType.NO_CONSTRAINTS),
PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE),
PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE),
}

def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
src_type, _ = self.INSTANCE_FILTER_RULES[pair_migration_type]
return lambda instance_info: instance_info.instance_type == src_type


def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]:
dst_instance_infos = [i for i in filtered_instance_infos
if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold]
return dst_instance_infos
def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
_, dst_type = self.INSTANCE_FILTER_RULES[pair_migration_type]
return lambda instance_info: instance_info.instance_type == dst_type

class FilterRelaxed(FilteringInstanceInfosPolicy):
class ConstrainedFilter(MigrationFilterPolicy):
def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests > 0 \
or instance_info.instance_load_migrate > filter_config.migrate_out_load_threshold

def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests == 0 \
and instance_info.instance_load_migrate < filter_config.migrate_out_load_threshold

class RelaxedFilter(MigrationFilterPolicy):
# The policy is currently used to select the decoding instances to migrate requests from the prefill instances.
def filter_src_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]:
src_instance_infos = list(reversed(filtered_instance_infos))
return src_instance_infos
def filter_src_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: True

def filter_dst_instances(self, filtered_instance_infos: List[InstanceInfo]) -> Dict[str, InstanceInfo]:
dst_instance_infos = [i for i in filtered_instance_infos
if i.num_killed_requests == 0]
return dst_instance_infos
def filter_dst_condition(self, filter_config: MigrationFilterConfig,
pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
return lambda instance_info: instance_info.num_killed_requests == 0

class FilteringInstanceInfosPolicyFactory:
class MigrationFilterPolicyFactory:
_POLICY_REGISTRY = {
PairMigrationConstraints.NO_CONSTRAINTS: FilterConstrained,
PairMigrationConstraints.DECODING_2_DECODING: FilterConstrained,
PairMigrationConstraints.PREFILL_2_DECODING: FilterRelaxed,
'instancetype': InstanceTypeFilter,
'constraint': ConstrainedFilter,
'relax': RelaxedFilter,
}

@classmethod
def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> FilteringInstanceInfosPolicy:
def get_policy(cls, policy_name: PairMigrationConstraints, **kwargs) -> MigrationFilterPolicy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)

class PairMigrationPolicy(ABC):
Expand All @@ -157,16 +190,28 @@ def __init__(self,

@abstractmethod
def pair_migration(self,
sorted_src_instance_infos: List[InstanceInfo],
sorted_dst_instance_infos: List[InstanceInfo],
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
raise NotImplementedError

@classmethod
def sort_instance_infos(cls, instance_infos: List[InstanceInfo], descending: bool = True) -> None:
key_attr = 'instance_load_migrate'
sorted_instance_infos = sorted(
instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
return sorted_instance_infos

class Balanced(PairMigrationPolicy):
def pair_migration(self,
sorted_src_instance_infos: List[InstanceInfo],
sorted_dst_instance_infos: List[InstanceInfo],
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True)
sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate
Expand Down Expand Up @@ -198,9 +243,11 @@ def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_m

class DefragConstrained(PairMigrationPolicy):
def pair_migration(self,
sorted_src_instance_infos: List[InstanceInfo],
sorted_dst_instance_infos: List[InstanceInfo],
src_instance_infos: List[InstanceInfo],
dst_instance_infos: List[InstanceInfo],
) -> List[Tuple[str, str]]:
sorted_src_instance_infos = self.sort_instance_infos(src_instance_infos, descending=True)
sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
migrate_instance_pairs = []
for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
# without any constrain in order to make prefill migrate happens as soon as possible
Expand Down
2 changes: 2 additions & 0 deletions llumnix/internal_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(
self.dispatch_policy = dispatch_policy

self.pair_migration_policy = pair_migration_policy
# TODO(KuilongCui): Use a better way to set the threshold, as having both positive and negative
# values can cause confusion.
self.migrate_out_load_threshold = migrate_out_threshold*(-1)
self.enable_defrag = enable_defrag

Expand Down
Loading

0 comments on commit 6cbd3da

Please sign in to comment.