diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py
index 87d21b2e..16a8ac1f 100644
--- a/llumnix/backends/backend_interface.py
+++ b/llumnix/backends/backend_interface.py
@@ -68,7 +68,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def _start_engine_step_loop(self) -> None:
+    async def _start_engine_step_loop(self) -> None:
        """Start step loop of backend engine.
        """
        raise NotImplementedError
@@ -244,7 +244,7 @@ def free_src_request(self, backend_request: LlumnixRequest) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]):
+    async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]):
         """
         Sends cache blocks from the source instance to the destination instance.
 
diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py
index 825e1005..21f63a9e 100644
--- a/llumnix/backends/vllm/executor.py
+++ b/llumnix/backends/vllm/executor.py
@@ -12,6 +12,8 @@
 # limitations under the License.
 
 import time
+import asyncio
+
 from collections import defaultdict
 from typing import List, Optional, Tuple
 import ray
@@ -19,9 +21,9 @@
 
 # pylint: disable=unused-import
 from ray.util.placement_group import PlacementGroup
 
-from vllm.executor.gpu_executor import GPUExecutor
-from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayWorkerWrapper, get_distributed_init_method,\
-    get_ip, get_vllm_instance_id, get_open_port
+from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper,\
+    get_distributed_init_method, get_ip, get_vllm_instance_id, get_open_port
 from vllm import envs
 from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest
@@ -34,7 +36,7 @@
 
 logger = init_logger(__name__)
 
 
-class LlumnixRayGPUExecutor(RayGPUExecutor):
+class LlumnixRayGPUExecutor(RayGPUExecutorAsync):
     node_id: str = None
     migration_config: MigrationConfig = None
@@ -157,17 +159,17 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                               cache_config=self.cache_config,
                               parallel_config=self.parallel_config)
 
-    def execute_model(self, *args, **kwargs):
+    async def execute_model_async(self, *args, **kwargs):
         t0 = time.time()
-        outputs = super().execute_model(*args, **kwargs)
+        outputs = await super().execute_model_async(*args, **kwargs)
         t1 = time.time()
         self.last_inference_latency = (t1 - t0) * 1000
         return outputs
 
-class SimGPUExecutor(GPUExecutor):
+class SimGPUExecutor(RayGPUExecutor):
     latency_mem: LatencyMemData = None
     def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+        RayGPUExecutor.__init__(self, *args, **kwargs)
         self.last_inference_latency = 0
         self.migration_bandwidth = self.latency_mem.migration_bandwidth
         # TODO(ZeldaHuang): add swap bandwidth
@@ -191,7 +193,7 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
         logger.info("# GPU blocks: %d, # CPU blocks: %d",
                     num_gpu_blocks, num_cpu_blocks)
 
-    def execute_model(
+    async def execute_model_async(
         self,
         execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         prefill_seq_len = 0
@@ -213,7 +215,7 @@ def execute_model(
             decode_meta_data = (decode_bs, decode_seq_len)
             latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \
                        else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params)
-        time.sleep(latency/1000)
+        await asyncio.sleep(latency/1000)
         sampler_outputs = []
         for meta_data in execute_model_req.seq_group_metadata_list:
             samples = []
@@ -225,6 +227,6 @@ def execute_model(
             sampler_outputs.append(output)
         return [SamplerOutput(outputs=sampler_outputs)]
 
-    def send_blocks(self, blocks_len) -> None:
+    async def send_blocks(self, blocks_len) -> None:
         migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth
-        time.sleep(migration_latency)
+        await asyncio.sleep(migration_latency)
diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index 3e25b393..c48ddb2c 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -22,7 +22,7 @@
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
 
-from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.async_llm_engine import _AsyncLLMEngine
 from vllm.core.scheduler import ScheduledSequenceGroup
 from vllm.outputs import RequestOutput
 from vllm.sequence import SequenceGroup, SequenceStatus, SamplerOutput, SequenceGroupMetadata
@@ -82,7 +82,7 @@ async def put_nowait_to_servers(self,
             logger.error("exception traceback: {}".format(traceback.format_exc()))
 
 
-class LLMEngineLlumnix(LLMEngine):
+class LLMEngineLlumnix(_AsyncLLMEngine):
     def __init__(self,
                  instance_id: str,
                  output_queue_type: QueueType,
@@ -171,38 +171,37 @@ def _process_model_outputs(
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[List[RequestOutput], List[ServerInfo]]:
         # ensure scheduled_seq_groups matching output
-        with self.scheduler.scheduler_lock:
-            server_infos = []
-            if output:
-                new_output = []
-                new_scheduled_seq_groups = []
-                new_seq_group_metadata_list = []
-                for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
-                    seq_group = scheduled_seq_group.seq_group
-                    if seq_group.get_seqs(SequenceStatus.RUNNING):
-                        new_scheduled_seq_groups.append(scheduled_seq_group)
-                        new_seq_group_metadata_list.append(seq_group_meta)
-                        new_output.append(seq_group_output)
-                        server_infos.append(seq_group.server_info)
-                scheduled_seq_groups = new_scheduled_seq_groups
-                output[0].outputs = new_output
-                seq_group_metadata_list = new_seq_group_metadata_list
-            for ignored_seq_group in ignored_seq_groups:
-                server_infos.append(ignored_seq_group.server_info)
-            for server_info in server_infos:
-                if hasattr(server_info, 'request_timestamps'):
-                    server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
-            request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
-            for request_output, server_info in zip(request_outputs, server_infos):
-                if hasattr(server_info, 'request_timestamps'):
-                    request_output.request_timestamps = server_info.request_timestamps
-                    request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time()
-            # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
-        return request_outputs, server_infos
-
-    def step(self) -> None:
+        server_infos = []
+        if output:
+            new_output = []
+            new_scheduled_seq_groups = []
+            new_seq_group_metadata_list = []
+            for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
+                seq_group = scheduled_seq_group.seq_group
+                if seq_group.get_seqs(SequenceStatus.RUNNING):
+                    new_scheduled_seq_groups.append(scheduled_seq_group)
+                    new_seq_group_metadata_list.append(seq_group_meta)
+                    new_output.append(seq_group_output)
+                    server_infos.append(seq_group.server_info)
+            scheduled_seq_groups = new_scheduled_seq_groups
+            output[0].outputs = new_output
+            seq_group_metadata_list = new_seq_group_metadata_list
+        for ignored_seq_group in ignored_seq_groups:
+            server_infos.append(ignored_seq_group.server_info)
+        for server_info in server_infos:
+            if hasattr(server_info, 'request_timestamps'):
+                server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
+        request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
+        for request_output, server_info in zip(request_outputs, server_infos):
+            if hasattr(server_info, 'request_timestamps'):
+                request_output.request_timestamps = server_info.request_timestamps
+                request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time()
+        # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
+        return request_outputs, server_infos
+
+    async def step_async(self) -> None:
         step_begin_time = time.time()
-        request_outputs, server_infos = super().step()
+        request_outputs, server_infos = await super().step_async()
         for request_output in request_outputs:
             if hasattr(request_output, 'request_timestamps'):
                 request_output.request_timestamps.engine_step_timestamp_begin = step_begin_time
@@ -251,7 +250,6 @@ def add_request(self, request_id: str, server_info: ServerInfo, expected_steps:
         self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
                                                           seq_group.sampling_params, seq_group.metrics.arrival_time,
                                                           seq_group.lora_request, seq_group.multi_modal_data)
-        self.scheduler.scheduler_lock.release()
 
     def _start_put_queue_loop(self):
         while True:
@@ -301,45 +299,38 @@ def __init__(
                                                        src_worker_handle_list=self.worker_handle_list,
                                                        placement_group=placement_group,
                                                        node_id=node_id)
-        self.state_lock = threading.Lock()
         self.state = EngineState.INIT
         logger.info("engine ({}) current state {}".format(self.instance_id, self.state))
 
-        self._stop_event = threading.Event()
-        self.engine_step_loop_thread = threading.Thread(
-            target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
-        )
-        self.engine_step_loop_thread.start()
+        self._stop_event = asyncio.Event()
+        asyncio.create_task(self._start_engine_step_loop())
 
-    def _start_engine_step_loop(self) -> None:
+    async def _start_engine_step_loop(self) -> None:
         self._stop_event.clear()
 
-        with self.state_lock:
-            previous_state = self.state
-            self.state = EngineState.RUNNING
-            logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
+        previous_state = self.state
+        self.state = EngineState.RUNNING
+        logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
 
         while not self._stop_event.is_set():
             try:
-                request_outputs, _ = self.engine.step()
+                request_outputs, _ = await self.engine.step_async()
                 if len(request_outputs) == 0:
-                    time.sleep(0.01)
+                    await asyncio.sleep(0.01)
             # pylint: disable=broad-except
             except Exception as e:
                 logger.error("Error in engine loop: {}".format(e))
                 logger.error("exception traceback: {}".format(traceback.format_exc()))
                 self._run_workers("shutdown")
 
-                with self.state_lock:
-                    previous_state = self.state
-                    self.state = EngineState.CRASHED
-                    logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
+                previous_state = self.state
+                self.state = EngineState.CRASHED
+                logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
                 break
 
-        with self.state_lock:
-            if self.state == EngineState.RUNNING:
-                self.state = EngineState.STOPPED
-                logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))
+        if self.state == EngineState.RUNNING:
+            self.state = EngineState.STOPPED
+            logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))
 
     def execute_worker_method(self, method, *args, **kwargs):
         return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
@@ -362,12 +353,12 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
         backend_request.reset_migration_args()
         self.add_running_request(backend_request)
 
-    def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
-        ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers",
+    async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
+        await dst_ray_actor.execute_engine_method.remote("_run_workers",
                                                            "migrate_cache",
                                                            dst_blocks=dst_blocks,
                                                            src_blocks=src_blocks,
-                                                           src_worker_handle_list=self.worker_handle_list))
+                                                           src_worker_handle_list=self.worker_handle_list)
 
     def _run_workers(self, *args, **kwargs):
         # pylint: disable=protected-access
diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py
index 7e9064d7..a14db0b3 100644
--- a/llumnix/backends/vllm/scheduler.py
+++ b/llumnix/backends/vllm/scheduler.py
@@ -13,7 +13,6 @@
 
 from asyncio.log import logger
 import time
-import threading
 from typing import Dict, List, Optional, Tuple
 from collections import deque
 
@@ -23,7 +22,6 @@
 from llumnix.instance_info import InstanceInfo
 from llumnix.logger import init_logger
 from llumnix.llumlet.request import RequestInferenceType
-from llumnix.backends.vllm.utils import scheduler_lock
 from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
 
 logger = init_logger(__name__)
@@ -56,7 +54,6 @@ def __init__(self, *args, **kwargs) -> None:
             sliding_window=self.cache_config.sliding_window,
             enable_caching=self.cache_config.enable_prefix_caching)
         self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
-        self.scheduler_lock = threading.Lock()
         self.migrating_out_request_last_stage: List[SequenceGroupLlumnix] = []
 
     def add_update_instance_info_callback(self, update_instance_info_callback):
@@ -79,11 +76,9 @@ def _get_num_killed_requests(self) -> int:
                 cnt += 1
         return cnt
 
-    @scheduler_lock
     def get_running_queue(self):
         return self.running
 
-    @scheduler_lock
     def get_all_request_ids(self) -> List[str]:
         request_ids : List[str] = []
         for state_queue in [self.waiting, self.running, self.swapped]:
@@ -91,13 +86,11 @@ def get_all_request_ids(self) -> List[str]:
                 request_ids.append(seq_group.request_id)
         return request_ids
 
-    @scheduler_lock
     def get_request_incremental_blocks(self, backend_request: SequenceGroupLlumnix, pre_stage_num_blocks: int) -> List[int]:
         seq = backend_request.get_seqs()[0]
         blocks = self.block_manager.get_block_table(seq)
         return blocks[pre_stage_num_blocks:]
 
-    @scheduler_lock
     def remove_running_request(self, request_id: str) -> None:
         for seq_group in self.running:
             if seq_group.request_id == request_id:
@@ -117,7 +110,6 @@ def pop_migrating_out_requests_last_stage(self) -> List[SequenceGroupLlumnix]:
         self.migrating_out_request_last_stage.clear()
         return migrating_out_request_last_stage
 
-    @scheduler_lock
     def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
         blocks = self.block_manager.get_free_blocks(block_num)
         pre_blocks = self.pre_alloc_cache_dict.get(request_id, [])
@@ -126,17 +118,14 @@ def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
         blocks = [block.block_number for block in blocks]
         return blocks
 
-    @scheduler_lock
     def add_running_request(self, backend_request: SequenceGroupLlumnix) -> None:
         seq = backend_request.get_seqs()[0]
         seq.status = SequenceStatus.RUNNING
         self.running.append(backend_request)
 
-    @scheduler_lock
     def is_request_running(self, backend_request: SequenceGroupLlumnix) -> bool:
         return backend_request in self.running
 
-    @scheduler_lock
     def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
         if request_id:
             blocks = self.pre_alloc_cache_dict.pop(request_id, [])
@@ -150,7 +139,6 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
             # pylint: disable=protected-access
             self.block_manager._free_block_table(blocks)
 
-    @scheduler_lock
     def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None:
         seq = backend_request.get_seqs()[0]
         logger.info("free seq {}".format(seq.seq_id))
@@ -201,7 +189,6 @@ def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -
         instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()]
         return instance_info
 
-    @scheduler_lock
     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         seq_group_metadata_list, scheduler_outputs = super().schedule()
         self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \
@@ -220,13 +207,3 @@ def _schedule_running(self, running_queue: deque, *args, **kwargs):
         for seq_group in remove_running:
             remaining_running.extend([seq_group])
         return remaining_running, running_scheduled
-
-    def add_seq_group(self, *args, **kwargs):
-        # The scheduler lock is mannually released in the end of LLMEngineLlumnix.add_request function.
-        # pylint: disable=R1732
-        self.scheduler_lock.acquire()
-        return super().add_seq_group(*args, **kwargs)
-
-    @scheduler_lock
-    def abort_seq_group(self, *args, **kwargs):
-        return super().abort_seq_group(*args, **kwargs)
diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py
index e763ecb2..b5ccb45b 100644
--- a/llumnix/backends/vllm/simulator.py
+++ b/llumnix/backends/vllm/simulator.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 
 import os
-import threading
+import asyncio
 from typing import List
 
 import ray.actor
@@ -49,15 +49,11 @@ def __init__(
         self.engine.output_processor.scheduler = self.engine.scheduler
         self.instance_id = instance_id
 
-        self.state_lock = threading.Lock()
         self.state = EngineState.INIT
         logger.info("engine ({}) current state {}".format(self.instance_id, self.state))
 
-        self._stop_event = threading.Event()
-        self.engine_step_loop_thread = threading.Thread(
-            target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
-        )
-        self.engine_step_loop_thread.start()
+        self._stop_event = asyncio.Event()
+        asyncio.create_task(self._start_engine_step_loop())
 
     def _get_lantecy_mem(self, profiling_result_file_path: str, engine_args: EngineArgs) -> LatencyMemData:
         # load database
@@ -80,5 +76,5 @@ def _get_lantecy_mem(self, profiling_result_file_path: str, engine_args: EngineA
         return latency_mem
 
     # pylint: disable=unused-argument
-    def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
-        self.engine.model_executor.send_blocks(len(src_blocks))
+    async def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
+        await self.engine.model_executor.send_blocks(len(src_blocks))
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index 5d220676..5aa3e4c2 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
+import asyncio
 import traceback
 from typing import List, Union, Iterable
 import time
@@ -55,9 +55,7 @@ def __init__(self,
                                                                self.backend_engine)
             self.log_requests = True
 
-            self.check_state_thread = threading.Thread(target=self.check_state, daemon=True,
-                                                       name="llumlet_check_state_loop")
-            self.check_state_thread.start()
+            self.check_state_thread = asyncio.create_task(self.check_state())
         # pylint: disable=broad-except
         except Exception as e:
             logger.error("Failed to initialize llumlet: {}".format(e))
@@ -120,22 +118,17 @@ def from_args(cls,
         llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
         return llumlet
 
-    def check_state(self):
+    async def check_state(self):
         while True:
-            time.sleep(1)
-
-            with self.backend_engine.state_lock:
-                if self.backend_engine.state == EngineState.CRASHED:
-                    logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
-                    # pylint: disable=protected-access
-                    self.backend_engine._stop_event.set()
-                    if self.backend_engine.engine_step_loop_thread.is_alive():
-                        self.backend_engine.engine_step_loop_thread.join()
-
-                    self_actor = ray.get_actor(self.actor_name)
-                    ray.kill(self_actor)
-
-    def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]:
+            await asyncio.sleep(1)
+            if self.backend_engine.state == EngineState.CRASHED:
+                logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
+                # pylint: disable=protected-access
+                self.backend_engine._stop_event.set()
+                self_actor = ray.get_actor(self.actor_name)
+                ray.kill(self_actor)
+
+    async def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]:
         try:
             migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix')
             dst_instance_id = dst_instance_name[len("instance_"):]
@@ -149,16 +142,16 @@ def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]:
                 if migrate_out_request is None:
                     return migrated_request_list
                 logger.info("{}->{} begin migrate out {}".format(self.instance_id, dst_instance_id, migrate_out_request.request_id))
-                status = self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request)
+                status = await self.migration_coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request)
                 if status == MigrationStatus.FINISHED_DONE:
-                    ray.get(migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request))
+                    await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request)
                     self.backend_engine.free_src_request(migrate_out_request)
                     migrated_request_list.append(migrate_out_request.request_id)
                     migrate_out_request.stage_timestamps.append(time.time())
                     self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
                 else:
                     migrate_out_request.reset_migration_args()
-                    ray.get(migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id))
+                    await migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id)
                     continue_migrate = False
             t1 = time.time()
             logger.info("{}->{} migrate done, migrate request {}, status:{}, len:{} blocks, cost:{} ms" \
diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py
index 5b2d5451..03b20cb2 100644
--- a/llumnix/llumlet/migration_coordinator.py
+++ b/llumnix/llumlet/migration_coordinator.py
@@ -15,6 +15,7 @@
 import enum
 from typing import List
 
+# pylint: disable=unused-import
 import ray
 
 from llumnix.logger import init_logger
@@ -49,7 +50,7 @@ def __init__(self,
         self.max_stages = max_stages
         self.backend_engine = backend_engine
 
-    def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest, ) -> "MigrationStatus":
+    async def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest, ) -> "MigrationStatus":
         """one-stage live migration until last stage
         """
         pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list)
@@ -60,8 +61,8 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m
         if not is_last_stage:
             src_blocks = incremental_blocks[:-1]
             stage_block_num = len(incremental_blocks) - 1
-            dst_blocks = ray.get(migrate_in_ray_actor.execute_migration_method \
-                .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num))
+            dst_blocks = await migrate_in_ray_actor.execute_migration_method \
+                .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num)
         else:
             # last stage migration, stop inference, transfer all blocks
             migration_status = MigrationStatus.FINISHED_DONE
@@ -69,8 +70,8 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m
             self.backend_engine.add_migrating_out_request_last_stage(migrate_out_request)
             stage_block_num = len(incremental_blocks)
             src_blocks = incremental_blocks[:]
-            dst_blocks = ray.get(migrate_in_ray_actor.execute_migration_method \
-                .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num))
+            dst_blocks = await migrate_in_ray_actor.execute_migration_method \
+                .remote("migrate_in_pre_alloc", migrate_out_request.request_id, stage_block_num)
 
         if len(dst_blocks) != len(src_blocks):
             # migrate-in instance failed to prev alloc
@@ -83,14 +84,14 @@ def migrate_out_onestage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", m
         migrate_out_request.stage_timestamps.append(time.time())
         migrate_out_request.stage_num_blocks_list.append(stage_block_num)
         # TODO(ZeldaHuang): send_blocks in migrate_in_pre_alloc/migrate_in_last_stage
-        self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks)
+        await self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks)
         if not is_last_stage and migrate_out_request.should_abort_migration():
             # migrate-out request abort by scheduler during send/recv
             migration_status = MigrationStatus.ABORTED_SRC
 
         return migration_status
 
-    def migrate_out_multistage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest) -> "MigrationStatus":
+    async def migrate_out_multistage(self, migrate_in_ray_actor: "ray.actor.ActorHandle", migrate_out_request: LlumnixRequest) -> "MigrationStatus":
         """Migrate out requests to a specified instance, return migrated request id.
         Args:
             dst_instance_name:instance actor name, used to get ray actor handle
