From 0b48bbcf8c40caf172c3d7c4fb449ba7200dd27a Mon Sep 17 00:00:00 2001 From: Biao Sun Date: Fri, 23 Aug 2024 09:43:05 +0800 Subject: [PATCH] [CI] Add unittest for global_scheduler and entrypoints (#12) --- benchmark/benchmark_serving.py | 30 +-- docs/Arguments.md | 12 +- llumnix/arg_utils.py | 16 +- llumnix/backends/vllm/llm_engine.py | 8 +- llumnix/backends/vllm/scheduler.py | 46 ++-- llumnix/backends/vllm/worker.py | 4 +- llumnix/entrypoints/llumnix_utils.py | 17 +- llumnix/entrypoints/vllm/api_server.py | 22 +- .../global_scheduler/dispatch_scheduler.py | 50 ++-- llumnix/global_scheduler/global_scheduler.py | 22 +- .../global_scheduler/migration_scheduler.py | 70 +++--- llumnix/global_scheduler/scaling_scheduler.py | 50 ++-- llumnix/instance_info.py | 116 ++++----- llumnix/llm_engine_manager.py | 68 +++--- llumnix/llumlet/llumlet.py | 6 - requirements.txt | 1 + tests/entrypoints/test_llumnix_utils.py | 65 +++++ tests/entrypoints/vllm/api_server_manager.py | 79 ++++++ tests/entrypoints/vllm/test_api_server.py | 123 ++++++++++ .../test_dispatch_scheduler.py | 89 +++++++ .../global_scheduler/test_global_scheduler.py | 87 +++++++ .../test_llm_engine_manager.py | 231 ++++++++++++++++++ .../test_migration_scheduler.py | 70 ++++++ 23 files changed, 1026 insertions(+), 256 deletions(-) create mode 100644 tests/entrypoints/test_llumnix_utils.py create mode 100644 tests/entrypoints/vllm/api_server_manager.py create mode 100644 tests/entrypoints/vllm/test_api_server.py create mode 100644 tests/global_scheduler/test_dispatch_scheduler.py create mode 100644 tests/global_scheduler/test_global_scheduler.py create mode 100644 tests/global_scheduler/test_llm_engine_manager.py create mode 100644 tests/global_scheduler/test_migration_scheduler.py diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py index d85f1ee3..a2d0f67d 100644 --- a/benchmark/benchmark_serving.py +++ b/benchmark/benchmark_serving.py @@ -31,8 +31,8 @@ from typing import List -num_finished_request = 0 -server_num_request = {} +num_finished_requests = 0 +server_num_requests = {} def get_wait_time(mean_time_between_requests: float, distribution: str, coefficient_variation: float = 0.0) -> float: @@ -76,11 +76,11 @@ async def query_model_vllm(prompt, verbose, ip_ports): prompt, prompt_len, expected_response_len = prompt # Round-Robin dispatch request to the given api servers. - global server_num_request - server_id = min(server_num_request, key=server_num_request.get) - server_num_request[server_id] += 1 + global server_num_requests + server_id = min(server_num_requests, key=server_num_requests.get) + server_num_requests[server_id] += 1 timeout = aiohttp.ClientTimeout(total=4*60*60) - global num_finished_request + global num_finished_requests async with aiohttp.ClientSession(timeout=timeout) as session: # TODO(yiwang): Remove hard codes of params. 
@@ -111,8 +111,8 @@ async def query_model_vllm(prompt, verbose, ip_ports): output['response_len'] = expected_response_len if verbose and 'generated_text' in output: print(json.dumps(output['generated_text'])) - num_finished_request += 1 - print("num_finised_request: {}".format(num_finished_request)) + num_finished_requests += 1 + print("num_finished_requests: {}".format(num_finished_requests)) return (prompt, output) except aiohttp.ClientError as e: print(f"Connect to {ip_ports[server_id]} failed with: {str(e)}") @@ -334,18 +334,18 @@ def plot_instance(log_filename_0): log_files.sort(key=os.path.getmtime, reverse=True) df_0 = pd.read_csv(log_files[0]).sort_values(by=["timestamp"]) timestamp_list_0 = df_0["timestamp"].to_numpy() - instance_num_list_0 = df_0["num_instance"].to_numpy() + num_instances_list_0 = df_0["num_instances"].to_numpy() time_0 = 0 sum_0 = 0 for idx, t in enumerate(timestamp_list_0): if t > time_0: time_0 += 1 - sum_0 += instance_num_list_0[idx] + sum_0 += num_instances_list_0[idx] print(f"{sum_0/time_0} gpu/s") avg_instance_num = np.round(sum_0/time_0, 2) fig, ax = plt.subplots() - ax.plot(timestamp_list_0, instance_num_list_0, color="red", label=f"instance_num(avg {avg_instance_num} /s)") + ax.plot(timestamp_list_0, num_instances_list_0, color="red", label=f"instance_num(avg {avg_instance_num} /s)") ax.legend(loc='upper left') fig_filename = os.path.splitext(log_filename_0)[0] + "_instance.png" index1 = fig_filename.rfind('/') @@ -437,10 +437,10 @@ async def benchmark( else: raise ValueError(f'unknown backend {backend}') - global server_num_request - num_server = len(ip_ports) - for server_id in range(num_server): - server_num_request[server_id] = 0 + global server_num_requests + num_servers = len(ip_ports) + for server_id in range(num_servers): + server_num_requests[server_id] = 0 m = MeasureLatency() diff --git a/docs/Arguments.md b/docs/Arguments.md index d947f9f4..2b08e33a 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -9,12 +9,12 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--fixed-node-init-instance] [--init-instance-by-manager] [--initial-instances INITIAL_INSTANCES] - [--load-metric {consumed_speed,used_ratio}] + [--load-metric {remaining_steps,usage_ratio}] [--polling-interval POLLING_INTERVAL] [--dispatch-policy {balanced,load,queue}] [--enable-migration] [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY] - [--pair-migration-policy {balanced,prefill_constrained,prefill_relaxed}] + [--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}] [--migrate-out-threshold MIGRATE_OUT_THRESHOLD] [--request-migration-policy {LCFS,SJF,LJF}] [--enable-defrag ENABLE_DEFRAG] @@ -48,8 +48,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] `--load-metric` - Instance load metric. -- Possible choices: consumed_speed, used_ratio -- Default: "consumed_speed" +- Possible choices: remaining_steps, usage_ratio +- Default: "remaining_steps" `--polling-interval` - Time interval(s) to update instance info and pair migration. @@ -139,11 +139,11 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] - Default: 512 `--last-stage-max-blocks` -- If the remaining blocks num < last_stage_max_blocks, do last stage migration. +- If the number of remaining blocks < last_stage_max_blocks, do last stage migration. - Default: 4 `--max-stages` -- Drop migration if stage num > max_stages. +- Drop migration if the number of stages > max_stages.
- Default: 3 # Unsupported vLLM feature options diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 0817cd24..9e15413c 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -26,7 +26,7 @@ class EngineManagerArgs: initial_instances: int = 1 fixed_node_init_instance: bool = False - load_metric: str = 'consumed_speed' + load_metric: str = 'remaining_steps' polling_interval: float = 0.05 dispatch_policy: str = 'load' @@ -34,7 +34,7 @@ enable_migration: bool = True enable_defrag: bool = True pair_migration_frequency: int = 1 - pair_migration_policy: str = 'prefill_constrained' + pair_migration_policy: str = 'defrag_constrained' migrate_out_threshold: float = 3.0 request_migration_policy: str = 'SJF' @@ -87,8 +87,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineManagerArgs': # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. - engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) - return engine_args + engine_manager_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_manager_args @staticmethod def add_cli_args( @@ -107,7 +107,7 @@ def add_cli_args( parser.add_argument('--load-metric', type=str, default=EngineManagerArgs.load_metric, - choices=['consumed_speed', 'used_ratio'], + choices=['remaining_steps', 'usage_ratio'], help='instance load metric') parser.add_argument('--polling-interval', type=float, @@ -130,7 +130,7 @@ def add_cli_args( parser.add_argument('--pair-migration-policy', type=str, default=EngineManagerArgs.pair_migration_policy, - choices=['balanced', 'prefill_constrained', 'prefill_relaxed'], + choices=['balanced', 'defrag_constrained', 'defrag_relaxed'], help='pair migration policy') parser.add_argument('--migrate-out-threshold', type=float, @@ -207,10 +207,10 @@ def add_cli_args( parser.add_argument('--last-stage-max-blocks', type=int, default=EngineManagerArgs.last_stage_max_blocks, - help='if the remain blocks num < last_stage_max_blocks, do last stage migration') + help='if the number of remaining blocks < last_stage_max_blocks, do last stage migration') parser.add_argument('--max-stages', type=int, default=EngineManagerArgs.max_stages, - help='drop migration if stage num > max_stages') + help='drop migration if the number of stages > max_stages') return parser diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index a3a22c9e..340096a2 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -121,9 +121,9 @@ def step(self) -> None: instance_info: InstanceInfo = self.scheduler.get_instance_info() if self.scaling_down: - instance_info.num_running_request = 1 - instance_info.num_available_gpu_block = -self.cache_config.num_gpu_blocks - instance_info.num_available_gpu_block_waiting = -self.cache_config.num_gpu_blocks + instance_info.num_running_requests = 1 + instance_info.num_available_gpu_blocks = -self.cache_config.num_gpu_blocks + instance_info.num_available_gpu_blocks_waiting = -self.cache_config.num_gpu_blocks instance_info.instance_id = self.instance_id instance_info.step_id = next(self.step_counter) @@ -136,7 +136,7 @@ def step(self) -> None: blocks = self.scheduler.block_manager.get_block_table(seq) tot_blocks.extend(blocks) tot_blocks = set(tot_blocks) - instance_info.num_block_last_running_request = len(tot_blocks) + instance_info.num_blocks_last_running_request = len(tot_blocks)
self.free_request_states(instance_info.finished_request_ids) diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index 7b591b35..2d8f3611 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -71,7 +71,7 @@ def _preempt( self.last_preemption_time_dict[seq_group.request_id] = time.time() return super()._preempt(seq_group, blocks_to_swap_out, preemption_mode) - def _get_num_killed_request(self) -> int: + def _get_num_killed_requests(self) -> int: cnt = len(self.swapped) for seq_group in self.waiting: if seq_group.request_id in self.last_preemption_time_dict: @@ -187,44 +187,44 @@ def free_src_request(self, backend_request: SequenceGroup) -> None: @scheduler_lock def get_instance_info(self) -> InstanceInfo: - num_total_gpu_block = self.cache_config.num_gpu_blocks - num_free_gpu_block = self.block_manager.get_num_free_gpu_blocks() - num_used_gpu_block = num_total_gpu_block - num_free_gpu_block - gpu_cache_usage = num_used_gpu_block / num_total_gpu_block + num_total_gpu_blocks = self.cache_config.num_gpu_blocks + num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks() + num_used_gpu_blocks = num_total_gpu_blocks - num_free_gpu_blocks + gpu_cache_usage = num_used_gpu_blocks / num_total_gpu_blocks if self.waiting: - num_block_waiting_requests = [] + num_blocks_waiting_requests = [] waiting_time_waiting_requests = [] for seq_group in self.waiting: - num_prompt_token = seq_group.get_seqs()[0].get_len() - num_block = num_prompt_token / self.cache_config.block_size + num_prompt_tokens = seq_group.get_seqs()[0].get_len() + num_blocks = num_prompt_tokens / self.cache_config.block_size waiting_time = time.time() - seq_group.metrics.arrival_time - num_block_waiting_requests.append(num_block) + num_blocks_waiting_requests.append(num_blocks) waiting_time_waiting_requests.append(waiting_time) - num_block_first_waiting_request = num_block_waiting_requests[0] + num_blocks_first_waiting_request = num_blocks_waiting_requests[0] waiting_time_first_waiting_request = waiting_time_waiting_requests[0] - num_block_all_waiting_request = sum(num_block_waiting_requests) + num_blocks_all_waiting_requests = sum(num_blocks_waiting_requests) else: - num_block_first_waiting_request = 0 + num_blocks_first_waiting_request = 0 waiting_time_first_waiting_request = 0 - num_block_all_waiting_request = 0 + num_blocks_all_waiting_requests = 0 instance_info = InstanceInfo( - num_total_gpu_block=num_total_gpu_block, - num_watermark_block=self.block_manager.watermark_blocks, - num_free_gpu_block=num_free_gpu_block, - num_used_gpu_block=num_used_gpu_block, + num_total_gpu_blocks=num_total_gpu_blocks, + num_watermark_blocks=self.block_manager.watermark_blocks, + num_free_gpu_blocks=num_free_gpu_blocks, + num_used_gpu_blocks=num_used_gpu_blocks, gpu_cache_usage=gpu_cache_usage, - num_running_request=len(self.running), - num_waiting_request=len(self.waiting), - num_killed_request=self._get_num_killed_request(), - num_block_first_waiting_request=num_block_first_waiting_request, + num_running_requests=len(self.running), + num_waiting_requests=len(self.waiting), + num_killed_requests=self._get_num_killed_requests(), + num_blocks_first_waiting_request=num_blocks_first_waiting_request, waiting_time_first_waiting_request=waiting_time_first_waiting_request, - num_block_all_waiting_request=num_block_all_waiting_request, + num_blocks_all_waiting_requests=num_blocks_all_waiting_requests, inference_type=BackendInferenceType.PREFILL if self.prefilling_seq_groups \ else 
BackendInferenceType.DECODE, ) for seq_group in self.running: instance_info.running_seq_lens.extend([seq.get_len() for seq in seq_group.get_seqs()]) - instance_info.num_seq = len(instance_info.running_seq_lens) + instance_info.num_seqs = len(instance_info.running_seq_lens) instance_info.num_batched_tokens = sum([seq_group.get_seqs()[0].get_len() for seq_group in self.prefilling_seq_groups])\ if self.prefilling_seq_groups else len(instance_info.running_seq_lens) instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()] diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index a18a6d43..6a4c72d6 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -209,8 +209,8 @@ def restart(self) -> None: self.init_cache_engine(self.cache_config) # instance_id is changed from int to str, this function should be modified if used - # def init_migration_dist_ray(self, num_instance, instance_id): - # self.ray_world_size = num_instance * self.parallel_config.world_size + # def init_migration_dist_ray(self, num_instances, instance_id): + # self.ray_world_size = num_instances * self.parallel_config.world_size # self.ray_rank = self.rank + instance_id * self.parallel_config.world_size # logger.info(f"{self.ray_world_size, self.ray_rank}") # # col.init_collective_group(world_size=self.ray_world_size, rank=self.ray_rank , backend="gloo") diff --git a/llumnix/entrypoints/llumnix_utils.py b/llumnix/entrypoints/llumnix_utils.py index 3eb43de5..a55088eb 100644 --- a/llumnix/entrypoints/llumnix_utils.py +++ b/llumnix/entrypoints/llumnix_utils.py @@ -1,3 +1,16 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import subprocess import sys import os @@ -15,6 +28,7 @@ from llumnix.logger import init_logger from llumnix.arg_utils import EngineManagerArgs + logger = init_logger(__name__) # TODO(s5u13b): Set the values through tests. 
@@ -29,7 +43,7 @@ def get_ip_address(): ip_address = result.stdout.decode('utf-8').strip() return ip_address -def launch_ray_cluster(ray_cluster_port: int) -> None: +def launch_ray_cluster(ray_cluster_port: int) -> subprocess.CompletedProcess: head_node_ip = os.getenv('HEAD_NODE_IP') node_ip_address = get_ip_address() try: @@ -66,6 +80,7 @@ def launch_ray_cluster(ray_cluster_port: int) -> None: sys.exit(1) logger.info("'{}' succeeed with: \n{}".format(ray_start_command, result.stdout)) ray.init(address=f"{head_node_ip}:{ray_cluster_port}", ignore_reinit_error=True, namespace='llumnix') + return result def is_gpu_available() -> bool: try: diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py index 7625297f..07ea7ccb 100644 --- a/llumnix/entrypoints/vllm/api_server.py +++ b/llumnix/entrypoints/vllm/api_server.py @@ -36,14 +36,14 @@ logger = init_logger(__name__) engine_manager = None instances = {} -instance_num_request: Dict[str, int] = {} +instance_num_requests: Dict[str, int] = {} # request_output_queue could be None if initialzed in lifespan. request_output_queue = None server_id = None TIMEOUT_KEEP_ALIVE = 5 # seconds. request_streams: Dict[str, AsyncStream] = {} log_requests = None -num_finished_request = 0 +num_finished_requests = 0 WAIT_MANAGER_INTERVAL = 5 @@ -82,9 +82,9 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream: await engine_manager.generate.remote(request_id, server_info, prompt, sampling_params) except ray.exceptions.RayActorError: try: - if instance_num_request: - instance_id = min(instance_num_request, key=instance_num_request.get) - instance_num_request[instance_id] += 1 + if instance_num_requests: + instance_id = min(instance_num_requests, key=instance_num_requests.get) + instance_num_requests[instance_id] += 1 await instances[instance_id].generate.remote(request_id, server_info, prompt, sampling_params) print("Manager is unavailable, directly pass request {} to instance {}".format(request_id, instance_id)) else: @@ -96,7 +96,7 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream: if instance_id in instances: print("[manager_generate] instance {} is dead".format(instance_id)) del instances[instance_id] - del instance_num_request[instance_id] + del instance_num_requests[instance_id] return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id)) return results_generator @@ -185,12 +185,12 @@ async def generate_benchmark(request: Request) -> Response: start = now_time final_output = request_output - global num_finished_request + global num_finished_requests if log_requests: # TODO(s5u13b): Use logger. 
print(f"Finished request {request_id}.") - num_finished_request += 1 - print(f"num_finished_request {num_finished_request}.") + num_finished_requests += 1 + print(f"num_finished_requests {num_finished_requests}.") generation = final_output.outputs[0].text num_output_tokens = len(final_output.outputs[0].token_ids) @@ -218,7 +218,7 @@ async def is_ready(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8003) + parser.add_argument("--port", type=int, default=8000) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) parser.add_argument('--disable-log-requests-server', @@ -249,7 +249,7 @@ async def is_ready(): engine_manager, instance_ids, llumlets, request_output_queue = init_llumnix_components(engine_manager_args, engine_args, node_id) for idx, ins_id in enumerate(instance_ids): instances[ins_id] = llumlets[idx] - instance_num_request[ins_id] = 0 + instance_num_requests[ins_id] = 0 log_requests = not args.disable_log_requests_server # Start the api server after all the components of llumnix are ready. print(f"Start Api Server on '{args.host}:{args.port}'") diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 65e84479..bd727413 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -27,26 +27,26 @@ def __init__(self, instance_load_calculator: InstanceLoadCalculator) -> None: self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy) self.instance_load_calculator = instance_load_calculator - self.num_instance = 0 + self.num_instances = 0 self.instance_id_set: Set[str] = set() # instance info args self.instance_info: Dict[str, InstanceInfo] = {} self.sorted_instance_infos: List[InstanceInfo] = None # statistics - self.num_request = 0 - self.instance_num_request: Dict[str, int] = {} + self.num_requests = 0 + self.instance_num_requests: Dict[str, int] = {} def dispatch(self) -> str: - self.num_request += 1 + self.num_requests += 1 if isinstance(self.dispatch_policy, (Load, Queue)): self._sort_instance_infos(descending=False) - dispatch_instance_id = self.dispatch_policy.dispatch(self.instance_num_request, - self.sorted_instance_infos) - self.instance_num_request[dispatch_instance_id] += 1 - if self.num_request % 100 == 0: - logger.info("self.num_request: {}".format(self.num_request)) - for instance_id, num_request in self.instance_num_request.items(): - logger.info("Instance {} num_dispatched_request: {}".format(instance_id, num_request)) + dispatch_instance_id = self.dispatch_policy.dispatch(self.instance_num_requests, + self.sorted_instance_infos) + self.instance_num_requests[dispatch_instance_id] += 1 + if self.num_requests % 100 == 0: + logger.info("self.num_requests: {}".format(self.num_requests)) + for instance_id, num_requests in self.instance_num_requests.items(): + logger.info("Instance {} num_dispatched_requests: {}".format(instance_id, num_requests)) return dispatch_instance_id def update_instance_infos(self, @@ -55,19 +55,19 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) - self.num_instance = len(self.instance_id_set) - self.instance_num_request[instance_id] = 0 + self.num_instances = len(self.instance_id_set) + self.instance_num_requests[instance_id] = 0 def remove_instance(self, 
instance_id: str) -> None: self.instance_id_set.remove(instance_id) - self.num_instance = len(self.instance_id_set) - del self.instance_num_request[instance_id] + self.num_instances = len(self.instance_id_set) + del self.instance_num_requests[instance_id] def _sort_instance_infos(self, descending: bool = True) -> None: instance_infos: List[InstanceInfo] = list(self.instance_info.values()) if isinstance(self.dispatch_policy, Queue): - key_attr = 'num_waiting_request' + key_attr = 'num_waiting_requests' else: key_attr = 'instance_load_dispatch_scale' self.sorted_instance_infos = sorted( @@ -82,21 +82,21 @@ def __init__(self): @abstractmethod def dispatch(self, - instance_num_request: Dict[str, int], + instance_num_requests: Dict[str, int], sorted_instance_infos: List[InstanceInfo]) -> int: pass class Balanced(DispatchPolicy): def dispatch(self, - instance_num_request: Dict[str, int], + instance_num_requests: Dict[str, int], sorted_instance_infos: List[InstanceInfo]) -> str: - # dispatch request according to the number of request dispatched to instance by manager - instance_id = min(instance_num_request, key=instance_num_request.get) + # dispatch request according to the number of requests dispatched to instance by manager + instance_id = min(instance_num_requests, key=instance_num_requests.get) return instance_id class Load(DispatchPolicy): def dispatch(self, - instance_num_request: Dict[str, int], + instance_num_requests: Dict[str, int], sorted_instance_infos: List[InstanceInfo]) -> str: instance_id = sorted_instance_infos[0].instance_id logger.info("dispatch to {}, load: {}".format(instance_id, sorted_instance_infos[0].instance_load_dispatch_scale)) @@ -104,15 +104,15 @@ def dispatch(self, class Queue(DispatchPolicy): def dispatch(self, - instance_num_request: Dict[str, int], + instance_num_requests: Dict[str, int], sorted_instance_infos: List[InstanceInfo]) -> str: - min_queue_size = sorted_instance_infos[0].num_waiting_request + min_queue_size = sorted_instance_infos[0].num_waiting_requests instance_id_list = [] for instance_info in sorted_instance_infos: - if instance_info.num_waiting_request == min_queue_size: + if instance_info.num_waiting_requests == min_queue_size: instance_id_list.append(instance_info.instance_id) instance_id = random.choice(instance_id_list) - logger.info("dispatch to {}, queue size: {}".format(instance_id, sorted_instance_infos[0].num_waiting_request)) + logger.info("dispatch to {}, queue size: {}".format(instance_id, sorted_instance_infos[0].num_waiting_requests)) return instance_id class DispatchPolicyFactory: diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index d8626d2e..9f872d4a 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -46,13 +46,13 @@ def __init__(self, global_scheduler_config.scaling_policy, self.instance_load_calculator) - self.num_instance = 0 + self.num_instances = 0 self.instance_id_set: Set[str] = set() self.instance_info: Dict[str, InstanceInfo] = {} def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None: for instance_info in instance_infos: - if instance_info.instance_id in self.instance_info: + if instance_info.instance_id in self.instance_id_set: # Llumnix have different instance load compuatation methods for dispatch/migrate/scale. 
instance_info.instance_load_dispatch_scale = self.instance_load_calculator.compute_instance_load(instance_info, action='dispatch') instance_info.instance_load_migrate = self.instance_load_calculator.compute_instance_load(instance_info, action='migrate') @@ -73,40 +73,42 @@ def check_scale(self) -> Tuple[str, str]: scale_up_num, scale_down_num = self.scaling_scheduler.check_scale() return scale_up_num, scale_down_num - def scale_up(self, instance_id: Union[str, Iterable[str]]) -> None: + def scale_up(self, instance_id: Union[str, Iterable[str]]) -> int: if isinstance(instance_id, str): instance_id = [instance_id,] instance_ids = list(instance_id) for ins_id in instance_ids: - if ins_id not in self.instance_info: + if ins_id not in self.instance_id_set: logger.info("scale up instance: {}".format(ins_id)) new_intance_info = self._get_empty_instance_info() new_intance_info.instance_id = ins_id self.instance_info[ins_id] = new_intance_info self._add_instance(ins_id) - logger.info("self.num_instance: {}, self.instances: {}".format(self.num_instance, self.instance_id_set)) + logger.info("self.num_instances: {}, self.instances: {}".format(self.num_instances, self.instance_id_set)) + return self.num_instances - def scale_down(self, instance_id: Union[str, Iterable[str]]) -> None: + def scale_down(self, instance_id: Union[str, Iterable[str]]) -> int: if isinstance(instance_id, str): instance_id = [instance_id,] instance_ids = list(instance_id) for ins_id in instance_ids: - if ins_id in self.instance_info: + if ins_id in self.instance_id_set: logger.info("scale down instance: {}".format(ins_id)) del self.instance_info[ins_id] self._remove_instance(ins_id) - logger.info("self.num_instance: {}, self.instances: {}".format(self.num_instance, self.instance_id_set)) + logger.info("self.num_instances: {}, self.instances: {}".format(self.num_instances, self.instance_id_set)) + return self.num_instances def _add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) - self.num_instance = len(self.instance_id_set) + self.num_instances = len(self.instance_id_set) for scheduler in (self.dispatch_scheduler, self.migration_scheduler, self.scaling_scheduler): scheduler.update_instance_infos(self.instance_info) scheduler.add_instance(instance_id) def _remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) - self.num_instance = len(self.instance_id_set) + self.num_instances = len(self.instance_id_set) for scheduler in (self.dispatch_scheduler, self.migration_scheduler, self.scaling_scheduler): scheduler.update_instance_infos(self.instance_info) scheduler.remove_instance(instance_id) diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 11405d48..a87f833b 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -41,7 +41,7 @@ def __init__(self, migrate_out_load_threshold=migrate_out_load_threshold, instance_load_calculator=instance_load_calculator) - self.num_instance = 0 + self.num_instances = 0 self.instance_id_set: Set[str] = set() # instance info args self.instance_info: Dict[str, InstanceInfo] = None @@ -57,11 +57,11 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) - self.num_instance = len(self.instance_id_set) + self.num_instances = len(self.instance_id_set) def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) - 
self.num_instance = len(self.instance_id_set) + self.num_instances = len(self.instance_id_set) def _sort_instance_infos(self, descending: bool = True) -> None: @@ -91,70 +91,72 @@ def pair_migration(self, sorted_instance_infos: List[InstanceInfo] ) -> List[Tuple[str, str]]: # migrate in instances - left_instance_infos = [i for i in sorted_instance_infos - if i.num_killed_request == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] + migrate_in_instance_infos = [i for i in sorted_instance_infos + if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] # migrate out instances - right_instance_infos = [i for i in reversed(sorted_instance_infos) - if i.num_killed_request > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] + migrate_out_instance_infos = [i for i in reversed(sorted_instance_infos) + if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] migrate_instance_pairs = [] - for i in range(min(len(left_instance_infos), len(right_instance_infos))): - load_diff_before_mig = right_instance_infos[i].instance_load_migrate - left_instance_infos[i].instance_load_migrate - left_load_after_mig = self._compute_instance_load_after_migrate(left_instance_infos[i], is_migrate_in=True) - right_load_after_mig = self._compute_instance_load_after_migrate(right_instance_infos[i], is_migrate_in=False) + for i in range(min(len(migrate_in_instance_infos), len(migrate_out_instance_infos))): + load_diff_before_mig = migrate_out_instance_infos[i].instance_load_migrate - migrate_in_instance_infos[i].instance_load_migrate + left_load_after_mig = self._compute_instance_load_after_migrate(migrate_in_instance_infos[i], is_migrate_in=True) + right_load_after_mig = self._compute_instance_load_after_migrate(migrate_out_instance_infos[i], is_migrate_in=False) # Add some constrains to reduce unnecessary migrations if left_load_after_mig > self.migrate_out_load_threshold: continue load_diff_after_mig = right_load_after_mig - left_load_after_mig - if (0 < load_diff_after_mig < load_diff_before_mig) or (left_instance_infos[i].instance_load_migrate == -np.inf): - migrate_instance_pairs.append((right_instance_infos[i].instance_id, left_instance_infos[i].instance_id)) + if (0 < load_diff_after_mig < load_diff_before_mig) or (migrate_in_instance_infos[i].instance_load_migrate == -np.inf): + migrate_instance_pairs.append((migrate_out_instance_infos[i].instance_id, migrate_in_instance_infos[i].instance_id)) return migrate_instance_pairs def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: instance_info_after_migrate = copy.deepcopy(instance_info) - num_block_last_running_request = instance_info_after_migrate.num_block_last_running_request + num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request if is_migrate_in: - instance_info_after_migrate.num_running_request += 1 - instance_info_after_migrate.num_free_gpu_block -= num_block_last_running_request + instance_info_after_migrate.num_running_requests += 1 + instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request else: - instance_info_after_migrate.num_running_request -= 1 - instance_info_after_migrate.num_free_gpu_block += num_block_last_running_request + instance_info_after_migrate.num_running_requests -= 1 + instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request return 
self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') -class PrefillConstrained(PairMigrationPolicy): +class DefragConstrained(PairMigrationPolicy): def pair_migration(self, sorted_instance_infos: List[InstanceInfo] ) -> List[Tuple[str, str]]: # migrate in instances - left_instance_infos = [i for i in sorted_instance_infos - if i.num_killed_request == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] + migrate_in_instance_infos = [i for i in sorted_instance_infos + if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] # migrate out instances - right_instance_infos = [i for i in reversed(sorted_instance_infos) - if i.num_killed_request > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] + migrate_out_instance_infos = [i for i in reversed(sorted_instance_infos) + if i.num_killed_requests > 0 or i.instance_load_migrate > self.migrate_out_load_threshold] migrate_instance_pairs = [] - for i in range(min(len(left_instance_infos), len(right_instance_infos))): - # without any constrain in order to make prefill migrate happens as soon as possible - migrate_instance_pairs.append((right_instance_infos[i].instance_id, left_instance_infos[i].instance_id)) + for i in range(min(len(migrate_in_instance_infos), len(migrate_out_instance_infos))): + # without any constraint, so that defragmentation migration happens as soon as possible + migrate_instance_pairs.append((migrate_out_instance_infos[i].instance_id, migrate_in_instance_infos[i].instance_id)) return migrate_instance_pairs -class PrefillRelaxed(PairMigrationPolicy): +class DefragRelaxed(PairMigrationPolicy): def pair_migration(self, sorted_instance_infos: List[InstanceInfo] ) -> List[Tuple[str, str]]: # migrate in instances - left_instance_infos = [i for i in sorted_instance_infos - if i.num_killed_request == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] + migrate_in_instance_infos = [i for i in sorted_instance_infos + if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold] # migrate out instances - right_instance_infos = list(reversed(sorted_instance_infos)) + migrate_out_instance_infos = list(reversed(sorted_instance_infos)) migrate_instance_pairs = [] - for i in range(min(len(left_instance_infos), len(right_instance_infos))): - migrate_instance_pairs.append((right_instance_infos[i].instance_id, left_instance_infos[i].instance_id)) + for i in range(min(len(migrate_in_instance_infos), len(migrate_out_instance_infos))): + if migrate_out_instance_infos[i].num_killed_requests != 0 \ + or migrate_out_instance_infos[i].instance_load_migrate > migrate_in_instance_infos[i].instance_load_migrate: + migrate_instance_pairs.append((migrate_out_instance_infos[i].instance_id, migrate_in_instance_infos[i].instance_id)) return migrate_instance_pairs class PairMigrationPolicyFactory: _POLICY_REGISTRY = { 'balanced': Balanced, - 'prefill_constrained': PrefillConstrained, - 'prefill_relaxed': PrefillRelaxed, + 'defrag_constrained': DefragConstrained, + 'defrag_relaxed': DefragRelaxed, } @classmethod diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py index 99913098..3a1c6c3f 100644 --- a/llumnix/global_scheduler/scaling_scheduler.py +++ b/llumnix/global_scheduler/scaling_scheduler.py @@ -33,7 +33,7 @@ def __init__(self, instance_load_calculator=instance_load_calculator) self.instance_load_calculator = instance_load_calculator -
self.num_instance = 0 + self.num_instances = 0 self.instance_id_set: Set[str] = set() # instance info args self.instance_info: Dict[str, InstanceInfo] = None @@ -43,7 +43,7 @@ def check_scale(self) -> Tuple[str, str]: scale_up_num = 0 scale_down_num = 0 # if not all instances have returned instance_info, not scale - if len(self.instance_info.keys()) < self.num_instance: + if len(self.instance_info.keys()) < self.num_instances: return scale_up_num, scale_down_num now_instances = [self.instance_info[instance_id] for instance_id in self.instance_id_set] load_metric_up = self.scaling_policy.compute_load_metric_up(now_instances) @@ -62,21 +62,21 @@ def update_instance_infos(self, def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) - self.num_instance = len(self.instance_id_set) + self.num_instances = len(self.instance_id_set) def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) - self.num_instance = len(self.instance_id_set) + self.num_instances = len(self.instance_id_set) def get_empty_instance_info(self) -> InstanceInfo: dummy_intance_info = InstanceInfo() dummy_intance_info.instance_id = -1 dummy_intance_info.step_id = -1 # TODO(s5u13b): Should be changed for proactive auto-scaling. - dummy_intance_info.num_total_gpu_block = np.inf - dummy_intance_info.num_available_gpu_block = np.inf - dummy_intance_info.num_free_gpu_block = np.inf - dummy_intance_info.num_available_gpu_block_waiting = np.inf + dummy_intance_info.num_total_gpu_blocks = np.inf + dummy_intance_info.num_available_gpu_blocks = np.inf + dummy_intance_info.num_free_gpu_blocks = np.inf + dummy_intance_info.num_available_gpu_blocks_waiting = np.inf return dummy_intance_info class ScalePolicy(ABC): @@ -96,13 +96,13 @@ def compute_load_metric_avg(self, instance_infos: List[InstanceInfo]) -> float: tot_instance_info = InstanceInfo() tot_instance_info.instance_id = -1 tot_instance_info.step_id = -1 - tot_instance_info.num_running_request = sum([i.num_running_request for i in instance_infos]) - tot_instance_info.num_waiting_request = sum([i.num_waiting_request for i in instance_infos]) - tot_instance_info.num_free_gpu_block = sum([i.num_free_gpu_block for i in instance_infos]) - tot_instance_info.num_total_gpu_block = sum([i.num_total_gpu_block for i in instance_infos]) - tot_instance_info.num_watermark_block = sum([i.num_watermark_block for i in instance_infos]) - tot_instance_info.num_block_all_waiting_request = sum([i.num_block_all_waiting_request for i in instance_infos]) - tot_instance_info.num_available_gpu_block = tot_instance_info.num_free_gpu_block - tot_instance_info.num_watermark_block + tot_instance_info.num_running_requests = sum([i.num_running_requests for i in instance_infos]) + tot_instance_info.num_waiting_requests = sum([i.num_waiting_requests for i in instance_infos]) + tot_instance_info.num_free_gpu_blocks = sum([i.num_free_gpu_blocks for i in instance_infos]) + tot_instance_info.num_total_gpu_blocks = sum([i.num_total_gpu_blocks for i in instance_infos]) + tot_instance_info.num_watermark_blocks = sum([i.num_watermark_blocks for i in instance_infos]) + tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos]) + tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks return self.instance_load_calculator.compute_instance_load(tot_instance_info, action="scale") class MaxLoad(ScalePolicy): @@ -123,23 +123,23 @@ def 
compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: return self.compute_load_metric_avg(instance_infos) def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float: - num_instance = len(instance_infos) + num_instances = len(instance_infos) tot_instance_info = InstanceInfo() tot_instance_info.instance_id = -1 tot_instance_info.step_id = -1 # the average load after scale down the last instance - tot_instance_info.num_running_request = sum([i.num_running_request for i in instance_infos]) - tot_instance_info.num_waiting_request = sum([i.num_waiting_request for i in instance_infos]) - tot_instance_info.num_free_gpu_block = sum([i.num_free_gpu_block-i.num_total_gpu_block - if i.instance_id + 1 == num_instance else i.num_free_gpu_block + tot_instance_info.num_running_requests = sum([i.num_running_requests for i in instance_infos]) + tot_instance_info.num_waiting_requests = sum([i.num_waiting_requests for i in instance_infos]) + tot_instance_info.num_free_gpu_blocks = sum([i.num_free_gpu_blocks - i.num_total_gpu_blocks + if i.instance_id + 1 == num_instances else i.num_free_gpu_blocks for i in instance_infos]) - tot_instance_info.num_free_gpu_block = max(0, tot_instance_info.num_free_gpu_block) - tot_instance_info.num_total_gpu_block = sum([0 if i.instance_id + 1 == num_instance else i.num_total_gpu_block + tot_instance_info.num_free_gpu_blocks = max(0, tot_instance_info.num_free_gpu_blocks) + tot_instance_info.num_total_gpu_blocks = sum([0 if i.instance_id + 1 == num_instances else i.num_total_gpu_blocks for i in instance_infos]) - tot_instance_info.num_watermark_block = sum([0 if i.instance_id + 1 == num_instance else i.num_watermark_block + tot_instance_info.num_watermark_blocks = sum([0 if i.instance_id + 1 == num_instances else i.num_watermark_blocks for i in instance_infos]) - tot_instance_info.num_block_all_waiting_request = sum([i.num_block_all_waiting_request for i in instance_infos]) - tot_instance_info.num_available_gpu_block = tot_instance_info.num_free_gpu_block - tot_instance_info.num_watermark_block + tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos]) + tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks return self.instance_load_calculator.compute_instance_load(tot_instance_info, action='scale') class ScalePolicyFactory: diff --git a/llumnix/instance_info.py b/llumnix/instance_info.py index 47ba6352..c0d076e9 100644 --- a/llumnix/instance_info.py +++ b/llumnix/instance_info.py @@ -22,34 +22,34 @@ class InstanceInfo: def __init__(self, - num_total_gpu_block: int = 0, - num_watermark_block: int= 0, - num_used_gpu_block: int = 0, - num_free_gpu_block: int = 0, + num_total_gpu_blocks: int = 0, + num_watermark_blocks: int= 0, + num_used_gpu_blocks: int = 0, + num_free_gpu_blocks: int = 0, gpu_cache_usage: float = 0.0, - num_running_request: int = 0, - num_waiting_request: int = 0, - num_killed_request: int = 0, - num_block_first_waiting_request: int = 0, + num_running_requests: int = 0, + num_waiting_requests: int = 0, + num_killed_requests: int = 0, + num_blocks_first_waiting_request: int = 0, waiting_time_first_waiting_request: int = 0, - num_block_all_waiting_request: int = 0, + num_blocks_all_waiting_requests: int = 0, inference_type: str = "", num_batched_tokens: int = 0) -> None: - self.num_total_gpu_block = num_total_gpu_block - self.num_watermark_block = num_watermark_block - self.num_used_gpu_block = 
num_used_gpu_block - self.num_free_gpu_block = num_free_gpu_block - self.num_available_gpu_block = self.num_free_gpu_block - self.num_watermark_block + self.num_total_gpu_blocks = num_total_gpu_blocks + self.num_watermark_blocks = num_watermark_blocks + self.num_used_gpu_blocks = num_used_gpu_blocks + self.num_free_gpu_blocks = num_free_gpu_blocks + self.num_available_gpu_blocks = self.num_free_gpu_blocks - self.num_watermark_blocks self.gpu_cache_usage = gpu_cache_usage - self.num_running_request = num_running_request - self.num_waiting_request = num_waiting_request - self.num_killed_request = num_killed_request - self.num_block_first_waiting_request = num_block_first_waiting_request + self.num_running_requests = num_running_requests + self.num_waiting_requests = num_waiting_requests + self.num_killed_requests = num_killed_requests + self.num_blocks_first_waiting_request = num_blocks_first_waiting_request self.waiting_time_first_waiting_request = waiting_time_first_waiting_request - self.num_block_all_waiting_request = num_block_all_waiting_request - self.num_available_gpu_block_waiting = self.num_available_gpu_block - self.num_block_all_waiting_request + self.num_blocks_all_waiting_requests = num_blocks_all_waiting_requests + self.num_available_gpu_blocks_waiting = self.num_available_gpu_blocks - self.num_blocks_all_waiting_requests # For instance load computation before migration. - self.num_block_last_running_request = 0 + self.num_blocks_last_running_request = 0 # For global scheduling. self.instance_load_migrate = -np.inf @@ -59,31 +59,31 @@ def __init__(self, self.inference_type = inference_type self.num_batched_tokens = num_batched_tokens self.running_seq_lens = [] - self.num_seq = 0 + self.num_seqs = 0 self.max_tot_tokens = 0 self.finished_request_ids = None # For record statistics, assigned in backend engine. 
- self.instance_id: None - self.step_id: None - self.timestamp: None + self.instance_id = None + self.step_id = None + self.timestamp = None self.latency = 0.0 class InstanceLoadInfo: def __init__(self, instance_info: InstanceInfo) -> None: - self.num_total_gpu_block = instance_info.num_total_gpu_block - self.num_watermark_block = instance_info.num_watermark_block - self.num_used_gpu_block = instance_info.num_used_gpu_block - self.num_free_gpu_block = instance_info.num_free_gpu_block - self.num_available_gpu_block = instance_info.num_available_gpu_block + self.num_total_gpu_blocks = instance_info.num_total_gpu_blocks + self.num_watermark_blocks = instance_info.num_watermark_blocks + self.num_used_gpu_blocks = instance_info.num_used_gpu_blocks + self.num_free_gpu_blocks = instance_info.num_free_gpu_blocks + self.num_available_gpu_blocks = instance_info.num_available_gpu_blocks - self.num_waiting_request = instance_info.num_waiting_request - self.num_running_request = instance_info.num_running_request - self.num_killed_request = instance_info.num_killed_request + self.num_waiting_requests = instance_info.num_waiting_requests + self.num_running_requests = instance_info.num_running_requests + self.num_killed_requests = instance_info.num_killed_requests - self.num_block_first_waiting_request = instance_info.num_block_first_waiting_request + self.num_blocks_first_waiting_request = instance_info.num_blocks_first_waiting_request self.waiting_time_first_waiting_request = instance_info.waiting_time_first_waiting_request - self.num_block_all_waiting_request = instance_info.num_block_all_waiting_request + self.num_blocks_all_waiting_requests = instance_info.num_blocks_all_waiting_requests self.instance_id = instance_info.instance_id self.step_id = instance_info.step_id @@ -92,7 +92,7 @@ class InstanceLoadCalculator: def __init__(self, load_metric: str, enable_defrag: bool) -> None: - assert load_metric in ['consumed_speed', 'used_ratio'] + assert load_metric in ['remaining_steps', 'usage_ratio'] self.load_metric = load_metric self.enable_defrag = enable_defrag self.load_computation_strategies: Dict[str, LoadComputationStrategy] = { @@ -122,36 +122,36 @@ def compute_instance_load(self, i: InstanceLoadInfo) -> float: class MigrationLoadComputation(LoadComputationStrategy): def compute_instance_load(self, i: InstanceLoadInfo) -> float: - assert self.load_metric in ['used_ratio', 'consumed_speed'] + assert self.load_metric in ['usage_ratio', 'remaining_steps'] instance_load = -np.inf - if self.load_metric == 'used_ratio': - instance_load = i.num_used_gpu_block / i.num_total_gpu_block - elif self.load_metric == 'consumed_speed': + if self.load_metric == 'usage_ratio': + instance_load = (i.num_used_gpu_blocks + i.num_blocks_first_waiting_request) / i.num_total_gpu_blocks + elif self.load_metric == 'remaining_steps': if not self.enable_defrag: - num_request = i.num_running_request - num_available_gpu_block = i.num_available_gpu_block + num_requests = i.num_running_requests + num_available_gpu_blocks = i.num_available_gpu_blocks else: - num_request = i.num_running_request - if i.num_waiting_request != 0: - num_request += 1 - # num_request = i.num_running_request + i.num_waiting_request - num_available_gpu_block = i.num_available_gpu_block - i.num_block_first_waiting_request - # num_available_gpu_block = i.num_available_gpu_block - i.num_block_all_waiting_request - if num_request == 0: + num_requests = i.num_running_requests + if i.num_waiting_requests != 0: + num_requests += 1 + # num_requests = 
i.num_running_requests + i.num_waiting_requests + num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_first_waiting_request + # num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_all_waiting_requests + if num_requests == 0: return -np.inf - instance_load = (num_available_gpu_block / num_request)*(-1) + instance_load = (num_available_gpu_blocks / num_requests)*(-1) return instance_load class DispatchAndScalingLoadComputation(LoadComputationStrategy): def compute_instance_load(self, i: InstanceLoadInfo) -> float: - assert self.load_metric in ['used_ratio', 'consumed_speed'] + assert self.load_metric in ['usage_ratio', 'remaining_steps'] instance_load = -np.inf - if self.load_metric == 'used_ratio': - instance_load = (i.num_used_gpu_block + i.num_block_all_waiting_request) / i.num_total_gpu_block - elif self.load_metric == 'consumed_speed': - num_request = i.num_running_request + i.num_waiting_request - num_available_gpu_block = i.num_available_gpu_block - i.num_block_all_waiting_request - if num_request == 0: + if self.load_metric == 'usage_ratio': + instance_load = (i.num_used_gpu_blocks + i.num_blocks_all_waiting_requests) / i.num_total_gpu_blocks + elif self.load_metric == 'remaining_steps': + num_requests = i.num_running_requests + i.num_waiting_requests + num_available_gpu_blocks = i.num_available_gpu_blocks - i.num_blocks_all_waiting_requests + if num_requests == 0: return -np.inf - instance_load = (num_available_gpu_block / num_request)*(-1) + instance_load = (num_available_gpu_blocks / num_requests)*(-1) return instance_load diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 94b54866..2a72e949 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -36,6 +36,7 @@ MANAGER_ACTOR_NAME = 'manager' CLEARING_INTERVAL = 3600 +RETRIES_INTERVALS = 5.0 # TODO(yiwang): add unit test for CI # TODO(yiwang): Fix the logger when manager failover. 
@@ -55,7 +56,7 @@ def __init__(self, self.log_requests = log_requests - self.num_instance = 0 + self.num_instances = 0 self.enable_migration = engine_manager_args.enable_migration self.enable_scaling = engine_manager_args.enable_scaling self.max_instances = engine_manager_args.max_instances @@ -63,7 +64,7 @@ def __init__(self, logger.info("LLMEngineManager starts") logger.info("enable_migration: {}".format(self.enable_migration)) - logger.info("num_instance: {}".format(self.num_instance)) + logger.info("num_instances: {}".format(self.num_instances)) logger.info("max_instances: {}, min_instances: {}".format(self.max_instances, self.min_instances)) # TODO(yiwang): refactor auto-scaling @@ -87,7 +88,7 @@ def __init__(self, asyncio.create_task(self._clear_request_instance_loop(self.clearing_interval)) # migrate states - self.num_instance_info_update = 0 + self.num_instance_info_updates = 0 self.migrating = False # auto-scaling states @@ -107,6 +108,10 @@ async def generate( server_info: ServerInfo, *args, **kwargs,) -> None: + while self.num_instances == 0: + logger.info("No instance available temporarily, sleep {}s, " + "and retry generate request {} again....".format(RETRIES_INTERVALS, request_id)) + await asyncio.sleep(RETRIES_INTERVALS) instance_id = self.global_scheduler.dispatch() try: await self.instances[instance_id].generate.remote(request_id, server_info, *args, **kwargs) @@ -117,7 +122,7 @@ async def generate( except (ray.exceptions.RayActorError, KeyError): logger.info("[generate] instance {} is dead, regenerate request {}".format(instance_id, request_id)) self.scale_down(instance_id) - if self.num_instance != 0: + if self.num_instances != 0: asyncio.create_task(self.generate(request_id, server_info, *args, **kwargs)) async def abort(self, request_id: Union[str, Iterable[str]]) -> None: @@ -142,7 +147,7 @@ async def abort(self, request_id: Union[str, Iterable[str]]) -> None: logger.info("[abort] instance {} is dead".format(instance_id)) self.scale_down(instance_id) - async def _get_request_instance(self): + async def _get_request_instance(self) -> None: logger.info("_get_request_instance:") tasks = [instance_actor_handle.get_all_request_ids.remote() for instance_actor_handle in self.instances.values()] instance_ids = list(self.instances.keys()) @@ -180,10 +185,10 @@ async def _update_instance_info_loop(self, interval: float) -> None: logger.info("[_update_instance_info_loop] instance {} is dead".format(instance_id)) self.scale_down(instance_id) self.global_scheduler.update_instance_infos(instance_info_list) - self.num_instance_info_update += 1 + self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. 
- if self.enable_migration and self.num_instance_info_update != 0 \ - and self.num_instance_info_update % self.pair_migration_frequency == 0: + if self.enable_migration and self.num_instance_info_updates != 0 \ + and self.num_instance_info_updates % self.pair_migration_frequency == 0: asyncio.create_task(self._migrate()) if self.log_instance_info: self._log_instance_infos_to_csv(instance_info_list) @@ -250,18 +255,24 @@ async def _migrate(self) -> None: logger.error("unexpected exception occurs: {}".format(e)) logger.error("exception traceback: {}".format(traceback.format_exc())) - def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles: List["ray.actor.ActorHandle"]) -> None: + def scale_up(self, + instance_id: Union[str, Iterable[str]], + llumlet_actor_handle: Union["ray.actor.ActorHandle", List["ray.actor.ActorHandle"]]) -> int: if isinstance(instance_id, str): instance_id = [instance_id,] instance_ids = list(instance_id) + if not isinstance(llumlet_actor_handle, list): + llumlet_actor_handle = [llumlet_actor_handle,] + llumlet_actor_handles = list(llumlet_actor_handle) for idx, ins_id in enumerate(instance_ids): if ins_id not in self.instances: self.instances[ins_id] = llumlet_actor_handles[idx] self.instance_migrating[ins_id] = False self.global_scheduler.scale_up(instance_ids) - self.num_instance = len(self.instances) + self.num_instances = len(self.instances) + return self.num_instances - def scale_down(self, instance_id: Union[str, Iterable[str]]) -> None: + def scale_down(self, instance_id: Union[str, Iterable[str]]) -> int: if isinstance(instance_id, str): instance_id = [instance_id,] instance_ids = list(instance_id) @@ -270,7 +281,8 @@ def scale_down(self, instance_id: Union[str, Iterable[str]]) -> None: del self.instances[ins_id] del self.instance_migrating[ins_id] self.global_scheduler.scale_down(instance_ids) - self.num_instance = len(self.instances) + self.num_instances = len(self.instances) + return self.num_instances def _connect_to_instances(self): actor_names_dict = ray.util.list_named_actors(True) @@ -375,20 +387,20 @@ def _init_instance_info_csv(self, engine_manager_args: EngineManagerArgs) -> Non 'instance_id', 'step_id', 'gpu_cache_usage', - 'num_available_gpu_block', + 'num_available_gpu_blocks', 'instance_load', 'max_tot_tokens', - 'num_running_request', - 'num_waiting_request', - 'num_killed_request', + 'num_running_requests', + 'num_waiting_requests', + 'num_killed_requests', 'inference_type', 'bs', 'latency', 'seq_lens', - 'num_instance', - 'num_seq', - 'num_block_first_waiting_request', - 'num_block_all_waiting_request', + 'num_instances', + 'num_seqs', + 'num_blocks_first_waiting_request', + 'num_blocks_all_waiting_requests', 'waiting_time_first_waiting_request']) def _log_instance_infos_to_csv(self, instance_infos: List[InstanceInfo]) -> None: @@ -398,19 +410,19 @@ def _log_instance_infos_to_csv(self, instance_infos: List[InstanceInfo]) -> None instance_info.instance_id, instance_info.step_id, instance_info.gpu_cache_usage, - instance_info.num_available_gpu_block, + instance_info.num_available_gpu_blocks, instance_info.instance_load_migrate, instance_info.max_tot_tokens, - instance_info.num_running_request, - instance_info.num_waiting_request, - instance_info.num_killed_request, + instance_info.num_running_requests, + instance_info.num_waiting_requests, + instance_info.num_killed_requests, instance_info.inference_type, instance_info.num_batched_tokens, instance_info.latency, instance_info.running_seq_lens, - self.num_instance, - 
instance_info.num_seq, - instance_info.num_block_first_waiting_request, - instance_info.num_block_all_waiting_request, + self.num_instances, + instance_info.num_seqs, + instance_info.num_blocks_first_waiting_request, + instance_info.num_blocks_all_waiting_requests, instance_info.waiting_time_first_waiting_request]) self.instance_info_file.flush() diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index a4932378..dac7cf3a 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -130,12 +130,6 @@ def migrate_out(self, dst_instance_name: str) -> List[str]: def get_instance_info(self) -> InstanceInfo: return self.backend_engine.engine.instance_info - def get_actor_name(self) -> str: - return self.actor_name - - def get_instance_id(self) -> str: - return self.instance_id - def is_ready(self) -> bool: return True diff --git a/requirements.txt b/requirements.txt index 0db75cb6..ac5953fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ aiohttp scipy pandas matplotlib +pytest-asyncio diff --git a/tests/entrypoints/test_llumnix_utils.py b/tests/entrypoints/test_llumnix_utils.py new file mode 100644 index 00000000..7a36e146 --- /dev/null +++ b/tests/entrypoints/test_llumnix_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import pytest +import ray + +from llumnix.arg_utils import EngineManagerArgs +from llumnix.entrypoints.llumnix_utils import (get_ip_address, + launch_ray_cluster, + init_manager, + init_request_output_queue, + retry_manager_method_sync, + retry_manager_method_async) +from llumnix.llm_engine_manager import MANAGER_ACTOR_NAME + + +def test_launch_ray_cluster(): + ip_address = get_ip_address() + os.environ['HEAD_NODE'] = '1' + os.environ['HEAD_NODE_IP'] = ip_address + result = launch_ray_cluster(30050) + assert result.returncode == 0 + +def test_init_manager(): + engine_manager_args = EngineManagerArgs() + engine_manager = init_manager(engine_manager_args) + assert engine_manager is not None + engine_manager_actor_handle = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix') + assert engine_manager_actor_handle is not None + assert engine_manager == engine_manager_actor_handle + ray.kill(engine_manager) + ray.shutdown() + +def test_init_request_output_queue(): + request_output_queue = init_request_output_queue() + assert request_output_queue is not None + ray.shutdown() + +def test_retry_manager_method_sync(): + engine_manager_args = EngineManagerArgs() + engine_manager = init_manager(engine_manager_args) + ret = retry_manager_method_sync(engine_manager.is_ready.remote, 'is_ready') + assert ret is True + ray.kill(engine_manager) + ray.shutdown() + +@pytest.mark.asyncio +async def test_retry_manager_method_async(): + engine_manager_args = EngineManagerArgs() + engine_manager = init_manager(engine_manager_args) + ret = await retry_manager_method_async(engine_manager.is_ready.remote, 'is_ready') + assert ret is True + ray.kill(engine_manager) + ray.shutdown() diff --git a/tests/entrypoints/vllm/api_server_manager.py b/tests/entrypoints/vllm/api_server_manager.py new file mode 100644 index 00000000..05e431c5 --- /dev/null +++ b/tests/entrypoints/vllm/api_server_manager.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
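+
+# Test-only API server used by test_api_server.py: it mounts the real FastAPI
+# app from llumnix.entrypoints.vllm.api_server on top of a MockLLMEngineManager
+# actor that counts generate/abort calls and reports them via a /stats endpoint.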
+ +import argparse +import uvicorn +import ray +from fastapi.responses import JSONResponse, Response +from ray.util.queue import Queue as RayQueue + +from vllm.outputs import CompletionOutput, RequestOutput + +import llumnix.entrypoints.vllm.api_server +import llumnix.llm_engine_manager +from llumnix.arg_utils import EngineManagerArgs + + +app = llumnix.entrypoints.vllm.api_server.app +engine_manager = None +request_output_queue = RayQueue() +llumnix.entrypoints.vllm.api_server.request_output_queue = request_output_queue +MANAGER_ACTOR_NAME = llumnix.llm_engine_manager.MANAGER_ACTOR_NAME + + +@ray.remote(num_cpus=0) +class MockLLMEngineManager: + def __init__(self): + self._num_generates = 0 + self._num_aborts = 0 + + async def generate(self, request_id, server_info, *args, **kwargs): + self._num_generates += 1 + completion_output = CompletionOutput(0, "", [], 0.0, None) + request_output = RequestOutput(request_id, "", [], None, [completion_output], finished=True) + request_output_queue.put(request_output) + + async def abort(self, request_id): + self._num_aborts += 1 + + def testing_stats(self): + return {"num_aborted_requests": self._num_aborts} + + +def init_manager(): + engine_manager = MockLLMEngineManager.options(name=MANAGER_ACTOR_NAME, + namespace='llumnix').remote() + return engine_manager + +@app.get("/stats") +def stats() -> Response: + """Get the statistics of the engine.""" + return JSONResponse(ray.get(engine_manager.testing_stats.remote())) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser = EngineManagerArgs.add_cli_args(parser) + args = parser.parse_args() + + engine_manager = init_manager() + llumnix.entrypoints.vllm.api_server.engine_manager = engine_manager + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=llumnix.entrypoints.vllm.api_server.TIMEOUT_KEEP_ALIVE) diff --git a/tests/entrypoints/vllm/test_api_server.py b/tests/entrypoints/vllm/test_api_server.py new file mode 100644 index 00000000..75624513 --- /dev/null +++ b/tests/entrypoints/vllm/test_api_server.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
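+
+# End-to-end test for the Llumnix vLLM API server: starts api_server_manager.py
+# in a subprocess and checks that /generate and /generate_benchmark handle
+# concurrent requests and client-side cancellations without crashing.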
+
+import subprocess
+import sys
+import time
+from multiprocessing import Pool
+from pathlib import Path
+import pytest
+import requests
+import ray
+
+
+def _query_server(prompt: str, max_tokens: int = 5, interface: str = 'generate') -> dict:
+    response = requests.post("http://localhost:8000/{}".format(interface),
+                             json={
+                                 "prompt": prompt,
+                                 "max_tokens": max_tokens,
+                                 "temperature": 0,
+                                 "ignore_eos": True
+                             })
+    response.raise_for_status()
+    return response.json()
+
+def _query_server_long(prompt: str) -> dict:
+    return _query_server(prompt, max_tokens=500)
+
+def _query_server_generate(prompt: str) -> dict:
+    return _query_server(prompt, interface='generate')
+
+def _query_server_generate_benchmark(prompt: str) -> dict:
+    return _query_server(prompt, interface='generate_benchmark')
+
+@pytest.fixture
+def api_server():
+    script_path = Path(__file__).parent.joinpath(
+        "api_server_manager.py").absolute()
+    commands = [
+        sys.executable,
+        "-u",
+        str(script_path),
+        "--host", "127.0.0.1",
+    ]
+    uvicorn_process = subprocess.Popen(commands)
+    yield
+    uvicorn_process.terminate()
+    ray.shutdown()
+    time.sleep(1.0)
+
+@pytest.mark.parametrize("interface", ['generate', 'generate_benchmark'])
+def test_api_server(api_server, interface: str):
+    """
+    Run the API server and test it.
+
+    We run both the server and requests in separate processes.
+
+    We test that the server can handle incoming requests, including
+    multiple requests at the same time, and that it can handle requests
+    being cancelled without crashing.
+    """
+    if interface == 'generate':
+        _query_server = _query_server_generate
+    elif interface == 'generate_benchmark':
+        _query_server = _query_server_generate_benchmark
+
+    with Pool(32) as pool:
+        # Wait until the server is ready
+        prompts = ["warm up"] * 1
+        result = None
+        while not result:
+            try:
+                for r in pool.map(_query_server, prompts):
+                    result = r
+                    break
+            except requests.exceptions.ConnectionError:
+                time.sleep(1)
+
+        # Actual tests start here
+        # Try with 1 prompt
+        for result in pool.map(_query_server, prompts):
+            assert result
+
+        num_aborted_requests = requests.get(
+            "http://localhost:8000/stats").json()["num_aborted_requests"]
+        assert num_aborted_requests == 0
+
+        # Try with 100 prompts
+        prompts = ["test prompt"] * 100
+        for result in pool.map(_query_server, prompts):
+            assert result
+
+    with Pool(32) as pool:
+        # Cancel requests
+        prompts = ["canceled requests"] * 100
+        pool.map_async(_query_server_long, prompts)
+        time.sleep(0.01)
+        pool.terminate()
+        pool.join()
+
+        # check cancellation stats
+        # give it some time to update the stats
+        time.sleep(1)
+
+        num_aborted_requests = requests.get(
+            "http://localhost:8000/stats").json()["num_aborted_requests"]
+        assert num_aborted_requests > 0
+
+    # check that server still runs after cancellations
+    with Pool(32) as pool:
+        # Try with 100 prompts
+        prompts = ["test prompt after canceled"] * 100
+        for result in pool.map(_query_server, prompts):
+            assert result
diff --git a/tests/global_scheduler/test_dispatch_scheduler.py b/tests/global_scheduler/test_dispatch_scheduler.py
new file mode 100644
index 00000000..bcc58a06
--- /dev/null
+++ b/tests/global_scheduler/test_dispatch_scheduler.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2024, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import pytest + +from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo +from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler + + +def init_dispatch_scheduler(policy='load'): + instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) + dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator) + return dispatch_scheduler + +@pytest.fixture +def dispatch_scheduler(): + dispatch_scheduler = init_dispatch_scheduler() + yield dispatch_scheduler + +def test_add_instance_and_remove_instance(dispatch_scheduler): + dispatch_scheduler.add_instance('instance_1') + assert dispatch_scheduler.num_instances == 1 + dispatch_scheduler.add_instance('instance_2') + assert dispatch_scheduler.num_instances == 2 + dispatch_scheduler.remove_instance('instance_1') + assert dispatch_scheduler.num_instances == 1 + dispatch_scheduler.remove_instance('instance_2') + assert dispatch_scheduler.num_instances == 0 + +def test_dispatch_balanced(): + dispatch_scheduler = init_dispatch_scheduler('balanced') + num_tests = 100 + for _ in range(num_tests): + instance_num_requests = {} + for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: + instance_num_requests[instance_id] = random.randint(1, 10) + dispatch_scheduler.instance_num_requests = instance_num_requests + min_instance_id = next(key for key, value in sorted(instance_num_requests.items(), key=lambda item: item[1])) + instance_id = dispatch_scheduler.dispatch() + assert min_instance_id == instance_id + +def test_dispatch_load(): + dispatch_scheduler = init_dispatch_scheduler('load') + num_tests = 100 + for _ in range(num_tests): + instance_num_requests = {} + instance_info_dict = {} + for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: + instance_num_requests[instance_id] = 0 + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_info.instance_load_dispatch_scale = random.random() + instance_info_dict[instance_id] = instance_info + dispatch_scheduler.instance_num_requests = instance_num_requests + dispatch_scheduler.instance_info = instance_info_dict + min_instance_id = next(key for key, value in sorted(instance_info_dict.items(), + key=lambda item: item[1].instance_load_dispatch_scale)) + instance_id = dispatch_scheduler.dispatch() + assert min_instance_id == instance_id + +def test_dispatch_queue(): + dispatch_scheduler = init_dispatch_scheduler('queue') + num_tests = 100 + for _ in range(num_tests): + instance_num_requests = {} + instance_info_dict = {} + for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: + instance_num_requests[instance_id] = 0 + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_info.num_waiting_requests = random.randint(1, 10) + instance_info_dict[instance_id] = instance_info + dispatch_scheduler.instance_num_requests = instance_num_requests + dispatch_scheduler.instance_info = instance_info_dict + min_instance_id = next(key for key, value in sorted(instance_info_dict.items(), + key=lambda item: 
item[1].num_waiting_requests)) + instance_id = dispatch_scheduler.dispatch() + assert instance_info_dict[min_instance_id].num_waiting_requests == instance_info_dict[instance_id].num_waiting_requests diff --git a/tests/global_scheduler/test_global_scheduler.py b/tests/global_scheduler/test_global_scheduler.py new file mode 100644 index 00000000..7ba94145 --- /dev/null +++ b/tests/global_scheduler/test_global_scheduler.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from vllm.utils import random_uuid + +from llumnix.config import GlobalSchedulerConfig +from llumnix.global_scheduler.global_scheduler import GlobalScheduler +from llumnix.instance_info import InstanceInfo + +from tests.global_scheduler.test_llm_engine_manager import get_instance_info_migrate_in, get_instance_info_migrate_out + + +def init_global_scheduler(): + global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', 'defrag_constrained', 3.0, True, 'avg_load', 10, 60) + global_scheduler = GlobalScheduler(global_scheduler_config) + return global_scheduler + +def init_instance_infos(initial_instances): + instance_infos = [] + for _ in range(initial_instances): + instance_id = random_uuid() + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_infos.append(instance_info) + return instance_infos + +@pytest.fixture +def global_scheduler(): + global_scheduler = init_global_scheduler() + yield global_scheduler + +def test_scale_up_and_scale_down(global_scheduler): + initial_instances = 4 + instance_infos = init_instance_infos(initial_instances) + instance_ids = [instance_info.instance_id for instance_info in instance_infos] + num_instances = global_scheduler.scale_up(instance_ids) + assert num_instances == initial_instances + instance_infos = init_instance_infos(initial_instances) + instance_ids_1 = [instance_info.instance_id for instance_info in instance_infos] + num_instances = global_scheduler.scale_down(instance_ids_1) + assert num_instances == initial_instances + num_instances = global_scheduler.scale_down(instance_ids) + assert num_instances == 0 + +def test_update_instance_infos(global_scheduler): + initial_instances = 4 + instance_infos = init_instance_infos(initial_instances) + global_scheduler.update_instance_infos(instance_infos) + assert len(global_scheduler.instance_info) == 0 + instance_ids = [instance_info.instance_id for instance_info in instance_infos] + global_scheduler.scale_up(instance_ids) + global_scheduler.update_instance_infos(instance_infos) + assert len(global_scheduler.instance_info) == initial_instances + +def test_dispatch(global_scheduler): + initial_instances = 4 + instance_infos = init_instance_infos(initial_instances) + instance_ids = [instance_info.instance_id for instance_info in instance_infos] + global_scheduler.scale_up(instance_ids) + global_scheduler.update_instance_infos(instance_infos) + instance_id = global_scheduler.dispatch() + assert instance_id in instance_ids + +def 
test_pair_migration(global_scheduler): + instance_id = random_uuid() + instance_id_1 = random_uuid() + instance_ids = [instance_id, instance_id_1] + instance_info_migrate_in = get_instance_info_migrate_in(instance_id) + instance_info_migrate_out = get_instance_info_migrate_out(instance_id_1) + instance_infos = [instance_info_migrate_in, instance_info_migrate_out] + global_scheduler.scale_up(instance_ids) + global_scheduler.update_instance_infos(instance_infos) + migrate_instace_pairs = global_scheduler.pair_migration() + assert migrate_instace_pairs[0][0] == instance_id_1 + assert migrate_instace_pairs[0][1] == instance_id diff --git a/tests/global_scheduler/test_llm_engine_manager.py b/tests/global_scheduler/test_llm_engine_manager.py new file mode 100644 index 00000000..c0edb419 --- /dev/null +++ b/tests/global_scheduler/test_llm_engine_manager.py @@ -0,0 +1,231 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ray +import pytest +import numpy as np +import time + +from vllm.utils import random_uuid + +from llumnix.arg_utils import EngineManagerArgs +from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME +from llumnix.instance_info import InstanceInfo + + +@ray.remote(num_cpus=1, max_concurrency=4) +class MockLlumlet: + def __init__(self, instance_id): + self.instance_id = instance_id + self.actor_name = f"instance_{instance_id}" + self.num_requests = 0 + self.request_id_set = set() + self.instance_info = None + self.num_migrate_out = 0 + + def get_instance_id(self) -> str: + return self.instance_id + + def set_instance_info(self, instance_info): + self.instance_info = instance_info + + def get_instance_info(self): + return self.instance_info + + def is_ready(self) -> bool: + return True + + def get_all_request_ids(self): + return list(self.request_id_set) + + def get_num_requests(self): + return self.num_requests + + def generate(self, request_id, server_info, *args, **kwargs): + self.request_id_set.add(request_id) + self.num_requests = len(self.request_id_set) + return self.num_requests + + def abort(self, request_id): + if isinstance(request_id, str): + request_id = (request_id,) + request_ids = set(request_id) + for request_id in request_ids: + if request_id in self.request_id_set: + self.request_id_set.remove(request_id) + self.num_requests = len(self.request_id_set) + return self.num_requests + + def migrate_out(self, dst_instance_name): + self.num_migrate_out += 1 + + def get_num_migrate_out(self): + return self.num_migrate_out + + +def init_manager(): + ray.init(ignore_reinit_error=True, namespace='llumnix') + try: + engine_manager_args = EngineManagerArgs() + engine_manager_args.log_instance_info = False + engine_manager = LLMEngineManager.from_args(engine_manager_args, None) + except ValueError: + engine_manager = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix') + ray.get(engine_manager.is_ready.remote()) + return engine_manager + +def init_llumlets(initial_instances): + instance_ids = [] + llumlets = [] + for _ 
in range(initial_instances): + instance_id = random_uuid() + instance_name = 'instance_{}'.format(instance_id) + llumlet = MockLlumlet.options(name=instance_name, + namespace='llumnix').remote(instance_id) + instance_ids.append(instance_id) + llumlets.append(llumlet) + ray.get([llumlet.is_ready.remote() for llumlet in llumlets]) + return instance_ids, llumlets + +@pytest.fixture +def engine_manager(): + engine_manager = init_manager() + ray.get(engine_manager.is_ready.remote()) + yield engine_manager + ray.kill(engine_manager) + ray.shutdown() + +@pytest.fixture +def llumlet(): + instance_id = random_uuid() + instance_name = 'instance_{}'.format(instance_id) + llumlet = MockLlumlet.options(name=instance_name, + namespace='llumnix').remote(instance_id) + ray.get(llumlet.is_ready.remote()) + return llumlet + +def test_init_manager(engine_manager): + assert engine_manager is not None + engine_manager_actor_handle = ray.get_actor(MANAGER_ACTOR_NAME, namespace='llumnix') + assert engine_manager_actor_handle is not None + assert engine_manager == engine_manager_actor_handle + +def test_init_llumlet(llumlet): + assert llumlet is not None + ray.get(llumlet.is_ready.remote()) + +# TODO(s5u13b): Add init_llumlets test. + +def test_scale_up_and_down(engine_manager): + initial_instances = 4 + instance_ids, llumlets = init_llumlets(initial_instances) + num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) + assert num_instances == initial_instances + instance_ids_1, llumlets_1 = init_llumlets(initial_instances) + num_instances = ray.get(engine_manager.scale_down.remote(instance_ids_1)) + assert num_instances == initial_instances + num_instances = ray.get(engine_manager.scale_up.remote(instance_ids_1, llumlets_1)) + assert num_instances == initial_instances * 2 + num_instances = ray.get(engine_manager.scale_down.remote(instance_ids)) + assert num_instances == initial_instances + num_instances = ray.get(engine_manager.scale_down.remote(instance_ids_1)) + assert num_instances == 0 + +def test_connect_to_instances(): + initial_instances = 4 + instance_ids, llumlets = init_llumlets(initial_instances) + ray.get([llumlet.is_ready.remote() for llumlet in llumlets]) + engine_manager = init_manager() + instance_ids_1, llumlets_1 = init_llumlets(initial_instances) + num_instances = ray.get(engine_manager.scale_up.remote(instance_ids_1, llumlets_1)) + assert num_instances == initial_instances * 2 + num_instances = ray.get(engine_manager.scale_down.remote(instance_ids)) + assert num_instances == initial_instances + ray.kill(engine_manager) + ray.shutdown() + +def test_generate_and_abort(engine_manager, llumlet): + instance_id = ray.get(llumlet.get_instance_id.remote()) + ray.get(engine_manager.scale_up.remote(instance_id, llumlet)) + request_id = random_uuid() + num_requests = ray.get(llumlet.get_num_requests.remote()) + assert num_requests == 0 + ray.get(engine_manager.generate.remote(request_id, None, None, None)) + num_requests = ray.get(llumlet.get_num_requests.remote()) + assert num_requests == 1 + ray.get(engine_manager.abort.remote(request_id)) + num_requests = ray.get(llumlet.get_num_requests.remote()) + assert num_requests == 0 + request_id_1 = random_uuid() + request_id_2 = random_uuid() + request_ids = [request_id_1, request_id_2] + ray.get(engine_manager.abort.remote(request_ids)) + num_requests = ray.get(llumlet.get_num_requests.remote()) + assert num_requests == 0 + +def test_get_request_instance(): + instance_ids, llumlets = init_llumlets(2) + instance_id, instance_id_1 = 
instance_ids[0], instance_ids[1] + llumlet, llumlet_1 = llumlets[0], llumlets[1] + request_id = random_uuid() + request_id_1 = random_uuid() + ray.get(llumlet.generate.remote(request_id, None, None, None)) + ray.get(llumlet_1.generate.remote(request_id_1, None, None, None)) + num_requests = ray.get(llumlet.get_num_requests.remote()) + num_requests_1 = ray.get(llumlet_1.get_num_requests.remote()) + assert num_requests == 1 + assert num_requests_1 == 1 + engine_manager = init_manager() + ray.get(engine_manager.abort.remote(request_id)) + ray.get(engine_manager.abort.remote(request_id_1)) + num_requests = ray.get(llumlet.get_num_requests.remote()) + num_requests_1 = ray.get(llumlet_1.get_num_requests.remote()) + assert num_requests == 0 + assert num_requests_1 == 0 + ray.kill(engine_manager) + ray.shutdown() + +def get_instance_info_migrate_in(instance_id): + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_info.num_available_gpu_blocks = np.inf + instance_info.num_running_requests = 1 + instance_info.num_blocks_first_waiting_request = 0 + return instance_info + +def get_instance_info_migrate_out(instance_id): + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_info.num_available_gpu_blocks = 0 + instance_info.num_running_requests = 1 + instance_info.num_blocks_first_waiting_request = np.inf + return instance_info + +def test_update_instance_info_loop_and_migrate(engine_manager): + instance_ids, llumlets = init_llumlets(2) + instance_id, instance_id_1 = instance_ids[0], instance_ids[1] + llumlet, llumlet_1 = llumlets[0], llumlets[1] + request_id = random_uuid() + request_id_1 = random_uuid() + ray.get(llumlet.generate.remote(request_id, None, None, None)) + ray.get(llumlet_1.generate.remote(request_id_1, None, None, None)) + instance_info_migrate_out = get_instance_info_migrate_out(instance_id) + instance_info_migrate_in = get_instance_info_migrate_in(instance_id_1) + ray.get(llumlet.set_instance_info.remote(instance_info_migrate_out)) + ray.get(llumlet_1.set_instance_info.remote(instance_info_migrate_in)) + num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) + assert num_migrate_out == 0 + ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) + time.sleep(0.2) + num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) + assert num_migrate_out != 0 diff --git a/tests/global_scheduler/test_migration_scheduler.py b/tests/global_scheduler/test_migration_scheduler.py new file mode 100644 index 00000000..0e69c81f --- /dev/null +++ b/tests/global_scheduler/test_migration_scheduler.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
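+
+# Unit tests for MigrationScheduler: instance add/remove bookkeeping and the
+# migrate-out/migrate-in pairing invariants of the 'balanced',
+# 'defrag_constrained' and 'defrag_relaxed' policies under random loads.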
+ +import random +import pytest +import numpy as np + +from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo +from llumnix.global_scheduler.migration_scheduler import MigrationScheduler + + +MIGRATE_OUT_LOAD_THRESHOLD = 3.0 + + +def init_migration_scheduler(policy='balanced'): + instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) + migration_scheduler = MigrationScheduler(policy, MIGRATE_OUT_LOAD_THRESHOLD, instance_load_calculator) + return migration_scheduler + +@pytest.fixture +def migration_scheduler(): + migration_scheduler = init_migration_scheduler() + yield migration_scheduler + +def test_add_instance_and_remove_instance(migration_scheduler): + migration_scheduler.add_instance('instance_1') + assert migration_scheduler.num_instances == 1 + migration_scheduler.add_instance('instance_2') + assert migration_scheduler.num_instances == 2 + migration_scheduler.remove_instance('instance_1') + assert migration_scheduler.num_instances == 1 + migration_scheduler.remove_instance('instance_2') + assert migration_scheduler.num_instances == 0 + +@pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained', 'defrag_relaxed']) +def test_pair_migration(policy): + migration_scheduler = init_migration_scheduler(policy) + num_tests = 1000 + for _ in range(num_tests): + instance_info_dict = {} + for instance_id in ['instance_1', 'instance_2', 'instance_3', 'instance_4']: + instance_info = InstanceInfo() + instance_info.instance_id = instance_id + instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1) + instance_info.num_killed_requests = random.randint(0, 1) + instance_info.num_blocks_last_running_request = random.randint(0, 1) * np.inf + instance_info_dict[instance_id] = instance_info + migration_scheduler.instance_info = instance_info_dict + migrate_instance_pairs = migration_scheduler.pair_migration() + for migrate_out_instance, migrate_in_instance in migrate_instance_pairs: + assert migrate_out_instance != migrate_in_instance + if policy != 'defrag_relaxed': + assert instance_info_dict[migrate_out_instance].num_killed_requests > 0 \ + or instance_info_dict[migrate_out_instance].instance_load_migrate > MIGRATE_OUT_LOAD_THRESHOLD + assert instance_info_dict[migrate_in_instance].num_killed_requests == 0 \ + and instance_info_dict[migrate_in_instance].instance_load_migrate < MIGRATE_OUT_LOAD_THRESHOLD + if policy == 'balanced': + assert instance_info_dict[migrate_out_instance].num_blocks_last_running_request == 0 + if instance_info_dict[migrate_out_instance].num_killed_requests == 0: + assert instance_info_dict[migrate_out_instance].instance_load_migrate > instance_info_dict[migrate_in_instance].instance_load_migrate