From a9b15c606fea67a072416ea0ea115261a2756058 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 27 Sep 2024 08:11:32 -0700 Subject: [PATCH] [torch.compile] use empty tensor instead of None for profiling (#8875) --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++++-- vllm/attention/backends/blocksparse_attn.py | 6 ++++-- vllm/attention/backends/flash_attn.py | 6 ++++-- vllm/attention/backends/flashinfer.py | 6 +++--- vllm/attention/backends/ipex_attn.py | 9 ++++++--- vllm/attention/backends/pallas.py | 12 +++++++----- vllm/attention/backends/rocm_flash_attn.py | 6 ++++-- vllm/attention/backends/torch_sdpa.py | 9 ++++++--- vllm/attention/backends/xformers.py | 8 +++++--- vllm/worker/embedding_model_runner.py | 8 +++++++- vllm/worker/enc_dec_model_runner.py | 8 +++++++- vllm/worker/model_runner.py | 8 +++++++- vllm/worker/tpu_model_runner.py | 4 ++-- vllm/worker/tpu_worker.py | 10 +++++++++- vllm/worker/xpu_model_runner.py | 8 +++++++- 15 files changed, 84 insertions(+), 32 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index b550a7fdd84f0..6b979d0558c46 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -136,7 +136,9 @@ class that Attention will automatically select when it is constructed. ) if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache - return TestResources(scale, attn_backend, attn, None) + return TestResources( + scale, attn_backend, attn, + torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) # Construct KV cache kv_cache = make_kv_cache(test_pt.num_blocks, @@ -620,7 +622,9 @@ def _run_encoder_attention_test( return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - None, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), attn_metadata, attn_type=attn_type) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index d84a40890ebbd..656cfd124ab44 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -357,6 +357,8 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -373,7 +375,7 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if kv_cache is not None: + if kv_cache.numel() > 0: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -399,7 +401,7 @@ def forward( # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - assert kv_cache is None \ + assert kv_cache.numel() == 0 \ or prefill_meta.block_tables is None \ or prefill_meta.block_tables.numel() == 0, \ "Does not support prefix-enabled attention." 
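The test fixture and the blocksparse backend above establish the convention that the rest of the patch repeats: a zero-element tensor stands in for "no KV cache yet", and `kv_cache.numel() > 0` replaces `kv_cache is not None`. The toy function below (not vLLM code; the name and shapes are invented for illustration) shows that sentinel convention in isolation:

import torch


def toy_attention_forward(query: torch.Tensor,
                          kv_cache: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a backend forward pass. An empty tensor
    # (numel() == 0) marks the profiling run; anything non-empty is
    # treated as a real, populated KV cache.
    if kv_cache.numel() > 0:
        # "Real" run: pretend the cached keys/values affect the output.
        return query + kv_cache.mean()
    # Profiling run: nothing to read from or write to.
    return query * 2.0


query = torch.randn(4, 8)
profiling_cache = torch.tensor([], dtype=torch.float32)  # placeholder cache
real_cache = torch.randn(2, 4, 8)                        # populated cache

assert toy_attention_forward(query, profiling_cache).shape == query.shape
assert toy_attention_forward(query, real_cache).shape == query.shape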
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 084e8113cd421..22d07c0a4f689 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -665,6 +665,8 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -685,7 +687,7 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if kv_cache is not None: + if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -722,7 +724,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a602fbfbbc04..784cff0d9878e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -746,7 +746,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Optional[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, k_scale: float = 1.0, v_scale: float = 1.0, @@ -770,7 +770,7 @@ def forward( if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( "Chunked prefill is not supported with flashinfer yet.") - if kv_cache is not None: + if kv_cache.numel() > 0: # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( key, @@ -796,7 +796,7 @@ def forward( # when kv_cache is not provided. # This happens when vllm runs the profiling to # determine the number of blocks. - if kv_cache is None: + if kv_cache.numel() == 0: output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 113a2788eacd3..7398732ddfc92 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -167,7 +167,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Optional[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: IpexAttnMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, @@ -180,6 +180,8 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. attn_metadata: Metadata for attention. 
Returns: shape = [num_tokens, num_heads * head_size] @@ -196,7 +198,7 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if kv_cache is not None: + if kv_cache.numel() > 0: key_cache, value_cache = self.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) ipex_ops.reshape_and_cache( @@ -212,7 +214,8 @@ def forward( if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index a8a78d41c666c..86716602985ac 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -143,7 +143,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], + kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, @@ -155,8 +155,10 @@ def forward( query: shape = [batch_size, seq_len, num_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size] value: shape = [batch_size, seq_len, num_kv_heads * head_size] - key_cache = [num_kv_heads, num_blocks, block_size, head_size] - value_cache = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. attn_metadata: Metadata for attention. Returns: shape = [batch_size, seq_len, num_heads * head_size] @@ -173,7 +175,7 @@ def forward( value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_size) - if kv_cache[0] is not None: + if kv_cache[0].numel() > 0: slot_mapping = attn_metadata.slot_mapping key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) @@ -205,7 +207,7 @@ def forward( output = output.permute(0, 2, 1, 3) else: # Decoding run. - assert kv_cache is not None + assert kv_cache[0].numel() > 0 pages_per_compute_block = 16 # TODO(woosuk): Tune this value. if self.megacore_mode == "batch" and batch_size % 2 != 0: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 5560f44be4196..5ee3c3b69cf36 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -396,6 +396,8 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -412,7 +414,7 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if kv_cache is not None: + if kv_cache.numel() > 0: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -449,7 +451,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
assert prefill_meta.seq_lens is not None - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8a1f8f2930c84..2a215331704c1 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -151,7 +151,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Optional[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, @@ -164,6 +164,8 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -180,7 +182,7 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - if kv_cache is not None: + if kv_cache.numel() > 0: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) PagedAttention.write_to_paged_cache(key, value, key_cache, @@ -191,7 +193,8 @@ def forward( if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e073d616bf01d..143fa6ee7dea4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -445,7 +445,7 @@ def forward( query: torch.Tensor, key: Optional[torch.Tensor], value: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: "XFormersMetadata", k_scale: float = 1.0, v_scale: float = 1.0, @@ -489,6 +489,8 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. attn_metadata: Metadata for attention. attn_type: Select attention type, between encoder attention, decoder self-attention, or encoder/decoder cross- @@ -522,7 +524,7 @@ def forward( # which KV cache memory-mapping & which # seqlen datastructures we utilize - if (attn_type != AttentionType.ENCODER and kv_cache is not None): + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -588,7 +590,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. 
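The backends above all keep `kv_cache` as a tensor-typed argument, which matches the rationale spelled out in the worker comments later in this patch: Dynamo specializes on the Python constant `None`, while a zero-element tensor is passed by reference like any other tensor input. The standalone sketch below is a rough illustration of that point only; `cached_add` is invented for this example and is not part of vLLM:

import torch


@torch.compile
def cached_add(x: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    # Both calls below pass a tensor, so Dynamo only needs shape guards on
    # `kv_cache`; passing `None` during profiling would instead specialize
    # the traced graph on the constant `None`.
    if kv_cache.numel() > 0:
        return x + kv_cache.sum()
    return x


x = torch.ones(4)
print(cached_add(x, torch.tensor([], dtype=torch.float32)))  # profiling-style call
print(cached_add(x, torch.ones(4)))                          # call with a "real" cache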
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py
index 0121f5da79f1d..5c5d20a51e7da 100644
--- a/vllm/worker/embedding_model_runner.py
+++ b/vllm/worker/embedding_model_runner.py
@@ -97,7 +97,13 @@ def execute_model(
         model_executable = self.model
 
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
 
         execute_model_kwargs = {
             "input_ids":
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index bd716ac3e7ec3..3bb4e28c6e1b6 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -340,7 +340,13 @@ def profile_run(self) -> None:
 
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 0a90f767567d6..8c2e6c2d721b9 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1223,7 +1223,13 @@ def profile_run(self) -> None:
 
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 575769ca1aa4a..2472ac25aee44 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -714,7 +714,7 @@ def forward(
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
@@ -745,7 +745,7 @@ def forward(
         )
 
         # Skip this in memory profiling at initialization.
-        if kv_caches[0][0] is not None:
+        if kv_caches[0][0].numel() > 0:
             # index_copy_(slot_mapping) only works when the inserted dimension
             # is 0. However, the KV cache in the Pallas backend has the shape
             # [num_kv_heads, num_blocks, block_size, head_size]. To make it
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 9e0c522cee453..fe819b9f4b3a8 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -115,7 +115,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
 
-        kv_caches = [(None, None) for _ in range(num_layers)]
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [(torch.tensor([], dtype=torch.float32,
+                                   device=self.device),
+                      torch.tensor([], dtype=torch.float32,
+                                   device=self.device))
+                     for _ in range(num_layers)]
         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index d3c763c995b34..8282736cf479b 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -464,7 +464,13 @@ def profile_run(self) -> None:
 
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
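Taken together, the worker-side hunks mean a profiling run now hands the model a list of placeholder caches rather than a list of `None`s. A minimal standalone sketch of that setup (with `num_layers` and the device chosen arbitrarily for illustration) looks like:

import torch

num_layers = 32  # arbitrary value for this sketch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder KV caches for the profiling run. The dtype of a zero-element
# tensor never matters; float32 is just a widely supported default.
kv_caches = [
    torch.tensor([], dtype=torch.float32, device=device)
] * num_layers

assert all(cache.numel() == 0 for cache in kv_caches)

Note that `[...] * num_layers` replicates references to one tensor object instead of allocating a tensor per layer; since the placeholders are never written to during profiling, sharing a single zero-element tensor is harmless.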