From 598f08bd4e888b8c5dc91999939a57085b77c487 Mon Sep 17 00:00:00 2001 From: "zhizhi.zxy" Date: Wed, 28 Aug 2024 11:14:01 +0800 Subject: [PATCH] fix --- benchmark/benchmark_serving.py | 2 +- llumnix/arg_utils.py | 31 ++- llumnix/backends/backend_interface.py | 36 +-- llumnix/backends/vllm/llm_engine.py | 30 +-- llumnix/backends/vllm/scheduler.py | 59 +++-- llumnix/common/config.py | 211 ++++++++++++++++++ llumnix/common/defaults.py | 15 ++ llumnix/global_scheduler/global_scheduler.py | 11 +- .../global_scheduler/migration_scheduler.py | 61 +++-- llumnix/llm_engine_manager.py | 41 ++-- llumnix/llumlet/llumlet.py | 23 +- llumnix/llumlet/local_migration_scheduler.py | 19 +- llumnix/llumlet/migrating_request.py | 5 +- llumnix/llumlet/migration_coordinator.py | 2 +- 14 files changed, 383 insertions(+), 163 deletions(-) create mode 100644 llumnix/common/config.py create mode 100644 llumnix/common/defaults.py diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py index a2d0f67d..e585023e 100644 --- a/benchmark/benchmark_serving.py +++ b/benchmark/benchmark_serving.py @@ -380,7 +380,7 @@ def __init__(self): self._all_decode_latencies = [] def measure(self, f): - async def measured(*args, **kwargs): + async def measured(*args, **gs): start = time.time() prompt, output = await f(*args, **kwargs) # Do not record latency if request failed. diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index a2115396..fb970f5b 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -15,9 +15,12 @@ from dataclasses import dataclass import argparse from typing import Tuple -import json + +from llumnix.common.config import get_cfg from llumnix.config import GlobalSchedulerConfig, MigrationConfig +from llumnix.logger import init_logger +logger = init_logger(__name__) @dataclass class EngineManagerArgs: @@ -58,21 +61,15 @@ class EngineManagerArgs: last_stage_max_blocks: int = 4 max_stages: int = 3 - pdd_config: str = None + config_file: str = None def create_engine_manager_configs( self, ) -> Tuple[GlobalSchedulerConfig]: - # Create the Configuration for prefill decoding disaggregation. - prefill_instance_num = -1 - # TODO[xinyi]: Bind the prefill instance to several fixed IP addresses, - # expanding into heterogeneous scenarios. - # prefill_instance_ip = None - if self.pdd_config: - pdd_config = json.load(open(self.pdd_config, 'r', encoding='utf-8')) - # TODO[xinyi]: hardcode the key fields in the pdd_config. - prefill_instance_num = pdd_config["prefill_instance_num"] - # prefill_instance_ip = pdd_config["prefill_instance_ip"] + + if self.config_file: + config_data = get_cfg() + config_data.merge_from_file(self.config_file) # Create the GlobalScheduler Configuration. global_scheduler_config = GlobalSchedulerConfig(self.initial_instances, @@ -84,8 +81,8 @@ def create_engine_manager_configs( self.scaling_policy, self.scale_up_threshold, self.scale_down_threshold, - self.pdd_config != None, - prefill_instance_num) + config_data.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION, + config_data.PDD_CONFIG.PREFILL_INSTANCE_NUM) return global_scheduler_config def create_migration_configs( @@ -228,8 +225,8 @@ def add_cli_args( type=int, default=EngineManagerArgs.max_stages, help='drop migration if the number of stages > max_stages') - parser.add_argument("--pdd-config", + parser.add_argument("--config-file", type=str, - default=EngineManagerArgs.pdd_config, - help="configuration for prefill decoding disaggregation") + default=EngineManagerArgs.config_file, + help="path to the configuration file") return parser diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index 2c13e796..129ca1d9 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -36,7 +36,7 @@ class BackendInferenceType(str, Enum): class BackendInterface(ABC): # Methods for inference @abstractmethod - def add_request(self, request_id: str, server_info: ServerInfo, req_expected_step: int, + def add_request(self, request_id: str, server_info: ServerInfo, instance_expected_steps: int, *args, **kwargs) -> None: """Adds a new inference request to the backend's processing queue. @@ -46,7 +46,9 @@ def add_request(self, request_id: str, server_info: ServerInfo, req_expected_ste Args: request_id: Request ID. server_info: The information of the api server where the request come. - req_expected_step: The expected number of steps for the request. + instance_expected_steps: The expected number of steps for the request to run on the instance. + The number of steps represents the times 'engine.step()' has been + called by the backend instance for the request. *args: Positional arguments that represent request-specific data. **kwargs: Keyword arguments that contain metadata of the backend request (request_id, arrival_time, etc.). @@ -195,7 +197,7 @@ def should_abort_migration(self, backend_request: Any, last_stage_time: int) -> raise NotImplementedError @abstractmethod - def add_running_request(self, backend_request: Any, req_expected_step: Optional[int]) -> None: + def add_running_request(self, backend_request: Any) -> None: """ Adds a backend request to the running queue for processing. @@ -207,7 +209,6 @@ def add_running_request(self, backend_request: Any, req_expected_step: Optional[ backend_request: An object representing the backend request. The type of this object is dependent on the backend implementation and the details of the request. - req_expected_step: The expected number of steps for the request. """ raise NotImplementedError @@ -277,7 +278,7 @@ def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[i raise NotImplementedError @abstractmethod - def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, req_expected_step: int) -> None: + def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, instance_expected_steps: int) -> None: """Commits the migrating request to the destination instance. This method finalizes the migration process by transferring all necessary metadata and resource @@ -289,7 +290,8 @@ def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, req_ object is dependent on the backend implementation and the details of the request. server_info: The information of the api server where the request come. - req_expected_step: The expected number of steps for the request. + instance_expected_steps: The expected number of steps for the request to run on the instance. In the current version, it is implemented for prefill-decoding + disaggregation. """ raise NotImplementedError @@ -335,13 +337,12 @@ def get_shortest_running_request(self) -> Optional[MigratingRequest]: raise NotImplementedError @abstractmethod - def get_pre_migration_request(self) -> Optional[MigratingRequest]: - """Retrieves the request which meets the migration conditions from the running queue. + def get_ready_migration_request(self) -> Optional[MigratingRequest]: + """Retrieves the request which is ready for migration from the running queue. - This method iterates over the running queue in reverse order and returns the last request - that has moved past the prefilling stage and met the migration conditions. In the current - version, a request is considered to meet the migration confitions if its number of steps - exceeds expected_steps and backend.pre_migration is True. + This method iterates over the running queue in reverse order and returns the last request that has + met the migration conditions. A request is considered to meet the migration conditions if its number + of steps exceeds instance_expected_steps and backend.strict_pre_migration is True. Returns: An instance of MigratingRequest representing the last request in the running queue @@ -351,14 +352,17 @@ def get_pre_migration_request(self) -> Optional[MigratingRequest]: raise NotImplementedError @abstractmethod - def update_pre_migration(self, new_pre_migration: bool) -> None: + def update_strict_pre_migration(self, new_strict_pre_migration: bool) -> None: """Update the status of whether to force migration in the backend engine. - This method updates the status of whether to force migration in the backend engine. This - action is performed only when the corresponding status in the llumlet is changed. + This method updates the status of whether to force migration in the backend engine. This action is performed only when the + corresponding status in the llumlet is changed. + `pre_migration` represents whether the backend instance enables migration. By default, `pre_migration` is set to True, indicating that + the instance enables migration when `request.instance_completed_steps` >= `request.instance_expected_steps`. If `pre_migration` is set + to False, migration will not occur, and requests on the instance that reach the `instance_expected_steps` will continue with inference. Args: - new_pre_migration: New migration status provided for backend engine. + new_strict_pre_migration: New migration status provided for backend engine. """ raise NotImplementedError diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 5457751a..aadd8046 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -110,7 +110,7 @@ def _process_model_outputs( new_scheduled_seq_groups.append(scheduled_seq_group) new_seq_group_metadata_list.append(seq_group_meta) new_output.append(seq_group_output) - self.scheduler.request_info[seq_group.request_id].completed_step += 1 + self.scheduler.request_infos[seq_group.request_id].instance_completed_steps += 1 scheduled_seq_groups = new_scheduled_seq_groups output[0].outputs = new_output seq_group_metadata_list = new_seq_group_metadata_list @@ -175,8 +175,8 @@ def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None: del self.request_server_info[req_id] if req_id in self.scheduler.last_preemption_time_dict: del self.scheduler.last_preemption_time_dict[req_id] - if req_id in self.scheduler.request_info: - del self.scheduler.request_info[req_id] + if req_id in self.scheduler.request_infos: + del self.scheduler.request_infos[req_id] class BackendVLLM(BackendInterface): def __init__( @@ -196,7 +196,7 @@ def __init__( self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) self.engine.output_processor.scheduler = self.engine.scheduler self.instance_id = instance_id - self.pre_migration = True + self.strict_pre_migration = True self.worker_handle_list = self.engine.model_executor.workers.copy() if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size: self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix")) @@ -242,7 +242,7 @@ def restart_workers(self) -> None: def add_request(self, request_id: str, server_info: ServerInfo, - req_expected_step: int, + instance_expected_steps: int, *args, **kwargs) -> None: # When manager is unavailable, api server might dispatch the request that has already been dispatched. @@ -250,17 +250,17 @@ def add_request(self, return # Store the server information of each request to put the request outputs back to the corresponding api server correctly. self.engine.request_server_info[request_id] = server_info - with self.engine.scheduler.scheduler_lock: - self.engine.scheduler.request_info[request_id] = RequestInfo(expected_step=req_expected_step, completed_step=0) + self.engine.scheduler.update_request_infos(request_id, instance_expected_steps=instance_expected_steps, instance_completed_steps=0) self.engine.add_request(request_id, *args, **kwargs) - def commit_dst_request(self, backend_request: SequenceGroup, server_info: ServerInfo, req_expected_step: int) -> None: + def commit_dst_request(self, backend_request: SequenceGroup, server_info: ServerInfo, instance_expected_steps: int) -> None: seq = backend_request.get_seqs()[0] seq.seq_id = next(self.engine.seq_counter) logger.info("add seq {} to block table".format(seq.seq_id)) pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id) self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id) - self.add_running_request(backend_request, req_expected_step) + self.add_running_request(backend_request) + self.engine.scheduler.update_request_infos(backend_request.request_id, instance_expected_steps=instance_expected_steps) self.engine.request_server_info[backend_request.request_id] = server_info def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: @@ -287,10 +287,10 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None: return self.engine.free_request_states(request_id) - def update_pre_migration(self, new_migration_state: bool): - if self.pre_migration != new_migration_state: - self.pre_migration = new_migration_state - self.engine.scheduler.update_pre_migration(new_migration_state) + def update_strict_pre_migration(self, new_migration_state: bool): + if self.strict_pre_migration != new_migration_state: + self.strict_pre_migration = new_migration_state + self.engine.scheduler.update_strict_pre_migration(new_migration_state) def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]: return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs) @@ -334,8 +334,8 @@ def get_longest_running_request(self) -> Optional[MigratingRequest]: def get_shortest_running_request(self) -> Optional[MigratingRequest]: return self.engine.scheduler.get_shortest_running_request() - def get_pre_migration_request(self) -> Optional[MigratingRequest]: - return self.engine.scheduler.get_pre_migration_request() + def get_ready_migration_request(self) -> Optional[MigratingRequest]: + return self.engine.scheduler.get_ready_migration_request() def get_request_server_info(self, request_id: str) -> ServerInfo: return self.engine.request_server_info[request_id] diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index 2c8ae5ae..7ed708b3 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -31,9 +31,13 @@ logger = init_logger(__name__) class RequestInfo(): - def __init__(self, completed_step: int, expected_step: int) -> None: - self.expected_step = expected_step - self.completed_step = completed_step + def __init__(self, instance_completed_steps: int, instance_expected_steps: int) -> None: + # The number of steps executed on backend instance for the request. Using vLLM as the backend, + # each step generates one token, with the first step representing the prefill phase of the requets. + self.instance_completed_steps = instance_completed_steps + + # The expected number of steps for the request to run on the instance. + self.instance_expected_steps = instance_expected_steps # TODO(ZeldaHuang): adapt prefix cache and sliding window, now use v1 manager class BlockManagerLlumnix(BlockSpaceManagerV1): @@ -67,8 +71,8 @@ def __init__(self, *args, **kwargs) -> None: self.prefilling_seq_groups = [] self.scheduler_lock = threading.Lock() self.migrating_out_request_last_stage = [] - self.pre_migration = True - self.request_info: Dict[str, RequestInfo] = {} + self.strict_pre_migration = True + self.request_infos: Dict[str, RequestInfo] = {} def _preempt( self, @@ -89,15 +93,14 @@ def _get_num_killed_requests(self) -> int: # TODO(xinyi): Currently, the function is only used for Prefill-decoding disaggregation, # and only selects request that migrates from the prefill instance to the decoding instance. @scheduler_lock - def get_pre_migration_request(self) -> Optional[MigratingRequest]: - pre_migration_request = None - if self.pre_migration: + def get_ready_migration_request(self) -> Optional[MigratingRequest]: + if self.strict_pre_migration: for seq_group in reversed(self.running): - # logger.info("get_pre_migration_request {} {}".format(self.request_info[seq_group.request_id].expected_step,self.request_info[seq_group.request_id].completed_step)) - if self.request_info[seq_group.request_id].expected_step > 0 and \ - self.request_info[seq_group.request_id].completed_step >= self.request_info[seq_group.request_id].expected_step: - return MigratingRequest(seq_group.request_id, seq_group, expected_step=-1, blocking_migration=False) - return pre_migration_request + # logger.info("get_ready_migration_request {} {}".format(self.request_infos[seq_group.request_id].instance_expected_steps,self.request_info[seq_group.request_id].instance_completed_steps)) + if self.request_infos[seq_group.request_id].instance_expected_steps > 0 and \ + self.request_infos[seq_group.request_id].instance_completed_steps >= self.request_infos[seq_group.request_id].instance_expected_steps: + return MigratingRequest(seq_group.request_id, seq_group, instance_expected_steps=-1, blocking_migration=False) + return None @scheduler_lock def get_last_running_request(self) -> Optional[MigratingRequest]: @@ -177,15 +180,10 @@ def should_abort_migration(self, backend_request: SequenceGroup, last_stage_time return False @scheduler_lock - def add_running_request(self, backend_request: SequenceGroup, req_expected_step: Optional[int] = None) -> None: + def add_running_request(self, backend_request: SequenceGroup) -> None: seq = backend_request.get_seqs()[0] seq.status = SequenceStatus.RUNNING self.running.append(backend_request) - if req_expected_step: - if backend_request.request_id not in self.request_info: - self.request_info[backend_request.request_id] = RequestInfo(expected_step=req_expected_step, completed_step=0) - else: - self.request_info[backend_request.request_id].expected_step = req_expected_step @scheduler_lock def is_request_running(self, backend_request: SequenceGroup) -> bool: @@ -269,10 +267,10 @@ def _schedule_running(self, *args, **kwargs): args_list = list(args) args_list[0] = copy.deepcopy(self.running) remove_running = [] - if self.pre_migration: + if self.strict_pre_migration: for seq_group in list(args_list[0]): - if self.request_info[seq_group.request_id].expected_step > 0 and \ - self.request_info[seq_group.request_id].completed_step >= self.request_info[seq_group.request_id].expected_step: + if self.request_infos[seq_group.request_id].instance_expected_steps > 0 and \ + self.request_infos[seq_group.request_id].instance_completed_steps >= self.request_infos[seq_group.request_id].instance_expected_steps: args_list[0].remove(seq_group) remove_running.append(seq_group) new_args = tuple(args_list) @@ -280,10 +278,23 @@ def _schedule_running(self, *args, **kwargs): for seq_group in remove_running: remaining_running.append(seq_group) return remaining_running, running_scheduled + + @scheduler_lock + def update_request_infos(self, request_id: str, + instance_expected_steps: Optional[int] = None, + instance_completed_steps: Optional[int] = None) -> None: + if request_id not in self.request_infos: + self.request_infos[request_id] = RequestInfo(instance_expected_steps, instance_completed_steps) + else: + if instance_expected_steps is not None: + self.request_infos[request_id].instance_expected_steps = instance_expected_steps + if instance_completed_steps is not None: + self.request_infos[request_id].instance_completed_steps = instance_completed_steps + @scheduler_lock - def update_pre_migration(self, new_migration_state: bool) -> None: - self.pre_migration = new_migration_state + def update_strict_pre_migration(self, new_migration_state: bool) -> None: + self.strict_pre_migration = new_migration_state @scheduler_lock def add_seq_group(self, *args, **kwargs): diff --git a/llumnix/common/config.py b/llumnix/common/config.py new file mode 100644 index 00000000..7149f5aa --- /dev/null +++ b/llumnix/common/config.py @@ -0,0 +1,211 @@ +# encoding: utf-8 +""" +@author: l1aoxingyu +@contact: sherlockliao01@gmail.com +""" + +import copy +from typing import Any +from ast import literal_eval +import yaml + +from llumnix.logger import init_logger + +logger = init_logger(__name__) + +class Config(dict): + """ + Config represents an internal node in the configuration tree. It's a simple + dict-like container that allows for attribute-based access to keys. + """ + def __init__(self, init_dict=None, key_list=None): + """ + Args: + init_dict (dict): the possibly-nested dictionary to initailize the Config. + key_list (list[str]): a list of names which index this Config from the root. + Currently only used for logging purposes. + new_allowed (bool): whether adding new key is allowed when merging with + other configs. + """ + # Recursively convert nested dictionaries in init_dict into Configs + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + init_dict = self._create_config_tree_from_dict(init_dict, key_list) + super(Config, self).__init__(init_dict) + + @classmethod + def _create_config_tree_from_dict(cls, dic, key_list): + """ + Create a configuration tree using the given dict. + Any dict-like objects inside dict will be treated as a new Config. + + Args: + dic (dict): + key_list (list[str]): a list of names which index this Config from the root. + Currently only used for logging purposes. + """ + dic = copy.deepcopy(dic) + for k, v in dic.items(): + if isinstance(v, dict): + # Convert dict to Config + dic[k] = cls(v, key_list=key_list + [k]) + return dic + + @classmethod + def _decode_cfg_value(cls, value): + """ + Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. + + If the value is a dict, it will be interpreted as a new Config. + If the value is a str, it will be evaluated as literals. + Otherwise it is returned as-is. + """ + # Configs parsed from raw yaml will contain dictionary keys that need to be + # converted to Config objects + if isinstance(value, dict): + return cls(value) + # All remaining processing is only applied to strings + if not isinstance(value, str): + return value + # Try to interpret `value` as a: + # string, number, tuple, list, dict, boolean, or None + try: + value = literal_eval(value) + # The following two excepts allow v to pass through when it represents a + # string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, like + # a list. In the case that v represents a string, what we got back from the + # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is + # ok with '"foo"', but will raise a ValueError if given 'foo'. In other + # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return value + + + @staticmethod + def load_yaml_with_base(filename: str, allow_unsafe: bool = False): + """ + With "allow_unsafe=True", it supports pyyaml tags that evaluate + expressions in config. See examples in + https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + Args: + filename (str): the file name of the current config. Will be used to + find the base config file. + allow_unsafe (bool): whether to allow loading the config file with + `yaml.unsafe_load`. + Returns: + (dict): the loaded yaml + """ + with open(filename, "r", encoding='utf-8') as f: + try: + cfg = yaml.safe_load(f) + except yaml.constructor.ConstructorError: + if not allow_unsafe: + raise + logger.warning( + "Loading config {} with yaml.unsafe_load. Your machine may " + "be at risk if the file contains malicious content.".format( + filename + ) + ) + f.close() + with open(filename, "r") as f: + cfg = yaml.unsafe_load(f) + return cfg + + # def __str__(self): + # def _indent(s_, num_spaces): + # s = s_.split("\n") + # if len(s) == 1: + # return s_ + # first = s.pop(0) + # s = [(num_spaces * " ") + line for line in s] + # s = "\n".join(s) + # s = first + "\n" + s + # return s + + # r = "" + # s = [] + # for k, v in sorted(self.items()): + # seperator = "\n" if isinstance(v, Config) else " " + # attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + # attr_str = _indent(attr_str, 2) + # s.append(attr_str) + # r += "\n".join(s) + # return r + + # def __repr__(self): + # return "{}({})".format(self.__class__.__name__, super(Config, self).__repr__()) + + + + def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False): + """ + Merge configs from a given yaml file. + Args: + cfg_filename: the file name of the yaml config. + allow_unsafe: whether to allow loading the config file with + `yaml.unsafe_load`. + """ + loaded_cfg = self.load_yaml_with_base( + cfg_filename, allow_unsafe=allow_unsafe + ) + _merge_a_into_b(loaded_cfg, self, self, []) + + def clone(self): + """Recursively copy this Config.""" + return copy.deepcopy(self) + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + self[name] = value + + +def get_cfg() -> Config: + """ + Get a copy of the default config. + Returns: + a Config instance. + """ + from llumnix.common.defaults import _C + + return _C.clone() + +def _merge_a_into_b(a, b, root, key_list): + """Merge config dictionary a into config dictionary b, clobbering the + options in b whenever they are also specified in a. + """ + if a is None: + return + for k, v_ in a.items(): + full_key = ".".join(key_list + [k]) + + v = copy.deepcopy(v_) + v = b._decode_cfg_value(v) + if k in b: + # Recursively merge dicts + if isinstance(v, Config): + try: + _merge_a_into_b(v, b[k], root, key_list + [k]) + except BaseException: + raise + else: + b[k] = v + else: + raise KeyError("Non-existent config key: {}".format(full_key)) diff --git a/llumnix/common/defaults.py b/llumnix/common/defaults.py new file mode 100644 index 00000000..1fa0e8c7 --- /dev/null +++ b/llumnix/common/defaults.py @@ -0,0 +1,15 @@ +from llumnix.common.config import Config as Cfg + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = Cfg() + +# ----------------------------------------------------------------------------- +# Prefill Decoding Disaggregation Config +# ----------------------------------------------------------------------------- +_C.PDD_CONFIG = Cfg() +_C.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION = False +_C.PDD_CONFIG.PREFILL_INSTANCE_NUM = -1 +_C.PDD_CONFIG.PREFILL_INSTANCE_TYPE = None \ No newline at end of file diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index f629b9ee..31dccae9 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -64,15 +64,12 @@ def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None: def dispatch(self) -> str: self.dispatch_scheduler.update_instance_infos(self.instance_info) instance_id = self.dispatch_scheduler.dispatch() - if self.enable_pd_disaggregation: - req_expected_step = 1 - else: - req_expected_step = -1 - return instance_id, req_expected_step + instance_expected_steps = 1 if self.enable_pd_disaggregation else -1 + return instance_id, instance_expected_steps - def pair_migration(self, migrate_target:str) -> List[Tuple[str, str]]: + def pair_migration(self, migration_target:str) -> List[Tuple[str, str]]: self.migration_scheduler.update_instance_infos(self.instance_info) - migrate_instance_pairs = self.migration_scheduler.pair_migration(migrate_target) + migrate_instance_pairs = self.migration_scheduler.pair_migration(migration_target) return migrate_instance_pairs def check_scale(self) -> Tuple[str, str]: diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index e5f71a42..c586109a 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -22,9 +22,9 @@ logger = init_logger(__name__) -class MigrationTarget(enum.Enum): +class PairMigrationConstraints(enum.Enum): """Target of Migration.""" - GENERAL = enum.auto() + NO_CONSTRAINTS = enum.auto() # Enable the prefill-decoding disaggregration. DECODING_2_DECODING = enum.auto() @@ -61,37 +61,31 @@ def __init__(self, self.sorted_prefill_instance_infos: List[InstanceInfo] = None self.sorted_decoding_instance_infos: List[InstanceInfo] = None - def pair_migration(self, migrate_target:str) -> List[Tuple[str, str]]: + def pair_migration(self, migration_target:str) -> List[Tuple[str, str]]: self._sort_instance_infos(descending=False) - sorted_src_instance_infos, sorted_dst_instance_infos, pre_migration = self._get_migration_pattern(migrate_target) - return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos, pre_migration) - def _get_migration_pattern(self, migrate_target:str) -> Dict[str, InstanceInfo]: - pre_migration = True - if migrate_target == MigrationTarget.GENERAL: + sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration = self._get_migration_settings(migration_target) + return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration) + + def _get_migration_settings(self, migration_target:str) -> Dict[str, InstanceInfo]: + strict_pre_migration = True + if migration_target == PairMigrationConstraints.NO_CONSTRAINTS: # migrate in instances sorted_src_instance_infos = [i for i in reversed(self.sorted_prefill_instance_infos) if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] # migrate out instances sorted_dst_instance_infos = [i for i in self.sorted_prefill_instance_infos if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] - elif migrate_target == MigrationTarget.PREFILL_2_DECODING: + elif migration_target == PairMigrationConstraints.PREFILL_2_DECODING: sorted_src_instance_infos = [i for i in reversed(self.sorted_prefill_instance_infos)] sorted_dst_instance_infos = [i for i in self.sorted_decoding_instance_infos - if i.num_killed_requests == 0] # and i.instance_load_migrate < self.migrate_out_load_threshold - # TODO[xinyi]: For PDD, add more constaints considering decoding instances load. - # if len(sorted_dst_instance_infos) == 0: - # pre_migration = False - # else: - # idx = -1 - # while len(sorted_src_instance_infos) > len(sorted_dst_instance_infos): - # idx = (idx+1) % len(sorted_src_instance_infos) - # sorted_dst_instance_infos.insert(0, sorted_dst_instance_infos[idx]) - elif migrate_target == MigrationTarget.DECODING_2_DECODING: + if i.num_killed_requests == 0] + # TODO[xinyi]: Considering decoding instances load, try to decode on the prefill instance(set strict_pre_migration as False). + elif migration_target == PairMigrationConstraints.DECODING_2_DECODING: sorted_src_instance_infos = [i for i in reversed(self.sorted_decoding_instance_infos) if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] sorted_dst_instance_infos = [i for i in self.sorted_decoding_instance_infos if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] - return sorted_src_instance_infos, sorted_dst_instance_infos, pre_migration + return sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration def update_instance_infos(self, instance_info: Dict[str, InstanceInfo]) -> None: @@ -137,7 +131,7 @@ def __init__(self, def pair_migration(self, sorted_src_instance_infos: List[InstanceInfo], sorted_dst_instance_infos: List[InstanceInfo], - pre_migration: bool, + strict_pre_migration: bool, ) -> List[Tuple[str, str]]: raise NotImplementedError @@ -145,19 +139,19 @@ class Balanced(PairMigrationPolicy): def pair_migration(self, sorted_src_instance_infos: List[InstanceInfo], sorted_dst_instance_infos: List[InstanceInfo], - pre_migration: bool, + strict_pre_migration: bool, ) -> List[Tuple[str, str]]: migrate_instance_pairs = [] for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate - left_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) - right_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) + left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) + right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) # Add some constrains to reduce unnecessary migrations - if left_load_after_mig > self.migrate_out_load_threshold: + if right_load_after_mig > self.migrate_out_load_threshold: continue - load_diff_after_mig = right_load_after_mig - left_load_after_mig + load_diff_after_mig = left_load_after_mig - right_load_after_mig if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf): - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, pre_migration)) + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, strict_pre_migration)) return migrate_instance_pairs def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: @@ -175,17 +169,12 @@ class DefragConstrained(PairMigrationPolicy): def pair_migration(self, sorted_src_instance_infos: List[InstanceInfo], sorted_dst_instance_infos: List[InstanceInfo], - pre_migration: bool, + strict_pre_migration: bool, ) -> List[Tuple[str, str]]: migrate_instance_pairs = [] - if not pre_migration and len(sorted_dst_instance_infos) == 0: - # No suitable migrating in instances, and the migrating out instance is allowed to continue. - for i in range(len(sorted_src_instance_infos)): - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, "invalid_instance", pre_migration)) - else: - for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): - # without any constrain in order to make prefill migrate happens as soon as possible - migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, pre_migration)) + for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): + # without any constrain in order to make prefill migrate happens as soon as possible + migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, strict_pre_migration)) return migrate_instance_pairs class PairMigrationPolicyFactory: diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 8e125471..52cd37b6 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -23,7 +23,7 @@ from llumnix.llumlet.llumlet import Llumlet from llumnix.logger import init_logger from llumnix.global_scheduler.global_scheduler import GlobalScheduler -from llumnix.global_scheduler.migration_scheduler import MigrationTarget +from llumnix.global_scheduler.migration_scheduler import PairMigrationConstraints from llumnix.instance_info import InstanceInfo from llumnix.config import GlobalSchedulerConfig from llumnix.arg_utils import EngineManagerArgs @@ -63,7 +63,7 @@ def __init__(self, self.max_instances = engine_manager_args.max_instances self.min_instances = engine_manager_args.min_instances - self.enable_pd_disaggregation = (engine_manager_args.pdd_config != None) + self.enable_pd_disaggregation = global_scheduler_config.enable_pd_disaggregation logger.info("LLMEngineManager starts") logger.info("enable_migration: {}".format(self.enable_migration)) @@ -115,9 +115,9 @@ async def generate( logger.info("No instance available temporarily, sleep {}s, " "and retry generate request {} again....".format(RETRIES_INTERVALS, request_id)) await asyncio.sleep(RETRIES_INTERVALS) - instance_id, req_expected_step = self.global_scheduler.dispatch() + instance_id, instance_expected_steps = self.global_scheduler.dispatch() try: - await self.instances[instance_id].generate.remote(request_id, server_info, req_expected_step, *args, **kwargs) + await self.instances[instance_id].generate.remote(request_id, server_info, instance_expected_steps, *args, **kwargs) if self.log_requests: logger.info("received request {}.".format(request_id)) logger.info("dispath to instance {}".format(instance_id)) @@ -232,36 +232,33 @@ async def _post_migrate(self, rets: List[str], call_migrate_instance_pairs: List self.request_instance[migrate_out_request_id] = call_migrate_instance_pairs[i][1] logger.info("{}->{} migrate done, migrate request {}".format( call_migrate_instance_pairs[i][0], call_migrate_instance_pairs[i][1], migrate_out_request_ids)) + async def _migrate_control(self) -> None: - try: - # Push migrate when the instance_info have updated a certain number of times. - if self.enable_pd_disaggregation: - asyncio.create_task(self._migrate(MigrationTarget.PREFILL_2_DECODING, -1)) - asyncio.create_task(self._migrate(MigrationTarget.DECODING_2_DECODING, 1)) - else: - asyncio.create_task(self._migrate(MigrationTarget.GENERAL, 1)) - # pylint: disable=W0703 - except Exception as e: - logger.error("unexpected exception occurs: {}".format(e)) - logger.error("exception traceback: {}".format(traceback.format_exc())) - - async def _migrate(self, migrate_target:str, migrate_in_num_requests:int) -> None: - migrate_instance_pairs = self.global_scheduler.pair_migration(migrate_target) + # Push migrate when the instance_info have updated a certain number of times. + if self.enable_pd_disaggregation: + asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, -1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING, 1)) + else: + asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1)) + + async def _migrate(self, migration_target:str, migrate_in_num_requests:int) -> None: + migrate_instance_pairs = self.global_scheduler.pair_migration(migration_target) # if len(migrate_instance_pairs)>0: - # logger.info("[_migrate] migrate_instance_pairs {} {}".format(migrate_target, migrate_instance_pairs)) + # logger.info("[_migrate] migrate_instance_pairs {} {}".format(migration_target, migrate_instance_pairs)) try: migration_tasks = [] call_migrate_instance_pairs: List[Tuple[str, str]] = [] for _, migrate_instance_pair in enumerate(migrate_instance_pairs): - migrate_out_instance_id, migrate_in_instance_id, pre_migration = migrate_instance_pair - # logger.info("[_migrate] migrate_instance_pairs {} {} {} {} {}".format(migrate_target, migrate_out_instance_id, migrate_in_instance_id, self.instance_migrating[migrate_out_instance_id], self.instance_migrating[migrate_in_instance_id])) + migrate_out_instance_id, migrate_in_instance_id, strict_pre_migration = migrate_instance_pair + # logger.info("[_migrate] migrate_instance_pairs {} {} {} {} {}".format(migration_target, migrate_out_instance_id, migrate_in_instance_id, self.instance_migrating[migrate_out_instance_id], self.instance_migrating[migrate_in_instance_id])) if self.instance_migrating[migrate_out_instance_id] or self.instance_migrating[migrate_in_instance_id]: continue + # logger.info("[_migrate] migrate_instance_pairs {} {} {} ".format(migration_target, migrate_out_instance_id, migrate_in_instance_id)) self.instance_migrating[migrate_out_instance_id] = True self.instance_migrating[migrate_in_instance_id] = True migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id) call_migrate_instance_pairs.append(migrate_instance_pair) - task = self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests, pre_migration) + task = self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests, strict_pre_migration) migration_tasks.append(task) # TODO(yiwang): It's not necessary for manager to await for each migration. # TODO(yiwang): Migration failover could be implemented in Llumlet rather than manager. diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index d3b7a54a..7a325742 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -49,7 +49,8 @@ def __init__(self, self.migration_scheduler = LocalMigrationScheduler(migration_config.request_migration_policy, self.backend_engine) self.log_requests = True - self.pre_migration = True + + self.strict_pre_migration = True @classmethod def from_args(cls, @@ -100,19 +101,19 @@ def from_args(cls, llumlet = engine_class.remote(instance_id, backend_type, migration_config, *args, **kwargs) return llumlet - def migrate_out(self, dst_instance_name: str, num_requests: int, pre_migration: bool) -> List[str]: + def migrate_out(self, dst_instance_name: str, num_requests: int, strict_pre_migration: bool) -> List[str]: try: - self.update_pre_migration(pre_migration) + self.update_strict_pre_migration(strict_pre_migration) migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') dst_instance_id = dst_instance_name[len("instance_"):] migrated_request_list = [] continue_migrate = True while continue_migrate: t0 = time.time() - migrate_out_request = self.migration_scheduler.get_migrate_out_request(num_requests) + migrate_out_request = self.migration_scheduler.get_migrate_out_request() if migrate_out_request is not None: logger.info("migrate_out {}".format(migrate_out_request.request_id)) - if migrate_out_request is None or not self.pre_migration: + if migrate_out_request is None or not self.strict_pre_migration: return migrated_request_list logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id, migrate_out_request.request_id)) status = self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) @@ -145,21 +146,21 @@ def is_ready(self) -> bool: def get_all_request_ids(self) -> List[str]: return self.backend_engine.get_all_request_ids() - def update_pre_migration(self, new_pre_migration:str) -> None: - if self.pre_migration != new_pre_migration: - self.pre_migration = new_pre_migration - self.backend_engine.update_pre_migration(new_pre_migration) + def update_strict_pre_migration(self, new_strict_pre_migration:str) -> None: + if self.strict_pre_migration != new_strict_pre_migration: + self.strict_pre_migration = new_strict_pre_migration + self.backend_engine.update_strict_pre_migration(new_strict_pre_migration) def generate( self, request_id: str, server_info: ServerInfo, - req_expected_step: int, + expected_steps: int, *args, **kwargs, ) -> None: # This should not be used for logging, as it is monotonic time. - self.backend_engine.add_request(request_id, server_info, req_expected_step, *args, **kwargs) + self.backend_engine.add_request(request_id, server_info, expected_steps, *args, **kwargs) def abort(self, request_id: Union[str, Iterable[str]]) -> None: if isinstance(request_id, str): diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py index 37b649d9..b432c1dc 100644 --- a/llumnix/llumlet/local_migration_scheduler.py +++ b/llumnix/llumlet/local_migration_scheduler.py @@ -20,17 +20,14 @@ def __init__(self, request_migration_policy: str, backend_engine: BackendInterfa self.request_migration_policy = request_migration_policy self.backend_engine = backend_engine - def get_migrate_out_request(self, num_requests) -> Optional[MigratingRequest]: - # TODO(s5u13b): remove the if-else codes - migrate_out_request: MigratingRequest = None - if num_requests == -1: - migrate_out_request = self.backend_engine.get_pre_migration_request() - else: + def get_migrate_out_request(self) -> Optional[MigratingRequest]: + # Requests meet the strict pre-migration always have higher prioirity than other migration policy. + migrate_out_request = self.backend_engine.get_ready_migration_request() + if migrate_out_request is not None: if self.migrate_policy == 'LCFS': - migrate_out_request = self.backend_engine.get_last_running_request() - elif self.migrate_policy in ['SJF', 'LJF']: - if self.migrate_policy == 'LJF': + migrate_out_request = self.backend_engine.get_last_running_request() + elif self.migrate_policy == 'LJF': migrate_out_request = self.backend_engine.get_longest_running_request() - elif self.migrate_policy == 'SJF': - migrate_out_request = self.backend_engine.get_shortest_running_request() + elif self.migrate_policy == 'SJF': + migrate_out_request = self.backend_engine.get_shortest_running_request() return migrate_out_request diff --git a/llumnix/llumlet/migrating_request.py b/llumnix/llumlet/migrating_request.py index 4c649064..8afb267b 100644 --- a/llumnix/llumlet/migrating_request.py +++ b/llumnix/llumlet/migrating_request.py @@ -18,7 +18,7 @@ def __init__( self, request_id: int, backend_request: Any, - expected_step: Optional[int] = -1, + instance_expected_steps: Optional[int] = -1, blocking_migration: Optional[bool] = True, ) -> None: self.request_id = request_id @@ -26,5 +26,6 @@ def __init__( self.stage_timestamps = [] self.stage_num_blocks_list = [] self.server_info = None - self.expected_step = expected_step + self.instance_expected_steps = instance_expected_steps + # whether to migrate in multiple stages self.blocking_migration = blocking_migration \ No newline at end of file diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index ae2b0953..68cd987c 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -107,7 +107,7 @@ def migrate_in_last_stage(self, request: MigratingRequest, block_num: int) -> Li pre_alloc_blocks = self.migrate_in_pre_alloc(request.request_id, block_num) if len(pre_alloc_blocks) == block_num: # Pass the server information of the request to dst instance. - self.backend_engine.commit_dst_request(request.backend_request, request.server_info, request.expected_step) + self.backend_engine.commit_dst_request(request.backend_request, request.server_info, request.instance_expected_steps) return pre_alloc_blocks def migrate_in_pre_alloc(self, request_id: str, block_num: int) -> List[int]: