diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 2975a41797e9f..26da0d89def29 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -1,6 +1,7 @@
 """Attention layer with FlashAttention."""
 from collections import defaultdict
 from dataclasses import dataclass
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
@@ -503,6 +504,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))
         num_seqs = len(seq_lens)
 
         if use_captured_graph:
@@ -525,29 +528,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
 
         return FlashAttentionMetadata(
             num_prefills=self.num_prefills,
@@ -561,8 +553,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 096c920c4833a..12800668af223 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -1,6 +1,7 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 
 import numpy as np
@@ -216,6 +217,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
@@ -244,29 +247,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
 
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
@@ -279,8 +271,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
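
A minimal sketch (not part of the diff) of why the two code paths agree: list(accumulate(lens, initial=0)) computed on the host produces the same [0, l0, l0 + l1, ...] prefix-sum offsets that the removed torch.cumsum calls wrote into a zero-initialized device tensor, so the builder can hand the precomputed list straight to async_tensor_h2d and skip the extra device-side kernels. The query_lens values below are hypothetical, chosen only for illustration.

    from itertools import accumulate

    import torch

    query_lens = [3, 1, 4]  # hypothetical per-sequence query lengths

    # New path: prefix sums built on the CPU, ready for an async H2D copy.
    query_start_loc = list(accumulate(query_lens, initial=0))
    assert query_start_loc == [0, 3, 4, 8]

    # Old path: zero-filled tensor whose tail is written by torch.cumsum.
    query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
    old_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=old_start_loc.dtype,
                 out=old_start_loc[1:])

    # Both paths yield identical start-location offsets.
    assert old_start_loc.tolist() == query_start_loc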