From ef527be06c4064f3a2753a3b2c7ede862fe459e8 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 5 Aug 2024 16:41:27 -0700 Subject: [PATCH] [MISC] Use non-blocking transfer in prepare_input (#7172) --- vllm/attention/backends/flash_attn.py | 27 ++++++++++++--------------- vllm/attention/backends/flashinfer.py | 23 +++++++++++------------ vllm/attention/backends/utils.py | 27 ++++++++++++--------------- vllm/worker/model_runner.py | 15 ++++++++------- 4 files changed, 43 insertions(+), 49 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 26b3159682b3e..8a895bbdc2dd7 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,7 +13,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.utils import make_tensor_with_pad +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -310,7 +310,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.from_numpy(input_block_tables).to( + device=device, non_blocking=True) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -320,15 +321,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -344,10 +345,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - return FlashAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 91abaab78dcb8..03188164a9637 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -21,7 +21,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -356,7 +357,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) last_paged_kv_indptr = self.paged_kv_indptr[-1] self.paged_kv_indptr.extend([last_paged_kv_indptr] * @@ -371,12 +373,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -392,10 +395,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - if len(self.paged_kv_indptr) > 0: paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index bca1370343b7b..f7cb2ee996501 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -4,7 +4,7 @@ import torch from vllm.attention import AttentionMetadata, AttentionMetadataBuilder -from vllm.utils import make_tensor_with_pad +from vllm.utils import async_tensor_h2d, make_tensor_with_pad # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -181,7 +181,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -191,15 +192,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, "query_lens: {}".format(query_lens) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -215,10 +216,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f9c26e0c318b1..8b744a438e81a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -50,7 +50,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists, +from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( @@ -549,12 +549,13 @@ def build(self) -> ModelInputForGPU: # Tokens and positions. input_tokens.extend([0] * cuda_graph_pad_size) input_positions.extend([0] * cuda_graph_pad_size) - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.runner.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.runner.device) + assert self.runner.device is not None + input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, + self.runner.device, + self.runner.pin_memory) + input_positions_tensor = async_tensor_h2d(input_positions, torch.long, + self.runner.device, + self.runner.pin_memory) # Sequence and query lengths. seq_lens.extend([1] * cuda_graph_pad_size)