From fc5ebbd1d3453461ea6e00a78faf87c41d1aa625 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 23 Aug 2024 11:06:54 +0800 Subject: [PATCH] [Hardware][Intel GPU] refactor xpu_model_runner for tp (#7712) --- vllm/executor/ray_xpu_executor.py | 383 +----------------- vllm/worker/xpu_model_runner.py | 628 +++++++++++++++++------------- vllm/worker/xpu_worker.py | 9 +- 3 files changed, 370 insertions(+), 650 deletions(-) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 938f83bc1338b..2b1cdc09b0a9f 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,386 +1,37 @@ import asyncio -import os -from collections import defaultdict -from itertools import islice, repeat -from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, - Tuple, Union) +from typing import List, Optional import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig) -from vllm.executor.distributed_gpu_executor import ( # yapf: disable - DistributedGPUExecutor, DistributedGPUExecutorAsync) -from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync +from vllm.executor.xpu_executor import XPUExecutor from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) - -if ray is not None: - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -if TYPE_CHECKING: - from ray.util.placement_group import PlacementGroup +from vllm.utils import get_vllm_instance_id, make_async logger = init_logger(__name__) -# If the env var is set, it uses the Ray's compiled DAG API -# which optimizes the control plane overhead. -# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG - - -class RayXPUExecutor(DistributedGPUExecutor): - - uses_ray: bool = True - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - assert device_config.device_type == "xpu" - assert (not speculative_config - ), "Speculative decoding not yet supported for XPU backend" - - self.model_config = model_config - self.cache_config = cache_config - self.load_config = load_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.prompt_adapter_config = prompt_adapter_config - - placement_group = self.parallel_config.placement_group - - # Disable Ray usage stats collection. - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - - # Create the parallel GPU workers. - self._init_workers_ray(placement_group) - - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - - # This is non-None when the execute model loop is running - # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. - self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None - # Updated by implementations that require additional args to be passed - # to the _run_workers execute_model call - self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} - - def _init_executor(self) -> None: - pass - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - Tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - return dict( - worker_module_name="vllm.worker.xpu_worker", - worker_class_name="XPUWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) - - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): - if self.parallel_config.tensor_parallel_size == 1: - # For single GPU case, we use a ray worker with constrained memory. - num_gpus = self.cache_config.gpu_memory_utilization - else: - # Otherwise, the ray workers are allocated with a full GPU. - num_gpus = 1 - - # The driver dummy worker does not actually use any resources. - # It holds the resource for the driver worker. - self.driver_dummy_worker: Optional[RayWorkerWrapper] = None - # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] - - # Create the workers. - driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("GPU", 0): - continue - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_id, - ) - worker = ray.remote( - num_cpus=0, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) - - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs) - else: - # Else, added to the list of workers. - self.workers.append(worker) - if self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any GPUs on the driver node. Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") +class RayXPUExecutor(RayGPUExecutor, XPUExecutor): + def _get_env_vars_to_be_updated(self): # Get the set of GPU IDs used on each node. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - node_workers = defaultdict(list) - node_gpus = defaultdict(list) - - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): - node_workers[node_id].append(i) - node_gpus[node_id].extend(gpu_ids) - for node_id, gpu_ids in node_gpus.items(): - node_gpus[node_id] = sorted(gpu_ids) - - # TODO: add env var for xpu - - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - def collect_arg_helper_func(**kwargs): - # avoid writing `{"name": value}` manually - return kwargs - - init_worker_all_kwargs = [] - - # Initialize the actual workers inside worker wrapper. - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): - local_rank = node_workers[node_id].index(rank) - init_worker_all_kwargs.append( - collect_arg_helper_func( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - is_driver_worker=rank == 0, - )) - self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + VLLM_INSTANCE_ID = get_vllm_instance_id() - self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for (_, _) in worker_node_and_gpu_ids] + return all_args_to_update_environment_variables - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info("# GPU blocks: %d, " - "# CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - return self.driver_worker.execute_method("execute_model", - execute_model_req) - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - - def _run_workers( - self, - method: str, - *args, - async_run_remote_workers_only: bool = False, - all_args: Optional[List[Tuple[Any, ...]]] = None, - all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. Can be used in the following - ways: - - - args/kwargs: All workers share the same args/kwargs - - args/kwargs and driver_args/driver_kwargs: Driver worker has - different args - - all_args/all_kwargs: args/kwargs for each worker are specified - individually - """ - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - count = len(self.workers) - all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, 1, None) - all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, 1, None) - - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) - ] - - if async_run_remote_workers_only: - # Just return futures - return ray_worker_outputs - - driver_worker_output = [] - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - # Get the results of the ray workers. - if self.workers: - ray_worker_outputs = ray.get(ray_worker_outputs) - - return driver_worker_output + ray_worker_outputs - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - ray.get(parallel_worker_tasks) - - def _compiled_ray_dag(self, enable_asyncio: bool): - import pkg_resources - from packaging import version - - required_version = version.parse("2.32") - current_version = version.parse( - pkg_resources.get_distribution("ray").version) - if current_version < required_version: - raise ValueError(f"Ray version {required_version} or greater is " - f"required, but found {current_version}") - - from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.use_ray - - # Right now, compiled DAG requires at least 1 arg. We send - # a dummy value for now. It will be fixed soon. - with InputNode() as input_data: - forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote. - bind( # type: ignore[attr-defined] - input_data) for worker in self.workers - ]) - return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() - - def _check_if_any_actor_is_dead(self): - if not self.workers: - return - - dead_actors = [] - for actor in self.workers: - actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access - if actor_state["State"] == "DEAD": - dead_actors.append(actor) - if dead_actors: - raise RuntimeError("At least one Worker is dead. " - f"Dead Workers: {dead_actors}. ") - - -class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync): +class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) - - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - return await self.driver_exec_method("execute_model", - execute_model_req) - - async def _start_worker_execution_loop(self): - coros = [ - worker.execute_method.remote("start_worker_execution_loop") - for worker in self.workers - ] - return await asyncio.gather(*coros) + self.pp_locks: Optional[List[asyncio.Lock]] = None diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 0bfc57a1c57de..0335bbcd091e8 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,14 +1,17 @@ +import dataclasses +import time +import weakref from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar) import torch import torch.nn as nn from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -20,7 +23,7 @@ from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -37,6 +40,8 @@ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] +TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") + @dataclass(frozen=True) class ModelInputForXPU(ModelRunnerInputBase): @@ -46,11 +51,40 @@ class ModelInputForXPU(ModelRunnerInputBase): input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForXPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForXPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, @@ -62,10 +96,10 @@ def as_broadcastable_tensor_dict( @classmethod def from_broadcasted_tensor_dict( - cls: Type["ModelInputForXPU"], + cls, tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForXPU": + ) -> "ModelInputForXPUWithSamplingMetadata": tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( @@ -73,7 +107,230 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) -class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): +class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): + + def __init__(self, + runner: "XPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def build(self) -> ModelInputForXPU: + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = [] + multi_modal_kwargs = None + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, seq_len))) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + max_seqlen = max(seq_lens) + tmp = [0] + tmp.extend(seq_lens) + seqlen = torch.tensor(tmp) + seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seqlen_q=seqlen_q, + max_seqlen=max_seqlen, + seq_lens_tensor=torch.tensor([]), + max_decode_seq_len=0, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([], device=self.device, dtype=torch.int), + ) + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + block_tables = make_tensor_with_pad( + block_tables, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seqlen_q=torch.tensor([]), + max_seqlen=0, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + +class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): + _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = ( + ModelInputForXPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder def __init__( self, @@ -84,30 +341,32 @@ def __init__( cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - *args, - **kwargs, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.load_config = load_config - self.cache_config = cache_config - self.prompt_adapter_config = prompt_adapter_config - self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + if self.observability_config is not None: + print(f"observability_config is {self.observability_config}") + self.return_hidden_states = return_hidden_states - self.sliding_window = model_config.get_sliding_window() - self.device_config = device_config self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( @@ -203,166 +462,68 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input(seqs) + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU: - return (ModelInputForXPU.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - )) - - def prepare_model_input( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForXPU: - multi_modal_kwargs = None - if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # subquery_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - pin_memory=False, - generators=self.get_generators(finished_requests_ids)) - # Broadcast the metadata. - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": - sampling_metadata.selected_token_indices, - "multi_modal_kwargs": multi_modal_kwargs, - } - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - selected_token_indices = metadata_dict.pop( - "selected_token_indices") - multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - - return ModelInputForXPU(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs) + tensor_dict: Dict[str, + Any]) -> ModelInputForXPUWithSamplingMetadata: + return ( + ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) - def _prepare_decode( + def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - block_tables: List[List[int]] = [] - + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 + builder.add_seq_group(seq_group_metadata) - seq_ids = list(seq_group_metadata.seq_data.keys()) + return builder.build() # type: ignore - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append(position) - - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - seq_lens.append(seq_len) - - block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) - - max_decode_seq_len = max(seq_lens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - - block_tables = make_tensor_with_pad( - block_tables, - pad=0, - dtype=torch.int, - device=self.device, - ) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - slot_mapping=slot_mapping, - seq_lens=seq_lens, - seqlen_q=None, - max_seqlen=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_seq_len=max_decode_seq_len, - num_prefill_tokens=0, - num_decode_tokens=len(input_tokens), - num_prefills=0, - block_tables=block_tables, - ) - return ( - input_tokens, - input_positions, - attn_metadata, - ) + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) @torch.inference_mode() def execute_model( self, - model_input: ModelInputForXPU, + model_input: ModelInputForXPUWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, @@ -372,20 +533,21 @@ def execute_model( "XPUModelRunner does not support multi-step execution.") model_executable = self.model - execute_model_kwargs = { - "input_ids": - model_input.input_tokens, - "positions": - model_input.input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start_time = time.time() + + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - } - - hidden_states = model_executable(**execute_model_kwargs) + device=self.device)) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end_time = time.time() # Compute the logits. logits = self.model.compute_logits(hidden_states, @@ -396,109 +558,19 @@ def execute_model( return [] # Sample the next token. - output = self.model.sample( + output: SamplerOutput = self.model.sample( logits=logits, sampling_metadata=model_input.sampling_metadata, ) - return [output] - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - computed_len = seq_data.get_num_computed_tokens() - seq_len = len(prompt_tokens) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_time = (model_forward_end_time - + model_forward_start_time) + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. + output.model_forward_time = model_forward_time - seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids - - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - - if seq_group_metadata.block_tables is None: - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(computed_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // - self.block_size] # type: ignore - block_offset = i % self.block_size # type: ignore - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - num_prompt_tokens = len(input_tokens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # type: ignore - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) # type: ignore - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) # type: ignore - - max_seqlen = max(seq_lens) - tmp = [0] - tmp.extend(seq_lens) - seqlen = torch.tensor(tmp) - seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - slot_mapping=slot_mapping, - seq_lens=seq_lens, - seqlen_q=seqlen_q, - max_seqlen=max_seqlen, - seq_lens_tensor=None, - max_decode_seq_len=None, - num_prefills=len(seq_lens), - num_prefill_tokens=num_prompt_tokens, - num_decode_tokens=0, - block_tables=torch.tensor([], device=self.device, dtype=torch.int), - ) - - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) - - return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) + return [output] diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 7c8f5e0cf65ec..b00d1889f8d4b 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -9,8 +9,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ObservabilityConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) @@ -46,7 +46,6 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, - multimodal_config: Optional[MultiModalConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, @@ -73,8 +72,6 @@ def __init__( assert rank % parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." - self.multimodal_config = multimodal_config - self.model_runner = XPUModelRunner( # type: ignore model_config, parallel_config, @@ -85,7 +82,7 @@ def __init__( lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - multimodal_config=multimodal_config, + observability_config=self.observability_config, ) # Uninitialized cache engine. Will be initialized by # initialize_cache.