@@ -98,7 +99,7 @@ def migrate_out_multistage(self, migrate_in_ray_actor: "ray.actor.ActorHandle",
         state_count = 0
         while state_count < self.max_stages:
             state_count += 1
-            status = self.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
+            status = await self.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
             if MigrationStatus.is_finished(status):
                 return status
         # exceed max stages
diff --git a/tests/unit_test/backends/vllm/test_llm_engine.py b/tests/unit_test/backends/vllm/test_llm_engine.py
index a6c6e3a1..6e9e6a05 100644
--- a/tests/unit_test/backends/vllm/test_llm_engine.py
+++ b/tests/unit_test/backends/vllm/test_llm_engine.py
@@ -108,7 +108,6 @@ def test_llm_engine_add_requset():
                                           migration_config=None,
                                           latency_mem=MagicMock(sepc=LatencyMemData))
     sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)
-    llm_engine.scheduler.scheduler_lock = MagicMock()
     server_info = ServerInfo(None, None, None, None, None)
     llm_engine.add_request("0", server_info, math.inf, "prompt", sampling_params)
     assert len(llm_engine.scheduler.waiting) == 1
diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py
index ae3033d0..7fb94baa 100644
--- a/tests/unit_test/backends/vllm/test_simulator.py
+++ b/tests/unit_test/backends/vllm/test_simulator.py
@@ -26,7 +26,8 @@ def _get_lantecy_mem(self, *args, **kwargs):
         latency_mem.decode_model_params = (0,0,0)
         return latency_mem
 
