From a7534ecfbdb7ac4646d0a40503d7b7d029ac553f Mon Sep 17 00:00:00 2001
From: "jiang1.li"
Date: Mon, 18 Nov 2024 09:28:10 +0000
Subject: [PATCH] fix comments

---
 .buildkite/run-cpu-test.sh                    |   4 +-
 .../getting_started/cpu-installation.rst      |   1 +
 .../basic_correctness/test_chunked_prefill.py |   4 +-
 vllm/attention/backends/torch_sdpa.py         | 109 +++++-
 vllm/worker/cpu_model_runner.py               | 310 +++++++-----------
 5 files changed, 233 insertions(+), 195 deletions(-)

diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
index 4b214e6abf730..6e5a61061ee20 100644
--- a/.buildkite/run-cpu-test.sh
+++ b/.buildkite/run-cpu-test.sh
@@ -60,7 +60,7 @@ function cpu_tests() {
   # Run chunked-prefill and prefix-cache test
   docker exec cpu-test bash -c "
     set -e
-    pytest -s -v -k cpu_only \
+    pytest -s -v -k cpu_model \
      tests/basic_correctness/test_chunked_prefill.py"
 
   # online inference
@@ -81,4 +81,4 @@
 
 # All of CPU tests are expected to be finished less than 25 mins.
 export -f cpu_tests
-timeout 25m bash -c "cpu_tests $CORE_RANGE"
+timeout 30m bash -c "cpu_tests $CORE_RANGE"
diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst
index 6f49d9fa78dd6..649de1cd9b53c 100644
--- a/docs/source/getting_started/cpu-installation.rst
+++ b/docs/source/getting_started/cpu-installation.rst
@@ -9,6 +9,7 @@ vLLM initially supports basic model inferencing and serving on x86 CPU platform,
 - Model Quantization (``INT8 W8A8, AWQ``)
 - Chunked-prefill
 - Prefix-caching
+- FP8-E5M2 KV-Caching (TODO)
 
 Table of contents:
 
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index ee3ed1f9e8853..ffe8de0ba0849 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -274,7 +274,7 @@ def test_with_prefix_caching(
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
 @pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
-@pytest.mark.cpu_only
+@pytest.mark.cpu_model
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
 def test_models_cpu(
     hf_runner,
@@ -311,7 +311,7 @@ def test_models_cpu(
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.cpu_only
+@pytest.mark.cpu_model
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
 def test_with_prefix_caching_cpu(
     vllm_runner,
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index b3981edf3de01..00b72f893e78d 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -7,10 +7,14 @@
 from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
+from vllm.utils import make_tensor_with_pad
+from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
 
 
 class TorchSDPABackend(AttentionBackend):
@@ -31,6 +35,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_state_cls() -> Type["CommonAttentionState"]:
         return CommonAttentionState
 
+    @staticmethod
+    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
+        return TorchSDPAMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -266,6 +274,105 @@ def get_seq_len_block_table_args(
         raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
 
+class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
+
+    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
+        self.chunked_prefill = input_builder.chunked_prefill
+        self.input_data = input_builder.input_data
+
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
+        input_data = self.input_data
+        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
+        prefill_query_lens = query_lens[0:input_data.num_prefills]
+        slot_mapping = torch.tensor(input_data.slot_mapping,
+                                    dtype=torch.long,
+                                    device="cpu")
+
+        # For chunked-prefill
+        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
+            prefill_block_tables = make_tensor_with_pad(
+                self.input_data.prefill_block_tables,
+                pad=0,
+                dtype=torch.int,
+                device="cpu",
+            )
+            query_lens_tensor = torch.tensor(prefill_query_lens,
+                                             dtype=torch.int32,
+                                             device="cpu")
+            kv_lens_tensor = torch.tensor(prefill_seq_lens,
+                                          dtype=torch.int32,
+                                          device="cpu")
+            query_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                          dtype=torch.int32,
+                                          device="cpu")
+            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                       dtype=torch.int32,
+                                       device="cpu")
+            torch.cumsum(query_lens_tensor,
+                         dim=0,
+                         dtype=torch.int32,
+                         out=query_start_loc[1:])
+            torch.cumsum(kv_lens_tensor,
+                         dim=0,
+                         dtype=torch.int32,
+                         out=kv_start_loc[1:])
+            max_query_len = max(prefill_query_lens)
+            max_kv_len = max(prefill_seq_lens)
+        else:
+            prefill_block_tables = None
+            query_start_loc = None
+            kv_start_loc = None
+            max_query_len = None
+            max_kv_len = None
+
+        # For paged attention
+        if input_data.num_decode_tokens != 0:
+            seq_lens_tensor = torch.tensor(
+                input_data.seq_lens[input_data.num_prefills:],
+                dtype=torch.int,
+                device="cpu",
+            )
+            block_tables = make_tensor_with_pad(
+                self.input_data.decode_block_tables,
+                pad=0,
+                dtype=torch.int,
+                device="cpu",
+            )
+        else:
+            block_tables = torch.tensor([])
+            seq_lens_tensor = torch.tensor([])
+
+        # For multi-modal models
+        placeholder_index_maps = None
+        if len(input_data.multi_modal_inputs_list) != 0:
+            placeholder_index_maps = {
+                modality: placeholder_map.index_map()
+                for modality, placeholder_map in
+                input_data.multi_modal_placeholder_maps.items()
+            }
+
+        attn_metadata = TorchSDPAMetadata(
+            chunked_prefill=self.chunked_prefill,
+            seq_lens=prefill_seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_kv_len=max_kv_len,
+            query_start_loc=query_start_loc,
+            kv_start_loc=kv_start_loc,
+            max_decode_seq_len=input_data.max_decode_seq_len,
+            num_prefills=input_data.num_prefills,
+            num_prefill_tokens=input_data.num_prefill_tokens,
+            num_decode_tokens=input_data.num_decode_tokens,
+            block_tables=block_tables,
+            prefill_block_tables=prefill_block_tables,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=placeholder_index_maps,
+        )
+
+        return attn_metadata
+
+
 class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 
     def __init__(
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 635377dbe0fc8..7d566c45ac2a9 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -19,7 +19,6 @@
                              MultiModalKwargs, MultiModalPlaceholderMap)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -134,6 +133,7 @@ def __init__(self,
         super().__init__()
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
+        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
+                                or runner.cache_config.enable_prefix_caching)
 
         self.model_input_cls = self.runner._model_input_cls
@@ -141,6 +141,8 @@ def __init__(self,
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
+        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
+            self)
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -153,8 +155,6 @@ def build(self) -> ModelInputForCPU:
         self._build_input_data()
 
         input_data = self.input_data
-        prefill_seq_lens = input_data.seq_lens[0:input_data.num_prefills]
-        prefill_query_lens = input_data.query_lens[0:input_data.num_prefills]
         input_tokens = torch.tensor(input_data.input_tokens,
                                     dtype=torch.long,
                                     device="cpu")
@@ -163,93 +163,15 @@ def build(self) -> ModelInputForCPU:
             if not input_data.use_mrope else input_data.input_mrope_positions,
             dtype=torch.long,
             device="cpu")
-        slot_mapping = torch.tensor(input_data.slot_mapping,
-                                    dtype=torch.long,
-                                    device="cpu")
-
-        # For chunked-prefill
-        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
-            prefill_block_tables = make_tensor_with_pad(
-                self.input_data.prefill_block_tables,
-                pad=0,
-                dtype=torch.int,
-                device="cpu",
-            )
-            query_lens_tensor = torch.tensor(prefill_query_lens,
-                                             dtype=torch.int32,
-                                             device="cpu")
-            kv_lens_tensor = torch.tensor(prefill_seq_lens,
-                                          dtype=torch.int32,
-                                          device="cpu")
-            query_start_loc = torch.zeros(input_data.num_prefills + 1,
-                                          dtype=torch.int32,
-                                          device="cpu")
-            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
-                                       dtype=torch.int32,
-                                       device="cpu")
-            torch.cumsum(query_lens_tensor,
-                         dim=0,
-                         dtype=torch.int32,
-                         out=query_start_loc[1:])
-            torch.cumsum(kv_lens_tensor,
-                         dim=0,
-                         dtype=torch.int32,
-                         out=kv_start_loc[1:])
-            max_query_len = max(prefill_query_lens)
-            max_kv_len = max(prefill_seq_lens)
-        else:
-            prefill_block_tables = None
-            query_start_loc = None
-            kv_start_loc = None
-            max_query_len = None
-            max_kv_len = None
-
-        # For paged attention
-        if input_data.num_decode_tokens != 0:
-            seq_lens_tensor = torch.tensor(
-                input_data.seq_lens[input_data.num_prefills:],
-                dtype=torch.int,
-                device="cpu",
-            )
-            block_tables = make_tensor_with_pad(
-                self.input_data.decode_block_tables,
-                pad=0,
-                dtype=torch.int,
-                device="cpu",
-            )
-        else:
-            block_tables = torch.tensor([])
-            seq_lens_tensor = torch.tensor([])
 
         # For multi-modal models
         multi_modal_kwargs = None
-        placeholder_index_maps = None
         if len(input_data.multi_modal_inputs_list) != 0:
             multi_modal_kwargs = MultiModalKwargs.batch(
                 input_data.multi_modal_inputs_list)
-            placeholder_index_maps = {
-                modality: placeholder_map.index_map()
-                for modality, placeholder_map in
-                input_data.multi_modal_placeholder_maps.items()
-            }
-
-        attn_metadata = self.runner.attn_backend.make_metadata(
-            chunked_prefill=self.chunked_prefill,
-            seq_lens=prefill_seq_lens,
-            seq_lens_tensor=seq_lens_tensor,
-            max_query_len=max_query_len,
-            max_kv_len=max_kv_len,
-            query_start_loc=query_start_loc,
-            kv_start_loc=kv_start_loc,
-            max_decode_seq_len=input_data.max_decode_seq_len,
-            num_prefills=input_data.num_prefills,
-            num_prefill_tokens=input_data.num_prefill_tokens,
-            num_decode_tokens=input_data.num_decode_tokens,
-            block_tables=block_tables,
-            prefill_block_tables=prefill_block_tables,
-            slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
-        )
+
+        attn_metadata = self.att_metadata_builder.build(
+            input_data.seq_lens, input_data.query_lens, -1, -1)
 
         return self.model_input_cls(
             input_tokens=input_tokens,
@@ -263,115 +185,126 @@ def build(self) -> ModelInputForCPU:
     def _build_input_data(self):
         for seq_group_metadata in self.seq_group_metadata_list:
             for seq_id, seq_data in seq_group_metadata.seq_data.items():
-                self._compute_input_tokens(self.input_data, seq_group_metadata,
-                                           seq_data, seq_id)
-                if (seq_group_metadata.is_prompt
-                        and seq_group_metadata.multi_modal_data):
-                    self._compute_multi_modal_input(seq_group_metadata,
-                                                    seq_data)
-
-    def _compute_input_tokens(self, data: ModelInputData,
-                              seq_group_metadata: SequenceGroupMetadata,
-                              seq_data: SequenceData, seq_id: int):
+                if seq_group_metadata.is_prompt:
+                    self._compute_prompt_input_tokens(self.input_data,
+                                                      seq_group_metadata,
+                                                      seq_data, seq_id)
+                    if seq_group_metadata.multi_modal_data:
+                        self._compute_multi_modal_input(
+                            seq_group_metadata, seq_data)
+                else:
+                    self._compute_decode_input_tokens(self.input_data,
+                                                      seq_group_metadata,
+                                                      seq_data, seq_id)
+
+    def _compute_decode_input_tokens(self, data: ModelInputData,
+                                     seq_group_metadata: SequenceGroupMetadata,
+                                     seq_data: SequenceData, seq_id: int):
         """
-        Compute input tokens, positions, block table and slot mapping.
+        Compute decode input tokens, positions, block table and slot mapping.
         """
-        is_prompt = seq_group_metadata.is_prompt
-        token_chunk_size = seq_group_metadata.token_chunk_size
         block_size = self.runner.block_size
         block_table = seq_group_metadata.block_tables[seq_id]
         seq_len = seq_data.get_len()
         context_len = seq_data.get_num_computed_tokens()
-        if is_prompt:
-            seq_len = min(seq_len, context_len + token_chunk_size)
-
-            # For prefix caching
-            prefix_cache_block_num = len(
-                seq_group_metadata.computed_block_nums)
-            if prefix_cache_block_num > 0:
-                prefix_cache_len = (prefix_cache_block_num *
-                                    self.runner.block_size)
-                if prefix_cache_len <= context_len:
-                    # We already passed the cache hit region,
-                    # so do normal computation.
-                    pass
-                elif context_len < prefix_cache_len < seq_len:
-                    # Partial hit. Compute the missing part.
-                    context_len = prefix_cache_len
-                    token_chunk_size = seq_len - context_len
-                elif seq_len <= prefix_cache_len:
-                    # Full hit. Only compute the last token to avoid
-                    # erroneous behavior. FIXME: Ideally we should directly
-                    # mark all tokens as computed in the scheduler and do not
-                    # schedule this sequence, so this case should not happen.
-                    context_len = seq_len - 1
-                    token_chunk_size = 1
-
-            tokens = seq_data.get_token_ids()
-            tokens = tokens[context_len:seq_len]
-            token_positions = range(context_len, seq_len)
-
-            # For encoder-only models, the block_table is None,
-            # and there is no need to initialize the slot_mapping.
-            if block_table is not None:
-                slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
-                for i, pos in enumerate(token_positions):
-                    block_number = block_table[pos // block_size]
-                    block_offset = pos % block_size
-                    slot = block_number * block_size + block_offset
-                    slot_mapping[i] = slot
-                data.slot_mapping.extend(slot_mapping)
-
-            # The MRPOE positions are prepared in _compute_multi_modal_input
-            if data.input_positions is not None:
-                data.input_positions.extend(token_positions)
-
-            # Update fields
-            data.input_tokens.extend(tokens)
-            data.num_prefills += 1
-            data.num_prefill_tokens += len(tokens)
-            data.query_lens.append(len(tokens))
-            data.prefill_block_tables.append(block_table)
+
+        tokens = seq_data.get_last_token_id()
+        token_positions = seq_len - 1
+        block_number = block_table[token_positions // block_size]
+        block_offset = token_positions % block_size
+        slot = block_number * block_size + block_offset
+
+        # For paged_attention kernel
+        if self.runner.sliding_window:
+            start_idx = max(0, seq_len - self.runner.sliding_window)
+            start_block = start_idx // block_size
+            start_idx = start_block * block_size
+            seq_len = seq_len - start_idx
+            block_table = block_table[start_block:]
+
+        # For MRotaryEmbedding
+        if data.input_positions is None:
+            next_pos = MRotaryEmbedding.get_next_input_positions(
+                seq_data.mrope_position_delta,
+                context_len,
+                seq_len,
+            )
+            for idx in range(3):
+                data.input_mrope_positions[idx].extend(  # type: ignore
+                    next_pos[idx])
         else:
-            tokens = seq_data.get_last_token_id()
-            token_positions = seq_len - 1
-            block_number = block_table[token_positions // block_size]
-            block_offset = token_positions % block_size
-            slot = block_number * block_size + block_offset
-
-            # For paged_attention kernel
-            if self.runner.sliding_window:
-                start_idx = max(0, seq_len - self.runner.sliding_window)
-                start_block = start_idx // block_size
-                start_idx = start_block * block_size
-                seq_len = seq_len - start_idx
-                block_table = block_table[start_block:]
-
-            # For MRotaryEmbedding
-            if data.input_positions is None:
-                next_pos = MRotaryEmbedding.get_next_input_positions(
-                    seq_data.mrope_position_delta,
-                    context_len,
-                    seq_len,
-                )
-                data.input_mrope_positions[0].extend(  # type: ignore
-                    next_pos[0])
-                data.input_mrope_positions[1].extend(  # type: ignore
-                    next_pos[1])
-                data.input_mrope_positions[2].extend(  # type: ignore
-                    next_pos[2])
-            else:
-                data.input_positions.append(token_positions)  # type: ignore
-
-            # Update fields
-            data.input_tokens.append(tokens)
-            data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
-            data.num_decode_tokens += 1
-            data.slot_mapping.append(slot)
-            data.decode_block_tables.append(block_table)
-            data.query_lens.append(1)
+            data.input_positions.append(token_positions)  # type: ignore
+
+        # Update fields
+        data.input_tokens.append(tokens)
+        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
+        data.num_decode_tokens += 1
+        data.slot_mapping.append(slot)
+        data.decode_block_tables.append(block_table)
+        data.query_lens.append(1)
+        data.seq_lens.append(seq_len)
+
+    def _compute_prompt_input_tokens(self, data: ModelInputData,
+                                     seq_group_metadata: SequenceGroupMetadata,
+                                     seq_data: SequenceData, seq_id: int):
+        """
+        Compute prompt input tokens, positions, block table and slot mapping.
+        """
+        token_chunk_size = seq_group_metadata.token_chunk_size
+        block_size = self.runner.block_size
+        block_table = seq_group_metadata.block_tables[seq_id]
+        seq_len = seq_data.get_len()
+        context_len = seq_data.get_num_computed_tokens()
+        seq_len = min(seq_len, context_len + token_chunk_size)
+
+        # For prefix caching
+        prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
+        if prefix_cache_block_num > 0:
+            prefix_cache_len = (prefix_cache_block_num *
+                                self.runner.block_size)
+            if prefix_cache_len <= context_len:
+                # We already passed the cache hit region,
+                # so do normal computation.
+                pass
+            elif context_len < prefix_cache_len < seq_len:
+                # Partial hit. Compute the missing part.
+                context_len = prefix_cache_len
+                token_chunk_size = seq_len - context_len
+            elif seq_len <= prefix_cache_len:
+                # Full hit. Only compute the last token to avoid
+                # erroneous behavior. FIXME: Ideally we should directly
+                # mark all tokens as computed in the scheduler and do not
+                # schedule this sequence, so this case should not happen.
+                context_len = seq_len - 1
+                token_chunk_size = 1
+
+        tokens = seq_data.get_token_ids()
+        tokens = tokens[context_len:seq_len]
+        token_positions = range(context_len, seq_len)
+
+        # For encoder-only models, the block_table is None,
+        # and there is no need to initialize the slot_mapping.
+        if block_table is not None:
+            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
+            for i, pos in enumerate(token_positions):
+                block_number = block_table[pos // block_size]
+                block_offset = pos % block_size
+                slot = block_number * block_size + block_offset
+                slot_mapping[i] = slot
+            data.slot_mapping.extend(slot_mapping)
+
+        # The MRoPE positions are prepared in _compute_multi_modal_input
+        if data.input_positions is not None:
+            data.input_positions.extend(token_positions)
+
+        # Update fields
+        data.input_tokens.extend(tokens)
+        data.num_prefills += 1
+        data.num_prefill_tokens += len(tokens)
+        data.query_lens.append(len(tokens))
+        data.prefill_block_tables.append(block_table)
         data.seq_lens.append(seq_len)
 
     def _compute_multi_modal_input(self,
@@ -425,12 +358,9 @@ def _compute_multi_modal_input(self,
                 )
             seq_data.mrope_position_delta = mrope_position_delta
 
-        self.input_data.input_mrope_positions[0].extend(  # type: ignore
-            mrope_positions[0])
-        self.input_data.input_mrope_positions[1].extend(  # type: ignore
-            mrope_positions[1])
-        self.input_data.input_mrope_positions[2].extend(  # type: ignore
-            mrope_positions[2])
+        for i in range(3):
+            self.input_data.input_mrope_positions[  # type: ignore
+                i].extend(mrope_positions[i])
 
         self.input_data.multi_modal_inputs_list.append(mm_kwargs)
         for modality, placeholder_map in placeholder_maps.items():