diff --git a/.github/workflows/offline_inference.yml b/.github/workflows/offline_inference.yml index 24084c6e..427bd00a 100644 --- a/.github/workflows/offline_inference.yml +++ b/.github/workflows/offline_inference.yml @@ -20,7 +20,7 @@ jobs: offline_inference: needs: cancel_previous_workflows runs-on: [self-hosted] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v4 - name: Run offline inference example diff --git a/Makefile b/Makefile index c3618524..2a703565 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ install: .PHONY: lint lint: check_pylint_installed check_pytest_installed @pylint --rcfile=.pylintrc -s n --jobs=128 ./llumnix - + @pylint --rcfile=.pylintrc \ --disable=protected-access,super-init-not-called,unused-argument,redefined-outer-name,invalid-name \ -s n --jobs=128 ./tests @@ -62,7 +62,7 @@ test: check_pytest_installed .PHONY: unit_test unit_test: check_pytest_installed @pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings - + .PHONY: offline_test offline_test: @python examlpes/offline_inference.py diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py index c78bf9c0..d9250a0e 100644 --- a/benchmark/benchmark_serving.py +++ b/benchmark/benchmark_serving.py @@ -84,14 +84,12 @@ async def query_model_vllm(prompt, verbose, ip_ports): async with aiohttp.ClientSession(timeout=timeout) as session: best_of = 1 - use_beam_search = False output_len = expected_response_len request_dict = { "prompt": prompt, "n": 1, "best_of": best_of, - "use_beam_search": use_beam_search, - "temperature": 0.0 if use_beam_search else 1.0, + "temperature": 1.0, "top_k": 1, "max_tokens": max(output_len, 1), "ignore_eos": True, @@ -815,18 +813,18 @@ def main(): except FileNotFoundError: os.mknod(file_name) with open(file_name, 'w') as f: - results.append({"qps": args.qps, + results.append({"qps": args.qps, "cv": args.coefficient_variation, - "request_ids": request_ids, + "request_ids": request_ids, "request_lens": request_lens, - "request_latencies": request_latencies, - "prefill_token_latencies": prefill_token_latencies, + "request_latencies": request_latencies, + "prefill_token_latencies": prefill_token_latencies, "decode_token_latencies": decode_token_latencies, - "decode_sum_latencies": decode_sum_latencies, + "decode_sum_latencies": decode_sum_latencies, "all_decode_token_latencies": all_decode_token_latencies, "inference_latencies": inference_latencies, "per_token_latencies_breakdown_dict": per_token_latencies_breakdown_dict, - "throughput": throughput, + "throughput": throughput, "instance_num": avg_instance_num}) json.dump(results, f) diff --git a/configs/vllm.yml b/configs/vllm.yml index 3bd5d9bf..5f65a8ea 100644 --- a/configs/vllm.yml +++ b/configs/vllm.yml @@ -18,7 +18,7 @@ MANAGER: ENABLE_DEFRAG: True REQUEST_MIGRATION_POLICY: 'SR' - MIGRATION_BACKEND: 'gloo' + MIGRATION_BACKEND: 'rayrpc' MIGRATION_BUFFER_BLOCKS: 512 ENABLE_SCALING: False diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 4fcd605f..5c80a8fb 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -82,11 +82,10 @@ HEAD_NODE=1 python -m llumnix.entrypoints.vllm.api_server \ --initial-instances $INITIAL_INSTANCES \ --launch-ray-cluster \ --model $MODEL_PATH \ - --engine-use-ray \ --worker-use-ray \ --max-model-len 4096 ``` -`CONFIG_PATH` is the path to the configuration file for Llumnix, and we give an example configuration file [here](../configs/base.yml). `MODEL_PATH` defines the location of your model. 
`INITIAL_INSTANCES` determines the number of instances to be launched on the current node, +`CONFIG_PATH` is the path to the configuration file for Llumnix, and we give an example configuration file [here](../configs/base.yml). `MODEL_PATH` defines the location of your model. `INITIAL_INSTANCES` determines the number of instances to be launched on the current node, Second, you can run the benchmark to evaluate the serving performance: diff --git a/examlpes/offline_inference.py b/examlpes/offline_inference.py index 5ab8f39f..a8353c3b 100644 --- a/examlpes/offline_inference.py +++ b/examlpes/offline_inference.py @@ -74,10 +74,10 @@ async def main(): for request in prompts: request_id = random_uuid() await manager.generate.remote(request_id=request_id, - server_info=server_info, + server_info=server_info, prompt=request, - sampling_params=sampling_params,) - + params=sampling_params,) + await output_task asyncio.run(main()) diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index 5e34c01f..e7b74e68 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Iterable, List, Union, Deque +from typing import Iterable, List, Union, Deque, Tuple from llumnix.llumlet.request import LlumnixRequest, RequestStatus from llumnix.server_info import ServerInfo @@ -71,7 +71,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: # Methods for migration @abstractmethod - def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]: + def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> Tuple[List[int], List[int]]: """Retrieves the incremental block table for a given request. This method is used to fetch a list of block numbers that represent the incremental @@ -88,7 +88,7 @@ def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_st need to be fetched in the current stage. Returns: - A list of integers, where each integer represents a block number that indicates + A list of integers and its token ids, where each integer represents a block number that indicates physical index of kv cache block tensor. These block numbers can then be used to transfer to dstination instance. """ @@ -191,7 +191,8 @@ def pre_alloc(self, request_id: str, request_status: RequestStatus, request_arrival_time: float, - block_num: int) -> List[int]: + block_num: int, + token_ids: List[int]) -> List[int]: """Pre-allocates cache blocks for a migrating request. This method selects a specified number of free cache blocks to be reserved for an incoming @@ -207,7 +208,7 @@ def pre_alloc(self, request_status: The status (waiting/running) of the request. request_arrival_time: The arrival time of the request. block_num: The number of cache blocks that need to be pre-allocated for the request. - + token_ids: The token IDs of the request. Returns: A list of integers where each integer represents the block table reserved for the migration request. 
""" diff --git a/llumnix/backends/bladellm/llm_engine.py b/llumnix/backends/bladellm/llm_engine.py index 557b1bc1..bb7c47e1 100644 --- a/llumnix/backends/bladellm/llm_engine.py +++ b/llumnix/backends/bladellm/llm_engine.py @@ -316,7 +316,8 @@ def pre_alloc(self, request_id: str, request_status: RequestStatus, request_arrival_time: float, - block_num: int) -> List[int]: + block_num: int, + token_ids: List[int]) -> List[int]: pass def add_running_request(self, backend_request: LlumnixRequest) -> None: diff --git a/llumnix/backends/migration_backend_interface.py b/llumnix/backends/migration_backend_interface.py index 808ba8c8..cb88c26b 100644 --- a/llumnix/backends/migration_backend_interface.py +++ b/llumnix/backends/migration_backend_interface.py @@ -33,9 +33,9 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] raise NotImplementedError @abstractmethod - def do_send(self, dst_handle, blocks: List[int]): + def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int): raise NotImplementedError @abstractmethod - def do_recv(self, src_handle, blocks: List[int]): + def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int): raise NotImplementedError diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py index 21f63a9e..1a045856 100644 --- a/llumnix/backends/vllm/executor.py +++ b/llumnix/backends/vllm/executor.py @@ -15,19 +15,20 @@ import asyncio from collections import defaultdict -from typing import List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Type import ray from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy # pylint: disable=unused-import from ray.util.placement_group import PlacementGroup from vllm.executor.executor_base import ExecutorBase -from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper,\ - get_distributed_init_method, get_ip, get_vllm_instance_id, get_open_port +from vllm.model_executor.layers.sampler import SamplerOutput, CompletionSequenceGroupOutput +from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper, envs, \ + get_ip, get_vllm_instance_id, get_distributed_init_method, get_open_port +from vllm.worker.worker_base import WorkerBase -from vllm import envs -from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest -from vllm.config import _GB +from vllm.sequence import Logprob, SequenceOutput, ExecuteModelRequest +from vllm.utils import GiB_bytes from llumnix.internal_config import MigrationConfig from llumnix.logger import init_logger @@ -39,11 +40,12 @@ class LlumnixRayGPUExecutor(RayGPUExecutorAsync): node_id: str = None migration_config: MigrationConfig = None + last_inference_latency:int = 0 def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - self.last_inference_latency = 0 - if self.parallel_config.tensor_parallel_size == 1: + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): # For single GPU case, we use a ray worker with constrained memory. num_gpus = self.cache_config.gpu_memory_utilization else: @@ -56,13 +58,21 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The remaining workers are the actual ray actors. self.workers: List[RayWorkerWrapper] = [] + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. 
In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs) + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + # Create the workers. driver_ip = get_ip() node_id = self.node_id + worker_wrapper_kwargs = self._get_worker_wrapper_args() for rank in range(self.parallel_config.world_size): if placement_group: bundle = placement_group.bundle_specs[rank+1] @@ -77,50 +87,93 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", node_id=node_id, soft=False, ) + worker = ray.remote( num_cpus=0, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, - max_concurrency=2, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( - worker_module_name="llumnix.backends.vllm.worker", - worker_class_name="MigrationWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name="llumnix.backends.vllm.worker", - worker_class_name="MigrationWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + + if self.use_ray_spmd_worker: self.workers.append(worker) + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + **worker_wrapper_kwargs) + else: + # Else, added to the list of workers. + self.workers.append(worker) - if self.driver_dummy_worker is None: + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "GPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + # Get the set of GPU IDs used on each node. 
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - node_workers = defaultdict(list) - node_gpus = defaultdict(list) + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] node_gpus[node_id].extend(gpu_ids) for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP` or " + "`HOST_IP` environment variable, make sure it is unique for" + " each node.") # pylint: disable=invalid-name VLLM_INSTANCE_ID = get_vllm_instance_id() @@ -132,10 +185,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) + all_args=self._get_env_vars_to_be_updated()) + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) @@ -153,11 +223,46 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("load_model", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) - self._run_workers("reserve_memory_for_migration", - migration_config=self.migration_config, - model_config=self.model_config, - cache_config=self.cache_config, - parallel_config=self.parallel_config) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. 
+ self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + def _get_worker_module_and_class( + self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: + worker_class_fn = None + if self.scheduler_config.is_multi_step or self.speculative_config: + raise NotImplementedError + worker_module_name = "llumnix.backends.vllm.worker" + worker_class_name = "MigrationWorker" + return (worker_module_name, worker_class_name, worker_class_fn) async def execute_model_async(self, *args, **kwargs): t0 = time.time() @@ -176,7 +281,7 @@ def __init__(self, *args, **kwargs) -> None: self.cache_block_size = get_cache_block_size( self.cache_config.block_size, self.model_config, self.parallel_config) - self.cache_block_size /= _GB + self.cache_block_size /= GiB_bytes self.sim_cache_config = SimCacheConfig(self.cache_config.gpu_memory_utilization, self.cache_config.block_size, self.scheduler_config.max_num_batched_tokens) @@ -223,7 +328,7 @@ async def execute_model_async( dummy_sample_output = SequenceOutput(seq_id, 20, {20: Logprob(1.0)}) samples.append(dummy_sample_output) if samples: - output = SequenceGroupOutput(samples, None) + output = CompletionSequenceGroupOutput(samples, None) sampler_outputs.append(output) return [SamplerOutput(outputs=sampler_outputs)] diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 1af547cc..bb25798f 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -13,7 +13,7 @@ import time import traceback -from typing import Any, List, Optional, Union, Iterable, Tuple, Deque +from typing import Any, List, Optional, Union, Iterable, Deque, Tuple from collections import defaultdict import threading import asyncio @@ -23,9 +23,8 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy from vllm.engine.async_llm_engine import _AsyncLLMEngine -from vllm.core.scheduler import ScheduledSequenceGroup -from vllm.outputs import RequestOutput -from vllm.sequence import SequenceGroup, SequenceStatus, SamplerOutput, SequenceGroupMetadata +from vllm.outputs import RequestOutput, RequestOutputFactory, EmbeddingRequestOutput +from vllm.sequence import SequenceGroup, SequenceStatus from vllm.engine.arg_utils import EngineArgs from vllm.utils import Counter from vllm.usage.usage_lib import UsageContext @@ -45,14 +44,28 @@ NO_OUTPUTS_STEP_INTERVAL = 0.01 + +class LlumnixOutput(RequestOutputFactory): + @staticmethod + def create(seq_group: SequenceGroupLlumnix, use_cache: bool = False): + # Determine the type based on a condition, for example: + if hasattr(seq_group, + 'embeddings') and seq_group.embeddings is not None: + return EmbeddingRequestOutput.from_seq_group(seq_group), seq_group.server_info + # pylint: disable=too-many-function-args + return RequestOutput.from_seq_group(seq_group, use_cache), seq_group.server_info + class LLMEngineLlumnix(_AsyncLLMEngine): def __init__(self, instance_id: str, request_output_queue_type: 
QueueType, placement_group: Optional[PlacementGroup], node_id: Optional[str], - *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + *arg, **kwargs) -> None: + # pylint: disable=import-outside-toplevel + import vllm.outputs + vllm.outputs.RequestOutputFactory.create = LlumnixOutput.create + super().__init__(*arg, **kwargs) self.instance_id = instance_id self.step_counter = Counter() self.instance_info = None @@ -105,7 +118,7 @@ def from_engine_args( from llumnix.backends.vllm.executor import SimGPUExecutor executor_class = SimGPUExecutor executor_class.latency_mem = latency_mem - elif engine_config.parallel_config.worker_use_ray: + elif engine_config.parallel_config.use_ray: from llumnix.backends.vllm.executor import LlumnixRayGPUExecutor executor_class = LlumnixRayGPUExecutor executor_class.migration_config = migration_config @@ -126,52 +139,26 @@ def from_engine_args( ) return engine - def _process_model_outputs( - self, - output: List[SamplerOutput], - scheduled_seq_groups: List[ScheduledSequenceGroup], - ignored_seq_groups: List[SequenceGroup], - seq_group_metadata_list: List[SequenceGroupMetadata], + def _process_request_outputs( + self, + outputs: List[Tuple[RequestOutput,ServerInfo]], + step_begin_time: float ) -> Tuple[List[RequestOutput], List[ServerInfo]]: - # ensure scheduled_seq_groups matching output + request_outputs = [] server_infos = [] - if output: - new_output = [] - new_scheduled_seq_groups = [] - new_seq_group_metadata_list = [] - for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs): - seq_group = scheduled_seq_group.seq_group - if seq_group.get_seqs(SequenceStatus.RUNNING): - new_scheduled_seq_groups.append(scheduled_seq_group) - new_seq_group_metadata_list.append(seq_group_meta) - new_output.append(seq_group_output) - server_infos.append(seq_group.server_info) - scheduled_seq_groups = new_scheduled_seq_groups - output[0].outputs = new_output - seq_group_metadata_list = new_seq_group_metadata_list - for ignored_seq_group in ignored_seq_groups: - server_infos.append(ignored_seq_group.server_info) - - for server_info in server_infos: - if hasattr(server_info, 'request_timestamps'): - server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time() - - request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list) - + if outputs: + request_outputs, server_infos = zip(*outputs) + request_outputs = list(request_outputs) + server_infos = list(server_infos) for request_output, server_info in zip(request_outputs, server_infos): if hasattr(server_info, 'request_timestamps'): request_output.request_timestamps = server_info.request_timestamps request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time() if request_output.finished: logger.info("engine finished request {}".format(request_output.request_id)) - - # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args. 
- return request_outputs, server_infos - - async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: - step_begin_time = time.time() - request_outputs, server_infos = await super().step_async() - + for server_info in server_infos: + if hasattr(server_info, 'request_timestamps'): + server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time() for request_output in request_outputs: if hasattr(request_output, 'request_timestamps'): request_output.request_timestamps.engine_step_timestamp_begin = step_begin_time @@ -185,11 +172,11 @@ async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: instance_info.num_seqs, sum(instance_info.running_seq_lens), self.model_executor.last_inference_latency) - seq_groups = self.scheduler.running + seq_groups = self.scheduler[0].running if seq_groups: tot_blocks = [] for seq in seq_groups[-1].get_seqs(SequenceStatus.RUNNING): - blocks = self.scheduler.block_manager.get_block_table(seq) + blocks = self.scheduler[0].block_manager.get_block_table(seq) tot_blocks.extend(blocks) tot_blocks = set(tot_blocks) instance_info.num_blocks_last_running_request = len(tot_blocks) @@ -205,6 +192,12 @@ async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: return request_outputs, server_infos + async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: + step_begin_time = time.time() + # pylint: disable=too-many-function-args + outputs = await super().step_async(0) + return self._process_request_outputs(outputs, step_begin_time) + def update_instance_info(self, instance_info: InstanceInfo) -> None: # These fields are updated after step. if self.instance_info is not None: @@ -217,12 +210,13 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None: def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs): super().add_request(request_id, *args, **kwargs) - seq_group = self.scheduler.waiting[-1] + seq_group = self.scheduler[0].waiting[-1] if hasattr(server_info, 'request_timestamps'): server_info.request_timestamps.engine_add_request_timestamp = time.time() - self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]], - seq_group.sampling_params, seq_group.metrics.arrival_time, seq_group.lora_request, - seq_group.multi_modal_data) + self.scheduler[0].waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]], + seq_group.metrics.arrival_time, seq_group.sampling_params, seq_group.lora_request, + seq_group.trace_headers, seq_group.prompt_adapter_request, seq_group.encoder_seq, + seq_group.priority) def _start_put_queue_loop(self): while True: @@ -261,8 +255,10 @@ def __init__( instance_id=instance_id, placement_group=placement_group, node_id=node_id) - self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) - self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) + self.engine.scheduler = [SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) + for _ in range(engine_args.pipeline_parallel_size)] + for vid in range(engine_args.pipeline_parallel_size): + self.engine.scheduler[vid].add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler self.instance_id = instance_id self.worker_handle_list = 
self.engine.model_executor.workers.copy() @@ -323,8 +319,8 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None: seq = backend_request.get_seqs()[0] seq.seq_id = next(self.engine.seq_counter) logger.info("pop request {} from pre_alloc_cache_dict".format(backend_request.request_id)) - pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id) - self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id) + pre_alloc_blocks = self.engine.scheduler[0].pre_alloc_cache_dict.pop(backend_request.request_id) + self.engine.scheduler[0].block_manager.add_block_table(pre_alloc_blocks, seq.seq_id) backend_request.reset_migration_args_dst() assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \ "The status of request migrated to dst instance should be \ @@ -355,47 +351,50 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: request_ids = set(request_id) return self.engine.abort_request(request_ids) - def get_running_queue(self) -> Deque[SequenceGroupLlumnix]: - return self.engine.scheduler.get_running_queue() + def get_running_queue(self) -> List[SequenceGroupLlumnix]: + return self.engine.scheduler[0].get_running_queue() def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]: - return self.engine.scheduler.get_waiting_queue() + return self.engine.scheduler[0].get_waiting_queue() - def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]: - return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs) + def get_request_incremental_blocks(self, *args, **kwargs) -> Tuple[List[int], List[int]]: + return self.engine.scheduler[0].get_request_incremental_blocks(*args, **kwargs) - def remove_running_request(self, *args, **kwargs) -> bool: - return self.engine.scheduler.remove_running_request(*args, **kwargs) + def remove_running_request(self, *args, **kwargs) -> None: + return self.engine.scheduler[0].remove_running_request(*args, **kwargs) def remove_waiting_request(self, *args, **kwargs) -> bool: - return self.engine.scheduler.remove_waiting_request(*args, **kwargs) + return self.engine.scheduler[0].remove_waiting_request(*args, **kwargs) def add_migrating_out_request_last_stage(self, *args, **kwargs) -> None: - return self.engine.scheduler.add_migrating_out_request_last_stage(*args, **kwargs) + return self.engine.scheduler[0].add_migrating_out_request_last_stage(*args, **kwargs) def remove_migrating_out_request_last_stage(self, *args, **kwargs) -> None: - return self.engine.scheduler.remove_migrating_out_request_last_stage(*args, **kwargs) + return self.engine.scheduler[0].remove_migrating_out_request_last_stage(*args, **kwargs) def pop_migrating_out_requests_last_stage(self, *args, **kwargs) -> List[Any]: - return self.engine.scheduler.pop_migrating_out_requests_last_stage(*args, **kwargs) + return self.engine.scheduler[0].pop_migrating_out_requests_last_stage(*args, **kwargs) def pre_alloc(self, *args, **kwargs) -> List[int]: - return self.engine.scheduler.pre_alloc(*args, **kwargs) + return self.engine.scheduler[0].pre_alloc(*args, **kwargs) def should_abort_migration(self, *args, **kwargs) -> bool: - return self.engine.scheduler.should_abort_migration(*args, **kwargs) + return self.engine.scheduler[0].should_abort_migration(*args, **kwargs) def add_running_request(self, *args, **kwargs) -> None: - return self.engine.scheduler.add_running_request(*args, **kwargs) + return 
self.engine.scheduler[0].add_running_request(*args, **kwargs) def add_waiting_request(self, *args, **kwargs) -> None: - return self.engine.scheduler.add_waiting_request(*args, **kwargs) + return self.engine.scheduler[0].add_waiting_request(*args, **kwargs) + + def is_request_running(self, *args, **kwargs) -> bool: + return self.engine.scheduler[0].is_request_running(*args, **kwargs) def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None: - return self.engine.scheduler.free_dst_pre_alloc_cache(*args, **kwargs) + return self.engine.scheduler[0].free_dst_pre_alloc_cache(*args, **kwargs) def free_src_request(self, backend_request: SequenceGroup) -> None: - return self.engine.scheduler.free_src_request(backend_request) + return self.engine.scheduler[0].free_src_request(backend_request) def get_all_request_ids(self) -> List[str]: - return self.engine.scheduler.get_all_request_ids() + return self.engine.scheduler[0].get_all_request_ids() diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index f21c2bab..a268e47f 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Tuple import torch from func_timeout import func_set_timeout, FunctionTimedOut @@ -41,7 +41,7 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs): NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16] class RayRpcMigrationBackend(MigrationBackendBase): - def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, worker_rank, worker_handle_list, \ + def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEngine], worker_rank, worker_handle_list, \ scheduling_strategy, is_driver_worker, gpu_cache) -> None: super().__init__() @@ -52,23 +52,22 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.worker_handle_list = worker_handle_list self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote() - self.rpc_dtype = self.cache_engine.dtype - if self.cache_engine.dtype in NUMPY_SUPPORTED_DTYPES: - self.rpc_dtype = self.cache_engine.dtype + if self.cache_engine[0].dtype in NUMPY_SUPPORTED_DTYPES: + self.rpc_dtype = self.cache_engine[0].dtype else: self.rpc_dtype = torch.float32 - logger.warning("Detect numpy unsupported dtype: {}. Using torch.float32.".format(self.cache_engine.dtype)) + logger.warning("Detect numpy unsupported dtype: {}. 
Using torch.float32.".format(self.cache_engine[0].dtype)) self.is_driver_worker = is_driver_worker self.gpu_cache = gpu_cache self.cache_device = "cpu" self.num_migration_buffer_blocks = self.migration_config.migration_buffer_blocks - self.num_layers = self.cache_engine.num_layers - self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + self.num_layers = self.cache_engine[0].num_attention_layers + self.migration_cache_size = self.cache_engine[0].block_size * self.cache_engine[0].num_kv_heads * self.cache_engine[0].head_size self.dummy_cache = torch.empty( size=(self.num_migration_buffer_blocks, self.num_layers, 2, self.migration_cache_size), - dtype=self.cache_engine.dtype, + dtype=self.cache_engine[0].dtype, device=self.cache_device, pin_memory=True ) @@ -104,26 +103,40 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] recv_blocks = dst_blocks[start_idx:start_idx+offset] self.do_recv(rpc_numpy_cache, recv_blocks) - def do_send(self, dst_handle, blocks: List[int]): + def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int=0): num_blocks = len(blocks) send_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) - src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} + # src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} + src_to_dst: List[Tuple[int, int]] = [] + for idx in range(num_blocks): + src_to_dst.append((blocks[idx], idx)) + block_mapping_tensor = torch.tensor(src_to_dst, + dtype=torch.int64, + device="cpu", pin_memory=True).view(-1, 2) with torch.cuda.stream(self.migration_stream): for layer_idx in range(self.num_layers): - self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[layer_idx], src_to_dst) + self.cache_engine[virtuel_engine].attn_backend \ + .swap_blocks(self.gpu_cache[virtuel_engine][layer_idx], send_cache[layer_idx], block_mapping_tensor) torch.cuda.Stream.synchronize(self.migration_stream) return send_cache.to(self.rpc_dtype).numpy() - def do_recv(self, src_handle, blocks: List[int]): + def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int=0): num_blocks = len(blocks) - src_to_dst = dict(enumerate(blocks)) + # src_to_dst = dict(enumerate(blocks)) + src_to_dst: List[Tuple[int, int]] = [] + for idx in range(num_blocks): + src_to_dst.append((idx, blocks[idx])) + block_mapping_tensor = torch.tensor(src_to_dst, + dtype=torch.int64, + device="cpu", pin_memory=True).view(-1, 2) recv_cache = self.dummy_cache[:num_blocks].view(self.num_layers, 2, num_blocks, self.migration_cache_size) # use pin memory dummy_cache to speed up data transfer recv_cache.copy_(torch.from_numpy(src_handle)) with torch.cuda.stream(self.migration_stream): for layer_idx in range(self.num_layers): - self.cache_engine.attn_backend.swap_blocks(recv_cache[layer_idx], self.gpu_cache[layer_idx], src_to_dst) + self.cache_engine[virtuel_engine].attn_backend \ + .swap_blocks(recv_cache[layer_idx], self.gpu_cache[virtuel_engine][layer_idx], block_mapping_tensor) torch.cuda.Stream.synchronize(self.migration_stream) def try_import_gloo(): @@ -140,7 +153,7 @@ def try_import_gloo(): raise ImportError("Gloo is not installed. 
Please install it first.") from e class RayColMigrationBackend(MigrationBackendBase): - def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, local_rank, + def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEngine], local_rank, scheduling_strategy, is_driver_worker, gpu_cache) -> None: super().__init__() @@ -150,7 +163,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.migration_config = migration_config self.cache_engine = cache_engine self.backend = migration_config.migration_backend - self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine.num_layers) + self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine[0].num_attention_layers) self.num_migration_buffer_blocks = migration_config.migration_buffer_blocks self.backend = migration_config.migration_backend @@ -163,7 +176,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, self.is_driver_worker = is_driver_worker self.gpu_cache = gpu_cache - self.migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size + self.migration_cache_size = self.cache_engine[0].block_size * self.cache_engine[0].num_kv_heads * self.cache_engine[0].head_size if self.backend == 'gloo': try_import_gloo() @@ -174,7 +187,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: CacheEngine, pin_memory = (self.backend == 'gloo') self.dummy_cache = torch.empty( size=(self.num_migration_buffer_blocks, self.migration_num_layers, 2, self.migration_cache_size), - dtype=self.cache_engine.dtype, + dtype=self.cache_engine[0].dtype, device=self.cache_device, pin_memory=pin_memory ) @@ -248,44 +261,55 @@ def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int] self.actor.exec_method.remote(self.is_driver_worker, src_handle, "do_send", self.global_rank, send_blocks) self.do_recv(src_rank, recv_blocks) - def do_send(self, dst_handle, blocks: List[int]): + def do_send(self, dst_handle, blocks: List[int], virtuel_engine: int=0): num_blocks = len(blocks) send_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) - src_to_dst = {block_num: idx for idx, block_num in enumerate(blocks)} - + src_to_dst: List[Tuple[int, int]] = [] + for idx in range(num_blocks): + src_to_dst.append((blocks[idx], idx)) + block_mapping_tensor = torch.tensor(src_to_dst, + dtype=torch.int64, + device="cpu", pin_memory=True).view(-1, 2) with self.migration_stream: - for layer_idx in range(self.cache_engine.num_layers): + for layer_idx in range(self.cache_engine[0].num_attention_layers): cache_idx = layer_idx % self.migration_num_layers - self.cache_engine.attn_backend.swap_blocks(self.gpu_cache[layer_idx], send_cache[cache_idx], src_to_dst) - if cache_idx + 1 == self.migration_num_layers or layer_idx + 1 == self.cache_engine.num_layers: + self.cache_engine[virtuel_engine].attn_backend \ + .swap_blocks(self.gpu_cache[virtuel_engine][layer_idx], send_cache[cache_idx], block_mapping_tensor) + if cache_idx + 1 == self.migration_num_layers or layer_idx + 1 == self.cache_engine[0].num_attention_layers: # TODO(KuilongCui): check the error code if peer is dead col.send(send_cache, dst_handle, self.group_name) self.migration_stream.synchronize() - def do_recv(self, src_handle, blocks: List[int]): + def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int=0): num_blocks 
= len(blocks) - src_to_dst = dict(enumerate(blocks)) + src_to_dst: List[Tuple[int, int]] = [] + for idx in range(num_blocks): + src_to_dst.append((idx, blocks[idx])) + block_mapping_tensor = torch.tensor(src_to_dst, + dtype=torch.int64, + device="cpu", pin_memory=True).view(-1, 2) recv_cache = self.dummy_cache[:num_blocks].view(self.migration_num_layers, 2, num_blocks, self.migration_cache_size) with self.migration_stream: - for layer_idx in range(self.cache_engine.num_layers): + for layer_idx in range(self.cache_engine[0].num_attention_layers): cache_idx = layer_idx % self.migration_num_layers if cache_idx == 0: col.recv(recv_cache, src_handle, self.group_name) - self.cache_engine.attn_backend.swap_blocks(recv_cache[cache_idx], self.gpu_cache[layer_idx], src_to_dst) + self.cache_engine[virtuel_engine].attn_backend \ + .swap_blocks(recv_cache[cache_idx], self.gpu_cache[virtuel_engine][layer_idx], block_mapping_tensor) self.migration_stream.synchronize() -def get_migration_backend(migration_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, +def get_migration_backend(migration_config: MigrationConfig, cache_engine: List[CacheEngine], worker_handle_list, scheduling_strategy, is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase: - if cache_engine.num_gpu_blocks < migration_config.migration_buffer_blocks: - logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks." - .format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks)) - migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks + if cache_engine[0].num_gpu_blocks < migration_config.migration_buffer_blocks: + logger.warning("migration_buffer_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
+ .format(migration_config.migration_buffer_blocks, cache_engine[0].num_gpu_blocks)) + migration_config.migration_buffer_blocks = cache_engine[0].num_gpu_blocks target_migration_backend = None backend = migration_config.migration_backend - assert backend in ['nccl', 'gloo', 'rayrpc'], "Unsupported migration backend: {} for llumnix".format(backend) + assert backend in ['nccl', 'rayrpc', 'gloo'], "Unsupported migration backend: {} for llumnix".format(backend) if backend in ['nccl', 'gloo']: target_migration_backend = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy, diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index ea0991f7..18439d05 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -13,12 +13,13 @@ from asyncio.log import logger import time +import bisect from typing import Dict, List, Optional, Tuple, Deque from collections import deque -from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable +from vllm.utils import Device +from vllm.core.block_manager import SelfAttnBlockSpaceManager, BlockTable from vllm.core.scheduler import (Scheduler, PreemptionMode, SequenceStatus, SequenceGroupMetadata, SchedulerOutputs) -from vllm.core.policy import PolicyFactory from vllm.sequence import SequenceGroup from vllm.core.interfaces import AllocStatus @@ -32,22 +33,24 @@ # TODO(ZeldaHuang): adapt prefix cache and sliding window, now use v1 manager -class BlockManagerLlumnix(BlockSpaceManagerV1): - def get_free_blocks(self, num_required_blocks: int) -> BlockTable: - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - if (num_free_gpu_blocks - num_required_blocks < +class BlockManagerLlumnix(SelfAttnBlockSpaceManager): + def get_free_blocks(self, num_required_blocks: int, token_ids: List[int]) -> BlockTable: + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(device=Device.GPU) + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + if (num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks): - return [] - blocks = [] - for _ in range(num_required_blocks): - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. 
- block.ref_count = 1 - blocks.append(block) - return blocks + block_table.allocate(token_ids) + + return block_table def add_block_table(self, block_table: BlockTable, seq_id: int) -> None: - self.block_tables[seq_id] = block_table.copy() + self.block_tables[seq_id] = block_table + self._computed_blocks_tracker.add_seq(seq_id) + self._last_access_blocks_tracker.add_seq(seq_id) class SchedulerLlumnix(Scheduler): def __init__(self, *args, **kwargs) -> None: @@ -94,13 +97,15 @@ def get_all_request_ids(self) -> List[str]: request_ids.append(seq_group.request_id) return request_ids - def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]: + def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> Tuple[List[int], List[int]]: seq = backend_request.get_seqs()[0] blocks = self.block_manager.get_block_table(seq) - return blocks[pre_stage_num_blocks:] + block_table = self.block_manager.block_tables[seq.seq_id] + token_ids = backend_request.token_ids + return blocks[pre_stage_num_blocks:], token_ids[pre_stage_num_blocks * self.block_manager.block_size:block_table.num_full_slots] def remove_running_request(self, request_id: str) -> bool: - for seq_group in self.running: + for seq_group in reversed(self.running): if seq_group.request_id == request_id: self.running.remove(seq_group) seq_group.set_status(RequestStatus.RUNNING_MIGRATING) @@ -130,20 +135,20 @@ def pre_alloc(self, request_id: str, request_status: RequestStatus, request_arrival_time: float, - block_num: int) -> List[int]: + block_num: int, + token_ids: List[int]) -> List[int]: # Only migrate waiting request when the waiting request is the earliest arrival one # among the requests of dst instance's waiting queue. 
if request_status == RequestStatus.WAITING_MIGRATING: - if (self.waiting and request_arrival_time > self.waiting[0].arrival_time) \ - or block_num * self.cache_config.block_size > self.prompt_limit: + if self.waiting and request_arrival_time > self.waiting[0].arrival_time: return [] - blocks = self.block_manager.get_free_blocks(block_num) - pre_blocks = self.pre_alloc_cache_dict.get(request_id, []) - pre_blocks.extend(blocks) - logger.info("add request {} to pre_alloc_cache_dict".format(request_id)) - self.pre_alloc_cache_dict[request_id] = pre_blocks - blocks = [block.block_number for block in blocks] - return blocks + block_table = self.pre_alloc_cache_dict.get(request_id, None) + if not block_table: + block_table = self.block_manager.get_free_blocks(block_num, token_ids) + self.pre_alloc_cache_dict[request_id] = block_table + else: + block_table.append_token_ids(token_ids) + return block_table.physical_block_ids[-block_num:] def add_running_request(self, backend_request: LlumnixRequest) -> None: self._set_status(backend_request, status_to=SequenceStatus.RUNNING) @@ -152,9 +157,12 @@ def add_running_request(self, backend_request: LlumnixRequest) -> None: def add_waiting_request(self, backend_request: LlumnixRequest) -> None: self._set_status(backend_request, status_to=SequenceStatus.WAITING) # pylint: disable=E0203 - self.waiting.append(backend_request) - fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") - self.waiting = fcfs_policy.sort_by_priority(time.time(), self.waiting) + arrival_time_list = [request.arrival_time for request in self.waiting] + idx = bisect.bisect_right(arrival_time_list, backend_request.arrival_time) + if idx < len(self.waiting): + self.waiting.insert(idx, backend_request) + else: + self.waiting.append(backend_request) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: if seq_group.status == RequestStatus.WAITING_MIGRATING: @@ -179,19 +187,17 @@ def _set_status(self, def free_dst_pre_alloc_cache(self, request_id: str = None) -> None: if request_id: - logger.info("pop request {} from pre_alloc_cache_dict".format(request_id)) - blocks = self.pre_alloc_cache_dict.pop(request_id, []) - # pylint: disable=protected-access - self.block_manager._free_block_table(blocks) + block_table = self.pre_alloc_cache_dict.pop(request_id, None) + if block_table: + block_table.free() else: # TODO(s5u13b): Only effective with one-to-one migration restriction. # Clear all pre-allocated cache of dst instance when src instance encounters exception. 
request_ids = list(self.pre_alloc_cache_dict.keys()) for req_id in request_ids: - logger.info("pop request {} from pre_alloc_cache_dict".format(req_id)) - blocks = self.pre_alloc_cache_dict.pop(req_id, []) - # pylint: disable=protected-access - self.block_manager._free_block_table(blocks) + block_table = self.pre_alloc_cache_dict.pop(req_id, None) + if block_table: + block_table.free() def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None: seq = backend_request.get_seqs()[0] @@ -244,15 +250,16 @@ def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) - instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.finished] return instance_info - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: - seq_group_metadata_list, scheduler_outputs = super().schedule() + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + seq_group_metadata_list, scheduler_outputs, allow_async_output_proc = super().schedule() self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \ for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups])) for seq_group in self.waiting: seq_group.try_schedule_times += 1 - return seq_group_metadata_list, scheduler_outputs + return seq_group_metadata_list, scheduler_outputs, allow_async_output_proc - def _schedule_running(self, running_queue: deque, *args, **kwargs): + def _schedule_running(self, *args, **kwargs): + running_queue = self.running filtered_running_queue = deque() remove_running = deque() for seq_group in running_queue: @@ -260,7 +267,9 @@ def _schedule_running(self, running_queue: deque, *args, **kwargs): remove_running.extend([seq_group]) else: filtered_running_queue.extend([seq_group]) - remaining_running, running_scheduled = super()._schedule_running(filtered_running_queue, *args, **kwargs) + + self.running = filtered_running_queue + ret = super()._schedule_running(*args, **kwargs) for seq_group in remove_running: - remaining_running.extend([seq_group]) - return remaining_running, running_scheduled + self.running.extend([seq_group]) + return ret diff --git a/llumnix/backends/vllm/sequence.py b/llumnix/backends/vllm/sequence.py index 5964f96d..ccc93f45 100644 --- a/llumnix/backends/vllm/sequence.py +++ b/llumnix/backends/vllm/sequence.py @@ -21,6 +21,10 @@ def __init__(self, request_id, server_info, expected_steps: int, *args, **kwargs SequenceGroup.__init__(self, request_id, *args, **kwargs) LlumnixRequest.__init__(self, request_id, server_info, expected_steps) + @property + def block_size(self) -> int: + return self.get_seqs()[0].block_size + @property def prompt_len(self) -> int: return self.get_seqs()[0].get_prompt_len() @@ -36,6 +40,13 @@ def request_len(self) -> int: def output_len(self) -> int: return self.get_seqs()[0].get_output_len() + @property + def n_blocks(self) -> int: + return self.get_seqs()[0].n_blocks + @property + def token_ids(self) -> int: + return self.get_seqs()[0].get_token_ids() + @property def inference_type(self) -> RequestInferenceType: if self.is_prefill(): @@ -47,8 +58,8 @@ def finished(self) -> bool: return self.get_seqs()[0].is_finished() @property - def arrival_time(self) -> float: - return self.metrics.arrival_time + def request_arrival_time(self) -> float: + return self.arrival_time @property def status(self) -> RequestStatus: @@ -66,4 +77,4 @@ def status(self) -> RequestStatus: @property def prefill_num_blocks(self) -> int: # Get the prefill len of 
the waiting request. - return len(self.get_seqs()[0].logical_token_blocks) + return self.get_seqs()[0].n_blocks diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/simulator.py index 809c61ce..eaed86bd 100644 --- a/llumnix/backends/vllm/simulator.py +++ b/llumnix/backends/vllm/simulator.py @@ -46,8 +46,10 @@ def __init__( instance_id=instance_id, latency_mem=latency_mem, node_id=node_id) - self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) - self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) + self.engine.scheduler = [SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) + for _ in range(engine_args.pipeline_parallel_size)] + for vid in range(engine_args.pipeline_parallel_size): + self.engine.scheduler[vid].add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler self.instance_id = instance_id diff --git a/llumnix/backends/vllm/utils.py b/llumnix/backends/vllm/utils.py index 7e49720a..1d0bcaa2 100644 --- a/llumnix/backends/vllm/utils.py +++ b/llumnix/backends/vllm/utils.py @@ -12,15 +12,16 @@ # limitations under the License. from functools import wraps -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import torch from vllm.config import ModelConfig, ParallelConfig from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sampling_params import SamplingType -from vllm.model_executor.layers.sampler import SampleResultType, _multinomial, _greedy_sample, _random_sample,\ - _modify_greedy_probs_inplace, _beam_search_sample +from vllm.model_executor.layers.sampler import SamplingMetadata, SamplingTensors, SampleResultArgsType, SampleReturnType, \ + SampleResultsDictType, SampleMetadataType, MultinomialSamplesType, \ + flashinfer_top_k_top_p_sampling, _top_k_top_p_multinomial_with_flashinfer, \ + VLLM_INVALID_TOKEN_ID, _multinomial, _modify_greedy_probs_inplace, get_pythonized_sample_results from llumnix.logger import init_logger from llumnix.arg_utils import EngineManagerArgs @@ -42,8 +43,8 @@ def detect_unsupported_feature(engine_args: EngineArgs) -> None: raise ValueError(f'Unsupported feature: Llumnix does not support "{unsupported_feature}" currently.') def check_engine_args(engine_args: AsyncEngineArgs, engine_manager_args: EngineManagerArgs) -> None: - assert engine_args.engine_use_ray and engine_args.worker_use_ray, \ - ("In Llumnix, engine and worker must be ray actor.") + assert engine_args.worker_use_ray, \ + ("In Llumnix, worker must be ray actor.") migration_config = engine_manager_args.create_migration_config() engine_config = engine_args.create_engine_config() parallel_config = engine_config.parallel_config @@ -75,9 +76,22 @@ def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. 
+ + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + categorized_seq_group_ids: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType} @@ -87,23 +101,25 @@ def _sample_with_torch( sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata = {} - multinomial_samples = {} + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + beam_search_logprobs: Optional[torch.Tensor] = None # Create output tensor for sampled token ids. if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) else: sampled_token_ids_tensor = None # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] + sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -116,7 +132,7 @@ def _sample_with_torch( greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) - if include_gpu_probs_tensor: + if sampled_token_ids_tensor is not None: # Store sampled tokens in output tensor. sampled_token_ids_tensor[ long_sample_indices] = greedy_samples.unsqueeze(-1) @@ -130,52 +146,64 @@ def _sample_with_torch( greedy_samples) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_best_of_in_batch = 1 + max_n_in_batch = 1 for seq_group in seq_groups: if seq_group.is_prompt: sampling_params = seq_group.sampling_params - max_best_of_in_batch = max(max_best_of_in_batch, - sampling_params.best_of) - seeded_args = {} if sampling_type == SamplingType.RANDOM else { - "seq_groups": seq_groups, - } - - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], max_best_of_in_batch, - **seeded_args) - - if include_gpu_probs_tensor: + max_n_in_batch = max(max_n_in_batch, sampling_params.n) + seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else + seq_groups) + + if flashinfer_top_k_top_p_sampling is not None: + multinomial_samples[ + sampling_type] = _top_k_top_p_multinomial_with_flashinfer( + probs[long_sample_indices], + sampling_tensors.top_ks[long_sample_indices], + sampling_tensors.top_ps[long_sample_indices], + max_n_in_batch, + seq_groups_arg, + ) + else: + multinomial_samples[sampling_type] = _multinomial( + probs[long_sample_indices], + max_n_in_batch, + seq_groups=seq_groups_arg) + + if sampled_token_ids_tensor is not None: # Store sampled tokens in output tensor. - sampled_token_ids_tensor[ - long_sample_indices] = multinomial_samples[sampling_type] + sampled_token_ids_tensor[long_sample_indices] = \ + multinomial_samples[sampling_type].to(torch.long) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: raise ValueError(f"Unsupported sampling type: {sampling_type}") - # GPU<->CPU sync happens in the loop below. 
- torch.cuda.current_stream().synchronize() - # This also converts the sample output to Python objects. - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - return sample_results, sampled_token_ids_tensor + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + beam_search_logprobs=beam_search_logprobs, + sample_results_dict=sample_results_dict) + + if not sampling_metadata.skip_sampler_cpu_output: + # GPU<->CPU sync happens here. + torch.cuda.current_stream().synchronize() + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor + + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) + def scheduler_lock(func): @wraps(func) diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index 0b2c6fb9..e6e3d079 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -22,7 +22,7 @@ from vllm.worker.worker import Worker from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.worker.cache_engine import CacheEngine -from vllm.config import _GB +from vllm.utils import GiB_bytes from llumnix.logger import init_logger from llumnix.backends.vllm.utils import _sample_with_torch @@ -117,8 +117,8 @@ def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_block total_kv_cache_size = len(src_blocks) * CacheEngine.get_cache_block_size( self.cache_config, self.model_config, self.parallel_config) - speed = total_kv_cache_size/_GB/(end_time - start_time) - logger.info("[migrate_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s." + speed = total_kv_cache_size/GiB_bytes/(end_time - start_time) + logger.info("[migration_cache] blocks_num: {}, total_kv_cache_size: {}, time: {}s, speed: {}GB/s." 
.format(len(src_blocks), convert_bytes(total_kv_cache_size), end_time-start_time, speed)) def do_recv(self, *args, **kwargs): diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py index 46cbf842..3b3abc54 100644 --- a/llumnix/entrypoints/vllm/api_server.py +++ b/llumnix/entrypoints/vllm/api_server.py @@ -81,7 +81,7 @@ async def generate(request: Request) -> Response: # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: - async for request_output in results_generator: + async for request_output in results_generator.generator(): prompt = request_output.prompt text_outputs = [ prompt + output.text for output in request_output.outputs @@ -94,7 +94,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: # Non-streaming case final_output = None - async for request_output in results_generator: + async for request_output in results_generator.generator(): if await request.is_disconnected(): # Abort the request if the client disconnects. await llumnix_client.abort(request_id) @@ -131,7 +131,7 @@ async def generate_benchmark(request: Request) -> Response: final_output = None per_token_latency = [] per_token_latency_breakdown_dict = init_per_token_latency_breakdown_dict() - async for request_output in results_generator: + async for request_output in results_generator.generator(): if await request.is_disconnected(): # Abort the request if the client disconnects. await llumnix_client.abort(request_id) diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py index b59ee4be..010a42a6 100644 --- a/llumnix/entrypoints/vllm/client.py +++ b/llumnix/entrypoints/vllm/client.py @@ -40,10 +40,10 @@ async def generate(self, request_id: str, *args, **kwargs) -> AsyncStream: - if sampling_params.n > 1 or sampling_params.use_beam_search: + if sampling_params.n > 1: raise ValueError("Unsupported feature: multiple sequence decoding") - - results_generator = AsyncStream(request_id) + # pylint: disable=unexpected-keyword-arg + results_generator = AsyncStream(request_id, cancel=None) self.request_streams[request_id] = results_generator server_info_copy = copy.deepcopy(self.server_info) diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 931d33a8..45baaa68 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -241,7 +241,7 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> logger.info("[_migrate] instance {} is dead".format(instance_id)) self.scale_down(instance_id) else: - migrate_out_request_ids = ret + migrate_out_request_ids = ret[0] if migrate_out_request_ids: migrate_out_request_id = migrate_out_request_ids[0] self.request_instance[migrate_out_request_id] = migrate_instance_pair[1] @@ -489,6 +489,7 @@ def init_llumlets(self, engine_args, node_id: str, request_output_queue_type: Qu world_size, engine_manager_args.create_migration_config(), engine_manager_args.profiling_result_file_path, + engine_args, *args, **kwargs ) diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 56ab4435..9d50a5a1 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -117,7 +117,7 @@ def from_args(cls, else: # backend_type == backend_type.SIM_VLLM: kwargs["node_id"] = node_id engine_class = ray.remote(num_cpus=1, - num_gpu=num_gpu, + num_gpus=num_gpu, name=actor_name, namespace='llumnix', max_concurrency=4, diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index 
bc356f48..6c56096a 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -12,6 +12,7 @@ # limitations under the License. import time +import traceback import enum from typing import List @@ -71,8 +72,9 @@ async def migrate_out_waiting_request(self, dst_blocks = await migrate_in_ray_actor.execute_migration_method \ .remote("migrate_in_pre_alloc", migrate_out_request.request_id, migrate_out_request.status, - migrate_out_request.arrival_time, - migrate_out_request.prefill_num_blocks) + migrate_out_request.request_arrival_time, + migrate_out_request.prefill_num_blocks, + migrate_out_request.token_ids) if len(dst_blocks) != migrate_out_request.prefill_num_blocks: self.backend_engine.add_waiting_request(migrate_out_request) self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) @@ -115,18 +117,20 @@ async def _migrate_out_onestage(self, return MigrationStatus.ABORTED_SRC pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list) - incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks) + incremental_blocks, incremental_token_ids = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks) # live migration, transfer all blocks except last one(currently updating) is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks) or migrate_out_request.blocking_migration if not is_last_stage: migration_status = MigrationStatus.RUNNING src_blocks = incremental_blocks[:-1] + incremental_token_ids = incremental_token_ids[:src_blocks*migrate_out_request.block_size] stage_block_num = len(incremental_blocks) - 1 dst_blocks = await migrate_in_ray_actor.execute_migration_method \ .remote("migrate_in_pre_alloc", migrate_out_request.request_id, migrate_out_request.status, - migrate_out_request.arrival_time, - stage_block_num) + migrate_out_request.request_arrival_time, + stage_block_num, + incremental_token_ids) else: # last stage migration, stop inference, transfer all blocks migration_status = MigrationStatus.FINISHED @@ -139,8 +143,9 @@ async def _migrate_out_onestage(self, dst_blocks = await migrate_in_ray_actor.execute_migration_method \ .remote("migrate_in_pre_alloc", migrate_out_request.request_id, migrate_out_request.status, - migrate_out_request.arrival_time, - stage_block_num) + migrate_out_request.request_arrival_time, + stage_block_num, + incremental_token_ids) if len(dst_blocks) != len(src_blocks): # migrate-in instance failed to pre alloc @@ -172,13 +177,15 @@ def migrate_in_pre_alloc(self, request_id: str, request_status: RequestStatus, request_arrival_time: float, - block_num: int) -> List[int]: + block_num: int, + token_ids: List[int]) -> List[int]: """prev alloc blocks to migrate in request """ pre_alloc_blocks = self.backend_engine.pre_alloc(request_id, request_status, request_arrival_time, - block_num) + block_num, + token_ids) if len(pre_alloc_blocks) != block_num: # failed to alloc, abort request self.free_dst_pre_alloc_cache(request_id) diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index d6c7dac5..d944b45a 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -92,7 +92,7 @@ def finished(self) -> bool: raise NotImplementedError @property - def arrival_time(self) -> float: + def request_arrival_time(self) -> float: raise NotImplementedError @property @@ -103,6 +103,18 @@ def status(self) -> RequestStatus: def prefill_num_blocks(self) -> int: raise 
NotImplementedError + @property + def n_blocks(self) -> int: + raise NotImplementedError + + @property + def token_ids(self) -> int: + raise NotImplementedError + + @property + def block_size(self) -> int: + raise NotImplementedError + # Whether the migration of request is completed within one stage. For requests that have already reached # the expected steps, blocking_migration is True. @property diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt index f9fbe6a6..2ebfe3c8 100644 --- a/requirements/requirements_vllm.txt +++ b/requirements/requirements_vllm.txt @@ -1,5 +1,5 @@ -vllm == 0.4.2 -ray >= 2.9.0 +vllm == 0.6.3.post1 +ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl pyarrow # Required for Ray data. aiohttp scipy @@ -8,6 +8,5 @@ matplotlib func_timeout pyyaml yacs -numpy < 1.24.0 # for gloo migration backend's compatibility with numpy.float pyzmq loguru diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py index 87b03417..fc52d200 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_e2e.py @@ -62,7 +62,7 @@ def run_vllm(model, max_model_len, sampling_params): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for e2e test") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo']) +@pytest.mark.parametrize("migration_backend", ['rayrpc']) @pytest.mark.parametrize("launch_mode", ['eief', 'eidf', 'dief', 'didf']) async def test_e2e(cleanup_ray_env, shutdown_llumnix_service, model, migration_backend, launch_mode): if migration_backend == 'gloo' and launch_mode != 'eief': @@ -71,7 +71,6 @@ async def test_e2e(cleanup_ray_env, shutdown_llumnix_service, model, migration_b sampling_params = { "n": 1, "best_of": 1, - "use_beam_search": False, "temperature": 0.0, "top_k": 1, "ignore_eos": False, @@ -92,10 +91,11 @@ async def test_e2e(cleanup_ray_env, shutdown_llumnix_service, model, migration_b ip=ip, port=base_port, migration_backend=migration_backend, + launch_ray_cluster=False, launch_mode=launch_mode) subprocess.run(launch_command, shell=True, check=True) - wait_for_llumnix_service_ready(ip_ports=[f"{ip}:{base_port}"]) + wait_for_llumnix_service_ready(ip_ports=[f"{ip}:{base_port}"], timeout=120) llumnix_output = {} for prompt in prompts: diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index ba0793da..1558f0dd 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -90,7 +90,7 @@ def get_instance_num_blocks(): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo']) +@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl']) @pytest.mark.parametrize("migrated_request_status", ['running', 'waiting']) async def test_migration_benchmark(cleanup_ray_env, shutdown_llumnix_service, model, migration_backend, migrated_request_status): if migrated_request_status == 'waiting' and migration_backend != 'rayrpc': diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index 5e2b05f6..513a2004 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -61,7 +61,6 @@ def generate_launch_command(result_filename: str = "", f"{'--log-instance-info ' if log_instance_info else ''}" 
f"--enable-migration " f"--model {model} " - f"--engine-use-ray " f"--worker-use-ray " f"--max-model-len {max_model_len} " f"--dispatch-policy {dispatch_policy} " diff --git a/tests/unit_test/backends/vllm/test_llm_engine.py b/tests/unit_test/backends/vllm/test_llm_engine.py index bbd8477e..8b68c1b4 100644 --- a/tests/unit_test/backends/vllm/test_llm_engine.py +++ b/tests/unit_test/backends/vllm/test_llm_engine.py @@ -12,12 +12,15 @@ # limitations under the License. import math +import torch from unittest.mock import MagicMock import ray +import pytest from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, - SequenceStatus,SamplerOutput) + SequenceStatus) from vllm import EngineArgs, SamplingParams +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.engine.output_processor.single_step import SingleStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.transformers_utils.detokenizer import Detokenizer @@ -31,63 +34,25 @@ from llumnix.server_info import ServerInfo from .utils import create_dummy_prompt, initialize_scheduler - +# pylint: disable=unused-import +from tests.conftest import setup_ray_env class MockEngine(LLMEngineLlumnix): def __init__(self, *args, executor_class=None, **kwargs): - self.scheduler = initialize_scheduler() + self.scheduler = [initialize_scheduler()] detokenizer = MagicMock(spec=Detokenizer) stop_checker = MagicMock(spec=StopChecker) self.seq_counter = Counter() self.instance_info = None self.executor_class = executor_class - self.scheduler.add_update_instance_info_callback(self.update_instance_info) - self.output_processor = SingleStepOutputProcessor(self.scheduler.scheduler_config,detokenizer, self.scheduler, self.seq_counter, stop_checker) + self.scheduler[0].add_update_instance_info_callback(self.update_instance_info) + self.output_processor = SingleStepOutputProcessor(self.scheduler[0].scheduler_config,detokenizer, self.scheduler, self.seq_counter, stop_checker) def update_instance_info(self, instance_info): pass - -def test_llm_engine_process_model_outputs(): - llm_engine = MockEngine() - _, seq_group_0 = create_dummy_prompt( - "0", prompt_length=7, block_size=4 - ) - _, seq_group_1 = create_dummy_prompt( - "1", prompt_length=7, block_size=4 - ) - llm_engine.scheduler.add_seq_group(seq_group_0) - llm_engine.scheduler.add_seq_group(seq_group_1) - metas, out = llm_engine.scheduler.schedule() - - seqs = [seq_group_0.get_seqs()[0], seq_group_1.get_seqs()[0]] - - outputs = [ - SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=1, - logprobs={1: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for seq in seqs - ] - sampler_outputs = [SamplerOutput(outputs=outputs)] - - scheduled_seq_groups = out.scheduled_seq_groups - # normal case, all requests be processed - ret, _ = llm_engine._process_model_outputs(sampler_outputs, scheduled_seq_groups,[], metas) - assert len(ret) == 2 - metas, out = llm_engine.scheduler.schedule() - scheduled_seq_groups = out.scheduled_seq_groups - seqs[0].status=SequenceStatus.WAITING - # migration case , requests stopping during last stage migration, stop process - ret, _ = llm_engine._process_model_outputs(sampler_outputs, scheduled_seq_groups,[], metas) - assert len(ret) == 1 - -def test_llm_engine_from_engine_args(): +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.") +def test_llm_engine_from_engine_args(setup_ray_env): engine_args = 
EngineArgs(model="facebook/opt-125m", worker_use_ray=True) llm_engine = MockEngine.from_engine_args(engine_args, request_output_queue_type=QueueType.RAYQUEUE, instance_id="0", migration_config=None) @@ -98,19 +63,20 @@ def test_llm_engine_from_engine_args(): instance_id="0", migration_config=None, latency_mem=latency_data) assert llm_engine.executor_class == SimGPUExecutor -def test_llm_engine_add_requset(): +def test_llm_engine_add_requset(setup_ray_env): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + latency_data = LatencyMemData({},{},{}) llm_engine = LLMEngineLlumnix.from_engine_args(engine_args, request_output_queue_type=QueueType.RAYQUEUE, instance_id="0", placement_group=None, + latency_mem = latency_data, node_id=ray.get_runtime_context().get_node_id(), - migration_config=None, - latency_mem=MagicMock(sepc=LatencyMemData)) + migration_config=None) sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100) server_info = ServerInfo(None, None, None, None, None) llm_engine.add_request("0", server_info, math.inf, "prompt", sampling_params) - assert len(llm_engine.scheduler.waiting) == 1 - assert llm_engine.scheduler.waiting[-1].request_id == "0" - assert llm_engine.scheduler.waiting[-1].expected_steps == math.inf - assert isinstance(llm_engine.scheduler.waiting[-1], LlumnixRequest) + assert len(llm_engine.scheduler[0].waiting) == 1 + assert llm_engine.scheduler[0].waiting[-1].request_id == "0" + assert llm_engine.scheduler[0].waiting[-1].expected_steps == math.inf + assert isinstance(llm_engine.scheduler[0].waiting[-1], LlumnixRequest) diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index d73b130e..37515649 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -58,7 +58,7 @@ class MockLlumletDoNotSchedule(Llumlet): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # stop the schedule in engine step loop - self.backend_engine.engine.scheduler.schedule = MagicMock() + self.backend_engine.engine.scheduler[0].schedule = MagicMock() # For some reason, if MockScheduelrOutputs is defined outside, the constructor would raise error. 
class MockScheduelrOutputs: @@ -66,18 +66,19 @@ def __init__(self): self.scheduled_seq_groups = [] self.ignored_seq_groups = [] self.num_batched_tokens = 0 + self.preempted = False def is_empty(self) -> bool: return not self.scheduled_seq_groups scheduler_outputs = MockScheduelrOutputs() - self.backend_engine.engine.scheduler.schedule.return_value = ([], scheduler_outputs) + self.backend_engine.engine.scheduler[0].schedule.return_value = ([], scheduler_outputs, False) self.step_async = self.backend_engine.engine.step_async async def step_async_try_schedule(): request_outputs, server_infos = await self.step_async() - for seq_group in self.backend_engine.engine.scheduler.waiting: + for seq_group in self.backend_engine.engine.scheduler[0].waiting: seq_group.try_schedule_times += 1 return request_outputs, server_infos @@ -87,7 +88,7 @@ async def step_async_try_schedule(): @pytest.mark.parametrize("migration_request_status", ['waiting', 'running']) @pytest.mark.asyncio async def test_migration_correctness(setup_ray_env, migration_backend, migration_request_status): - engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True, gpu_memory_utilization=0.9) id_rank_map = {"0": 0, "1": 1, "2": 2} if migration_request_status == 'running': request_migration_policy = "SR" @@ -206,7 +207,7 @@ async def test_correctness(prompt): @pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl']) @pytest.mark.asyncio async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): - engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True, gpu_memory_utilization=0.9) id_rank_map = {"0":0, "1":1} migration_config = MigrationConfig("SR", migration_backend, 16, 1, 4, 5, 20) @@ -288,13 +289,13 @@ async def test_correctness(prompt): que.cleanup() def test_clear_migration_states(): - llumlet = MockLlumlet() - llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, 1) num_gpu_blocks = 8 block_size = 4 + llumlet = MockLlumlet() + llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, 1, range(4)) llumlet.clear_migration_states(is_migrate_in=True) - assert len(llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, num_gpu_blocks)) == num_gpu_blocks + assert len(llumlet.backend_engine.pre_alloc("0", RequestStatus.RUNNING, 0.0, num_gpu_blocks, range(4*num_gpu_blocks))) == num_gpu_blocks _, seq_group = create_dummy_prompt("0",7,block_size,SequenceStatus.RUNNING) seq_group.set_status(RequestStatus.RUNNING_MIGRATING) llumlet.backend_engine.add_migrating_out_request_last_stage(seq_group) diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py index 5b92fb9c..e496f124 100644 --- a/tests/unit_test/backends/vllm/test_migration_backend.py +++ b/tests/unit_test/backends/vllm/test_migration_backend.py @@ -28,13 +28,14 @@ class MockMigrationWorker(MigrationWorker): def set_gpu_cache(self, data): - for layer_idx in range(self.cache_engine.num_layers): - self.gpu_cache[layer_idx].copy_(data[layer_idx]) + print(f"data shape:::{self.gpu_cache[0][0].shape, data[0].shape}") + for layer_idx in range(self.cache_engine[0].num_attention_layers): + self.gpu_cache[0][layer_idx].copy_(data[layer_idx]) torch.cuda.synchronize() def get_gpu_cache(self): torch.cuda.synchronize() - return self.gpu_cache + return self.gpu_cache[0] 
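# Illustrative sketch (not part of the patch): the per-virtual-engine cache layout
# assumed by the updated MockMigrationWorker above. With the vLLM 0.6.3-style worker,
# gpu_cache is indexed first by virtual engine and each layer tensor keeps the
# unflattened (2, num_blocks, block_size, num_heads, head_size) shape, which is why
# set_gpu_cache copies into gpu_cache[0] and test_migrate_cache builds dummy_data
# with that shape. All sizes and names below are arbitrary stand-ins.
import torch

num_layers, num_blocks, block_size, num_heads, head_size = 2, 4, 16, 8, 64
dummy_data = torch.randn(num_layers, 2, num_blocks, block_size, num_heads, head_size)
gpu_cache = [[torch.zeros(2, num_blocks, block_size, num_heads, head_size)
              for _ in range(num_layers)]]  # one inner list per virtual engine

for layer_idx in range(num_layers):
    # copy_ requires matching shapes, mirroring the layout used by set_gpu_cache
    # after this change.
    gpu_cache[0][layer_idx].copy_(dummy_data[layer_idx])
assert torch.equal(gpu_cache[0][1], dummy_data[1])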
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPU to run the test.") @pytest.mark.parametrize("backend", ['rayrpc', 'gloo', 'nccl']) @@ -87,7 +88,7 @@ def test_migrate_cache(setup_ray_env, backend): num_heads = engine_config.model_config.get_num_kv_heads(engine_config.parallel_config) block_size = engine_config.cache_config.block_size - dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size*num_heads*head_size)) + dummy_data = torch.randn(size=(num_layers, 2, num_gpu_blocks, block_size, num_heads, head_size)) ray.get(worker0.execute_method.remote('set_gpu_cache', data=dummy_data)) worker0_data = ray.get(worker0.execute_method.remote('get_gpu_cache')) diff --git a/tests/unit_test/backends/vllm/test_scheduler.py b/tests/unit_test/backends/vllm/test_scheduler.py index c8a03981..f20d14b5 100644 --- a/tests/unit_test/backends/vllm/test_scheduler.py +++ b/tests/unit_test/backends/vllm/test_scheduler.py @@ -14,9 +14,8 @@ import math import time -from vllm.sequence import Sequence -from vllm.sequence import Logprob -from vllm.core.policy import PolicyFactory +from vllm.sequence import Logprob, Sequence +from vllm.inputs import token_inputs from llumnix.backends.vllm.scheduler import BlockManagerLlumnix from llumnix.llumlet.request import RequestInferenceType, RequestStatus @@ -27,7 +26,7 @@ def get_sequence_groups(scheduler_output): return [s.seq_group for s in scheduler_output.scheduled_seq_groups] def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out @@ -52,10 +51,10 @@ def test_manager_get_free_blocks(): num_gpu_blocks, watermark=0) before_allocate = block_manager.get_num_free_gpu_blocks() - block_table = block_manager.get_free_blocks(2) + block_table = block_manager.get_free_blocks(2, range(2*block_size)) after_allocate = block_manager.get_num_free_gpu_blocks() assert after_allocate + 2 == before_allocate - block_manager._free_block_table(block_table) + block_table.free() after_free = block_manager.get_num_free_gpu_blocks() assert after_free == before_allocate @@ -67,8 +66,8 @@ def test_manager_add_block_table(): num_cpu_blocks, num_gpu_blocks, watermark=0) - block_table = block_manager.get_free_blocks(2) - seq = Sequence(1,"1",[0], block_size=block_size) + block_table = block_manager.get_free_blocks(2, range(2*block_size)) + seq = Sequence(1,token_inputs([0]),block_size) block_manager.add_block_table(block_table, seq.seq_id) after_allocate = block_manager.get_num_free_gpu_blocks() assert after_allocate + 2 == num_gpu_blocks @@ -88,7 +87,7 @@ def test_sequence_group_inference_type(): for req in scheduler.waiting: assert req.inference_type == RequestInferenceType.PREFILL # all seq_group in prefilling stage - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() for req in scheduler.running: assert req.inference_type == RequestInferenceType.PREFILL for s, meta in zip(out.scheduled_seq_groups, metas): @@ -162,18 +161,18 @@ def test_scheduler_migrating_out_request_last_stage(): def test_scheduler_pre_alloc(): # total 8 blocks scheduler = initialize_scheduler() - blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 2) + + blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 2, range(2*4)) assert len(blocks) == 2 - assert len(scheduler.pre_alloc_cache_dict["1"]) == 2 - blocks = scheduler.pre_alloc("1", 
RequestStatus.RUNNING, 0.0, 4) + assert len(scheduler.pre_alloc_cache_dict["1"].physical_block_ids) == 2 + blocks = scheduler.pre_alloc("1", RequestStatus.RUNNING, 0.0, 4, range(4*4)) assert len(blocks) == 4 - assert len(scheduler.pre_alloc_cache_dict["1"]) == 6 - blocks = scheduler.pre_alloc("2", RequestStatus.RUNNING, 0.0, 4) + assert len(scheduler.pre_alloc_cache_dict["1"].physical_block_ids) == 6 + blocks = scheduler.pre_alloc("2,", RequestStatus.RUNNING, 0.0, 4, range(4*4)) assert len(blocks) == 0 def test_schedule_running(): scheduler = initialize_scheduler() - policy = PolicyFactory.get_policy(policy_name="fcfs") budget = create_token_budget() curr_loras = None @@ -181,21 +180,21 @@ def test_schedule_running(): scheduler._allocate_and_set_running(seq_group_0) append_new_token_seq_group(1, seq_group_0, 1) scheduler.running.append(seq_group_0) - remainig_running, running_scheduled = scheduler._schedule_running( - scheduler.running, budget, curr_loras, policy) - assert len(running_scheduled.decode_seq_groups) == 1 - assert len(running_scheduled.prefill_seq_groups) == 0 - assert len(remainig_running) == 0 + running_scheduled = scheduler._schedule_running(budget, curr_loras) + + assert len(running_scheduled.decode_seq_groups_list) == 1 + assert len(running_scheduled.prefill_seq_groups_list) == 0 + assert len(scheduler.running) == 0 _, seq_group_1 = create_dummy_prompt("1", prompt_length=1, expected_steps=1) scheduler._allocate_and_set_running(seq_group_1) append_new_token_seq_group(1, seq_group_1, 1) scheduler.running.append(seq_group_1) - remainig_running, running_scheduled = scheduler._schedule_running( - scheduler.running, budget, curr_loras, policy) - assert len(running_scheduled.decode_seq_groups) == 1 - assert len(running_scheduled.prefill_seq_groups) == 0 - assert len(remainig_running) == 1 + running_scheduled = scheduler._schedule_running( + scheduler.running, budget, curr_loras) + assert len(running_scheduled.decode_seq_groups_list) == 0 + assert len(running_scheduled.prefill_seq_groups_list) == 0 + assert len(scheduler.running) == 1 # test pre alloc waiting condition # total 8 blocks @@ -203,19 +202,19 @@ def test_schedule_running(): before_arrival = time.time() _, seq_group = create_dummy_prompt("1", prompt_length=1, block_size=2, expected_steps=math.inf) after_arrival = time.time() - blocks = scheduler.pre_alloc("2", RequestStatus.WAITING_MIGRATING, after_arrival, 2) + blocks = scheduler.pre_alloc("2", RequestStatus.WAITING_MIGRATING, after_arrival, 2, range(2*4)) assert len(blocks) == 2 scheduler.add_waiting_request(seq_group) - blocks = scheduler.pre_alloc("3", RequestStatus.WAITING_MIGRATING, after_arrival, 2) + blocks = scheduler.pre_alloc("3", RequestStatus.WAITING_MIGRATING, after_arrival, 2, range(2*4)) assert len(blocks) == 0 - blocks = scheduler.pre_alloc("4", RequestStatus.WAITING_MIGRATING, before_arrival, 2) + blocks = scheduler.pre_alloc("4", RequestStatus.WAITING_MIGRATING, before_arrival, 2, range(2*4)) assert len(blocks) == 2 def test_try_schedule_times(): # total 8 blocks scheduler = initialize_scheduler() - _, seq_group_1 = create_dummy_prompt("1", prompt_length=8, block_size=1) - _, seq_group_2 = create_dummy_prompt("2", prompt_length=8, block_size=1) + _, seq_group_1 = create_dummy_prompt("1", prompt_length=32, block_size=4) + _, seq_group_2 = create_dummy_prompt("2", prompt_length=32, block_size=4) scheduler.add_seq_group(seq_group_1) scheduler.add_seq_group(seq_group_2) waiting_queue = scheduler.get_waiting_queue() @@ -226,6 +225,7 @@ def 
test_try_schedule_times(): # seq_group_2 cannot be scheduled due to lack of blocks assert seq_group_1.try_schedule_times == 0 assert seq_group_2.try_schedule_times == 1 + append_new_token_seq_group(1, seq_group_1, 1) scheduler.schedule() # seq_group_1 is preempted to waiting queue assert seq_group_1.try_schedule_times == 1 diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py index 9a685e18..d66a3523 100644 --- a/tests/unit_test/backends/vllm/test_simulator.py +++ b/tests/unit_test/backends/vllm/test_simulator.py @@ -42,11 +42,12 @@ async def test_executor(): scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, lora_config=engine_config.lora_config, - vision_language_config=engine_config.vision_language_config, speculative_config=engine_config.speculative_config, - load_config=engine_config.load_config) + load_config=engine_config.load_config, + prompt_adapter_config=engine_config.prompt_adapter_config, + observability_config=engine_config.observability_config) scheduler = initialize_scheduler() - scheduler.schedule() + metas, out, _ = scheduler.schedule() _, seq_group_0 = create_dummy_prompt( "0", prompt_length=7, block_size=4 ) @@ -55,7 +56,7 @@ async def test_executor(): ) scheduler.add_seq_group(seq_group_0) scheduler.add_seq_group(seq_group_1) - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() execute_model_req = ExecuteModelRequest( seq_group_metadata_list=metas, blocks_to_swap_in=out.blocks_to_swap_in, diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py index 440bf6e9..67205355 100644 --- a/tests/unit_test/backends/vllm/test_worker.py +++ b/tests/unit_test/backends/vllm/test_worker.py @@ -50,7 +50,6 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig, rank=rank, distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()), lora_config=engine_config.lora_config, - vision_language_config=engine_config.vision_language_config, is_driver_worker = False ) diff --git a/tests/unit_test/backends/vllm/utils.py b/tests/unit_test/backends/vllm/utils.py index 887bdd93..f4628e5e 100644 --- a/tests/unit_test/backends/vllm/utils.py +++ b/tests/unit_test/backends/vllm/utils.py @@ -19,6 +19,7 @@ from vllm import SamplingParams from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceStatus +from vllm.inputs import token_inputs from vllm.config import SchedulerConfig, CacheConfig from vllm.core.scheduler import SchedulingBudget @@ -47,7 +48,6 @@ def create_dummy_prompt( block_size: Optional[int] = None, status: SequenceStatus = SequenceStatus.WAITING, lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, best_of: int = 1, expected_steps: int = math.inf, ) -> Tuple[Sequence, SequenceGroupLlumnix]: @@ -57,16 +57,15 @@ def create_dummy_prompt( # Create dummy prompt sequence with tokens 0...block_size-1 # and prompt "0 ... block_size". 
prompt_tokens = list(range(prompt_length)) - prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + seq = Sequence(int(request_id), token_inputs(prompt_tokens), block_size) server_info = ServerInfo(None, None, None, None, None) seq_group = SequenceGroupLlumnix( - request_id, server_info, expected_steps, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) - seq_group.get_seqs()[0].status = status + request_id, server_info, expected_steps, [seq], + time.time(), + SamplingParams(best_of=best_of), + lora_request) - return prompt, seq_group + return seq, seq_group def create_seq_group( diff --git a/tests/unit_test/llumlet/test_local_migration_scheduler.py b/tests/unit_test/llumlet/test_local_migration_scheduler.py index ecca2b71..c1506c4a 100644 --- a/tests/unit_test/llumlet/test_local_migration_scheduler.py +++ b/tests/unit_test/llumlet/test_local_migration_scheduler.py @@ -46,7 +46,7 @@ def output_len(self) -> int: return self.length @property - def arrival_time(self) -> float: + def request_arrival_time(self) -> float: pass @property diff --git a/tests/unit_test/llumlet/test_migration_coordinator.py b/tests/unit_test/llumlet/test_migration_coordinator.py index fcdf0638..c2f3d682 100644 --- a/tests/unit_test/llumlet/test_migration_coordinator.py +++ b/tests/unit_test/llumlet/test_migration_coordinator.py @@ -43,7 +43,8 @@ async def test_migrate_out_onestage(setup_ray_env): # Mock method return values and test data src_blocks = [1, 2, 3] dst_blocks = [1, 2] - backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.get_request_incremental_blocks.return_value = src_blocks, [] + migrate_out_request.n_blocks = 3 migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) @@ -55,7 +56,7 @@ async def test_migrate_out_onestage(setup_ray_env): # Test the last stage of migration src_blocks = [3] dst_blocks = [3] - backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.get_request_incremental_blocks.return_value = src_blocks, [] migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) @@ -66,7 +67,8 @@ async def test_migrate_out_onestage(setup_ray_env): # Test migration dst aborted scenario src_blocks = [1, 2, 3] dst_blocks = [] - backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.get_request_incremental_blocks.return_value = src_blocks, [] + migrate_out_request.n_blocks = 3 migrate_out_request.should_abort_migration.return_value = False migrate_out_request.blocking_migration = False migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks) @@ -77,7 +79,8 @@ async def test_migrate_out_onestage(setup_ray_env): migrate_out_request = MagicMock() src_blocks = [1, 2, 3] dst_blocks = [1, 2] - backend_engine.get_request_incremental_blocks.return_value = src_blocks + backend_engine.get_request_incremental_blocks.return_value = src_blocks, [] + migrate_out_request.n_blocks = 3 migrate_out_request.should_abort_migration.return_value = True migrate_out_request.blocking_migration = False 
migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
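# Illustrative sketch (not part of the patch): the migration contract exercised in the
# tests above now moves token ids alongside block numbers, i.e.
# get_request_incremental_blocks returns (blocks, token_ids) and the pre-allocation path
# receives those token ids for the reserved blocks. The mock below only demonstrates
# that contract; the request id, block numbers, and block size of 4 are arbitrary.
from unittest.mock import MagicMock

backend_engine = MagicMock()
src_blocks = [1, 2, 3]
src_token_ids = list(range(3 * 4))  # 3 blocks * block_size 4 tokens
backend_engine.get_request_incremental_blocks.return_value = (src_blocks, src_token_ids)

blocks, token_ids = backend_engine.get_request_incremental_blocks(MagicMock(), 0)
backend_engine.pre_alloc("request-0", "RUNNING", 0.0, len(blocks), token_ids)
backend_engine.pre_alloc.assert_called_once_with("request-0", "RUNNING", 0.0, 3,
                                                 src_token_ids)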