-def test_executor():
+@pytest.mark.asyncio
+async def test_executor():
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
     engine_config = engine_args.create_engine_config()
     latency_mem = LatencyMemData({},{},{})
@@ -62,7 +63,7 @@ def test_executor():
         num_lookahead_slots=out.num_lookahead_slots,
         running_queue_size=out.running_queue_size,
     )
-    outputs = executor.execute_model(execute_model_req)
+    outputs = await executor.execute_model_async(execute_model_req)
     assert len(outputs[0].outputs) == 2
 
 @pytest.mark.asyncio
diff --git a/tests/unit_test/backends/vllm/utils.py b/tests/unit_test/backends/vllm/utils.py
index 769e2576..bc8d1f09 100644
--- a/tests/unit_test/backends/vllm/utils.py
+++ b/tests/unit_test/backends/vllm/utils.py
@@ -26,26 +26,18 @@
 from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
 from llumnix.server_info import ServerInfo
 
-
-class SchedulerLlumnixTest(SchedulerLlumnix):
-    def add_seq_group(self, *args, **kwargs):
-        ret = super().add_seq_group(*args, **kwargs)
-        self.scheduler_lock.release()
-        return ret
-
-
 def initialize_scheduler(*, max_num_seqs=1000, max_token_budget=1000, max_model_len=1000,
-                         lora_config=None) -> SchedulerLlumnixTest:
+                         lora_config=None) -> SchedulerLlumnix:
     block_size = 4
     scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, max_model_len)
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 8
     cache_config.num_gpu_blocks = 8
