diff --git a/requirements-tpu.txt b/requirements-tpu.txt index f9a0770804e55..3d1e80f6be620 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -16,8 +16,8 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241028+cpu -torchvision==0.20.0.dev20241028+cpu -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl +torch==2.6.0.dev20241114+cpu +torchvision==0.20.0.dev20241114+cpu +torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241114-cp310-cp310-linux_x86_64.whl jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 6fee81de14420..eeab8731a2c39 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata): # or all decoding. block_tables: Optional[torch.Tensor] = None context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None @property def prefill_metadata(self) -> Optional["PallasMetadata"]: @@ -72,8 +73,6 @@ def prefill_metadata(self) -> Optional["PallasMetadata"]: return None assert self.num_decode_tokens == 0 - assert self.block_tables is None - assert self.context_lens is None return self @property @@ -186,29 +185,50 @@ def forward( query = query * self.scale if attn_metadata.num_prefills > 0: - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") - - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) - # FlashAttention requires [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. - output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) else: # Decoding run. assert kv_cache[0].numel() > 0 diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index d7a641857a613..9a054eb8a4cf7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,3 +1,4 @@ +import enum import time from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, @@ -11,7 +12,6 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput @@ -39,6 +39,15 @@ _MAX_NUM_SAMPLES = 128 +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -140,16 +149,21 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model, self.vllm_config) + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) def _dummy_run( self, batch_size: int, seq_len: int, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - is_prompt: bool, + exec_mode: ExecutionMode, ) -> None: - if is_prompt: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): seq_len = (seq_len + 15) // 16 * 16 token_ids = torch.zeros((batch_size, seq_len), dtype=torch.int32, @@ -160,18 +174,38 @@ def _dummy_run( slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - ) input_lens = torch.ones((batch_size, ), dtype=torch.int32, device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device=self.device) + effective_query_lens = torch.ones_like(context_lens) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) else: assert seq_len == 1 token_ids = torch.zeros((batch_size, seq_len), @@ -204,7 +238,7 @@ def _dummy_run( ) t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile @@ -213,7 +247,7 @@ def _dummy_run( # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). - if is_prompt: + if exec_mode.is_prefill(): # Prefll torch._dynamo.mark_dynamic(token_ids, 1) torch._dynamo.mark_dynamic(position_ids, 1) @@ -229,15 +263,8 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, - num_samples, - kv_caches, - is_prompt=is_prompt) + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + num_samples, kv_caches) def warmup_model( self, @@ -248,13 +275,13 @@ def warmup_model( start = time.time() for batch_size in [1]: seq_len = 16 - while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFILL) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if seq_len >= self.model_config.max_model_len: - break num_tokens = batch_size * seq_len if num_tokens >= self.scheduler_config.max_num_batched_tokens: break @@ -263,12 +290,39 @@ def warmup_model( end = time.time() logger.info("Compilation for prefill done in %.2f s.", end - start) + # Prefix prefill + if self.cache_config.enable_prefix_caching: + logger.info("Compiling the model with different input shapes for " + "prefix prefill...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens >= + self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info("Compilation for prefix prefill done in %.2f s.", + end - start) + # Decode start = time.time() seq_len = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.DECODE) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) @@ -287,9 +341,11 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] prompt_lens: List[int] = [] + context_lens: List[int] = [] slot_mapping: List[int] = [] - for seq_group_metadata in seq_group_metadata_list: + for batch_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 @@ -298,19 +354,31 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] # Could include output tokens when a request is preempted. prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + num_computed_tokens = num_computed_blocks * self.block_size + if num_computed_tokens > 0: + prompt_tokens = prompt_tokens[num_computed_tokens:] + context_lens.append(seq_len) + else: + context_lens.append(0) + prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) input_tokens.extend(prompt_tokens) - input_positions.extend(list(range(prompt_len))) + input_positions.extend(range(num_computed_tokens, seq_len)) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] - for i in range(prompt_len): + for i in range(num_computed_tokens, seq_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if num_computed_tokens > 0: + self.block_tables[batch_idx, :len(block_table)] = block_table # Add paddings to EACH prompt to the smallest power of 2 that is # greater than or equal to the prompt length. @@ -338,14 +406,21 @@ def _prepare_prompt( prompt_lens = torch.tensor(prompt_lens, dtype=torch.int32, device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:num_prefills], + dtype=torch.int32, + device="cpu") attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, num_prefill_tokens=0, # NOTE: This is not used. num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=prompt_lens, ) return input_tokens, input_positions, attn_metadata, prompt_lens @@ -550,6 +625,10 @@ def execute_model( # process them separately. This is a temporary hack that should be # optimized by using SplashAttention. orig_slot_mapping = model_input.attn_metadata.slot_mapping + orig_block_tables = model_input.attn_metadata.block_tables + orig_context_lens = model_input.attn_metadata.context_lens + orig_effective_query_lens = \ + model_input.attn_metadata.effective_query_lens batch_size = model_input.input_lens.shape[0] start_idx = 0 next_token_ids = [] @@ -568,18 +647,24 @@ def execute_model( attn_metadata.num_prefills = 1 attn_metadata.slot_mapping = orig_slot_mapping[ None, start_idx:end_idx].to(self.device) + if orig_context_lens[i].item() > 0: + attn_metadata.context_lens = orig_context_lens[i:i + 1].to( + self.device) + attn_metadata.block_tables = orig_block_tables[ + i].unsqueeze(0).to(self.device) + attn_metadata.effective_query_lens = \ + orig_effective_query_lens[i:i + 1].to(self.device) + else: + attn_metadata.context_lens = None + attn_metadata.block_tables = None + attn_metadata.effective_query_lens = None input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, p, model_input.num_samples, - kv_caches, - is_prompt=True) + kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -624,15 +709,10 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, p, model_input.num_samples, - kv_caches, - is_prompt=False) + kv_caches) self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: @@ -667,34 +747,11 @@ def execute_model( return [sampler_output] -class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): +class ModelWrapper(nn.Module): - def __init__(self, model: nn.Module, vllm_config: VllmConfig): + def __init__(self, model: nn.Module): + super().__init__() self.model = model - compiled_callable = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__( - compiled_callable, - compilation_level=vllm_config.compilation_config.level) - - def __call__(self, *args, is_prompt: bool, **kwargs): - if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: - # not fully compiled yet, or not using the custom dispatcher, - # let PyTorch handle it - return self.compiled_callable(*args, **kwargs) - # the 3 compiled codes are: - # 0: for profiling - # 1: for prompt - # 2: for decode - # dispatch to the compiled code directly, skip PyTorch - if is_prompt: - with self.dispatch_to_code(1): - return self.forward(*args, **kwargs) - else: - with self.dispatch_to_code(2): - return self.forward(*args, **kwargs) def forward( self, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 096cb23416909..8754f7538f251 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -13,7 +13,7 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size -from vllm.worker.tpu_model_runner import TPUModelRunner +from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -112,7 +112,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, kv_caches=kv_caches, - is_prompt=True, + exec_mode=ExecutionMode.PREFILL, ) # Synchronize before measuring the memory usage. xm.wait_device_ops()