diff --git a/Makefile b/Makefile index b2cd80f3..8f75c380 100644 --- a/Makefile +++ b/Makefile @@ -29,14 +29,15 @@ lint: check_pylint_installed check_pytest_installed .PHONY: test test: check_pytest_installed - @pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings @python examlpes/offline_inference.py - @pytest -v -x tests/e2e_test/test_e2e.py - @pytest -v -x ./tests/e2e_test/test_migration.py + @pytest -v ./tests/e2e_test/test_e2e.py + @pytest -v ./tests/e2e_test/test_bench.py + @pytest -v ./tests/e2e_test/test_migration.py .PHONY: unit_test unit_test: check_pytest_installed - @pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings + @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings .PHONY: offline_test offline_test: @@ -44,15 +45,15 @@ offline_test: .PHONY: e2e_test e2e_test: - @pytest -v -x tests/e2e_test/test_e2e.py + @pytest -v ./tests/e2e_test/test_e2e.py .PHONY: bench_test bench_test: - @pytest -v -x ./tests/e2e_test/test_bench.py + @pytest -v ./tests/e2e_test/test_bench.py .PHONY: migration_test migration_test: - @pytest -v -x ./tests/e2e_test/test_migration.py + @pytest -v ./tests/e2e_test/test_migration.py #################### pygloo install for gloo migration backend begin #################### diff --git a/configs/base.yml b/configs/base.yml index 70358339..b9ee7077 100644 --- a/configs/base.yml +++ b/configs/base.yml @@ -16,7 +16,7 @@ MANAGER: ENABLE_MIGRATION: True ENABLE_DEFRAG: True - REQUEST_MIGRATION_POLICY: 'SJF' + REQUEST_MIGRATION_POLICY: 'SR' MIGRATION_BACKEND: 'gloo' MIGRATION_BUFFER_BLOCKS: 512 diff --git a/docs/Arguments.md b/docs/Arguments.md index a2584417..09841a5b 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -17,7 +17,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY] [--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}] [--migrate-out-threshold MIGRATE_OUT_THRESHOLD] - [--request-migration-policy {LCFS,SJF,LJF}] + [--request-migration-policy {LCR,SR,LR,FCW,FCWSR}] [--enable-defrag ENABLE_DEFRAG] [--enable-scaling] [--min-instances MIN_INSTANCES] @@ -90,8 +90,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] `--request-migration-policy` - Request migration policy. -- Possible choices: LCFS, SJF, LJF -- Default: "SJF" +- Possible choices: LCR, SR, LR, FCW, FCWSR +- Default: "SR" `--enable-defrag` - Enable defragmentation through migration based on virtual usage. diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 37b3bbc6..1c4c54b4 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -22,6 +22,7 @@ from llumnix.config import LlumnixConfig, get_llumnix_config from llumnix.config.default import _C + class LlumnixArgumentParser(argparse.ArgumentParser): def __init__(self, *args, **kwargs): self.cur_namespace = "llumnix" @@ -228,7 +229,11 @@ def add_cli_args( parser.add_argument('--dispatch-policy', type=str, choices=['balanced', 'load', 'queue', 'flood'], - help='request dispatch policy') + help='The request dispatch policy.\n\n' + '* "balanced" dispatch request to the instance with minimum requests dispatched.\n' + '* "load" dispatch request to the instance with lowest instance load.\n' + '* "queue" dispatch request to the instance with minimum waiting request queue length.\n' + '* "flood" dispatch request to the instance with maximum requests dispatched.\n') parser.add_argument('--num-available-dispatch-instances', type=int, help='number of available instances for dispatching') @@ -242,14 +247,25 @@ def add_cli_args( parser.add_argument('--pair-migration-policy', type=str, choices=['balanced', 'defrag_constrained', 'defrag_relaxed'], - help='pair migration policy') + help='The pair migration policy.\n\n' + '* "balanced" pair migration to make the instance load of instance more balanced.\n' + '* "defrag_constrained" pair migration without balanced constraint to ' + 'achieve defragmentation thoroughly (with instance constraints).\n' + '* "defrag_relaxed" pair migration to without balanced constraint ' + 'to achieve defragmentation thoroughly (without instance constraints).\n') parser.add_argument('--migrate-out-threshold', type=float, help='migrate out instance load threshold') parser.add_argument('--request-migration-policy', type=str, - choices=['LCFS', 'SJF', 'LJF'], - help='request migration policy') + default=None, + choices=['LCR', 'SR', 'LR', 'FCW', 'FCWSR'], + help='The request migration policy.\n\n' + '* "LCR" migrate the running request last come.\n' + '* "SR" migrate the running request shortest.\n' + '* "LR" migrate the running request longest.\n' + '* "FCW" migrate the waiting request first come.\n' + '* "FCWSR" migrate the waiting request first come and running request shortest.\n') parser.add_argument('--enable-defrag', type=bool, help='enable defragmentation through migration based on virtual usage') diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index 16a8ac1f..28e1e802 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -13,9 +13,9 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Iterable, List, Union +from typing import Iterable, List, Union, Deque -from llumnix.llumlet.request import LlumnixRequest +from llumnix.llumlet.request import LlumnixRequest, RequestStatus from llumnix.server_info import ServerInfo class EngineState(str, Enum): @@ -99,14 +99,21 @@ def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_st raise NotImplementedError @abstractmethod - def get_running_queue(self) -> List[LlumnixRequest]: + def get_running_queue(self) -> Deque[LlumnixRequest]: """ Return backend's running queue. """ raise NotImplementedError @abstractmethod - def remove_running_request(self, request_id: str) -> None: + def get_waiting_queue(self) -> Deque[LlumnixRequest]: + """ + Return backend's waiting queue. + """ + raise NotImplementedError + + @abstractmethod + def remove_running_request(self, request_id: str) -> bool: """ Removes a request from the backend's running queue. @@ -117,6 +124,26 @@ def remove_running_request(self, request_id: str) -> None: Args: request_id: A string identifier for the request that is to be removed from the running queue. This ID uniquely identifies the request within the backend system. + + Returns: + True if the request was successfully removed from the running queue, False otherwise. + """ + raise NotImplementedError + + @abstractmethod + def remove_waiting_request(self, request_id: str) -> bool: + """ + Removes a request from the backend's waiting queue. + + This method is responsible for safely halting and removing an active request from the waiting + queue of the backend engine. This action is performed in waiting request migration. + + Args: + request_id: A string identifier for the request that is to be removed from the waiting + queue. This ID uniquely identifies the request within the backend system. + + Returns: + True if the request was successfully removed from the waiting queue, False otherwise. """ raise NotImplementedError @@ -164,17 +191,25 @@ def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]: raise NotImplementedError @abstractmethod - def pre_alloc(self, request_id: str, block_num: int) -> List[int]: + def pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: """Pre-allocates cache blocks for a migrating request. This method selects a specified number of free cache blocks to be reserved for an incoming migration request identified by the given request ID. It updates the pre-allocation cache dictionary with the allocated blocks, which ensures that these blocks are not used by - another process until the migration is finished. + another process until the migration is finished. For the waiting request, it only reserves + free cache blocks when the request is the earliest arrival one among the requests of dst instance's + waiting queue. Args: request_id: The unique identifier of the migration request for which cache blocks are to be pre-allocated. + request_status: The status (waiting/running) of the request. + request_arrival_time: The arrival time of the request. block_num: The number of cache blocks that need to be pre-allocated for the request. Returns: @@ -187,9 +222,8 @@ def add_running_request(self, backend_request: LlumnixRequest) -> None: """ Adds a backend request to the running queue for processing. - This method enqueues a backend request into engine running queue, marking it for - active processing. It is used when a suspend migrating request should be added back - to running queue. + This method enqueues a backend request into engine running queue. + It is used when a suspend migrating request should be added back to running queue. Args: backend_request: An object representing the backend request. The type of this @@ -199,19 +233,17 @@ def add_running_request(self, backend_request: LlumnixRequest) -> None: raise NotImplementedError @abstractmethod - def is_request_running(self, backend_request: LlumnixRequest) -> bool: - """Checks if a given backend request is currently in the running queue. + def add_waiting_request(self, backend_request: LlumnixRequest) -> None: + """ + Adds a backend request to the waiting queue for processing. - This method determines whether a backend request is present and actively being processed - in the running queue. + This method enqueues a backend request into engine waiting queue. + It is used when a suspend migrating request should be added back to waiting queue. Args: backend_request: An object representing the backend request. The type of this object is dependent on the backend implementation and the details of the request. - - Returns: - True if the backend request is currently in the running queue; False otherwise. """ raise NotImplementedError diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 4b2a076d..59b41fa7 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -13,7 +13,7 @@ import time import traceback -from typing import Any, List, Optional, Dict, Union, Iterable, Tuple +from typing import Any, List, Optional, Dict, Union, Iterable, Tuple, Deque from collections import defaultdict import threading import asyncio @@ -34,7 +34,7 @@ from llumnix.instance_info import InstanceInfo from llumnix.backends.backend_interface import BackendInterface, EngineState from llumnix.backends.vllm.scheduler import SchedulerLlumnix -from llumnix.backends.vllm.sequence import SequenceGroupLlumnix +from llumnix.backends.vllm.sequence import SequenceGroupLlumnix, RequestStatus from llumnix.backends.profiling import LatencyMemData from llumnix.server_info import ServerInfo from llumnix.internal_config import MigrationConfig @@ -199,7 +199,7 @@ def _process_model_outputs( # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args. return request_outputs, server_infos - async def step_async(self) -> None: + async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: step_begin_time = time.time() request_outputs, server_infos = await super().step_async() for request_output in request_outputs: @@ -295,9 +295,11 @@ def __init__( self.worker_handle_list = self.engine.model_executor.workers.copy() if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size: self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix")) - self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\ - src_worker_handle_list=self.worker_handle_list, - placement_group=placement_group, node_id=node_id) + self._run_workers("init_migration", instance_id=instance_id, + migration_config=migration_config, + src_worker_handle_list=self.worker_handle_list, + placement_group=placement_group, + node_id=node_id) self.state = EngineState.INIT logger.info("engine ({}) current state {}".format(self.instance_id, self.state)) @@ -350,15 +352,22 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None: logger.info("add seq {} to block table".format(seq.seq_id)) pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id) self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id) - backend_request.reset_migration_args() - self.add_running_request(backend_request) + backend_request.reset_migration_args_dst() + assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \ + "The status of request migrated to dst instance should be \ + RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING" + if backend_request.status == RequestStatus.WAITING_MIGRATING: + self.add_waiting_request(backend_request) + else: # RUNNING_MIGRATING: + backend_request.reset_status() + self.add_running_request(backend_request) async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None: await dst_ray_actor.execute_engine_method.remote("_run_workers", - "migrate_cache", - dst_blocks=dst_blocks, - src_blocks=src_blocks, - src_worker_handle_list=self.worker_handle_list) + "migrate_cache", + dst_blocks=dst_blocks, + src_blocks=src_blocks, + src_worker_handle_list=self.worker_handle_list) def _run_workers(self, *args, **kwargs): # pylint: disable=protected-access @@ -373,15 +382,21 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: request_ids = set(request_id) return self.engine.abort_request(request_ids) - def get_running_queue(self) -> List[SequenceGroupLlumnix]: + def get_running_queue(self) -> Deque[SequenceGroupLlumnix]: return self.engine.scheduler.get_running_queue() + def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]: + return self.engine.scheduler.get_waiting_queue() + def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]: return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs) - def remove_running_request(self, *args, **kwargs) -> None: + def remove_running_request(self, *args, **kwargs) -> bool: return self.engine.scheduler.remove_running_request(*args, **kwargs) + def remove_waiting_request(self, *args, **kwargs) -> bool: + return self.engine.scheduler.remove_waiting_request(*args, **kwargs) + def add_migrating_out_request_last_stage(self, *args, **kwargs) -> None: return self.engine.scheduler.add_migrating_out_request_last_stage(*args, **kwargs) @@ -400,8 +415,8 @@ def should_abort_migration(self, *args, **kwargs) -> bool: def add_running_request(self, *args, **kwargs) -> None: return self.engine.scheduler.add_running_request(*args, **kwargs) - def is_request_running(self, *args, **kwargs) -> bool: - return self.engine.scheduler.is_request_running(*args, **kwargs) + def add_waiting_request(self, *args, **kwargs) -> None: + return self.engine.scheduler.add_waiting_request(*args, **kwargs) def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None: return self.engine.scheduler.free_dst_pre_alloc_cache(*args, **kwargs) diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index e69f3479..950c1b31 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -286,15 +286,15 @@ def get_migration_backend(migration_config: MigrationConfig, cache_engine: Cache .format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks)) migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks - target_col = None + target_migration_backend = None backend = migration_config.migration_backend - assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported backend: {} for VLLM".format(backend) + assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported migration backend: {} for llumnix".format(backend) if backend in ['nccl', 'gloo']: - target_col = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, + target_migration_backend = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, is_driver_worker, gpu_cache) else: - target_col = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list, + target_migration_backend = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list, scheduling_strategy, is_driver_worker, gpu_cache) - return target_col + return target_migration_backend diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index a14db0b3..4c6403ae 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -13,19 +13,24 @@ from asyncio.log import logger import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Deque from collections import deque from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable from vllm.core.scheduler import (Scheduler, PreemptionMode, SequenceStatus, SequenceGroupMetadata, SchedulerOutputs) +from vllm.core.policy import PolicyFactory +from vllm.sequence import SequenceGroup +from vllm.core.interfaces import AllocStatus from llumnix.instance_info import InstanceInfo from llumnix.logger import init_logger -from llumnix.llumlet.request import RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType, RequestStatus from llumnix.backends.vllm.sequence import SequenceGroupLlumnix + logger = init_logger(__name__) + # TODO(ZeldaHuang): adapt prefix cache and sliding window, now use v1 manager class BlockManagerLlumnix(BlockSpaceManagerV1): def get_free_blocks(self, num_required_blocks: int) -> BlockTable: @@ -76,9 +81,12 @@ def _get_num_killed_requests(self) -> int: cnt += 1 return cnt - def get_running_queue(self): + def get_running_queue(self) -> Deque[SequenceGroupLlumnix]: return self.running + def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]: + return self.waiting + def get_all_request_ids(self) -> List[str]: request_ids : List[str] = [] for state_queue in [self.waiting, self.running, self.swapped]: @@ -86,18 +94,26 @@ def get_all_request_ids(self) -> List[str]: request_ids.append(seq_group.request_id) return request_ids - def get_request_incremental_blocks(self, backend_request: SequenceGroupLlumnix, pre_stage_num_blocks: int) -> List[int]: + def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]: seq = backend_request.get_seqs()[0] blocks = self.block_manager.get_block_table(seq) return blocks[pre_stage_num_blocks:] - def remove_running_request(self, request_id: str) -> None: + def remove_running_request(self, request_id: str) -> bool: for seq_group in self.running: if seq_group.request_id == request_id: - seq = seq_group.get_seqs()[0] self.running.remove(seq_group) - seq.status = SequenceStatus.WAITING - break + seq_group.set_status(RequestStatus.RUNNING_MIGRATING) + return True + return False + + def remove_waiting_request(self, request_id: str) -> bool: + for seq_group in self.waiting: + if seq_group.request_id == request_id: + self.waiting.remove(seq_group) + seq_group.set_status(RequestStatus.WAITING_MIGRATING) + return True + return False def add_migrating_out_request_last_stage(self, backend_request: SequenceGroupLlumnix) -> None: self.migrating_out_request_last_stage.append(backend_request) @@ -110,7 +126,17 @@ def pop_migrating_out_requests_last_stage(self) -> List[SequenceGroupLlumnix]: self.migrating_out_request_last_stage.clear() return migrating_out_request_last_stage - def pre_alloc(self, request_id: str, block_num: int) -> List[int]: + def pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: + # Only migrate waiting request when the waiting request is the earliest arrival one + # among the requests of dst instance's waiting queue. + if request_status == RequestStatus.WAITING_MIGRATING: + if (self.waiting and request_arrival_time > self.waiting[0].arrival_time) \ + or block_num * self.cache_config.block_size > self.prompt_limit: + return [] blocks = self.block_manager.get_free_blocks(block_num) pre_blocks = self.pre_alloc_cache_dict.get(request_id, []) pre_blocks.extend(blocks) @@ -118,13 +144,37 @@ def pre_alloc(self, request_id: str, block_num: int) -> List[int]: blocks = [block.block_number for block in blocks] return blocks - def add_running_request(self, backend_request: SequenceGroupLlumnix) -> None: - seq = backend_request.get_seqs()[0] - seq.status = SequenceStatus.RUNNING + def add_running_request(self, backend_request: LlumnixRequest) -> None: + self._set_status(backend_request, status_to=SequenceStatus.RUNNING) self.running.append(backend_request) - def is_request_running(self, backend_request: SequenceGroupLlumnix) -> bool: - return backend_request in self.running + def add_waiting_request(self, backend_request: LlumnixRequest) -> None: + self._set_status(backend_request, status_to=SequenceStatus.WAITING) + # pylint: disable=E0203 + self.waiting.append(backend_request) + fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") + self.waiting = fcfs_policy.sort_by_priority(time.time(), self.waiting) + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + if seq_group.status == RequestStatus.WAITING_MIGRATING: + return AllocStatus.OK + return super().can_allocate(seq_group) + + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + # Change seq status to running, but request status is still waiting_migrating. + if seq_group.status == RequestStatus.WAITING_MIGRATING: + # For the waiting request migrated in, blocks have already been allocated when pre alloc. + self._set_status(seq_group, status_to=SequenceStatus.RUNNING) + seq_group.reset_status() + else: + super()._allocate_and_set_running(seq_group) + + def _set_status(self, + seq_group: SequenceGroup, + status_to: SequenceStatus, + status_from: SequenceStatus = None): + for seq in seq_group.get_seqs(status=status_from): + seq.status = status_to def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: if request_id: @@ -132,6 +182,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: # pylint: disable=protected-access self.block_manager._free_block_table(blocks) else: + # TODO(s5u13b): Only effective with one-to-one migration restriction. # Clear all pre-allocated cache of dst instance when src instance encounters exception. request_ids = list(self.pre_alloc_cache_dict.keys()) for req_id in request_ids: @@ -141,7 +192,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None: seq = backend_request.get_seqs()[0] - logger.info("free seq {}".format(seq.seq_id)) + logger.info("free request: {}, free seq: {}".format(backend_request.request_id, seq.seq_id)) self.free_seq(seq) def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -> InstanceInfo: @@ -184,15 +235,18 @@ def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) - if scheduled_seq_groups: instance_info.inference_type = scheduled_seq_groups[-1].inference_type # TODO(ZeldaHuang) adapt chunked-prefill - instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in scheduled_seq_groups])\ - if instance_info.inference_type == RequestInferenceType.PREFILL else len(instance_info.running_seq_lens) - instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()] + instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in scheduled_seq_groups]) \ + if instance_info.inference_type == RequestInferenceType.PREFILL \ + else len(instance_info.running_seq_lens) + instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.finished] return instance_info def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group_metadata_list, scheduler_outputs = super().schedule() self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \ for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups])) + for seq_group in self.waiting: + seq_group.try_schedule_times += 1 return seq_group_metadata_list, scheduler_outputs def _schedule_running(self, running_queue: deque, *args, **kwargs): diff --git a/llumnix/backends/vllm/sequence.py b/llumnix/backends/vllm/sequence.py index 3c41a5c6..5964f96d 100644 --- a/llumnix/backends/vllm/sequence.py +++ b/llumnix/backends/vllm/sequence.py @@ -11,9 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from vllm.sequence import SequenceGroup +from vllm.sequence import SequenceGroup, SequenceStatus -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType, RequestStatus class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest): @@ -41,3 +41,29 @@ def inference_type(self) -> RequestInferenceType: if self.is_prefill(): return RequestInferenceType.PREFILL return RequestInferenceType.DECODE + + @property + def finished(self) -> bool: + return self.get_seqs()[0].is_finished() + + @property + def arrival_time(self) -> float: + return self.metrics.arrival_time + + @property + def status(self) -> RequestStatus: + if self._status: + return self._status + status = self.get_seqs()[0].status + if status == SequenceStatus.RUNNING: + request_status = RequestStatus.RUNNING + elif status == SequenceStatus.WAITING: + request_status = RequestStatus.WAITING + else: + request_status = RequestStatus.FINISHED + return request_status + + @property + def prefill_num_blocks(self) -> int: + # Get the prefill len of the waiting request. + return len(self.get_seqs()[0].logical_token_blocks) diff --git a/llumnix/backends/vllm/utils.py b/llumnix/backends/vllm/utils.py index 8aafc9f1..7e49720a 100644 --- a/llumnix/backends/vllm/utils.py +++ b/llumnix/backends/vllm/utils.py @@ -48,8 +48,7 @@ def check_engine_args(engine_args: AsyncEngineArgs, engine_manager_args: EngineM engine_config = engine_args.create_engine_config() parallel_config = engine_config.parallel_config if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl': - # TODO(s5u13b): fix logger - print("Llumnix does not support TP or PP enabled model when the migration backend is nccl, change migration backend to gloo.") + logger.info("Llumnix does not support TP or PP enabled model when the migration backend is nccl, change migration backend to gloo.") engine_manager_args.migration_backend = 'gloo' detect_unsupported_feature(engine_args) diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index 2b0cab33..e38c3423 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -111,10 +111,8 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block start_time = time.time() try: self.migration_backend.migrate_cache(src_worker_handle, src_blocks, dst_blocks) - # pylint: disable=broad-except - except Exception as e: - logger.info("[migrate_cache] self.rank: {}, src_worker_handle {}, meet error : {}" - .format(self.rank, src_worker_handle, e)) + except ray.exceptions.RayActorError: + logger.info("[migrate_cache] self.rank: {}, src_worker_handle {} is dead".format(self.rank, src_worker_handle)) end_time = time.time() total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size( diff --git a/llumnix/config/default.py b/llumnix/config/default.py index 2a6c7758..358d9e1b 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -95,7 +95,7 @@ # Migrate out instance load threshold _C.MANAGER.MIGRATE_OUT_THRESHOLD = 3.0 # Request migration policy -_C.MANAGER.REQUEST_MIGRATION_POLICY = 'SJF' +_C.MANAGER.REQUEST_MIGRATION_POLICY = 'SR' # Enable defragmentation through migration based on virtual usage _C.MANAGER.ENABLE_DEFRAG = False # Drop migration if the number of stages > max_stages diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index 27458f26..51a0d36b 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -71,7 +71,7 @@ def remove_instance(self, instance_id: str) -> None: del self.instance_num_requests[instance_id] if instance_id in self.available_dispatch_instance_set: self.available_dispatch_instance_set.remove(instance_id) - + # TODO(KuilongCui): Check it when there is no decode instance. if self.num_instances >= self.num_dispatch_instances: free_instance_id = next(iter(self.instance_id_set - self.available_dispatch_instance_set)) self.available_dispatch_instance_set.add(free_instance_id) diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 77fd9b25..3445b210 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -170,10 +170,8 @@ def pair_migration(self, migrate_instance_pairs = [] for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))): load_diff_before_mig = sorted_src_instance_infos[i].instance_load_migrate - sorted_dst_instance_infos[i].instance_load_migrate - left_load_after_mig = self._compute_instance_load_after_migrate(sorted_src_instance_infos[i], is_migrate_in=False) right_load_after_mig = self._compute_instance_load_after_migrate(sorted_dst_instance_infos[i], is_migrate_in=True) - # Add some constrains to reduce unnecessary migrations if right_load_after_mig > self.migrate_out_load_threshold: continue @@ -186,14 +184,12 @@ def pair_migration(self, def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float: instance_info_after_migrate = copy.deepcopy(instance_info) num_blocks_last_running_request = instance_info_after_migrate.num_blocks_last_running_request - if is_migrate_in: instance_info_after_migrate.num_running_requests += 1 instance_info_after_migrate.num_free_gpu_blocks -= num_blocks_last_running_request else: instance_info_after_migrate.num_running_requests -= 1 instance_info_after_migrate.num_free_gpu_blocks += num_blocks_last_running_request - return self.instance_load_calculator.compute_instance_load(instance_info_after_migrate, action='migrate') class DefragConstrained(PairMigrationPolicy): diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index d98f3a8e..66739632 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -15,7 +15,6 @@ import time import csv import os -import math from typing import Dict, List, Tuple, Union, Iterable from collections import defaultdict import traceback @@ -222,22 +221,23 @@ async def _clear_request_instance_loop(self, interval: float): async def _push_migrations(self) -> None: # Push migrate when the instance_info have updated a certain number of times. if self.enable_pd_disagg: - asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, math.inf)) - asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING, 1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING)) + asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING)) else: - asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS)) - async def _migrate(self, pair_migration_type: PairMigrationConstraints, migrate_in_num_requests: int) -> None: + async def _migrate(self, pair_migration_type: PairMigrationConstraints) -> None: async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None: self.num_migrating -= 1 - if isinstance(ret, (ray.exceptions.RayActorError, KeyError)): + # TODO(s5u13b): Add more exception types for failover. + if isinstance(ret, (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError)): has_error_pair = await self._check_instance_error(migrate_instance_pair) for i, has_error in enumerate(has_error_pair): # Instance without error should clear migration states. if not has_error: try: await self.instances[migrate_instance_pair[i]].clear_migration_states.remote(is_migrate_in=bool(i)) - except (ray.exceptions.RayActorError, KeyError): + except (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError): has_error = True for i, has_error in enumerate(has_error_pair): if has_error: @@ -267,7 +267,7 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id) # Use asyncio.gather to wrap ray remote call to add done callback. - task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests), + task = asyncio.gather(self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name), return_exceptions=True) task.add_done_callback(partial(migrate_done_callback_wrapper, migrate_instance_pair)) migration_tasks.append(task) diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 42be2f2c..3af73ac5 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -27,6 +27,7 @@ from llumnix.server_info import ServerInfo from llumnix.internal_config import MigrationConfig from llumnix.queue.queue_type import QueueType +from llumnix.llumlet.request import LlumnixRequest, RequestStatus logger = init_logger(__name__) @@ -55,7 +56,7 @@ def __init__(self, self.backend_engine) self.log_requests = True - self.check_state_thread = asyncio.create_task(self.check_state()) + asyncio.create_task(self._check_state_loop()) # pylint: disable=broad-except except Exception as e: logger.error("Failed to initialize llumlet: {}".format(e)) @@ -118,7 +119,7 @@ def from_args(cls, llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs) return llumlet - async def check_state(self): + async def _check_state_loop(self): while True: await asyncio.sleep(1) if self.backend_engine.state == EngineState.CRASHED: @@ -128,46 +129,55 @@ async def check_state(self): self_actor = ray.get_actor(self.actor_name) ray.kill(self_actor) - async def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]: + async def migrate_out(self, dst_instance_name: str) -> List[str]: + migrate_out_requests = self.migration_scheduler.get_migrate_out_requests() + if len(migrate_out_requests) == 0: + return [] + migrated_request_list = [] + for migrate_out_request in migrate_out_requests: + migrated_request = await self._migrate_out_one_request(migrate_out_request, dst_instance_name) + migrated_request_list.extend(migrated_request) + if len(migrated_request) == 0 and migrate_out_request.eom: + break + return migrated_request_list + + async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, dst_instance_name: str) -> List[LlumnixRequest]: try: - migrate_out_request = None + t0 = time.time() migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') dst_instance_id = dst_instance_name[len("instance_"):] - migrated_request_list = [] - continue_migrate = True - while continue_migrate and len(migrated_request_list) < num_requests: - t0 = time.time() - migrate_out_request = self.migration_scheduler.get_migrate_out_request() - if migrate_out_request is None: - break - - migrate_out_request.migrating = True - logger.info("{}->{} begin migrate out {}".format(self.instance_id, dst_instance_id, migrate_out_request.request_id)) - status = await self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - - if status == MigrationStatus.FINISHED_DONE: - await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request) - self.backend_engine.free_src_request(migrate_out_request) - migrated_request_list.append(migrate_out_request.request_id) - migrate_out_request.stage_timestamps.append(time.time()) - self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) - else: - migrate_out_request.reset_migration_args() + logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id)) + migrated_request = [] + if migrate_out_request.status == RequestStatus.RUNNING: + status = await self.migration_coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) + elif migrate_out_request.status == RequestStatus.WAITING: + status = await self.migration_coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) + else: + return migrated_request + if status == MigrationStatus.FINISHED: + await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request) + self.backend_engine.free_src_request(migrate_out_request) + self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) + migrated_request.append(migrate_out_request.request_id) + else: # ABORTED_SRC or ABORTED_DST + migrate_out_request.reset_migration_args_src() + migrate_out_request.reset_status() + # If dst aborts itself, dst proactively frees the pre allocated cache in migrate_in_pre_alloc. + if status == MigrationStatus.ABORTED_SRC: await migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id) - continue_migrate = False - t1 = time.time() - logger.info("{}->{} migrate done, migrate request {}, status:{}, len:{} blocks, cost:{} ms" \ - .format(self.instance_id, dst_instance_id, migrated_request_list, status, \ - sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) - # pylint: disable=broad-except + t1 = time.time() + logger.info("{}->{} migrate done, migrate request {}, migration status: {}, len: {} blocks, cost: {} ms" \ + .format(self.instance_id, dst_instance_id, migrated_request, status, \ + sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) + except ray.exceptions.RayActorError: + logger.info("[migrate_out] instance {} is dead".format(dst_instance_name[len("instance_"):])) + raise + # pylint: disable=W0703 except Exception as e: - if migrate_out_request: - migrate_out_request.reset_migration_args() - - logger.info("[migrate_out] src instance {}, dst instance {}, meet error: {}" - .format(self.instance_id, dst_instance_name[len("instance_"):], e)) + logger.error("unexpected exception occurs: {}".format(e)) + logger.error("exception traceback: {}".format(traceback.format_exc())) raise - return migrated_request_list + return migrated_request def get_instance_info(self) -> InstanceInfo: return self.backend_engine.engine.instance_info @@ -209,7 +219,13 @@ def clear_migration_states(self, is_migrate_in: bool) -> None: migrating_out_requests_last_stage = self.backend_engine.pop_migrating_out_requests_last_stage() for backend_request in migrating_out_requests_last_stage: logger.info("clear_migration_states: add request {} back to engine".format(backend_request.request_id)) - self.backend_engine.add_running_request(backend_request) + assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \ + "The status of request in migrating_out_requests_last_stage should be \ + RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING" + if backend_request.status == RequestStatus.RUNNING_MIGRATING: + self.backend_engine.add_running_request(backend_request) + else: # WAITING_MIGRATING + self.backend_engine.add_waiting_request(backend_request) def execute_migration_method(self, method, *args, **kwargs): executor = getattr(self.migration_coordinator, method) diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py index ad676cc1..4f30f850 100644 --- a/llumnix/llumlet/local_migration_scheduler.py +++ b/llumnix/llumlet/local_migration_scheduler.py @@ -11,80 +11,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Deque, List import numpy as np -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestStatus, RequestInferenceType from llumnix.backends.backend_interface import BackendInterface + class LocalMigrationScheduler: def __init__(self, request_migration_policy: str, backend_engine: BackendInterface) -> None: self.request_migration_policy = request_migration_policy self.backend_engine = backend_engine - def get_migrate_out_request(self, min_request_len=0, max_request_len=np.inf) -> Optional[LlumnixRequest]: - # Requests meet the strict pre-migration always have higher prioirity than other migration policy. - migrate_out_request = self.get_ready_migration_request(min_request_len, max_request_len) - if migrate_out_request is None: - if self.request_migration_policy == 'LCFS': - migrate_out_request = self.get_last_running_request(min_request_len, max_request_len) - elif self.request_migration_policy == 'LJF': - migrate_out_request = self.get_longest_running_request(min_request_len, max_request_len) - elif self.request_migration_policy == 'SJF': - migrate_out_request = self.get_shortest_running_request(min_request_len, max_request_len) - return migrate_out_request + def get_migrate_out_requests(self, min_request_len=0, max_request_len=np.inf) -> List[LlumnixRequest]: + # Requests meet the strict pre-migration always have higher prioirity than other migration policy. + migrate_out_requests: List[LlumnixRequest] = self.get_required_migration_request() + if len(migrate_out_requests) == 0: + if self.request_migration_policy == 'LCR': + migrate_out_requests = self._get_last_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'LR': + migrate_out_requests = self._get_longest_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'SR': + migrate_out_requests = self._get_shortest_running_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'FCW': + migrate_out_requests = self._get_first_waiting_request(min_request_len, max_request_len) + elif self.request_migration_policy == 'FCWSR': + migrate_out_requests = self._get_first_waiting_and_shortest_running_requests(min_request_len, max_request_len) + return migrate_out_requests # The function is used to retrieve requests on the backend that have already met the expected_steps. - # TODO(xinyi): Currently, the function is only used for Prefill-decoding disaggregation, + # (xinyi): Currently, the function is only used for Prefill-decoding disaggregation, # and only selects request that migrates from the prefill instance to the decoding instance. - def get_ready_migration_request(self, min_request_len, max_request_len): + def get_required_migration_request(self): running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - target_request: LlumnixRequest = None + required_migration_requests = [] for request in reversed(running): - if request.migrating: - continue - - if request.output_len >= request.expected_steps \ + if request.status == RequestStatus.RUNNING \ and request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len: - target_request = request - break - - return target_request - - def get_last_running_request(self, min_request_len, max_request_len): - running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - target_request: LlumnixRequest = None - - for request in reversed(running): - if request.migrating: - continue - - if request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len: - target_request=request - break - - return target_request - - def get_longest_running_request(self, min_request_len, max_request_len): - running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - condition = lambda request : request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len \ - and (not request.migrating) - - longest_seq_group = max((request for request in running if condition(request)), \ - key=lambda request: request.request_len, default=None) - - return longest_seq_group - - def get_shortest_running_request(self, min_request_len, max_request_len): - running: List[LlumnixRequest] = self.backend_engine.get_running_queue() - condition = lambda request : request.inference_type == RequestInferenceType.DECODE \ - and min_request_len <= request.request_len <= max_request_len \ - and (not request.migrating) - - shortest_seq_group = min((request for request in running if condition(request)), \ - key=lambda request: request.request_len, default=None) - - return shortest_seq_group + and request.output_len >= request.expected_steps: + required_migration_requests.append(request) + return required_migration_requests + + def _filter_running_queue(self, running, min_request_len, max_request_len): + filtered_running = [ + request for request in running \ + if request.status == RequestStatus.RUNNING \ + and request.inference_type == RequestInferenceType.DECODE \ + and min_request_len < request.request_len < max_request_len \ + ] + return filtered_running + + def _filter_waiting_queue(self, waiting, min_request_len, max_request_len): + filtered_waiting = [ + request for request in waiting \ + if request.status == RequestStatus.WAITING \ + and request.try_schedule_times >= 1 \ + and min_request_len < request.request_len < max_request_len \ + ] + return filtered_waiting + + def _get_last_running_request(self, min_request_len, max_request_len): + running: Deque[LlumnixRequest] = self.backend_engine.get_running_queue() + filtered_running = self._filter_running_queue(running, min_request_len, max_request_len) + return [filtered_running[-1]] if filtered_running else [] + + def _get_longest_running_request(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + running: Deque[LlumnixRequest] = self.backend_engine.get_running_queue() + filtered_running = self._filter_running_queue(running, min_request_len, max_request_len) + longest_seq_group = max((request for request in filtered_running), \ + key=lambda request: request.request_len, default=None) + return [longest_seq_group] if longest_seq_group is not None else [] + + def _get_shortest_running_request(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + running: Deque[LlumnixRequest] = self.backend_engine.get_running_queue() + filtered_running = self._filter_running_queue(running, min_request_len, max_request_len) + shortest_seq_group = min((request for request in filtered_running), \ + key=lambda request: request.request_len, default=None) + return [shortest_seq_group] if shortest_seq_group is not None else [] + + def _get_first_waiting_request(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + waiting: Deque[LlumnixRequest] = self.backend_engine.get_waiting_queue() + filtered_waiting = self._filter_waiting_queue(waiting, min_request_len, max_request_len) + return [waiting[0]] if filtered_waiting else [] + + def _get_first_waiting_and_shortest_running_requests(self, min_request_len, max_request_len) -> List[LlumnixRequest]: + waiting_requests = self._get_first_waiting_request(min_request_len, max_request_len) + running_requests = self._get_shortest_running_request(min_request_len, max_request_len) + if waiting_requests: + waiting_requests[0].eom = True + return waiting_requests + running_requests diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index 03b20cb2..224c41c3 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -19,7 +19,7 @@ import ray from llumnix.logger import init_logger -from llumnix.llumlet.request import LlumnixRequest +from llumnix.llumlet.request import LlumnixRequest, RequestStatus from llumnix.backends.backend_interface import BackendInterface logger = init_logger(__name__) @@ -27,18 +27,16 @@ class MigrationStatus(enum.Enum): """Status of Migration.""" RUNNING = enum.auto() - # aborted by src instance - ABORTED_SRC = enum.auto() - # aborted by dst instance ABORTED_DST = enum.auto() - FINISHED_DONE = enum.auto() + ABORTED_SRC = enum.auto() + FINISHED = enum.auto() @staticmethod def is_finished(status: "MigrationStatus") -> bool: return status in [ - MigrationStatus.ABORTED_SRC, MigrationStatus.ABORTED_DST, - MigrationStatus.FINISHED_DONE + MigrationStatus.ABORTED_SRC, + MigrationStatus.FINISHED ] class MigrationCoordinator: @@ -50,36 +48,88 @@ def __init__(self, self.max_stages = max_stages self.backend_engine = backend_engine - async def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest, ) -> "MigrationStatus": - """one-stage live migration until last stage + async def migrate_out_running_request(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + return await self._migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) + + async def migrate_out_waiting_request(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + """one-stage migration for a waiting request + """ + found = self.backend_engine.remove_waiting_request(migrate_out_request.request_id) + if not found: + return MigrationStatus.ABORTED_SRC + self.backend_engine.add_migrating_out_request_last_stage(migrate_out_request) + dst_blocks = await migrate_in_ray_actor.execute_migration_method \ + .remote("migrate_in_pre_alloc", migrate_out_request.request_id, + migrate_out_request.status, + migrate_out_request.arrival_time, + migrate_out_request.prefill_num_blocks) + if len(dst_blocks) != migrate_out_request.prefill_num_blocks: + self.backend_engine.add_waiting_request(migrate_out_request) + self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) + return MigrationStatus.ABORTED_DST + + return MigrationStatus.FINISHED + + async def _migrate_out_multistage(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + """Migrate out requests to a specified instance, return migrated request id. + Args: + migrate_in_ray_actor: instance actor name, used to get ray actor handle + """ + stage_count = 0 + while stage_count < self.max_stages: + stage_count += 1 + status = await self._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + if MigrationStatus.is_finished(status): + return status + # exceed max stages + return MigrationStatus.ABORTED_SRC + + async def _migrate_out_onestage(self, + migrate_in_ray_actor: "ray.actor.ActorHandle", + migrate_out_request: LlumnixRequest) -> "MigrationStatus": + """one-stage live migration until last stage for a running request """ pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list) incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks) # live migration, transfer all blocks except last one(currently updating) - migration_status = MigrationStatus.RUNNING is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks) or migrate_out_request.blocking_migration if not is_last_stage: + migration_status = MigrationStatus.RUNNING src_blocks = incremental_blocks[:-1] stage_block_num = len(incremental_blocks) - 1 dst_blocks = await migrate_in_ray_actor.execute_migration_method \ - .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num) + .remote("migrate_in_pre_alloc", migrate_out_request.request_id, + migrate_out_request.status, + migrate_out_request.arrival_time, + stage_block_num) else: # last stage migration, stop inference, transfer all blocks - migration_status = MigrationStatus.FINISHED_DONE - self.backend_engine.remove_running_request(migrate_out_request.request_id) + migration_status = MigrationStatus.FINISHED + found = self.backend_engine.remove_running_request(migrate_out_request.request_id) + if not found: + return MigrationStatus.ABORTED_SRC self.backend_engine.add_migrating_out_request_last_stage(migrate_out_request) - stage_block_num = len(incremental_blocks) src_blocks = incremental_blocks[:] + stage_block_num = len(incremental_blocks) dst_blocks = await migrate_in_ray_actor.execute_migration_method \ - .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num) + .remote("migrate_in_pre_alloc", migrate_out_request.request_id, + migrate_out_request.status, + migrate_out_request.arrival_time, + stage_block_num) if len(dst_blocks) != len(src_blocks): - # migrate-in instance failed to prev alloc + # migrate-in instance failed to pre alloc if is_last_stage: self.backend_engine.add_running_request(migrate_out_request) self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) - migration_status = MigrationStatus.ABORTED_DST - return migration_status + return MigrationStatus.ABORTED_DST + # do stage send/recv migrate_out_request.stage_timestamps.append(time.time()) migrate_out_request.stage_num_blocks_list.append(stage_block_num) @@ -87,28 +137,21 @@ async def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandl await self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks) if not is_last_stage and migrate_out_request.should_abort_migration(): # migrate-out request abort by scheduler during send/recv - migration_status = MigrationStatus.ABORTED_SRC + return MigrationStatus.ABORTED_SRC return migration_status - async def migrate_out_multistage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest) -> "MigrationStatus": - """Migrate out requests to a specified instance, return migrated request id. - Args: - dst_instance_name:instance actor name, used to get ray actor handle - """ - state_count = 0 - while state_count < self.max_stages: - state_count += 1 - status = await self.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) - if MigrationStatus.is_finished(status): - return status - # exceed max stages - return MigrationStatus.ABORTED_SRC - - def migrate_in_pre_alloc(self, request_id: str, block_num: int) -> List[int]: + def migrate_in_pre_alloc(self, + request_id: str, + request_status: RequestStatus, + request_arrival_time: float, + block_num: int) -> List[int]: """prev alloc blocks to migrate in request """ - pre_alloc_blocks = self.backend_engine.pre_alloc(request_id ,block_num) + pre_alloc_blocks = self.backend_engine.pre_alloc(request_id, + request_status, + request_arrival_time, + block_num) if len(pre_alloc_blocks) != block_num: # failed to alloc, abort request self.free_dst_pre_alloc_cache(request_id) diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index c2aeda9e..d92e6564 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -20,6 +20,13 @@ class RequestInferenceType(str, Enum): PREFILL = "prefill" DECODE = "decode" +class RequestStatus(str, Enum): + RUNNING = "running" + WAITING = "waiting" + FINISHED = "finished" + RUNNING_MIGRATING = "running_migrating" + WAITING_MIGRATING = "waiting_migrating" + class LlumnixRequest: def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int) -> None: self.request_id = request_id @@ -32,18 +39,31 @@ def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] - self.migrating = False + self.try_schedule_times = 0 + self._status = None + + # end-of-migration, for multiple requests migration + self.eom = False + + def reset_migration_args_dst(self): + # By default, there is no limit on the number of steps expected for the request. + self.expected_steps = math.inf - def reset_migration_args(self): self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] - # By default, there is no limit on the number of steps expected for the request. - self.expected_steps = math.inf - self.migrating = False + self.try_schedule_times = 0 - def is_finished(self) -> bool: - raise NotImplementedError + def reset_migration_args_src(self): + self.last_preemption_time = None + self.stage_timestamps = [] + self.stage_num_blocks_list = [] + + def reset_status(self): + self._status = None + + def set_status(self, status: RequestStatus): + self._status = status @property def inference_type(self) -> RequestInferenceType: @@ -61,6 +81,22 @@ def prompt_len(self) -> int: def output_len(self) -> int: raise NotImplementedError + @property + def finished(self) -> bool: + raise NotImplementedError + + @property + def arrival_time(self) -> float: + raise NotImplementedError + + @property + def status(self) -> RequestStatus: + raise NotImplementedError + + @property + def prefill_num_blocks(self) -> int: + raise NotImplementedError + # Whether the migration of request is completed within one stage. For requests that have already reached # the expected steps, blocking_migration is True. @property @@ -68,7 +104,5 @@ def blocking_migration(self) -> bool: return self.output_len >= self.expected_steps def should_abort_migration(self) -> bool: - return self.output_len == 0 \ - or (self.last_preemption_time and self.last_preemption_time > self.stage_timestamps[-1]) \ - or self.inference_type == RequestInferenceType.PREFILL \ - or self.is_finished() + return self.finished \ + or (self.last_preemption_time is not None and self.last_preemption_time > self.stage_timestamps[-1]) diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index eb93fb89..5eba27d1 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -21,7 +21,7 @@ from .test_e2e import generate_launch_command, clear_ray_state # pylint: disable=unused-import -from .utils import to_markdown_table, clean_ray +from .utils import to_markdown_table, setup_ray_env def launch_llumnix_service(command): subprocess.run(command, shell=True, check=True) @@ -91,7 +91,7 @@ def get_markdown_data(key: str, head_name: str): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for simple benchmark") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -async def test_simple_benchmark(clean_ray, model): +async def test_simple_benchmark(setup_ray_env, model): device_count = torch.cuda.device_count() base_port = 37037 for i in range(device_count): diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py index 11b8617f..a3bf1977 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_e2e.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import subprocess import asyncio import pytest @@ -21,7 +20,7 @@ from vllm import LLM, SamplingParams # pylint: disable=unused-import -from .utils import clean_ray +from .utils import setup_ray_env def parse_launch_mode(launch_mode: str): # 'eief' means that enable init instance by manager and enable fixed node init instance, and so on. @@ -42,8 +41,8 @@ def parse_launch_mode(launch_mode: str): def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool = True, HEAD_NODE_IP: str = "127.0.0.1", ip: str = "127.0.0.1", port: int = 37000, instances_num = 1, dispatch_policy: str = "load", migration_backend = "gloo", model = "facebook/opt-125m", max_model_len: int = 2048, - launch_mode: str = 'eief', log_instance_info: bool = False, enable_pd_disagg: bool = False, - num_dispatch_instances: int = math.inf): + launch_mode: str = 'eief', log_instance_info: bool = False, + request_migration_policy: str = 'SR'): disable_init_instance_by_manager, disable_fixed_node_init_instance = parse_launch_mode(launch_mode) command = ( f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 " @@ -62,14 +61,12 @@ def generate_launch_command(result_filename: str = "", launch_ray_cluster: bool f"--max-model-len {max_model_len} " f"--dispatch-policy {dispatch_policy} " f"--trust-remote-code " - f"--request-migration-policy LCFS " + f"--request-migration-policy {request_migration_policy} " f"--migration-backend {migration_backend} " f"--migration-buffer-blocks 32 " f"--migration-internal-buffer-num 2 " f"--tensor-parallel-size 1 " f"--request-output-queue-port {1234+port} " - f"{'--enable-pd-disagg ' if enable_pd_disagg else ''} " - f"{f'--num-dispatch-instances {num_dispatch_instances} ' if num_dispatch_instances != math.inf else ''} " f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}" f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &" ) @@ -143,7 +140,7 @@ def run_vllm(model, max_model_len, sampling_params): @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) @pytest.mark.parametrize("launch_mode", ['eief', 'eidf', 'dief', 'didf']) -async def test_e2e(clean_ray, model, migration_backend, launch_mode): +async def test_e2e(setup_ray_env, model, migration_backend, launch_mode): if migration_backend == 'gloo' and launch_mode != 'eief': pytest.skip("When the migration backend is gloo, the launch mode of llumnix can only be eief") max_model_len = 370 diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index b1f446f1..ced1e0be 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import asyncio from collections import defaultdict import re @@ -23,7 +22,7 @@ from .test_e2e import generate_launch_command from .test_bench import generate_bench_command, clear_ray_state, shutdown_llumnix_service # pylint: disable=unused-import -from .utils import to_markdown_table, clean_ray +from .utils import to_markdown_table, setup_ray_env size_pattern = re.compile(r'total_kv_cache_size:\s*([\d.]+)\s*(B|KB|MB|GB|KB|TB)') speed_pattern = re.compile(r'speed:\s*([\d.]+)GB/s') @@ -43,18 +42,18 @@ def parse_instance_log_file(log_files): speed = float(speed_match.group(1)) speed_dict[total_kv_cache_size].append(speed) - averger_speed = {} + average_speed = {} for transfer_size, speeds in speed_dict.items(): if len(speeds) <= 2: continue speeds.sort() trimmed_speeds = speeds[1:-1] - averger_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) + average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) - assert len(averger_speed) > 0, "Migration should have occurred, but it was not detected. " + assert len(average_speed) > 0, "Migration should have occurred, but it was not detected. " - return averger_speed + return average_speed def parse_manager_log_file(log_file): df = pd.read_csv(log_file) @@ -68,7 +67,13 @@ def parse_manager_log_file(log_file): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) -async def test_migration_benchmark(clean_ray, model, migration_backend): +@pytest.mark.parametrize("migrated_request_status", ['running', 'waiting']) +async def test_migration_benchmark(model, migration_backend, migrated_request_status): + if migrated_request_status == 'waiting' and migration_backend != 'rpc': + pytest.skip("When the migrated request status is waiting, only test the rpc migration backend.") + + request_migration_policy = 'SR' if migrated_request_status == 'running' else 'FCW' + base_port = 37037 instance_output_logs = [] @@ -78,10 +83,11 @@ async def test_migration_benchmark(clean_ray, model, migration_backend): instance_output_logs.append("instance_"+output_log) launch_command = generate_launch_command(result_filename=output_log, launch_ray_cluster=False, port=base_port+i, model=model, dispatch_policy="flood", migration_backend=migration_backend, - log_instance_info=True, enable_pd_disagg=False, - num_dispatch_instances=math.inf) + log_instance_info=True, + request_migration_policy=request_migration_policy) subprocess.run(launch_command, shell=True, check=True) - await asyncio.sleep(60) + await asyncio.sleep(5) + await asyncio.sleep(30) async def run_bench_command(command): process = await asyncio.create_subprocess_shell(command) @@ -99,23 +105,23 @@ async def run_bench_command(command): _, pending = await asyncio.wait(tasks, timeout=60*30) + await asyncio.sleep(10) + if len(pending) > 0: raise RuntimeError("migration task Timeout") parse_manager_log_file("manager_instance.csv") - averger_speed = parse_instance_log_file(instance_output_logs) - - sorted_keys = sorted(averger_speed.keys(), key=lambda x: float(x.split()[0])) - - data = [ - ['migration_size'] + sorted_keys, - [f'{migration_backend}_speed(GB/s)'] + [f"{averger_speed[key]:.2f}" for key in sorted_keys] - ] - - with open("performance.txt", "a", encoding="utf-8") as f: - f.write(to_markdown_table(data)) + if migrated_request_status == 'running': + average_speed = parse_instance_log_file(instance_output_logs) + sorted_keys = sorted(average_speed.keys(), key=lambda x: float(x.split()[0])) + data = [ + ['migration_size'] + sorted_keys, + [f'{migration_backend}_speed(GB/s)'] + [f"{average_speed[key]:.2f}" for key in sorted_keys] + ] + with open("performance.txt", "a", encoding="utf-8") as f: + f.write(to_markdown_table(data)) shutdown_llumnix_service() clear_ray_state() - await asyncio.sleep(3) + await asyncio.sleep(10) diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index 492eb2fd..1c38dcc8 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -33,7 +33,7 @@ def to_markdown_table(data): return table @pytest.fixture -def clean_ray(): +def setup_ray_env(): subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index e5bf4567..b74950c2 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -11,20 +11,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List import asyncio import math +from unittest.mock import MagicMock import pytest import ray +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from vllm import EngineArgs, SamplingParams from vllm.utils import random_uuid +from vllm.sequence import SequenceStatus from llumnix.backends.vllm.llm_engine import BackendVLLM from llumnix.llumlet.llumlet import Llumlet from llumnix.backends.utils import BackendType from llumnix.internal_config import MigrationConfig -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import RequestInferenceType, RequestStatus from llumnix.queue.queue_type import QueueType from tests.unit_test.queue.utils import request_output_queue_server @@ -51,22 +53,60 @@ def __init__(self): self.instance_id = "0" self.backend_engine = MockBackendVLLM() +@ray.remote(num_cpus=1, max_concurrency=4) +class MockLlumletDoNotSchedule(Llumlet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # stop the schedule in engine step loop + self.backend_engine.engine.scheduler.schedule = MagicMock() + + # For some reason, if MockScheduelrOutputs is defined outside, the constructor would raise error. + class MockScheduelrOutputs: + def __init__(self): + self.scheduled_seq_groups = [] + self.ignored_seq_groups = [] + self.num_batched_tokens = 0 + + def is_empty(self) -> bool: + return not self.scheduled_seq_groups + + scheduler_outputs = MockScheduelrOutputs() + self.backend_engine.engine.scheduler.schedule.return_value = ([], scheduler_outputs) + + self.step_async = self.backend_engine.engine.step_async + + async def step_async_try_schedule(): + request_outputs, server_infos = await self.step_async() + for seq_group in self.backend_engine.engine.scheduler.waiting: + seq_group.try_schedule_times += 1 + return request_outputs, server_infos + + self.backend_engine.engine.step_async = step_async_try_schedule + +# TODO(s5u13b): Test migrate waiting request. @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migration_request_status", ['waiting', 'running']) @pytest.mark.asyncio -async def test_migration_correctness(setup_ray_env, migration_backend): +async def test_migration_correctness(setup_ray_env, migration_backend, migration_request_status): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - id_rank_map = {"0":0, "1":1} - migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20, 2) + id_rank_map = {"0": 0, "1": 1, "2": 2} + if migration_request_status == 'running': + request_migration_policy = "SR" + elif migration_request_status == 'waiting': + request_migration_policy = "FCW" + migration_config = MigrationConfig(request_migration_policy, migration_backend, 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) asyncio.create_task(que.run_server_loop()) + node_id = ray.get_runtime_context().get_node_id() + scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) llumlet_0: Llumlet = Llumlet.from_args( output_queue_type, False, - True, - ray.get_runtime_context().get_node_id(), + False, + node_id, "0", BackendType.VLLM, 1, @@ -76,24 +116,39 @@ async def test_migration_correctness(setup_ray_env, migration_backend): llumlet_1: Llumlet = Llumlet.from_args( output_queue_type, False, - True, - ray.get_runtime_context().get_node_id(), + False, + node_id, "1", BackendType.VLLM, 1, migration_config, engine_args) + llumlet_2: Llumlet = MockLlumletDoNotSchedule.options( + name='instance_2', + namespace='llumnix', + scheduling_strategy=scheduling_strategy).remote( + instance_id="2", + output_queue_type=output_queue_type, + backend_type=BackendType.VLLM, + migration_config=migration_config, + engine_args=engine_args, + node_id=node_id + ) + while True: - res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) + res = ray.get([llumlet_0.is_ready.remote(), llumlet_1.is_ready.remote(), llumlet_2.is_ready.remote()]) if all(res): break ray.get([llumlet_0.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"), - llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) + llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"), + llumlet_2.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) # empty instance migrate out - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + assert not res + res = ray.get(llumlet_2.migrate_out.remote("instance_1")) assert not res # running without migration @@ -110,16 +165,28 @@ async def test_correctness(prompt): origin_output = request_output.outputs[0] finished = request_output.finished - request_id1 = random_uuid() - ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) - # wait prefill done - while True: - running_queue: List[LlumnixRequest] = ray.get(llumlet_0.execute_engine_method.remote("get_running_queue")) - if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE: - break - # migrate request - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) - assert len(res) == 1 + if migration_request_status == 'running': + request_id1 = random_uuid() + ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) + # wait prefill done + while True: + running_queue = ray.get(llumlet_0.execute_engine_method.remote("get_running_queue")) + if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE: + break + # migrate request + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + assert len(res) == 1 + elif migration_request_status == 'waiting': + request_id1 = random_uuid() + ray.get(llumlet_2.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) + # wait try schedule done + while True: + waiting_queue = ray.get(llumlet_2.execute_engine_method.remote("get_waiting_queue")) + if len(waiting_queue) > 0 and waiting_queue[0].try_schedule_times >= 1: + break + # migrate request + res = ray.get(llumlet_2.migrate_out.remote("instance_1")) + assert len(res) == 1 request_output_queue = que output = None @@ -127,10 +194,6 @@ async def test_correctness(prompt): while not finished: request_outputs = await request_output_queue.get() for request_output in request_outputs: - origin_output = request_output.outputs[0] - finished = request_output.finished - if request_output.request_id != request_id1: - continue output = request_output.outputs[0] finished = request_output.finished @@ -146,7 +209,7 @@ async def test_correctness(prompt): async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) id_rank_map = {"0":0, "1":1} - migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20, 2) + migration_config = MigrationConfig("SR", migration_backend, 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) @@ -179,12 +242,10 @@ async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) if all(res): break - - ray.get([llumlet_0.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"), - llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) - + ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"), + llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")]) # empty instance migrate out - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) assert not res # running without migration @@ -207,7 +268,7 @@ async def test_correctness(prompt): ray.get(llumlet_0.generate.remote(request_id1, server_info, request_expected_steps_id1, prompt, sampling_params)) # migrate request for decoding while True: - res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests = math.inf)) + res = ray.get(llumlet_0.migrate_out.remote("instance_1")) if len(res) == 1: break request_output_queue = que @@ -216,12 +277,8 @@ async def test_correctness(prompt): while not finished: request_outputs = await request_output_queue.get() for request_output in request_outputs: - origin_output = request_output.outputs[0] + output = request_output.outputs[0] finished = request_output.finished - if request_output.request_id != request_id1: - continue - output = request_output.outputs[0] - finished = request_output.finished assert output.text == origin_output.text assert output.cumulative_logprob == origin_output.cumulative_logprob @@ -233,13 +290,19 @@ async def test_correctness(prompt): def test_clear_migration_states(): llumlet = MockLlumlet() - llumlet.backend_engine.pre_alloc("0", 1) + llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, 1) num_gpu_blocks = 8 block_size = 4 llumlet.clear_migration_states(is_migrate_in=True) - assert len(llumlet.backend_engine.pre_alloc("0", num_gpu_blocks)) == num_gpu_blocks - _, seq_group = create_dummy_prompt("0",7,block_size) + assert len(llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, num_gpu_blocks)) == num_gpu_blocks + _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.RUNNING) + seq_group.set_status(RequestStatus.RUNNING_MIGRATING) + llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group) + llumlet.clear_migration_states(is_migrate_in=False) + assert len(llumlet.backend_engine.get_running_queue()) == 1 + _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.WAITING) + seq_group.set_status(RequestStatus.WAITING_MIGRATING) llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group) llumlet.clear_migration_states(is_migrate_in=False) - assert len(llumlet.backend_engine.get_running_queue()) > 0 + assert len(llumlet.backend_engine.get_waiting_queue()) == 1 diff --git a/tests/unit_test/backends/vllm/test_scheduler.py b/tests/unit_test/backends/vllm/test_scheduler.py index 1c1af7ac..c8a03981 100644 --- a/tests/unit_test/backends/vllm/test_scheduler.py +++ b/tests/unit_test/backends/vllm/test_scheduler.py @@ -12,13 +12,14 @@ # limitations under the License. import math +import time from vllm.sequence import Sequence from vllm.sequence import Logprob from vllm.core.policy import PolicyFactory from llumnix.backends.vllm.scheduler import BlockManagerLlumnix -from llumnix.llumlet.request import RequestInferenceType +from llumnix.llumlet.request import RequestInferenceType, RequestStatus from .utils import create_dummy_prompt, initialize_scheduler, create_token_budget @@ -129,6 +130,25 @@ def test_scheduler_running_request(): scheduler.add_running_request(seq_group) assert scheduler.get_num_unfinished_seq_groups() == 4 +def test_scheduler_waiting_request(): + scheduler = initialize_scheduler() + num_seq_group = 4 + block_size = 4 + _, seq_group_0 = create_dummy_prompt("0", prompt_length=0, block_size=block_size) + for idx in range(1, num_seq_group + 1): + _, seq_group = create_dummy_prompt(str(idx), prompt_length=idx, block_size=block_size) + scheduler.add_seq_group(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == 4 + scheduler.remove_waiting_request("1") + assert scheduler.get_num_unfinished_seq_groups() == 3 + _, seq_group = create_dummy_prompt("6", prompt_length=idx, block_size=block_size) + scheduler.add_waiting_request(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == 4 + # Test if sort the waiting queue by arrival time in add_waiting_request. + scheduler.add_waiting_request(seq_group_0) + waiting_queue = scheduler.get_waiting_queue() + assert waiting_queue[0] == seq_group_0 + def test_scheduler_migrating_out_request_last_stage(): scheduler = initialize_scheduler() block_size = 4 @@ -142,13 +162,13 @@ def test_scheduler_migrating_out_request_last_stage(): def test_scheduler_pre_alloc(): # total 8 blocks scheduler = initialize_scheduler() - blocks = scheduler.pre_alloc("1", 2) + blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 2) assert len(blocks) == 2 assert len(scheduler.pre_alloc_cache_dict["1"]) == 2 - blocks = scheduler.pre_alloc("1", 4) + blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 4) assert len(blocks) == 4 assert len(scheduler.pre_alloc_cache_dict["1"]) == 6 - blocks = scheduler.pre_alloc("2,", 4) + blocks = scheduler.pre_alloc("2", RequestStatus.RUNNING, 0.0, 4) assert len(blocks) == 0 def test_schedule_running(): @@ -176,3 +196,37 @@ def test_schedule_running(): assert len(running_scheduled.decode_seq_groups) == 1 assert len(running_scheduled.prefill_seq_groups) == 0 assert len(remainig_running) == 1 + + # test pre alloc waiting condition + # total 8 blocks + scheduler = initialize_scheduler() + before_arrival = time.time() + _, seq_group = create_dummy_prompt("1", prompt_length=1, block_size=2, expected_steps=math.inf) + after_arrival = time.time() + blocks = scheduler.pre_alloc("2", RequestStatus.WAITING_MIGRATING, after_arrival, 2) + assert len(blocks) == 2 + scheduler.add_waiting_request(seq_group) + blocks = scheduler.pre_alloc("3", RequestStatus.WAITING_MIGRATING, after_arrival, 2) + assert len(blocks) == 0 + blocks = scheduler.pre_alloc("4", RequestStatus.WAITING_MIGRATING, before_arrival, 2) + assert len(blocks) == 2 + +def test_try_schedule_times(): + # total 8 blocks + scheduler = initialize_scheduler() + _, seq_group_1 = create_dummy_prompt("1", prompt_length=8, block_size=1) + _, seq_group_2 = create_dummy_prompt("2", prompt_length=8, block_size=1) + scheduler.add_seq_group(seq_group_1) + scheduler.add_seq_group(seq_group_2) + waiting_queue = scheduler.get_waiting_queue() + assert len(waiting_queue) == 2 + assert seq_group_1.try_schedule_times == 0 + assert seq_group_2.try_schedule_times == 0 + scheduler.schedule() + # seq_group_2 cannot be scheduled due to lack of blocks + assert seq_group_1.try_schedule_times == 0 + assert seq_group_2.try_schedule_times == 1 + scheduler.schedule() + # seq_group_1 is preempted to waiting queue + assert seq_group_1.try_schedule_times == 1 + assert seq_group_2.try_schedule_times == 2 diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py index c0753b06..7a2632cf 100644 --- a/tests/unit_test/backends/vllm/test_simulator.py +++ b/tests/unit_test/backends/vllm/test_simulator.py @@ -71,7 +71,7 @@ async def test_backend(setup_ray_env): # TODO(ZeldaHuang): add tests for BackendSimVLLM methods # (currently BackendSimVLLM is just a wrapper of BackendVLLM) engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - migration_config = MigrationConfig("LCFS", "gloo", 16, 1, 4, 5, 20, 2) + migration_config = MigrationConfig("SR", "gloo", 16, 1, 4, 5, 20, 2) output_queue_type = QueueType.RAYQUEUE que, server_info = request_output_queue_server(output_queue_type) diff --git a/tests/unit_test/backends/vllm/utils.py b/tests/unit_test/backends/vllm/utils.py index bc8d1f09..887bdd93 100644 --- a/tests/unit_test/backends/vllm/utils.py +++ b/tests/unit_test/backends/vllm/utils.py @@ -18,7 +18,7 @@ from vllm import SamplingParams from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, Sequence +from vllm.sequence import Logprob, Sequence, SequenceStatus from vllm.config import SchedulerConfig, CacheConfig from vllm.core.scheduler import SchedulingBudget @@ -45,6 +45,7 @@ def create_dummy_prompt( request_id: str, prompt_length: int, block_size: Optional[int] = None, + status: SequenceStatus = SequenceStatus.WAITING, lora_request: Optional[LoRARequest] = None, use_beam_search: bool = False, best_of: int = 1, @@ -63,6 +64,7 @@ def create_dummy_prompt( request_id, server_info, expected_steps, [prompt], SamplingParams(use_beam_search=use_beam_search, best_of=best_of), time.time(), lora_request) + seq_group.get_seqs()[0].status = status return prompt, seq_group diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 28fc129e..fedaa154 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -101,7 +101,7 @@ def test_dispatch_queue(): instance_info.instance_id = instance_id instance_info.num_waiting_requests = random.randint(1, 10) instance_info_dict[instance_id] = instance_info - if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: + if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: dispatch_scheduler.available_dispatch_instance_set.add(instance_id) instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index 024ad4bf..5f81baf6 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -78,14 +78,14 @@ def abort(self, request_id): self.num_requests = len(self.request_id_set) return self.num_requests - def migrate_out(self, dst_instance_name, num_requests): + def migrate_out(self, dst_instance_name): self.num_migrate_out += 1 migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') - ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name, num_requests)) + ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name)) time.sleep(0.1) return self.num_migrate_out - def migrate_in(self, src_instance_name, num_requests): + def migrate_in(self, src_instance_name): self.num_migrate_in += 1 return self.num_migrate_in diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py index c630a04f..86176ab2 100644 --- a/tests/unit_test/llumlet/test_engine_step_exception.py +++ b/tests/unit_test/llumlet/test_engine_step_exception.py @@ -39,7 +39,7 @@ async def raise_error_step(): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.") def test_engine_step_exception(setup_ray_env): engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True) - migration_config = MigrationConfig("LCFS", "rpc", 16, 1, 4, 5, 20, 2) + migration_config = MigrationConfig("SR", "rpc", 16, 1, 4, 5, 20, 2) node_id = ray.get_runtime_context().get_node_id() scheduling_strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) diff --git a/tests/unit_test/llumlet/test_local_migration_scheduler.py b/tests/unit_test/llumlet/test_local_migration_scheduler.py index c0c6f834..ecca2b71 100644 --- a/tests/unit_test/llumlet/test_local_migration_scheduler.py +++ b/tests/unit_test/llumlet/test_local_migration_scheduler.py @@ -13,20 +13,25 @@ import math from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler -from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType +from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType, RequestStatus class MockRequest(LlumnixRequest): - def __init__(self, request_id, length, expected_steps) -> None: + def __init__(self, request_id, length, expected_steps, status=RequestStatus.RUNNING) -> None: super().__init__(request_id=request_id, server_info=None, expected_steps=expected_steps) self.length = length - self.status = RequestInferenceType.DECODE + self._status = status + self._inference_type = RequestInferenceType.DECODE + self._finished = False + self.try_schedule_times = 0 + self.eom = False - def is_finished(self) -> bool: - return False + @property + def finished(self) -> bool: + return self._finished @property def inference_type(self) -> RequestInferenceType: - return self.status + return self._inference_type @property def request_len(self) -> int: @@ -40,16 +45,37 @@ def prompt_len(self) -> int: def output_len(self) -> int: return self.length + @property + def arrival_time(self) -> float: + pass + + @property + def status(self) -> RequestStatus: + return self._status + + @property + def prefill_num_blocks(self) -> int: + pass + class MockeEngine(): def __init__(self) -> None: self.running = [] + self.waiting = [] def add_request(self, request_id, length, expected_steps) -> None: self.running.append(MockRequest(request_id, length, expected_steps)) + def add_request_waiting(self, request_id, length, expected_steps) -> None: + request = MockRequest(request_id, length, expected_steps, status=RequestStatus.WAITING) + request.try_schedule_times += 1 + self.waiting.append(request) + def get_running_queue(self): return self.running + def get_waiting_queue(self): + return self.waiting + def test_scheduler_policy(): engine = MockeEngine() scheduler = LocalMigrationScheduler("", engine) @@ -57,37 +83,40 @@ def test_scheduler_policy(): engine.add_request(request_id="0", length=1, expected_steps=math.inf) engine.add_request(request_id="1", length=3, expected_steps=math.inf) engine.add_request(request_id="2", length=2, expected_steps=math.inf) - - scheduler.request_migration_policy = "LCFS" - assert scheduler.get_migrate_out_request().request_id == "2" - scheduler.request_migration_policy = "LJF" - assert scheduler.get_migrate_out_request().request_id == "1" - scheduler.request_migration_policy = "SJF" - assert scheduler.get_migrate_out_request().request_id == "0" - - engine.add_request(request_id="3", length=2, expected_steps=1) - engine.add_request(request_id="4", length=3, expected_steps=math.inf) - engine.add_request(request_id="5", length=4, expected_steps=math.inf) - scheduler.request_migration_policy = "LCFS" - request = scheduler.get_migrate_out_request() - request.migrating = True - assert request.request_id == "3" + engine.add_request_waiting(request_id="3", length=2, expected_steps=math.inf) + engine.add_request_waiting(request_id="4", length=2, expected_steps=math.inf) + + scheduler.request_migration_policy = "LCR" + assert scheduler.get_migrate_out_requests()[0].request_id == "2" + scheduler.request_migration_policy = "LR" + assert scheduler.get_migrate_out_requests()[0].request_id == "1" + scheduler.request_migration_policy = "SR" + assert scheduler.get_migrate_out_requests()[0].request_id == "0" + scheduler.request_migration_policy = "FCW" + assert scheduler.get_migrate_out_requests()[0].request_id == "3" + scheduler.request_migration_policy = "FCWSR" + assert scheduler.get_migrate_out_requests()[0].request_id == "3" + assert scheduler.get_migrate_out_requests()[1].request_id == "0" + + engine.add_request(request_id="5", length=2, expected_steps=1) + request = scheduler.get_migrate_out_requests()[0] + assert request.request_id == "5" assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE - request = scheduler.get_migrate_out_request() - request.migrating = True + engine.add_request(request_id="6", length=3, expected_steps=math.inf) + scheduler.request_migration_policy = "LCR" + request = scheduler.get_migrate_out_requests()[0] assert request.request_id == "5" - request = scheduler.get_migrate_out_request() - assert request.request_id == "4" + assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE def test_scheduler_should_abort_migration(): req_0 = MockRequest(request_id="0", length=1, expected_steps=math.inf) req_0.stage_timestamps = [1] assert req_0.should_abort_migration() is False - req_0.status = RequestInferenceType.PREFILL - assert req_0.should_abort_migration() is True - req_0.status = RequestInferenceType.DECODE req_0.last_preemption_time = 2 assert req_0.should_abort_migration() is True + req_0.last_preemption_time = None + req_0._finished = True + assert req_0.should_abort_migration() is True def test_blocking_migration(): req_0 = MockRequest(request_id="0", length=1, expected_steps=math.inf) diff --git a/tests/unit_test/llumlet/test_migration_coordinator.py b/tests/unit_test/llumlet/test_migration_coordinator.py index 8a1a4d44..fcdf0638 100644 --- a/tests/unit_test/llumlet/test_migration_coordinator.py +++ b/tests/unit_test/llumlet/test_migration_coordinator.py @@ -38,7 +38,7 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request = MagicMock() # Create an instance of MigrationCoordinator - coordinator = MigrationCoordinator(backend_engine, 1, 3) + coordinator = MigrationCoordinator(backend_engine, last_stage_max_blocks=1, max_stages=3) # Mock method return values and test data src_blocks = [1, 2, 3] @@ -49,7 +49,7 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) # Test normal migration scenario - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) assert status == MigrationStatus.RUNNING # Test the last stage of migration @@ -59,20 +59,21 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) - assert status == MigrationStatus.FINISHED_DONE + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED migrate_out_request = MagicMock() - # Test migration aborted scenario + # Test migration dst aborted scenario src_blocks = [1, 2, 3] dst_blocks = [] backend_engine.get_request_incremental_blocks.return_value = src_blocks migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) assert status == MigrationStatus.ABORTED_DST + # Test migration src aborted scenario migrate_out_request = MagicMock() src_blocks = [1, 2, 3] dst_blocks = [1, 2] @@ -80,23 +81,13 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request.should_abort_migration.return_value = True migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) + status = await coordinator._migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) assert status == MigrationStatus.ABORTED_SRC - migrate_out_request = MagicMock() - src_blocks = [1, 2, 3] - dst_blocks = [1, 2] - backend_engine.get_request_incremental_blocks.return_value = src_blocks - migrate_out_request.should_abort_migration.return_value = False - migrate_out_request.blocking_migration = True - migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) - status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request) - assert status == MigrationStatus.ABORTED_DST - -# setup_ray_env should be passed after migrate_out_onestage -@patch.object(MigrationCoordinator, 'migrate_out_onestage') +# setup_ray_env should be passed after _migrate_out_onestage +@patch.object(MigrationCoordinator, '_migrate_out_onestage') @pytest.mark.asyncio -async def test_migrate_out_multistage(_, setup_ray_env): +async def test_migrate_out_running_request(_, setup_ray_env): # Create mock objects backend_engine = MagicMock(spec=BackendInterface) migrate_in_ray_actor = MagicMock() @@ -110,16 +101,41 @@ async def test_migrate_out_multistage(_, setup_ray_env): migrate_in_ray_actor.execute_engine_method.remote = MagicMock() migrate_in_ray_actor.execute_engine_method.remote.return_value = ray_remote_call.remote([1]) migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote([1]) - coordinator.migrate_out_onestage.side_effect = [MigrationStatus.FINISHED_DONE] - status = await coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - assert coordinator.migrate_out_onestage.call_count == 1 - assert status == MigrationStatus.FINISHED_DONE + coordinator._migrate_out_onestage.side_effect = [MigrationStatus.FINISHED] + status = await coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) + assert coordinator._migrate_out_onestage.call_count == 1 + assert status == MigrationStatus.FINISHED max_stages = 3 - coordinator.migrate_out_onestage.side_effect = [MigrationStatus.RUNNING, - MigrationStatus.RUNNING, - MigrationStatus.RUNNING, - MigrationStatus.RUNNING] - status = await coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request) - assert coordinator.migrate_out_onestage.call_count == max_stages + 1 + coordinator._migrate_out_onestage.side_effect = [MigrationStatus.RUNNING, + MigrationStatus.RUNNING, + MigrationStatus.RUNNING, + MigrationStatus.RUNNING] + status = await coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) + assert coordinator._migrate_out_onestage.call_count == max_stages + 1 assert status == MigrationStatus.ABORTED_SRC + +@pytest.mark.asyncio +async def test_migrate_out_waiting_request(): + # Create mock objects + backend_engine = MagicMock(spec=BackendInterface) + migrate_in_ray_actor = MagicMock() + migrate_out_request = MagicMock() + + # Create an instance of MigrationCoordinator + coordinator = MigrationCoordinator(backend_engine, last_stage_max_blocks=1, max_stages=3) + + # Test FINISHED + migrate_out_request.prefill_num_blocks = 3 + dst_blocks = [1, 2, 3] + migrate_in_ray_actor.execute_engine_method = MagicMock() + migrate_in_ray_actor.execute_engine_method.remote = MagicMock() + migrate_in_ray_actor.execute_engine_method.remote.return_value = ray_remote_call.remote(dst_blocks) + migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) + status = await coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.FINISHED + + # Test FINISHED_ABORTED + migrate_out_request.prefill_num_blocks = 2 + status = await coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) + assert status == MigrationStatus.ABORTED_DST