-    scheduler = SchedulerLlumnixTest(scheduler_config, cache_config, lora_config)
+    scheduler = SchedulerLlumnix(scheduler_config, cache_config, lora_config)
     scheduler.update_instance_info_callback = MagicMock()
     return scheduler
diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py
index 6ce2fa1a..56b58322 100644
--- a/tests/unit_test/llumlet/test_engine_step_exception.py
+++ b/tests/unit_test/llumlet/test_engine_step_exception.py
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
+import asyncio
 import time
 import ray
 import torch
@@ -32,26 +32,21 @@ class MockLlumlet(Llumlet):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.origin_step = self.backend_engine.engine.step
+        self.origin_step = self.backend_engine.engine.step_async
 
     def set_error_step(self, broken: bool):
         self.backend_engine._stop_event.set()
-        if self.backend_engine.engine_step_loop_thread.is_alive():
-            self.backend_engine.engine_step_loop_thread.join()
 
-        def raise_error_step():
-            self.origin_step()
+        async def raise_error_step():
+            await self.origin_step()
             raise ValueError("Mock engine step error")
 
         if broken:
-            self.backend_engine.engine.step = raise_error_step
+            self.backend_engine.engine.step_async = raise_error_step
         else:
-            self.backend_engine.engine.step = self.origin_step
+            self.backend_engine.engine.step_async = self.origin_step
 
