From 64759431c5ff54157fc3a3e256e4adcdec394506 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 20 Nov 2024 18:57:39 +0800 Subject: [PATCH] [Hardware][CPU] Support chunked-prefill and prefix-caching on CPU (#10355) Signed-off-by: jiang1.li Signed-off-by: Tyler Michael Smith --- .buildkite/run-cpu-test.sh | 9 +- .../getting_started/cpu-installation.rst | 10 +- docs/source/serving/compatibility_matrix.rst | 4 +- .../basic_correctness/test_chunked_prefill.py | 63 ++- vllm/attention/backends/torch_sdpa.py | 189 +++++-- vllm/attention/ops/ipex_attn.py | 150 ++++-- vllm/platforms/cpu.py | 15 +- vllm/worker/cpu_model_runner.py | 488 ++++++++---------- 8 files changed, 559 insertions(+), 369 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f0128f091b742..4f1729d46dae2 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg function cpu_tests() { set -e + export NUMA_NODE=$2 # offline inference docker exec cpu-test-avx2-"$NUMA_NODE" bash -c " @@ -57,6 +58,12 @@ function cpu_tests() { pytest -s -v \ tests/quantization/test_ipex_quant.py" + # Run chunked-prefill and prefix-cache test + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -s -v -k cpu_model \ + tests/basic_correctness/test_chunked_prefill.py" + # online inference docker exec cpu-test-"$NUMA_NODE" bash -c " set -e @@ -75,4 +82,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 25 mins. export -f cpu_tests -timeout 25m bash -c "cpu_tests $CORE_RANGE" +timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 69530fd778c55..649de1cd9b53c 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -5,11 +5,11 @@ Installation with CPU vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features: -- Tensor Parallel (``-tp = N``) -- Quantization (``INT8 W8A8, AWQ``) - -.. note:: - More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. +- Tensor Parallel +- Model Quantization (``INT8 W8A8, AWQ``) +- Chunked-prefill +- Prefix-caching +- FP8-E5M2 KV-Caching (TODO) Table of contents: diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index 5fc86ab0a11d5..a4300761d2635 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -344,7 +344,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - ✗ + - ✅ - ✅ * - :ref:`APC ` - `✗ `__ @@ -352,7 +352,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - ✗ + - ✅ - ✅ * - :ref:`LoRA ` - ✅ diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index cc5bc2aca27c9..469d18a4dd7af 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -12,6 +12,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable +from vllm.platforms import current_platform from ..models.utils import check_logprobs_close, check_outputs_equal from ..utils import multi_gpu_test @@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache( # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("dtype", ["half"]) def test_with_prefix_caching( vllm_runner, max_tokens: int, enforce_eager: bool, chunk_size: int, tensor_parallel_size: int, + dtype: str, ) -> None: """ Checks exact match decode with and without prefix caching @@ -233,7 +236,7 @@ def test_with_prefix_caching( for enable in (True, False): with vllm_runner( model, - dtype="half", + dtype=dtype, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=True, enable_prefix_caching=enable, @@ -260,3 +263,61 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_models_cpu( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + attention_backend: str, + monkeypatch, +) -> None: + test_models( + hf_runner, + vllm_runner, + example_prompts, + model, + dtype, + max_tokens, + chunked_prefill_token_size, + enforce_eager, + 1, + attention_backend, + monkeypatch, + ) + + +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("chunk_size", [30, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_with_prefix_caching_cpu( + vllm_runner, + max_tokens: int, + enforce_eager: bool, + chunk_size: int, + dtype: str, +) -> None: + test_with_prefix_caching( + vllm_runner, + max_tokens, + enforce_eager, + chunk_size, + 1, + dtype, + ) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 563178d3ab60d..3d025df26a7a1 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,18 +7,14 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.platforms import current_platform - -if current_platform.is_cpu(): - try: - from vllm.attention.ops.ipex_attn import PagedAttention - except ImportError: - from vllm.attention.ops.paged_attn import PagedAttention -else: - from vllm.attention.ops.paged_attn import PagedAttention +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder class TorchSDPABackend(AttentionBackend): @@ -39,6 +35,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_state_cls() -> Type["CommonAttentionState"]: return CommonAttentionState + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. - is_prompt: bool - slot_mapping: torch.Tensor - seq_lens: Optional[List[int]] + chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None # Begin encoder attn & enc/dec cross-attn fields... # Encoder sequence lengths representation @@ -123,20 +129,14 @@ def is_all_cross_attn_metadata_set(self): @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: - # Currently chunked prefill is not supported - if self.num_decode_tokens == 0: - assert self.num_prefills > 0 - return self - - return None + if self.num_prefill_tokens == 0: + return None + return self @property def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: - # Currently chunked prefill is not supported - if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + if self.num_decode_tokens == 0: return None - return self def get_seq_lens( @@ -274,6 +274,105 @@ def get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_data = input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor([]) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + ) + + return attn_metadata + + class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( @@ -409,19 +508,35 @@ def forward( assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) if prefill_meta := attn_metadata.prefill_metadata: assert attn_metadata.seq_lens is not None - if (kv_cache.numel() == 0 - or prefill_meta.block_tables.numel() == 0): - output = self._run_sdpa_forward(query, - key, - value, - prefill_meta, - attn_type=attn_type) + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) else: # prefix-enabled attention - raise RuntimeError( - "Torch SDPA backend doesn't support prefix decoding.") + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -433,8 +548,9 @@ def forward( block_tables_arg, ) = decode_meta.get_seq_len_block_table_args(attn_type) - output = PagedAttention.forward_decode( - query, + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], key_cache, value_cache, block_tables_arg, @@ -453,12 +569,13 @@ def forward( def _run_sdpa_forward( self, + output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: TorchSDPAMetadata, attn_type: AttentionType = AttentionType.DECODER, - ): + ) -> None: if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -479,7 +596,6 @@ def _run_sdpa_forward( attn_masks = [None] * len(seq_lens) attn_metadata.set_attn_bias(attn_masks, attn_type) - output = torch.empty_like(query) query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) @@ -502,7 +618,6 @@ def _run_sdpa_forward( scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) output[start_q:end_q, :, :] = sub_out start_q, start_kv = end_q, end_kv - return output def _make_alibi_bias( diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 8df6d4ced9dc6..cbc6c74acf09a 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -1,12 +1,17 @@ from typing import Dict, List, Optional, Tuple -import intel_extension_for_pytorch.llm.modules as ipex_modules +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +except ImportError: + _use_ipex = False + import torch from vllm import _custom_ops as ops -class PagedAttention: +class _PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: @@ -22,6 +27,105 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + @staticmethod def split_kv_cache( kv_cache: torch.Tensor, @@ -55,6 +159,7 @@ def write_to_paged_cache( @staticmethod def forward_decode( + output: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -68,8 +173,7 @@ def forward_decode( k_scale: float, v_scale: float, *args, - ) -> torch.Tensor: - output = torch.empty_like(query) + ) -> None: block_size = value_cache.shape[2] head_mapping = torch.arange( 0, @@ -83,41 +187,5 @@ def forward_decode( scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes) - return output - - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache_dtype: str, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_subquery_len: int, - alibi_slopes: Optional[torch.Tensor], - *args, - ) -> torch.Tensor: - raise NotImplementedError - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - *args, - ) -> None: - raise NotImplementedError - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index f9a34a47959ec..43cbafe709d84 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -53,11 +53,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config - if cache_config.enable_prefix_caching: - logger.warning( - "Prefix caching is not supported on CPU, disable it.") - cache_config.enable_prefix_caching = False - kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: @@ -74,10 +69,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: f" {kv_cache_space}, expect a positive integer value.") scheduler_config = vllm_config.scheduler_config - if scheduler_config.chunked_prefill_enabled: - logger.warning( - "Chunked prefill is not supported on CPU, disable it.") - scheduler_config.chunked_prefill_enabled = False + if ((scheduler_config.chunked_prefill_enabled + or cache_config.enable_prefix_caching) + and model_config.dtype == torch.half): + logger.warning("Chunked-prefill on the CPU backend only does not" + " support fp16 for now, cast to bf16.") + model_config.dtype = torch.bfloat16 parallel_config = vllm_config.parallel_config if (parallel_config.distributed_executor_backend is not None diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d3e1202c15e61..66bd844c94901 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,8 +2,8 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, + Union) import torch from torch import nn @@ -19,7 +19,6 @@ MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -104,65 +103,223 @@ def from_broadcasted_tensor_dict( class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): + class ModelInputData: + + def __init__(self, use_mrope: bool): + self.use_mrope = use_mrope + self.input_tokens: List[int] = [] + self.input_positions: Optional[ + List[int]] = [] if not self.use_mrope else None + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.prefill_block_tables: List[List[int]] = [] + self.decode_block_tables: List[List[int]] = [] + self.max_decode_seq_len: int = 0 + self.num_prefills: int = 0 + self.num_prefill_tokens: int = 0 + self.num_decode_tokens: int = 0 + self.slot_mapping: List[int] = [] + self.multi_modal_inputs_list: List[MultiModalKwargs] = [] + self.multi_modal_placeholder_maps: Dict[ + str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + self.input_mrope_positions: Optional[List[List[int]]] = [ + [] for _ in range(3) + ] if self.use_mrope else None + def __init__(self, runner: "CPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner + + self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled + or runner.cache_config.enable_prefix_caching) self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.input_data = ModelInputForCPUBuilder.ModelInputData( + self.runner.model_config.uses_mrope) + self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( + self) def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) + def set_seq_group_list( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): + self.seq_group_metadata_list = seq_group_metadata_list + def build(self) -> ModelInputForCPU: + self._build_input_data() + + input_data = self.input_data + input_tokens = torch.tensor(input_data.input_tokens, + dtype=torch.long, + device="cpu") + input_positions = torch.tensor( + input_data.input_positions + if not input_data.use_mrope else input_data.input_mrope_positions, + dtype=torch.long, + device="cpu") + + # For multi-modal models multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = self.seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) = self._prepare_prompt( - self.seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode( - self.seq_group_metadata_list) - seq_lens = None + if len(input_data.multi_modal_inputs_list) != 0: + multi_modal_kwargs = MultiModalKwargs.batch( + input_data.multi_modal_inputs_list) + + attn_metadata = self.att_metadata_builder.build( + input_data.seq_lens, input_data.query_lens, -1, -1) return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens=seq_lens, - query_lens=seq_lens, ) - def _compute_multi_modal_input( - self, - seq_data: SequenceData, - computed_len: int, - seq_group_metadata: SequenceGroupMetadata, - ): + def _build_input_data(self): + for seq_group_metadata in self.seq_group_metadata_list: + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + if seq_group_metadata.is_prompt: + self._compute_prompt_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + if seq_group_metadata.multi_modal_data: + self._compute_multi_modal_input( + seq_group_metadata, seq_data) + else: + self._compute_decode_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + + def _compute_decode_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute decode input tokens, positions, block table and slot mapping. + """ + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + + tokens = seq_data.get_last_token_id() + token_positions = seq_len - 1 + block_number = block_table[token_positions // block_size] + block_offset = token_positions % block_size + slot = block_number * block_size + block_offset + + # For paged_attention kernel + if self.runner.sliding_window: + start_idx = max(0, seq_len - self.runner.sliding_window) + start_block = start_idx // block_size + start_idx = start_block * block_size + seq_len = seq_len - start_idx + block_table = block_table[start_block:] + + # For MRotaryEmbedding + if data.input_positions is None: + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + data.input_mrope_positions[idx].extend( # type: ignore + next_pos[idx]) + else: + data.input_positions.append(token_positions) # type: ignore + + # Update fields + data.input_tokens.append(tokens) + data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) + data.num_decode_tokens += 1 + data.slot_mapping.append(slot) + data.decode_block_tables.append(block_table) + data.query_lens.append(1) + data.seq_lens.append(seq_len) + + def _compute_prompt_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute prompt input tokens, positions, block table and slot mapping. + """ + token_chunk_size = seq_group_metadata.token_chunk_size + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + + # For prefix caching + prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) + if prefix_cache_block_num > 0: + prefix_cache_len = (prefix_cache_block_num * + self.runner.block_size) + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + context_len = prefix_cache_len + token_chunk_size = seq_len - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + context_len = seq_len - 1 + token_chunk_size = 1 + + tokens = seq_data.get_token_ids() + tokens = tokens[context_len:seq_len] + token_positions = range(context_len, seq_len) + + # For encoder-only models, the block_table is None, + # and there is no need to initialize the slot_mapping. + if block_table is not None: + slot_mapping = [_PAD_SLOT_ID] * len(token_positions) + for i, pos in enumerate(token_positions): + block_number = block_table[pos // block_size] + block_offset = pos % block_size + slot = block_number * block_size + block_offset + slot_mapping[i] = slot + data.slot_mapping.extend(slot_mapping) + + # The MROPE positions are prepared in _compute_multi_modal_input + if data.input_positions is not None: + data.input_positions.extend(token_positions) + + # Update fields + data.input_tokens.extend(tokens) + data.num_prefills += 1 + data.num_prefill_tokens += len(tokens) + data.query_lens.append(len(tokens)) + data.prefill_block_tables.append(block_table) + data.seq_lens.append(seq_len) + + def _compute_multi_modal_input(self, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData): + computed_len = seq_data.get_num_computed_tokens() + seq_len = self.input_data.seq_lens[-1] + # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group_metadata, - range(computed_len, len(seq_data.get_token_ids())), - ) + seq_group_metadata, range(computed_len, seq_len)) if not mm_data: - return None, None, None + return if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data @@ -173,8 +330,10 @@ def _compute_multi_modal_input( ) # special processing for mrope position deltas. - mrope_positions = None if self.runner.model_config.uses_mrope: + assert not self.chunked_prefill, \ + "MROPE on CPU does not support chunked-prefill." + image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -198,226 +357,15 @@ def _compute_multi_modal_input( context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta - return mm_kwargs, placeholder_maps, mrope_positions - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - input_mrope_positions: List[List[int]] = [[] for _ in range(3)] - - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - multi_modal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - computed_len = seq_data.get_num_computed_tokens() - seq_len = len(prompt_tokens) - - seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids - - mrope_positions = None - if seq_group_metadata.multi_modal_data: - ( - mm_kwargs, - placeholder_maps, - mrope_positions, - ) = self._compute_multi_modal_input(seq_data, computed_len, - seq_group_metadata) - - multi_modal_kwargs_list.append(mm_kwargs) - for modality, placeholder_map in placeholder_maps.items(): - multi_modal_placeholder_maps[modality].extend( - placeholder_map) - - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - if mrope_positions: - for idx in range(3): - input_mrope_positions[idx].extend(mrope_positions[idx]) - else: - input_positions.extend(list(range(computed_len, seq_len))) - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(computed_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - # For encoder-only models, the block_table is None, - # and there is no need to initialize the slot_mapping. - if block_table is not None: - block_number = block_table[i // - self.block_size] # type: ignore - block_offset = i % self.block_size # type: ignore - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if any(input_mrope_positions): - input_positions = None # type: ignore - else: - input_mrope_positions = None # type: ignore + for i in range(3): + self.input_data.input_mrope_positions[ # type: ignore + i].extend(mrope_positions[i]) - num_prompt_tokens = len(input_tokens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # type: ignore - input_positions = torch.tensor(input_positions - or input_mrope_positions, - dtype=torch.long, - device=self.device) # type: ignore - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) # type: ignore - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - multi_modal_placeholder_maps.items() - } - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=torch.tensor([]), - max_decode_seq_len=0, - num_prefills=len(seq_lens), - num_prefill_tokens=num_prompt_tokens, - num_decode_tokens=0, - block_tables=torch.tensor([]), - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, - ) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - input_mrope_positions: List[List[int]] = [[] for _ in range(3)] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - block_tables: List[List[int]] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) - - seq_len = seq_data.get_len() - position = seq_len - 1 - if seq_data.mrope_position_delta is not None: - context_len = seq_data.get_num_computed_tokens() - next_pos = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - for idx in range(3): - input_mrope_positions[idx].extend(next_pos[idx]) - else: - input_positions.append(position) - - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - seq_lens.append(seq_len) - - block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) - - if any(input_mrope_positions): - input_positions = None # type: ignore - else: - input_mrope_positions = None # type: ignore - - max_decode_seq_len = max(seq_lens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions - or input_mrope_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - - block_tables = make_tensor_with_pad( - block_tables, - pad=0, - dtype=torch.int, - device=self.device, - ) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_decode_seq_len=max_decode_seq_len, - num_prefill_tokens=0, - num_decode_tokens=len(input_tokens), - num_prefills=0, - block_tables=block_tables, - ) - return ( - input_tokens, - input_positions, - attn_metadata, - ) + self.input_data.multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + self.input_data.multi_modal_placeholder_maps[modality].extend( + placeholder_map) class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): @@ -436,8 +384,6 @@ def __init__( **kwargs, ): ModelRunnerBase.__init__(self, vllm_config) - # Currently, CPU worker doesn't support chunked prefill. - assert self.scheduler_config.chunked_prefill_enabled is False model_config = self.model_config cache_config = self.cache_config @@ -479,8 +425,7 @@ def _prepare_model_input_tensors( """ builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) - for seq_group_metadata in seq_group_metadata_list: - builder.add_seq_group(seq_group_metadata) + builder.set_seq_group_list(seq_group_metadata_list) return builder.build() # type: ignore @@ -537,22 +482,19 @@ def execute_model( "CPU worker does not support multi-step execution.") model_executable = self.model - execute_model_kwargs = { - "input_ids": - model_input.input_tokens, - "positions": - model_input.input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - "intermediate_tensors": - intermediate_tensors, - } - - hidden_states = model_executable(**execute_model_kwargs) + multimodal_kwargs = {} + if model_input.multi_modal_kwargs is not None: + multimodal_kwargs = MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs, device=self.device) + + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multimodal_kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states,