diff --git a/llumnix/__init__.py b/llumnix/__init__.py index 4ea77baf..bb0cee80 100644 --- a/llumnix/__init__.py +++ b/llumnix/__init__.py @@ -11,28 +11,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import vllm -from vllm import * +# import vllm +# from vllm import * -from llumnix.server_info import ServerInfo -from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster, - init_manager, init_llumlets) -from llumnix.arg_utils import EngineManagerArgs -from llumnix.llm_engine_manager import LLMEngineManager -from llumnix.llumlet.llumlet import Llumlet +# from llumnix.server_info import ServerInfo +# from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster, +# init_manager, init_llumlets) +# from llumnix.arg_utils import EngineManagerArgs +# from llumnix.llm_engine_manager import LLMEngineManager +# from llumnix.llumlet.llumlet import Llumlet -from .version import __version__ +# from .version import __version__ -__all__ = [ - "__version__", - "ServerInfo", - "launch_ray_cluster", - "connect_to_ray_cluster", - "init_manager", - "init_llumlets", - "EngineManagerArgs", - "LLMEngineManager", - "Llumlet" -] +# __all__ = [ +# "__version__", +# "ServerInfo", +# "launch_ray_cluster", +# "connect_to_ray_cluster", +# "init_manager", +# "init_llumlets", +# "EngineManagerArgs", +# "LLMEngineManager", +# "Llumlet" +# ] -__all__.extend(getattr(vllm, "__all__", [])) +# __all__.extend(getattr(vllm, "__all__", [])) diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 571b5d4f..15352b3a 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -16,8 +16,11 @@ import argparse from typing import Tuple +from llumnix.common.config import get_cfg from llumnix.config import GlobalSchedulerConfig, MigrationConfig +from llumnix.logger import init_logger +logger = init_logger(__name__) @dataclass class EngineManagerArgs: @@ -59,9 +62,15 @@ class EngineManagerArgs: last_stage_max_blocks: int = 16 max_stages: int = 3 + config_file: str = None def create_global_scheduler_configs( self, ) -> Tuple[GlobalSchedulerConfig]: + + config_data = get_cfg() + config_data.merge_from_file(self.config_file) + + # Create the GlobalScheduler Configuration. 
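As a quick illustration of the new `--config-file` flow (not part of the patch): the file is a YAML document whose keys must already exist in `llumnix/common/defaults.py`, and `merge_from_file` overrides those defaults. A minimal sketch, assuming a hypothetical file name `pdd.yml`; the key names (including the `ENABLE_PREFILL_DISAGGREATION` spelling) follow `defaults.py` as added later in this patch:

    from llumnix.common.config import get_cfg

    # Hypothetical config file enabling prefill-decoding disaggregation with one prefill instance.
    with open("pdd.yml", "w", encoding="utf-8") as f:
        f.write("PDD_CONFIG:\n  ENABLE_PREFILL_DISAGGREATION: True\n  PREFILL_INSTANCE_NUM: 1\n")

    config_data = get_cfg()                 # starts from the defaults in llumnix/common/defaults.py
    config_data.merge_from_file("pdd.yml")  # keys not present in the defaults raise KeyError
    assert config_data.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION is True
    assert config_data.PDD_CONFIG.PREFILL_INSTANCE_NUM == 1

Note that `create_global_scheduler_configs` calls `config_data.merge_from_file(self.config_file)` unconditionally, so `--config-file` apparently has to be provided (its default is None) whenever the manager is started through this path.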
        global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
                                                        self.load_metric,
                                                        self.dispatch_policy,
@@ -70,7 +79,9 @@ def create_global_scheduler_configs(
                                                        self.enable_defrag,
                                                        self.scaling_policy,
                                                        self.scale_up_threshold,
-                                                        self.scale_down_threshold)
+                                                        self.scale_down_threshold,
+                                                        config_data.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION,
+                                                        config_data.PDD_CONFIG.PREFILL_INSTANCE_NUM)
         return global_scheduler_config
 
     def create_migration_config(self) -> MigrationConfig:
@@ -229,5 +240,8 @@ def add_cli_args(
                             type=int,
                             default=EngineManagerArgs.max_stages,
                             help='drop migration if the number of stages > max_stages')
-
+        parser.add_argument("--config-file",
+                            type=str,
+                            default=EngineManagerArgs.config_file,
+                            help="path to the configuration file")
         return parser
diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py
index 441fab2d..f3bd91e2 100644
--- a/llumnix/backends/backend_interface.py
+++ b/llumnix/backends/backend_interface.py
@@ -32,7 +32,7 @@ def is_sim_backend(status: "BackendType") -> bool:
 class BackendInterface(ABC):
     # Methods for inference
     @abstractmethod
-    def add_request(self, request_id: str, server_info: ServerInfo,
+    def add_request(self, request_id: str, server_info: ServerInfo, request_expected_steps: int,
                     *args, **kwargs) -> None:
         """Adds a new inference request to the backend's processing queue.
@@ -42,6 +42,9 @@ def add_request(self, request_id: str, server_info: ServerInfo,
         Args:
             request_id: Request ID.
             server_info: The information of the api server where the request come.
+            request_expected_steps: The expected number of steps for the request to run. The number of steps
+                                    represents the sum of the times 'engine.step()' has been called by the
+                                    backend instances for the request.
             *args: Positional arguments that represent request-specific data.
             **kwargs: Keyword arguments that contain metadata of the backend request
                       (request_id, arrival_time, etc.).
@@ -267,6 +270,21 @@ def commit_dst_request(self, backend_request: LlumnixRequest) -> None:
             of the request.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def update_strict_pre_migration(self, new_strict_pre_migration: bool) -> None:
+        """Update whether the backend engine enforces strict pre-migration.
+
+        This method is called only when the corresponding status in the llumlet changes.
+        `strict_pre_migration` controls how requests that have reached their expected steps are handled.
+        By default it is True: once `request.output_len` >= `request.expected_steps`, the request is held
+        back from further decoding and becomes a migration candidate. If it is set to False, such requests
+        are not forced to migrate and simply continue inference on their current instance.
+
+        Args:
+            new_strict_pre_migration: New strict pre-migration status for the backend engine.
+ """ + raise NotImplementedError @abstractmethod def get_all_request_ids(self) -> List[str]: diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index b7c1ab48..966d1296 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -147,11 +147,12 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None: instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request self.instance_info = instance_info - def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs): + def add_request(self, request_id: str, server_info: ServerInfo, request_expected_steps: int, *args, **kwargs): super().add_request(request_id, *args, **kwargs) + logger.info("add_request") seq_group = self.scheduler.waiting[-1] - self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, [seq_group.get_seqs()[0]], seq_group.sampling_params, - seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data) + self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, request_expected_steps, [seq_group.get_seqs()[0]], seq_group.sampling_params, + seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data) self.scheduler.scheduler_lock.release() def _put_request_output_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None: @@ -181,13 +182,14 @@ def __init__( placement_group: "PlacementGroup" = None, node_id: str = None ) -> None: + self.strict_pre_migration = True self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, migration_config=migration_config, instance_id=instance_id, placement_group=placement_group, node_id=node_id) # multi-instance args - self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) + self.engine.scheduler = SchedulerLlumnix(self.strict_pre_migration, self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler self.instance_id = instance_id @@ -212,10 +214,12 @@ def execute_worker_method(self, method, *args, **kwargs): def add_request(self, request_id: str, server_info: ServerInfo, + request_expected_steps: int, *args, **kwargs) -> None: # Store the server information of each request to put the request outputs back to the corresponding api server correctly. 
- self.engine.add_request(request_id, server_info, *args, **kwargs) + self.engine.add_request(request_id, server_info, request_expected_steps, *args, **kwargs) + def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None: seq = backend_request.get_seqs()[0] @@ -223,8 +227,8 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None: logger.info("add seq {} to block table".format(seq.seq_id)) pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id) self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id) - self.add_running_request(backend_request) backend_request.reset_migration_args() + self.add_running_request(backend_request) def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers", @@ -248,7 +252,10 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def get_running_queue(self ) -> List[SequenceGroupLlumnix]: return self.engine.scheduler.get_running_queue() - + def update_strict_pre_migration(self, new_migration_state: bool): + if self.strict_pre_migration != new_migration_state: + self.strict_pre_migration = new_migration_state + self.engine.scheduler.update_strict_pre_migration(new_migration_state) def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]: return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs) @@ -281,6 +288,5 @@ def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None: def free_src_request(self, backend_request: SequenceGroup) -> None: return self.engine.scheduler.free_src_request(backend_request) - def get_all_request_ids(self) -> List[str]: return self.engine.scheduler.get_all_request_ids() diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index aed61fec..19fdfe2d 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -14,6 +14,7 @@ from asyncio.log import logger import time import threading +import copy from typing import Dict, List, Optional, Tuple from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable @@ -45,7 +46,7 @@ def add_block_table(self, block_table: BlockTable, seq_id: int) -> None: self.block_tables[seq_id] = block_table.copy() class SchedulerLlumnix(Scheduler): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, strict_pre_migration, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.block_manager = BlockManagerLlumnix( block_size=self.cache_config.block_size, @@ -56,6 +57,7 @@ def __init__(self, *args, **kwargs) -> None: self.pre_alloc_cache_dict: Dict[str, BlockTable] = {} self.scheduler_lock = threading.Lock() self.migrating_out_request_last_stage: List[LlumnixRequest] = [] + self.strict_pre_migration = strict_pre_migration def add_update_instance_info_callback(self, update_instance_info_callback): self.update_instance_info_callback = update_instance_info_callback @@ -205,6 +207,25 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: self.update_instance_info_callback(self._get_instance_info()) return seq_group_metadata_list, scheduler_outputs + def _schedule_running(self, *args, **kwargs): + args_list = list(args) + args_list[0] = copy.deepcopy(self.running) + remove_running = [] + if self.strict_pre_migration: + for seq_group in list(args_list[0]): + if seq_group.expected_steps > 0 and seq_group.output_len >= 
seq_group.expected_steps: + args_list[0].remove(seq_group) + remove_running.append(seq_group) + new_args = tuple(args_list) + remaining_running, running_scheduled = super()._schedule_running(*new_args, **kwargs) + for seq_group in remove_running: + remaining_running.append(seq_group) + return remaining_running, running_scheduled + + @scheduler_lock + def update_strict_pre_migration(self, new_migration_state: bool) -> None: + self.strict_pre_migration = new_migration_state + def add_seq_group(self, *args, **kwargs): # The scheduler lock is mannually released in the end of LLMEngineLlumnix.add_request function. # pylint: disable=R1732 diff --git a/llumnix/backends/vllm/sequence.py b/llumnix/backends/vllm/sequence.py index 1c0c4456..4c61b261 100644 --- a/llumnix/backends/vllm/sequence.py +++ b/llumnix/backends/vllm/sequence.py @@ -17,9 +17,9 @@ class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest): - def __init__(self, request_id, server_info, *args, **kwargs) -> None: + def __init__(self, request_id, server_info, request_expected_steps: int, *args, **kwargs) -> None: SequenceGroup.__init__(self, request_id, *args, **kwargs) - LlumnixRequest.__init__(self, request_id, server_info) + LlumnixRequest.__init__(self, request_id, server_info, request_expected_steps) @property def prompt_len(self) -> int: diff --git a/llumnix/common/config.py b/llumnix/common/config.py new file mode 100644 index 00000000..7149f5aa --- /dev/null +++ b/llumnix/common/config.py @@ -0,0 +1,211 @@ +# encoding: utf-8 +""" +@author: l1aoxingyu +@contact: sherlockliao01@gmail.com +""" + +import copy +from typing import Any +from ast import literal_eval +import yaml + +from llumnix.logger import init_logger + +logger = init_logger(__name__) + +class Config(dict): + """ + Config represents an internal node in the configuration tree. It's a simple + dict-like container that allows for attribute-based access to keys. + """ + def __init__(self, init_dict=None, key_list=None): + """ + Args: + init_dict (dict): the possibly-nested dictionary to initailize the Config. + key_list (list[str]): a list of names which index this Config from the root. + Currently only used for logging purposes. + new_allowed (bool): whether adding new key is allowed when merging with + other configs. + """ + # Recursively convert nested dictionaries in init_dict into Configs + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + init_dict = self._create_config_tree_from_dict(init_dict, key_list) + super(Config, self).__init__(init_dict) + + @classmethod + def _create_config_tree_from_dict(cls, dic, key_list): + """ + Create a configuration tree using the given dict. + Any dict-like objects inside dict will be treated as a new Config. + + Args: + dic (dict): + key_list (list[str]): a list of names which index this Config from the root. + Currently only used for logging purposes. + """ + dic = copy.deepcopy(dic) + for k, v in dic.items(): + if isinstance(v, dict): + # Convert dict to Config + dic[k] = cls(v, key_list=key_list + [k]) + return dic + + @classmethod + def _decode_cfg_value(cls, value): + """ + Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. + + If the value is a dict, it will be interpreted as a new Config. + If the value is a str, it will be evaluated as literals. + Otherwise it is returned as-is. 
+ """ + # Configs parsed from raw yaml will contain dictionary keys that need to be + # converted to Config objects + if isinstance(value, dict): + return cls(value) + # All remaining processing is only applied to strings + if not isinstance(value, str): + return value + # Try to interpret `value` as a: + # string, number, tuple, list, dict, boolean, or None + try: + value = literal_eval(value) + # The following two excepts allow v to pass through when it represents a + # string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, like + # a list. In the case that v represents a string, what we got back from the + # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is + # ok with '"foo"', but will raise a ValueError if given 'foo'. In other + # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return value + + + @staticmethod + def load_yaml_with_base(filename: str, allow_unsafe: bool = False): + """ + With "allow_unsafe=True", it supports pyyaml tags that evaluate + expressions in config. See examples in + https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + Args: + filename (str): the file name of the current config. Will be used to + find the base config file. + allow_unsafe (bool): whether to allow loading the config file with + `yaml.unsafe_load`. + Returns: + (dict): the loaded yaml + """ + with open(filename, "r", encoding='utf-8') as f: + try: + cfg = yaml.safe_load(f) + except yaml.constructor.ConstructorError: + if not allow_unsafe: + raise + logger.warning( + "Loading config {} with yaml.unsafe_load. Your machine may " + "be at risk if the file contains malicious content.".format( + filename + ) + ) + f.close() + with open(filename, "r") as f: + cfg = yaml.unsafe_load(f) + return cfg + + # def __str__(self): + # def _indent(s_, num_spaces): + # s = s_.split("\n") + # if len(s) == 1: + # return s_ + # first = s.pop(0) + # s = [(num_spaces * " ") + line for line in s] + # s = "\n".join(s) + # s = first + "\n" + s + # return s + + # r = "" + # s = [] + # for k, v in sorted(self.items()): + # seperator = "\n" if isinstance(v, Config) else " " + # attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + # attr_str = _indent(attr_str, 2) + # s.append(attr_str) + # r += "\n".join(s) + # return r + + # def __repr__(self): + # return "{}({})".format(self.__class__.__name__, super(Config, self).__repr__()) + + + + def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False): + """ + Merge configs from a given yaml file. + Args: + cfg_filename: the file name of the yaml config. + allow_unsafe: whether to allow loading the config file with + `yaml.unsafe_load`. + """ + loaded_cfg = self.load_yaml_with_base( + cfg_filename, allow_unsafe=allow_unsafe + ) + _merge_a_into_b(loaded_cfg, self, self, []) + + def clone(self): + """Recursively copy this Config.""" + return copy.deepcopy(self) + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + self[name] = value + + +def get_cfg() -> Config: + """ + Get a copy of the default config. 
+ Returns: + a Config instance. + """ + from llumnix.common.defaults import _C + + return _C.clone() + +def _merge_a_into_b(a, b, root, key_list): + """Merge config dictionary a into config dictionary b, clobbering the + options in b whenever they are also specified in a. + """ + if a is None: + return + for k, v_ in a.items(): + full_key = ".".join(key_list + [k]) + + v = copy.deepcopy(v_) + v = b._decode_cfg_value(v) + if k in b: + # Recursively merge dicts + if isinstance(v, Config): + try: + _merge_a_into_b(v, b[k], root, key_list + [k]) + except BaseException: + raise + else: + b[k] = v + else: + raise KeyError("Non-existent config key: {}".format(full_key)) diff --git a/llumnix/common/defaults.py b/llumnix/common/defaults.py new file mode 100644 index 00000000..1fa0e8c7 --- /dev/null +++ b/llumnix/common/defaults.py @@ -0,0 +1,15 @@ +from llumnix.common.config import Config as Cfg + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = Cfg() + +# ----------------------------------------------------------------------------- +# Prefill Decoding Disaggregation Config +# ----------------------------------------------------------------------------- +_C.PDD_CONFIG = Cfg() +_C.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION = False +_C.PDD_CONFIG.PREFILL_INSTANCE_NUM = -1 +_C.PDD_CONFIG.PREFILL_INSTANCE_TYPE = None \ No newline at end of file diff --git a/llumnix/config.py b/llumnix/config.py index af624ff2..8c292df6 100644 --- a/llumnix/config.py +++ b/llumnix/config.py @@ -40,7 +40,9 @@ def __init__( enable_defrag: bool, scaling_policy: str, scale_up_threshold: float, - scale_down_threshold: float) -> None: + scale_down_threshold: float, + enable_pd_disaggregation: bool, + available_dispatch_instance_num: int) -> None: self.initial_instances = initial_instances self.load_metric = load_metric @@ -53,3 +55,6 @@ def __init__( self.scaling_policy = scaling_policy self.scale_up_threshold = scale_up_threshold*(-1) self.scale_down_threshold = scale_down_threshold*(-1) + + self.enable_pd_disaggregation = enable_pd_disaggregation + self.available_dispatch_instance_num = available_dispatch_instance_num \ No newline at end of file diff --git a/llumnix/entrypoints/llumnix_utils.py b/llumnix/entrypoints/llumnix_utils.py index 10573268..01d3564a 100644 --- a/llumnix/entrypoints/llumnix_utils.py +++ b/llumnix/entrypoints/llumnix_utils.py @@ -141,7 +141,7 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)] migration_configs = engine_manager_args.create_migration_config() - + print("??",engine_manager_args.initial_instances) for idx in range(engine_manager_args.initial_instances): instance_id = instance_ids[idx] if not engine_manager_args.profiling_result_file_path: @@ -200,6 +200,10 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs, ray.get(task) available_instance_ids.append(instance_ids[idx]) available_llumlets.append(llumlets[idx]) + except Exception as e: + import traceback + logger.error("unexpected exception occurs: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) except ray.exceptions.RayActorError: dead_instance_ids.append(instance_ids[idx]) diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 3f49da17..e6650504 100644 --- 
a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -24,11 +24,14 @@ class DispatchScheduler: def __init__(self, dispatch_policy: str, - instance_load_calculator: InstanceLoadCalculator) -> None: + instance_load_calculator: InstanceLoadCalculator, + available_dispatch_instance_num: int) -> None: self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy) self.instance_load_calculator = instance_load_calculator self.num_instances = 0 self.instance_id_set: Set[str] = set() + self.available_dispatch_instance_set: Set[str] = set() + self.available_dispatch_instance_num = available_dispatch_instance_num # instance info args self.instance_info: Dict[str, InstanceInfo] = {} self.sorted_instance_infos: List[InstanceInfo] = None @@ -56,22 +59,26 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) - self.instance_num_requests[instance_id] = 0 + if self.available_dispatch_instance_num == -1 or (self.available_dispatch_instance_num > 0 and len(self.available_dispatch_instance_set) < self.available_dispatch_instance_num): + self.available_dispatch_instance_set.add(instance_id) + self.instance_num_requests[instance_id] = 0 def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) self.num_instances = len(self.instance_id_set) - del self.instance_num_requests[instance_id] + if instance_id in self.instance_num_requests: + del self.instance_num_requests[instance_id] def _sort_instance_infos(self, descending: bool = True) -> None: instance_infos: List[InstanceInfo] = list(self.instance_info.values()) + available_instance_infos = [info for info in instance_infos if info.instance_id in self.available_dispatch_instance_set] if isinstance(self.dispatch_policy, Queue): key_attr = 'num_waiting_requests' else: key_attr = 'instance_load_dispatch_scale' self.sorted_instance_infos = sorted( - instance_infos, + available_instance_infos, key=lambda instance_info: getattr(instance_info, key_attr), reverse=descending ) diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index 9f872d4a..68dfb63c 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -30,16 +30,19 @@ def __init__(self, # instance load and instance info args self.load_metric = global_scheduler_config.load_metric self.enable_defrag = global_scheduler_config.enable_defrag + self.enable_pd_disaggregation = global_scheduler_config.enable_pd_disaggregation self.instance_load_calculator = InstanceLoadCalculator(load_metric=self.load_metric, enable_defrag=self.enable_defrag) # dispatch args self.dispatch_policy = global_scheduler_config.dispatch_policy self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy, - self.instance_load_calculator) + self.instance_load_calculator, + global_scheduler_config.available_dispatch_instance_num) # migrate args self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy, global_scheduler_config.migrate_out_load_threshold, - self.instance_load_calculator) + self.instance_load_calculator, + global_scheduler_config.available_dispatch_instance_num) # auto-scaling args self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold, global_scheduler_config.scale_down_threshold, @@ -61,11 +64,12 @@ def update_instance_infos(self, 
instance_infos: List[InstanceInfo]) -> None: def dispatch(self) -> str: self.dispatch_scheduler.update_instance_infos(self.instance_info) instance_id = self.dispatch_scheduler.dispatch() - return instance_id + request_expected_steps = 1 if self.enable_pd_disaggregation else -1 + return instance_id, request_expected_steps - def pair_migration(self) -> List[Tuple[str, str]]: + def pair_migration(self, migration_target:str) -> List[Tuple[str, str]]: self.migration_scheduler.update_instance_infos(self.instance_info) - migrate_instance_pairs = self.migration_scheduler.pair_migration() + migrate_instance_pairs = self.migration_scheduler.pair_migration(migration_target) return migrate_instance_pairs def check_scale(self) -> Tuple[str, str]: diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index a87f833b..c5ebb5f8 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -13,6 +13,7 @@ from typing import Dict, List, Tuple, Set from abc import ABC, abstractmethod +from enum import Enum import copy import numpy as np @@ -21,12 +22,27 @@ logger = init_logger(__name__) +class PairMigrationConstraints(str, Enum): + """Target of Migration.""" + NO_CONSTRAINTS = "NO_CONSTRAINTS" + + # Enable the prefill-decoding disaggregration. + DECODING_2_DECODING = "DECODING_2_DECODING" + PREFILL_2_DECODING = "PREFILL_2_DECODING" + +class InstanceType(str, Enum): + NO_CONSTRAINTS = "NO_CONSTRAINTS" + + # Specific to Prefill-Decoding disaggregation. + PREFILL = "prefill" + DECODE = "decode" class MigrationScheduler: def __init__(self, pair_migration_policy: str, migrate_out_load_threshold: float, - instance_load_calculator: InstanceLoadCalculator) -> None: + instance_load_calculator: InstanceLoadCalculator, + constraint_prefill_instance_num: int) -> None: self.migrate_out_load_threshold = migrate_out_load_threshold self.instance_load_calculator = instance_load_calculator self.enable_defrag = instance_load_calculator.enable_defrag @@ -43,14 +59,38 @@ def __init__(self, self.num_instances = 0 self.instance_id_set: Set[str] = set() + self.instance_id_type_set: Dict[InstanceType, Set[str]] = {instance_type: set() for instance_type in InstanceType} + self.constraint_prefill_instance_num = constraint_prefill_instance_num # instance info args self.instance_info: Dict[str, InstanceInfo] = None - self.sorted_instance_infos: List[InstanceInfo] = None + self.sorted_instance_infos: Dict[str, List[InstanceInfo]] = {instance_type: set() for instance_type in InstanceType} - def pair_migration(self) -> List[Tuple[str, str]]: + def pair_migration(self, migration_target:str) -> List[Tuple[str, str]]: self._sort_instance_infos(descending=False) - return self.pair_migration_policy.pair_migration(self.sorted_instance_infos) - + sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration = self._get_migration_settings(migration_target) + return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration) + + def _get_migration_settings(self, migration_target:str) -> Dict[str, InstanceInfo]: + strict_pre_migration = True + if migration_target == PairMigrationConstraints.NO_CONSTRAINTS: + # migrate out instances + sorted_src_instance_infos = [i for i in reversed(self.sorted_instance_infos[InstanceType.NO_CONSTRAINTS]) + if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] + # migrate in instances + 
sorted_dst_instance_infos = [i for i in self.sorted_instance_infos[InstanceType.NO_CONSTRAINTS] + if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] + elif migration_target == PairMigrationConstraints.PREFILL_2_DECODING: + sorted_src_instance_infos = [i for i in reversed(self.sorted_instance_infos[InstanceType.PREFILL])] + sorted_dst_instance_infos = [i for i in self.sorted_instance_infos[InstanceType.DECODE] + if i.num_killed_requests == 0] + # TODO[xinyi]: Considering decoding instances load, try to decode on the prefill instance(set strict_pre_migration as False). + elif migration_target == PairMigrationConstraints.DECODING_2_DECODING: + sorted_src_instance_infos = [i for i in reversed(self.sorted_instance_infos[InstanceType.DECODE]) + if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] + sorted_dst_instance_infos = [i for i in self.sorted_instance_infos[InstanceType.DECODE] + if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] + return sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration + def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: self.instance_info = instance_info @@ -58,6 +98,13 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) + if self.constraint_prefill_instance_num > 0: + if len(self.instance_id_type_set[InstanceType.PREFILL]) < self.constraint_prefill_instance_num: + self.instance_id_type_set[InstanceType.PREFILL].add(instance_id) + else: + self.instance_id_type_set[InstanceType.DECODE].add(instance_id) + else: + self.instance_id_type_set[InstanceType.NO_CONSTRAINTS].add(instance_id) def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) @@ -66,12 +113,15 @@ def remove_instance(self, instance_id: str) -> None: def _sort_instance_infos(self, descending: bool = True) -> None: instance_infos: List[InstanceInfo] = list(self.instance_info.values()) + filter_instance_infos = {instance_type: set() for instance_type in InstanceType} key_attr = 'instance_load_migrate' - self.sorted_instance_infos = sorted( - instance_infos, - key=lambda instance_info: getattr(instance_info, key_attr), - reverse=descending - ) + for inst_type in InstanceType: + filter_instance_infos[inst_type] = [info for info in instance_infos if info.instance_id in self.instance_id_type_set[inst_type]] + self.sorted_instance_infos[inst_type] = sorted( + filter_instance_infos[inst_type], + key=lambda instance_info: getattr(instance_info, key_attr), + reverse=descending + ) class PairMigrationPolicy(ABC): def __init__(self, @@ -82,31 +132,29 @@ def __init__(self, @abstractmethod def pair_migration(self, - sorted_instance_infos: List[InstanceInfo] - ) -> List[Tuple[str, str]]: + sorted_src_instance_infos: List[InstanceInfo], + sorted_dst_instance_infos: List[InstanceInfo], + strict_pre_migration: bool, + ) -> List[Tuple[str, str]]: raise NotImplementedError class Balanced(PairMigrationPolicy): def pair_migration(self, - sorted_instance_infos: List[InstanceInfo] + sorted_src_instance_infos: List[InstanceInfo], + sorted_dst_instance_infos: List[InstanceInfo], + strict_pre_migration: bool, ) -> List[Tuple[str, str]]: - # migrate in instances - migrate_in_instance_infos = [i for i in sorted_instance_infos - if i.num_killed_requests == 0 and i.instance_load_migrate < 
self.migrate_out_load_threshold] - # migrate out instances - migrate_out_instance_infos = [i for i in reversed(sorted_instance_infos) - if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] migrate_instance_pairs = [] - for i in range(min(len(migrate_in_instance_infos), len(migrate_out_instance_infos))): - load_diff_before_mig = migrate_out_instance_infos[i].instance_load_migrate - migrate_in_instance_infos[i].instance_load_migrate - left_load_after_mig = self._compute_instance_load_after_migrate(migrate_in_instance_infos[i], is_migrate_in=True) - right_load_after_mig = self._compute_instance_load_after_migrate(migrate_out_instance_infos[i], is_migrate_in=False) + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate + left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) + right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) # Add some constrains to reduce unnecessary migrations - if left_load_after_mig > self.migrate_out_load_threshold: + if right_load_after_mig > self.migrate_out_load_threshold: continue - load_diff_after_mig = right_load_after_mig - left_load_after_mig - if (0 < load_diff_after_mig < load_diff_before_mig) or (migrate_in_instance_infos[i].instance_load_migrate == -np.inf): - migrate_instance_pairs.append((migrate_out_instance_infos[i].instance_id, migrate_in_instance_infos[i].instance_id)) + load_diff_after_mig = left_load_after_mig - right_load_after_mig + if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, strict_pre_migration)) return migrate_instance_pairs def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: @@ -122,41 +170,20 @@ def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_m class DefragConstrained(PairMigrationPolicy): def pair_migration(self, - sorted_instance_infos: List[InstanceInfo] - ) -> List[Tuple[str, str]]: - # migrate in instances - migrate_in_instance_infos = [i for i in sorted_instance_infos - if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] - # migrate out instances - migrate_out_instance_infos = [i for i in reversed(sorted_instance_infos) - if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] - migrate_instance_pairs = [] - for i in range(min(len(migrate_in_instance_infos), len(migrate_out_instance_infos))): - # without any constrain in order to make defragmentation migrate happens as soon as possible - migrate_instance_pairs.append((migrate_out_instance_infos[i].instance_id, migrate_in_instance_infos[i].instance_id)) - return migrate_instance_pairs - -class DefragRelaxed(PairMigrationPolicy): - def pair_migration(self, - sorted_instance_infos: List[InstanceInfo] + sorted_src_instance_infos: List[InstanceInfo], + sorted_dst_instance_infos: List[InstanceInfo], + strict_pre_migration: bool, ) -> List[Tuple[str, str]]: - # migrate in instances - migrate_in_instance_infos = [i for i in sorted_instance_infos - if i.num_killed_requests == 0 and i.instance_load_migrate < 
self.migrate_out_load_threshold] - # migrate out instances - migrate_out_instance_infos = list(reversed(sorted_instance_infos)) migrate_instance_pairs = [] - for i in range(min(len(migrate_in_instance_infos), len(migrate_out_instance_infos))): - if migrate_out_instance_infos[i].num_killed_requests != 0 \ - or migrate_out_instance_infos[i].instance_load_migrate > migrate_in_instance_infos[i].instance_load_migrate: - migrate_instance_pairs.append((migrate_out_instance_infos[i].instance_id, migrate_in_instance_infos[i].instance_id)) + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + # without any constrain in order to make prefill migrate happens as soon as possible + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, strict_pre_migration)) return migrate_instance_pairs class PairMigrationPolicyFactory: _POLICY_REGISTRY = { 'balanced': Balanced, 'defrag_constrained': DefragConstrained, - 'defrag_relaxed': DefragRelaxed, } @classmethod diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 489fd124..8b95de54 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -23,6 +23,7 @@ from llumnix.llumlet.llumlet import Llumlet from llumnix.logger import init_logger from llumnix.global_scheduler.global_scheduler import GlobalScheduler +from llumnix.global_scheduler.migration_scheduler import PairMigrationConstraints from llumnix.instance_info import InstanceInfo from llumnix.config import GlobalSchedulerConfig from llumnix.arg_utils import EngineManagerArgs @@ -61,6 +62,8 @@ def __init__(self, self.max_instances = engine_manager_args.max_instances self.min_instances = engine_manager_args.min_instances + self.enable_pd_disaggregation = global_scheduler_config.enable_pd_disaggregation + logger.info("LLMEngineManager starts") logger.info("enable_migration: {}".format(self.enable_migration)) logger.info("num_instances: {}".format(self.num_instances)) @@ -112,9 +115,10 @@ async def generate( logger.info("No instance available temporarily, sleep {}s, " "and retry generate request {} again....".format(RETRIES_INTERVALS, request_id)) await asyncio.sleep(RETRIES_INTERVALS) - instance_id = self.global_scheduler.dispatch() + try: - await self.instances[instance_id].generate.remote(request_id, server_info, *args, **kwargs) + instance_id, request_expected_steps = self.global_scheduler.dispatch() + await self.instances[instance_id].generate.remote(request_id, server_info, request_expected_steps, *args, **kwargs) if self.log_requests: logger.info("received request {}.".format(request_id)) logger.info("dispath to instance {}".format(instance_id)) @@ -122,6 +126,10 @@ async def generate( except (ray.exceptions.RayActorError, KeyError): logger.info("[generate] instance {} is dead, regenerate request {}".format(instance_id, request_id)) self.scale_down(instance_id) + except Exception as e: + import traceback + logger.error("unexpected exception occurs: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) async def abort(self, request_id: Union[str, Iterable[str]]) -> None: if isinstance(request_id, str): @@ -192,7 +200,7 @@ async def _update_instance_info_loop(self, interval: float) -> None: # Push migrate when the instance_info have updated a certain number of times. 
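For the `generate` path above, dispatch now returns a pair rather than a bare instance id; a minimal restatement of the expected-steps budget it carries (illustrative only):

    def expected_steps_for_dispatch(enable_pd_disaggregation: bool) -> int:
        # Mirrors GlobalScheduler.dispatch(): 1 means the request should leave its prefill instance
        # after a single engine step, -1 means it may decode to completion where it lands.
        return 1 if enable_pd_disaggregation else -1

    assert expected_steps_for_dispatch(True) == 1
    assert expected_steps_for_dispatch(False) == -1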
if self.enable_migration and self.num_instance_info_updates != 0 \ and self.num_instance_info_updates % self.pair_migration_frequency == 0: - asyncio.create_task(self._migrate()) + asyncio.create_task(self._migrate_control()) if self.log_instance_info: self._log_instance_infos_to_csv(instance_info_list) # pylint: disable=W0703 @@ -230,24 +238,35 @@ async def _post_migrate(self, rets: List[str], call_migrate_instance_pairs: List if migrate_out_request_ids: migrate_out_request_id = migrate_out_request_ids[0] self.request_instance[migrate_out_request_id] = call_migrate_instance_pairs[i][1] - logger.info("{}->{} migrate done, migrate request {}".format( - call_migrate_instance_pairs[i][0], call_migrate_instance_pairs[i][1], migrate_out_request_ids)) - - async def _migrate(self) -> None: - migrate_instance_pairs = self.global_scheduler.pair_migration() + logger.info("{}->{} migrate done, migrate request {}".format( + call_migrate_instance_pairs[i][0], call_migrate_instance_pairs[i][1], migrate_out_request_ids)) + + async def _migrate_control(self) -> None: + # Push migrate when the instance_info have updated a certain number of times. + if self.enable_pd_disaggregation: + asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, -1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING, 1)) + else: + asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1)) + + async def _migrate(self, migration_target:str, migrate_in_num_requests:int) -> None: + migrate_instance_pairs = self.global_scheduler.pair_migration(migration_target) + # if len(migrate_instance_pairs)>0: + # logger.info("[_migrate] migrate_instance_pairs {} {}".format(migration_target, migrate_instance_pairs)) try: migration_tasks = [] call_migrate_instance_pairs: List[Tuple[str, str]] = [] for _, migrate_instance_pair in enumerate(migrate_instance_pairs): - migrate_out_instance_id, migrate_in_instance_id = migrate_instance_pair + migrate_out_instance_id, migrate_in_instance_id, strict_pre_migration = migrate_instance_pair + # logger.info("[_migrate] migrate_instance_pairs {} {} {} {} {}".format(migration_target, migrate_out_instance_id, migrate_in_instance_id, self.instance_migrating[migrate_out_instance_id], self.instance_migrating[migrate_in_instance_id])) if self.instance_migrating[migrate_out_instance_id] or self.instance_migrating[migrate_in_instance_id]: continue + # logger.info("[_migrate] migrate_instance_pairs {} {} {} ".format(migration_target, migrate_out_instance_id, migrate_in_instance_id)) self.instance_migrating[migrate_out_instance_id] = True self.instance_migrating[migrate_in_instance_id] = True migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id) - logger.info("{}->{} begin migrate out".format(migrate_out_instance_id, migrate_in_instance_id)) call_migrate_instance_pairs.append(migrate_instance_pair) - task = self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name) + task = self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests, strict_pre_migration) migration_tasks.append(task) # TODO(s5u13b): Migration failover could be implemented in Llumlet rather than manager. 
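For intuition about how the PREFILL_2_DECODING and DECODING_2_DECODING targets used above find their candidates, here is a standalone restatement of the instance-typing rule from `MigrationScheduler.add_instance` (illustrative only; the instance ids are made up):

    from llumnix.global_scheduler.migration_scheduler import InstanceType

    def classify_instances(instance_ids, constraint_prefill_instance_num):
        # The first `constraint_prefill_instance_num` instances become PREFILL, the rest DECODE;
        # with the default of -1 everything stays NO_CONSTRAINTS.
        typed = {instance_type: set() for instance_type in InstanceType}
        for instance_id in instance_ids:
            if constraint_prefill_instance_num > 0:
                if len(typed[InstanceType.PREFILL]) < constraint_prefill_instance_num:
                    typed[InstanceType.PREFILL].add(instance_id)
                else:
                    typed[InstanceType.DECODE].add(instance_id)
            else:
                typed[InstanceType.NO_CONSTRAINTS].add(instance_id)
        return typed

    typed = classify_instances(["instance_0", "instance_1", "instance_2"], constraint_prefill_instance_num=1)
    assert typed[InstanceType.PREFILL] == {"instance_0"}
    assert typed[InstanceType.DECODE] == {"instance_1", "instance_2"}

PREFILL_2_DECODING then pairs sources from the PREFILL set with targets from the DECODE set, while DECODING_2_DECODING only rebalances load inside the DECODE set.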
rets = await asyncio.gather(*migration_tasks, return_exceptions=True) diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 712c145e..bef8dad1 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -38,6 +38,8 @@ def __init__(self, **kwargs) -> None: self.instance_id = instance_id self.actor_name = f"instance_{instance_id}" + self.strict_pre_migration = True + self.backend_engine: BackendInterface = init_backend_engine(self.instance_id, backend_type, migration_config, @@ -47,7 +49,7 @@ def __init__(self, migration_config.last_stage_max_blocks, migration_config.max_stages) self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy, - self.backend_engine) + self.backend_engine, self.strict_pre_migration) self.log_requests = True @classmethod @@ -98,32 +100,45 @@ def from_args(cls, llumlet = engine_class.remote(instance_id, backend_type, migration_config, *args, **kwargs) return llumlet - def migrate_out(self, dst_instance_name: str) -> List[str]: + def migrate_out(self, dst_instance_name: str, num_requests: int, strict_pre_migration: bool) -> List[str]: try: - t0 = time.time() + self.update_strict_pre_migration(strict_pre_migration) migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') dst_instance_id = dst_instance_name[len("instance_"):] - logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id)) - migrate_out_request = self.migration_scheduler.get_migrate_out_request() migrated_request_list = [] - if migrate_out_request is None: - return migrated_request_list - status = self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - if status == MigrationStatus.FINISHED_DONE: - ray.get(migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request)) - self.backend_engine.free_src_request(migrate_out_request) - migrated_request_list.append(migrate_out_request.request_id) - migrate_out_request.stage_timestamps.append(time.time()) - self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) - else: - ray.get(migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id)) - t1 = time.time() - logger.info("{}->{} migrate done, migrate request {}, status:{}, len:{} blocks, cost:{} ms" \ - .format(self.instance_id, dst_instance_id, migrated_request_list, status, \ - sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) + continue_migrate = True + while continue_migrate: + t0 = time.time() + migrate_out_request = self.migration_scheduler.get_migrate_out_request() + if migrate_out_request is not None: + logger.info("migrate_out {}".format(migrate_out_request.request_id)) + if migrate_out_request is None or not self.strict_pre_migration: + return migrated_request_list + logger.info("{}->{} begin migrate out {} {}".format(self.instance_id, dst_instance_id, migrate_out_request.request_id, migrate_out_request.expected_steps)) + status = self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) + if status == MigrationStatus.FINISHED_DONE: + migrate_out_request.expected_steps = -1 + ray.get(migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request)) + self.backend_engine.free_src_request(migrate_out_request) + migrated_request_list.append(migrate_out_request.request_id) + migrate_out_request.stage_timestamps.append(time.time()) + 
                    self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
+                else:
+                    ray.get(migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id))
+                    continue_migrate = False
+                if num_requests == 1:
+                    continue_migrate = False
+                t1 = time.time()
+                logger.info("{}->{} migrate done, migrate request {}, status:{}, len:{} blocks, cost:{} ms" \
+                            .format(self.instance_id, dst_instance_id, migrated_request_list, status, \
+                                    sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000))
         except ray.exceptions.RayActorError:
             logger.info("[migrate_out] instance {} is dead".format(dst_instance_name[len("instance_"):]))
             raise
+        except Exception as e:
+            import traceback
+            logger.error("unexpected exception occurs: {}".format(e))
+            logger.error("exception traceback: {}".format(traceback.format_exc()))
         return migrated_request_list
 
     def get_instance_info(self) -> InstanceInfo:
@@ -134,16 +149,23 @@ def is_ready(self) -> bool:
 
     def get_all_request_ids(self) -> List[str]:
         return self.backend_engine.get_all_request_ids()
+
+    def update_strict_pre_migration(self, new_strict_pre_migration: bool) -> None:
+        if self.strict_pre_migration != new_strict_pre_migration:
+            self.strict_pre_migration = new_strict_pre_migration
+            self.backend_engine.update_strict_pre_migration(new_strict_pre_migration)
+            self.migration_scheduler.strict_pre_migration = new_strict_pre_migration
 
     def generate(
         self,
         request_id: str,
         server_info: ServerInfo,
+        expected_steps: int,
         *args,
         **kwargs,
     ) -> None:
         # This should not be used for logging, as it is monotonic time.
-        self.backend_engine.add_request(request_id, server_info, *args, **kwargs)
+        self.backend_engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)
 
     def abort(self, request_id: Union[str, Iterable[str]]) -> None:
         if isinstance(request_id, str):
diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py
index c30ea20a..8675a738 100644
--- a/llumnix/llumlet/local_migration_scheduler.py
+++ b/llumnix/llumlet/local_migration_scheduler.py
@@ -18,20 +18,34 @@ from llumnix.backends.backend_interface import BackendInterface
 
 class LocalMigrationScheduler:
-    def __init__(self, request_migration_policy: str, backend_engine: BackendInterface) -> None:
+    def __init__(self, request_migration_policy: str, backend_engine: BackendInterface, strict_pre_migration: bool) -> None:
         self.request_migration_policy = request_migration_policy
         self.backend_engine = backend_engine
-
+        self.strict_pre_migration = strict_pre_migration
     def get_migrate_out_request(self, min_request_len=0, max_request_len=np.inf) -> Optional[LlumnixRequest]:
-        migrate_out_request: LlumnixRequest = None
-        if self.request_migration_policy == 'LCFS':
-            migrate_out_request = self.get_last_running_request(min_request_len, max_request_len)
-        elif self.request_migration_policy == 'LJF':
-            migrate_out_request = self.get_longest_running_request(min_request_len, max_request_len)
-        elif self.request_migration_policy == 'SJF':
-            migrate_out_request = self.get_shortest_running_request(min_request_len, max_request_len)
+        # Requests that meet the strict pre-migration condition always have higher priority than the configured migration policy.
+ migrate_out_request = self.get_ready_migration_request(min_request_len, max_request_len) + if migrate_out_request is None: + if self.request_migration_policy == 'LCFS': + migrate_out_request = self.get_last_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'LJF': + migrate_out_request = self.get_longest_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'SJF': + migrate_out_request = self.get_shortest_running_request(min_request_len, max_request_len) return migrate_out_request - + + # TODO(xinyi): Currently, the function is only used for Prefill-decoding disaggregation, + # and only selects request that migrates from the prefill instance to the decoding instance. + def get_ready_migration_request(self, min_request_len, max_request_len): + if self.strict_pre_migration: + running: List[LlumnixRequest] = self.backend_engine.get_running_queue() + for request in reversed(running): + if request.expected_steps > 0 and request.output_len >= request.expected_steps \ + and request.inference_type == RequestInferenceType.DECODE \ + and min_request_len <= request.request_len <= max_request_len: + return request + return None + def get_last_running_request(self, min_request_len, max_request_len): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() for request in reversed(running): diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index c06fe4c5..4ebab3a2 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -53,7 +53,7 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m # live migration, transfer all blocks except last one(currently updating) migration_status = MigrationStatus.RUNNING is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks) - if not is_last_stage: + if not is_last_stage and migrate_out_request.blocking_migration: src_blocks = incremental_blocks[:-1] stage_block_num = len(incremental_blocks) - 1 dst_blocks = ray.get(migrate_in_ray_actor.execute_migration_method \ @@ -70,7 +70,7 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m if len(dst_blocks) != len(src_blocks): # migrate-in instance failed to prev alloc - if is_last_stage: + if is_last_stage or not migrate_out_request.blocking_migration: self.backend_engine.add_running_request(migrate_out_request) self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) migration_status = MigrationStatus.FINISHED_ABORTED @@ -80,7 +80,7 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m migrate_out_request.stage_num_blocks_list.append(stage_block_num) # TODO(ZeldaHuang): send_blocks in migrate_in_pre_alloc/migrate_in_last_stage self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks) - if not is_last_stage and migrate_out_request.should_abort_migration(): + if not is_last_stage and migrate_out_request.blocking_migration and migrate_out_request.should_abort_migration(): # migrate-out request abort by scheduler during send/recv migration_status = MigrationStatus.FINISHED_ABORTED @@ -99,7 +99,7 @@ def migrate_out_multistage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", return status # exceed max stages return MigrationStatus.FINISHED_ABORTED - + def migrate_in_pre_alloc(self, request_id: str, block_num: int) -> List[int]: """prev alloc blocks to migrate in request """ diff --git 
a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index e96513f1..e9a5928e 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -20,10 +20,13 @@ class RequestInferenceType(str, Enum): DECODE = "decode" class LlumnixRequest: - def __init__(self, request_id: int, server_info: ServerInfo) -> None: + def __init__(self, request_id: int, server_info: ServerInfo, request_expected_steps: int) -> None: self.request_id = request_id self.server_info = server_info + # strict pre-migration args + self.expected_steps = request_expected_steps + # migration args self.last_preemption_time = None self.stage_timestamps = [] @@ -33,6 +36,8 @@ def reset_migration_args(self): self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] + # By default, there is no limit on the number of steps expected for the request." + self.expected_steps = -1 @property def inference_type(self) -> RequestInferenceType: @@ -49,6 +54,10 @@ def prompt_len(self) -> int: @property def output_len(self) -> int: raise NotImplementedError + + @property + def blocking_migration(self) -> bool: + return self.expected_steps < 0 or (self.expected_steps > 0 and self.expected_steps < self.output_len) def should_abort_migration(self) -> bool: return self.output_len == 0 \ diff --git a/tests/backends/vllm/test_migration.py b/tests/backends/vllm/test_migration.py index c0b83994..8e639253 100644 --- a/tests/backends/vllm/test_migration.py +++ b/tests/backends/vllm/test_migration.py @@ -18,7 +18,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from vllm import EngineArgs, SamplingParams -from vllm.utils import random_uuid +from llumnix.utils import random_uuid from llumnix.backends.vllm.llm_engine import BackendVLLM from llumnix.llumlet.llumlet import Llumlet diff --git a/tests/global_scheduler/test_dispatch_scheduler.py b/tests/global_scheduler/test_dispatch_scheduler.py index bcc58a06..d4d1bf69 100644 --- a/tests/global_scheduler/test_dispatch_scheduler.py +++ b/tests/global_scheduler/test_dispatch_scheduler.py @@ -17,10 +17,11 @@ from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler +INSTANCE_NUM = 4 def init_dispatch_scheduler(policy='load'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator) + dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(-1,4)) return dispatch_scheduler @pytest.fixture @@ -39,51 +40,60 @@ def test_add_instance_and_remove_instance(dispatch_scheduler): assert dispatch_scheduler.num_instances == 0 def test_dispatch_balanced(): - dispatch_scheduler = init_dispatch_scheduler('balanced') num_tests = 100 for _ in range(num_tests): + dispatch_scheduler = init_dispatch_scheduler('balanced') instance_num_requests = {} - for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: - instance_num_requests[instance_id] = random.randint(1, 10) + for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: + if dispatch_scheduler.available_dispatch_instance_num <= 0 or (dispatch_scheduler.available_dispatch_instance_num > 0 and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.available_dispatch_instance_num): + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = random.randint(1, 10) 
dispatch_scheduler.instance_num_requests = instance_num_requests min_instance_id = next(key for key, value in sorted(instance_num_requests.items(), key=lambda item: item[1])) instance_id = dispatch_scheduler.dispatch() assert min_instance_id == instance_id def test_dispatch_load(): - dispatch_scheduler = init_dispatch_scheduler('load') num_tests = 100 for _ in range(num_tests): + dispatch_scheduler = init_dispatch_scheduler('load') instance_num_requests = {} instance_info_dict = {} - for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: - instance_num_requests[instance_id] = 0 + for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: instance_info = InstanceInfo() instance_info.instance_id = instance_id instance_info.instance_load_dispatch_scale = random.random() instance_info_dict[instance_id] = instance_info + if dispatch_scheduler.available_dispatch_instance_num <= 0 or (dispatch_scheduler.available_dispatch_instance_num > 0 and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.available_dispatch_instance_num): + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - min_instance_id = next(key for key, value in sorted(instance_info_dict.items(), + available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict if key in dispatch_scheduler.available_dispatch_instance_set} + min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), key=lambda item: item[1].instance_load_dispatch_scale)) instance_id = dispatch_scheduler.dispatch() assert min_instance_id == instance_id def test_dispatch_queue(): - dispatch_scheduler = init_dispatch_scheduler('queue') num_tests = 100 for _ in range(num_tests): + dispatch_scheduler = init_dispatch_scheduler('queue') instance_num_requests = {} instance_info_dict = {} - for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: - instance_num_requests[instance_id] = 0 + for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: instance_info = InstanceInfo() instance_info.instance_id = instance_id instance_info.num_waiting_requests = random.randint(1, 10) instance_info_dict[instance_id] = instance_info + if dispatch_scheduler.available_dispatch_instance_num <= 0 or (dispatch_scheduler.available_dispatch_instance_num > 0 and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.available_dispatch_instance_num): + dispatch_scheduler.available_dispatch_instance_set.add(instance_id) + print(instance_id) + instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - min_instance_id = next(key for key, value in sorted(instance_info_dict.items(), + available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict if key in dispatch_scheduler.available_dispatch_instance_set} + min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), key=lambda item: item[1].num_waiting_requests)) instance_id = dispatch_scheduler.dispatch() assert instance_info_dict[min_instance_id].num_waiting_requests == instance_info_dict[instance_id].num_waiting_requests diff --git a/tests/global_scheduler/test_global_scheduler.py b/tests/global_scheduler/test_global_scheduler.py index 7ba94145..05279a16 100644 
--- a/tests/global_scheduler/test_global_scheduler.py
+++ b/tests/global_scheduler/test_global_scheduler.py
@@ -13,17 +13,16 @@
 import pytest
 
-from vllm.utils import random_uuid
-
 from llumnix.config import GlobalSchedulerConfig
 from llumnix.global_scheduler.global_scheduler import GlobalScheduler
 from llumnix.instance_info import InstanceInfo
+from llumnix.utils import random_uuid
 
 from tests.global_scheduler.test_llm_engine_manager import get_instance_info_migrate_in, get_instance_info_migrate_out
 
 
 def init_global_scheduler():
-    global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', 'defrag_constrained', 3.0, True, 'avg_load', 10, 60)
+    global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', 'defrag_constrained', 3.0, True, 'avg_load', 10, 60, False, -1)
     global_scheduler = GlobalScheduler(global_scheduler_config)
     return global_scheduler
@@ -70,8 +69,9 @@ def test_dispatch(global_scheduler):
     instance_ids = [instance_info.instance_id for instance_info in instance_infos]
     global_scheduler.scale_up(instance_ids)
     global_scheduler.update_instance_infos(instance_infos)
-    instance_id = global_scheduler.dispatch()
+    instance_id, request_expected_steps = global_scheduler.dispatch()
     assert instance_id in instance_ids
+    assert request_expected_steps in [-1, 1]
 
 def test_pair_migration(global_scheduler):
     instance_id = random_uuid()
@@ -82,6 +82,6 @@ def test_pair_migration(global_scheduler):
     instance_infos = [instance_info_migrate_in, instance_info_migrate_out]
     global_scheduler.scale_up(instance_ids)
     global_scheduler.update_instance_infos(instance_infos)
-    migrate_instace_pairs = global_scheduler.pair_migration()
+    migrate_instace_pairs = global_scheduler.pair_migration("NO_CONSTRAINTS")
     assert migrate_instace_pairs[0][0] == instance_id_1
     assert migrate_instace_pairs[0][1] == instance_id
diff --git a/tests/global_scheduler/test_llm_engine_manager.py b/tests/global_scheduler/test_llm_engine_manager.py
index 8c88459a..f0d1057a 100644
--- a/tests/global_scheduler/test_llm_engine_manager.py
+++ b/tests/global_scheduler/test_llm_engine_manager.py
@@ -17,9 +17,7 @@
 import pytest
 import numpy as np
 
-from vllm.utils import random_uuid
-from vllm import EngineArgs
-
+from llumnix.utils import random_uuid
 from llumnix.arg_utils import EngineManagerArgs
 from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME
 from llumnix.instance_info import InstanceInfo
diff --git a/tests/global_scheduler/test_migration_scheduler.py b/tests/global_scheduler/test_migration_scheduler.py
index 0e69c81f..b493b6c3 100644
--- a/tests/global_scheduler/test_migration_scheduler.py
+++ b/tests/global_scheduler/test_migration_scheduler.py
@@ -16,15 +16,14 @@
 import numpy as np
 
 from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
-from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
-
+from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints, InstanceType
 
 MIGRATE_OUT_LOAD_THRESHOLD = 3.0
-
+INSTANCE_NUM = 4
 
 def init_migration_scheduler(policy='balanced'):
     instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
-    migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator)
+    migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator, random.randint(-1, INSTANCE_NUM))
     return migration_scheduler
 
 @pytest.fixture
@@ -35,36 +34,61 @@
 def test_add_instance_and_remove_instance(migration_scheduler):
     migration_scheduler.add_instance('instance_1')
     assert migration_scheduler.num_instances == 1
+    if migration_scheduler.constraint_prefill_instance_num <= 0:
+        assert len(migration_scheduler.instance_id_type_set[InstanceType.NO_CONSTRAINTS]) == 1
+    else:
+        assert len(migration_scheduler.instance_id_type_set[InstanceType.PREFILL]) == 1
     migration_scheduler.add_instance('instance_2')
     assert migration_scheduler.num_instances == 2
+    if migration_scheduler.constraint_prefill_instance_num <= 0:
+        assert len(migration_scheduler.instance_id_type_set[InstanceType.NO_CONSTRAINTS]) == 2
+    else:
+        assert len(migration_scheduler.instance_id_type_set[InstanceType.PREFILL]) == min(2, migration_scheduler.constraint_prefill_instance_num)
+        assert len(migration_scheduler.instance_id_type_set[InstanceType.DECODE]) == max(2 - migration_scheduler.constraint_prefill_instance_num, 0)
     migration_scheduler.remove_instance('instance_1')
     assert migration_scheduler.num_instances == 1
     migration_scheduler.remove_instance('instance_2')
     assert migration_scheduler.num_instances == 0
 
-@pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained', 'defrag_relaxed'])
-def test_pair_migration(policy):
-    migration_scheduler = init_migration_scheduler(policy)
+@pytest.mark.parametrize("policy, migration_target", [
+    ('balanced', 'NO_CONSTRAINTS'),
+    ('defrag_constrained', 'NO_CONSTRAINTS'),
+    ('balanced', 'DECODING_2_DECODING'),
+    ('defrag_constrained', 'DECODING_2_DECODING'),
+    ('balanced', 'PREFILL_2_DECODING'),
+    ('defrag_constrained', 'PREFILL_2_DECODING')])
+def test_pair_migration(policy, migration_target):
     num_tests = 1000
     for _ in range(num_tests):
+        migration_scheduler = init_migration_scheduler(policy)
         instance_info_dict = {}
-        for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']:
+        for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
             instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1)
             instance_info.num_killed_requests = random.randint(0, 1)
             instance_info.num_blocks_last_running_request = random.randint(0, 1) * np.inf
             instance_info_dict[instance_id] = instance_info
+            if migration_scheduler.constraint_prefill_instance_num > 0:
+                if len(migration_scheduler.instance_id_type_set[InstanceType.PREFILL]) < migration_scheduler.constraint_prefill_instance_num:
+                    migration_scheduler.instance_id_type_set[InstanceType.PREFILL].add(instance_id)
+                else:
+                    migration_scheduler.instance_id_type_set[InstanceType.DECODE].add(instance_id)
+            else:
+                migration_scheduler.instance_id_type_set[InstanceType.NO_CONSTRAINTS].add(instance_id)
         migration_scheduler.instance_info = instance_info_dict
-        migrate_instance_pairs = migration_scheduler.pair_migration()
-        for migrate_out_instance, migrate_in_instance in migrate_instance_pairs:
+        migrate_instance_pairs = migration_scheduler.pair_migration(migration_target)
+        for migrate_out_instance, migrate_in_instance, strict_pre_migration in migrate_instance_pairs:
             assert migrate_out_instance != migrate_in_instance
-            if policy != 'defrag_relaxed':
+            assert strict_pre_migration == True
+            if migration_target != PairMigrationConstraints.PREFILL_2_DECODING:
                 assert instance_info_dict[migrate_out_instance].num_killed_requests > 0 \
                     or instance_info_dict[migrate_out_instance].instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD
-                assert instance_info_dict[migrate_in_instance].num_killed_requests == 0 \
-                    and instance_info_dict[migrate_in_instance].instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD
+                assert instance_info_dict[migrate_in_instance].num_killed_requests == 0 \
+                    and instance_info_dict[migrate_in_instance].instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD
+                if instance_info_dict[migrate_out_instance].num_killed_requests == 0:
+                    assert instance_info_dict[migrate_out_instance].instance_load_migrate > instance_info_dict[migrate_in_instance].instance_load_migrate
+            else:
+                assert instance_info_dict[migrate_in_instance].num_killed_requests == 0
             if policy == 'balanced':
-                assert instance_info_dict[migrate_out_instance].num_blocks_last_running_request == 0
-            if instance_info_dict[migrate_out_instance].num_killed_requests == 0:
-                assert instance_info_dict[migrate_out_instance].instance_load_migrate > instance_info_dict[migrate_in_instance].instance_load_migrate
+                assert instance_info_dict[migrate_out_instance].num_blocks_last_running_request == 0
\ No newline at end of file
diff --git a/tests/llumlet/test_local_migration_scheduler.py b/tests/llumlet/test_local_migration_scheduler.py
index cd05c247..c1b08fbb 100644
--- a/tests/llumlet/test_local_migration_scheduler.py
+++ b/tests/llumlet/test_local_migration_scheduler.py
@@ -2,8 +2,8 @@
 from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType
 
 class MockRequest(LlumnixRequest):
-    def __init__(self, request_id, length) -> None:
-        super().__init__(request_id=request_id, server_info=None)
+    def __init__(self, request_id, length, request_expected_steps) -> None:
+        super().__init__(request_id=request_id, server_info=None, request_expected_steps=request_expected_steps)
         self.length = length
         self.status = RequestInferenceType.DECODE
 
@@ -27,19 +27,19 @@ class MockeEngine():
     def __init__(self) -> None:
         self.running = []
 
-    def add_request(self, request_id, length) -> None:
-        self.running.append(MockRequest(request_id, length))
+    def add_request(self, request_id, length, request_expected_steps) -> None:
+        self.running.append(MockRequest(request_id, length, request_expected_steps))
 
     def get_running_queue(self):
         return self.running
 
 def test_scheduler_policy():
     engine = MockeEngine()
-    scheduler = LocalMigrationScheduler("", engine)
+    scheduler = LocalMigrationScheduler("", engine, strict_pre_migration=True)
 
-    engine.add_request(request_id="0", length=1)
-    engine.add_request(request_id="1", length=3)
-    engine.add_request(request_id="2", length=2)
+    engine.add_request(request_id="0", length=1, request_expected_steps=-1)
+    engine.add_request(request_id="1", length=3, request_expected_steps=-1)
+    engine.add_request(request_id="2", length=2, request_expected_steps=-1)
 
     scheduler.request_migration_policy = "LCFS"
     assert scheduler.get_migrate_out_request().request_id == "2"
@@ -48,8 +48,14 @@
     scheduler.request_migration_policy = "SJF"
     assert scheduler.get_migrate_out_request().request_id == "0"
 
+    engine.add_request(request_id="3", length=2, request_expected_steps=1)
+    assert scheduler.get_migrate_out_request().request_id == "3"
+    engine.add_request(request_id="4", length=3, request_expected_steps=-1)
+    scheduler.request_migration_policy = "LCFS"
+    assert scheduler.get_migrate_out_request().request_id == "3"
+
 def test_scheduler_should_abort_migration():
-    req_0 = MockRequest(request_id="0", length=1)
+    req_0 = MockRequest(request_id="0", length=1, request_expected_steps=-1)
     req_0.stage_timestamps = [1]
     assert req_0.should_abort_migration() is False
     req_0.status = RequestInferenceType.PREFILL
@@ -57,3 +63,11 @@ def test_scheduler_should_abort_migration():
     req_0.status = RequestInferenceType.DECODE
     req_0.last_preemption_time = 2
     assert req_0.should_abort_migration() is True
+
+def test_blocking_migration():
+    req_0 = MockRequest(request_id="0", length=1, request_expected_steps=-1)
+    assert req_0.blocking_migration is True
+    req_1 = MockRequest(request_id="1", length=2, request_expected_steps=1)
+    assert req_1.blocking_migration is True
+    req_2 = MockRequest(request_id="2", length=1, request_expected_steps=1)
+    assert req_2.blocking_migration is False
\ No newline at end of file
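
For reference, a minimal sketch of the expected_steps / blocking_migration semantics exercised by the tests above, assuming only the condition this patch adds to LlumnixRequest; SimpleRequest is a hypothetical stand-in used for illustration and is not part of the Llumnix codebase or of this patch.

# Illustrative sketch (not part of the patch): how a request's step budget
# determines whether migration is still "blocking" for it.
class SimpleRequest:
    def __init__(self, request_expected_steps: int, output_len: int) -> None:
        # -1 means "no limit": the request is never forced out by a step budget.
        self.expected_steps = request_expected_steps
        self.output_len = output_len

    @property
    def blocking_migration(self) -> bool:
        # Same condition as LlumnixRequest.blocking_migration in this patch:
        # migration is blocking while there is no limit, or while the request
        # has not yet produced expected_steps outputs.
        return self.expected_steps < 0 or (self.expected_steps > 0 and self.expected_steps < self.output_len)

# Mirrors the cases covered by test_blocking_migration above.
assert SimpleRequest(request_expected_steps=-1, output_len=1).blocking_migration is True
assert SimpleRequest(request_expected_steps=1, output_len=2).blocking_migration is True
assert SimpleRequest(request_expected_steps=1, output_len=1).blocking_migration is False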