-        self.backend_engine.engine_step_loop_thread = threading.Thread(
-            target=self.backend_engine._start_engine_step_loop, args=(), daemon=True, name="engine_loop"
-        )
-        self.backend_engine.engine_step_loop_thread.start()
+        asyncio.create_task(self.backend_engine._start_engine_step_loop())
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.")
 def test_engine_step_exception(setup_ray_env):
diff --git a/tests/unit_test/llumlet/test_migration_coordinator.py b/tests/unit_test/llumlet/test_migration_coordinator.py
index c85bbb29..8a1a4d44 100644
--- a/tests/unit_test/llumlet/test_migration_coordinator.py
+++ b/tests/unit_test/llumlet/test_migration_coordinator.py
@@ -15,6 +15,7 @@
 import math
 
 import ray
+import pytest
 
 from llumnix.llumlet.migration_coordinator import MigrationCoordinator
 from llumnix.backends.backend_interface import BackendInterface
@@ -29,7 +30,8 @@ def ray_remote_call(ret):
     return ret
 
 
-def test_migrate_out_onestage(setup_ray_env):
+@pytest.mark.asyncio
+async def test_migrate_out_onestage(setup_ray_env):
     # Create mock objects
     backend_engine = MagicMock(spec=BackendInterface)
     migrate_in_ray_actor = MagicMock()
@@ -47,7 +49,7 @@ def test_migrate_out_onestage(setup_ray_env):
     migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
 
     # Test normal migration scenario
-    status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
     assert status == MigrationStatus.RUNNING
 
     # Test the last stage of migration
@@ -57,7 +59,7 @@ def test_migrate_out_onestage(setup_ray_env):
     migrate_out_request.should_abort_migration.return_value = False
     migrate_out_request.blocking_migration = False
     migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
-    status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
     assert status == MigrationStatus.FINISHED_DONE
 
     migrate_out_request = MagicMock()
@@ -68,7 +70,7 @@ def test_migrate_out_onestage(setup_ray_env):
     migrate_out_request.should_abort_migration.return_value = False
     migrate_out_request.blocking_migration = False
     migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
-    status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
     assert status == MigrationStatus.ABORTED_DST
 
     migrate_out_request = MagicMock()
@@ -78,7 +80,7 @@ def test_migrate_out_onestage(setup_ray_env):
     migrate_out_request.should_abort_migration.return_value = True
     migrate_out_request.blocking_migration = False
     migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
-    status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
     assert status == MigrationStatus.ABORTED_SRC
 
     migrate_out_request = MagicMock()
@@ -88,12 +90,13 @@ def test_migrate_out_onestage(setup_ray_env):
     migrate_out_request.should_abort_migration.return_value = False
     migrate_out_request.blocking_migration = True
     migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
-    status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
     assert status == MigrationStatus.ABORTED_DST
 
 # setup_ray_env should be passed after migrate_out_onestage
 @patch.object(MigrationCoordinator, 'migrate_out_onestage')
-def test_migrate_out_multistage(_, setup_ray_env):
+@pytest.mark.asyncio
+async def test_migrate_out_multistage(_, setup_ray_env):
     # Create mock objects
     backend_engine = MagicMock(spec=BackendInterface)
     migrate_in_ray_actor = MagicMock()
@@ -108,7 +111,7 @@ def test_migrate_out_multistage(_, setup_ray_env):
     migrate_in_ray_actor.execute_engine_method.remote.return_value = ray_remote_call.remote([1])
     migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote([1])
     coordinator.migrate_out_onestage.side_effect = [MigrationStatus.FINISHED_DONE]
-    status = coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request)
     assert coordinator.migrate_out_onestage.call_count == 1
     assert status == MigrationStatus.FINISHED_DONE
@@ -117,6 +120,6 @@ def test_migrate_out_multistage(_, setup_ray_env):
                                                      MigrationStatus.RUNNING,
                                                      MigrationStatus.RUNNING,
                                                      MigrationStatus.RUNNING]
-    status = coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request)
+    status = await coordinator.migrate_out_multistage(migrate_in_ray_actor, migrate_out_request)
     assert coordinator.migrate_out_onestage.call_count == max_stages + 1
     assert status == MigrationStatus.ABORTED_SRC