[MISC] Use non-blocking transfer in prepare_input (vllm-project#7172)
comaniac authored Aug 5, 2024
1 parent 229f431 commit f8f4d5b
Showing 4 changed files with 43 additions and 49 deletions.
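
The thread running through all four files: every small piece of input metadata (token ids, positions, sequence lengths, context lengths, query lengths, slot mappings, block tables) that used to be moved to the GPU with a blocking torch.tensor(..., device=device) call is now staged on the host first and copied with non_blocking=True, optionally from pinned memory, so prepare_input no longer waits for each transfer to finish. The async_tensor_h2d helper imported from vllm.utils in the diffs below is not shown in this commit; judging from its call sites (data, dtype, device, pin_memory), it presumably boils down to something like the following sketch (illustrative only, not the vLLM source):

    from typing import List, Union

    import torch

    def async_tensor_h2d_sketch(data: List[int], dtype: torch.dtype,
                                target_device: Union[str, torch.device],
                                pin_memory: bool) -> torch.Tensor:
        # Build the tensor on the CPU, in pinned (page-locked) memory when
        # requested, so the driver can DMA it without an extra staging copy.
        cpu_tensor = torch.tensor(data, dtype=dtype, device="cpu",
                                  pin_memory=pin_memory)
        # Queue the host-to-device copy on the current stream and return
        # immediately instead of blocking until the transfer completes.
        return cpu_tensor.to(device=target_device, non_blocking=True)
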
27 changes: 12 additions & 15 deletions vllm/attention/backends/flash_attn.py
@@ -13,7 +13,7 @@
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.utils import make_tensor_with_pad
+from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
@@ -310,7 +310,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device=device, non_blocking=True)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
@@ -320,15 +321,15 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(self.context_lens,
-                                           dtype=torch.int,
-                                           device=device)
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
+        assert device is not None
+        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                               device, self.runner.pin_memory)
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=device)
@@ -344,10 +345,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])

-        slot_mapping_tensor = torch.tensor(self.slot_mapping,
-                                           dtype=torch.long,
-                                           device=device)
-
         return FlashAttentionMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
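
The cuda-graph block-table path above takes a slightly different route: the pre-padded numpy array is wrapped with torch.from_numpy, which reuses the numpy buffer instead of copying it, and is then sent with non_blocking=True. A small illustration of the difference between the old and new calls (the array shape and device name are made up for the example):

    import numpy as np
    import torch

    input_block_tables = np.zeros((256, 128), dtype=np.int32)  # hypothetical padded table

    # Old path: torch.tensor() copies the numpy data into a fresh CPU tensor and
    # then blocks the caller while it is transferred to the GPU.
    block_tables_old = torch.tensor(input_block_tables, device="cuda")

    # New path: from_numpy() shares the existing buffer (no host-side copy) and
    # non_blocking=True lets the call return without waiting for the transfer.
    # The source buffer is pageable, so the driver may still stage it internally,
    # but prepare_input itself no longer synchronizes on the copy.
    block_tables_new = torch.from_numpy(input_block_tables).to(
        "cuda", non_blocking=True)
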
23 changes: 11 additions & 12 deletions vllm/attention/backends/flashinfer.py
@@ -21,7 +21,8 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
+from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
+                        make_tensor_with_pad)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
@@ -356,7 +357,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device, non_blocking=True)

             last_paged_kv_indptr = self.paged_kv_indptr[-1]
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
@@ -371,12 +373,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
+        assert device is not None
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=device)
@@ -392,10 +395,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])

-        slot_mapping_tensor = torch.tensor(self.slot_mapping,
-                                           dtype=torch.long,
-                                           device=device)
-
         if len(self.paged_kv_indptr) > 0:
             paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                    device="cpu",
27 changes: 12 additions & 15 deletions vllm/attention/backends/utils.py
@@ -4,7 +4,7 @@
 import torch

 from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
-from vllm.utils import make_tensor_with_pad
+from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
@@ -181,7 +181,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device, non_blocking=True)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
@@ -191,15 +192,15 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             )
         assert max_query_len > 0, "query_lens: {}".format(query_lens)

-        context_lens_tensor = torch.tensor(self.context_lens,
-                                           dtype=torch.int,
-                                           device=device)
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
+        assert device is not None
+        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                               device, self.runner.pin_memory)
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=device)
@@ -215,10 +216,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])

-        slot_mapping_tensor = torch.tensor(self.slot_mapping,
-                                           dtype=torch.long,
-                                           device=device)
-
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
15 changes: 8 additions & 7 deletions vllm/worker/model_runner.py
@@ -50,7 +50,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
+from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists,
                         get_kv_cache_torch_dtype, is_hip,
                         is_pin_memory_available)
 from vllm.worker.model_runner_base import (
@@ -549,12 +549,13 @@ def build(self) -> ModelInputForGPU:
         # Tokens and positions.
         input_tokens.extend([0] * cuda_graph_pad_size)
         input_positions.extend([0] * cuda_graph_pad_size)
-        input_tokens_tensor = torch.tensor(input_tokens,
-                                           dtype=torch.long,
-                                           device=self.runner.device)
-        input_positions_tensor = torch.tensor(input_positions,
-                                              dtype=torch.long,
-                                              device=self.runner.device)
+        assert self.runner.device is not None
+        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
+                                               self.runner.device,
+                                               self.runner.pin_memory)
+        input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
+                                                  self.runner.device,
+                                                  self.runner.pin_memory)

         # Sequence and query lengths.
         seq_lens.extend([1] * cuda_graph_pad_size)
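
One general note on the correctness of this pattern (standard PyTorch/CUDA stream semantics, not something introduced by this commit): a non_blocking host-to-device copy is still ordered before any kernels submitted later on the same CUDA stream, so the model's forward pass sees fully transferred inputs without extra synchronization; only the Python side stops waiting. An explicit synchronize is needed only when the host wants to inspect the device tensor right away, for example:

    import torch

    if torch.cuda.is_available():
        src = torch.arange(1024, dtype=torch.long).pin_memory()
        dst = src.to("cuda", non_blocking=True)  # returns before the copy completes
        # Kernels launched on the same stream after this point are ordered after
        # the copy; an explicit sync is needed only for immediate host-side reads.
        torch.cuda.current_stream().synchronize()
        print(dst[:4])
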
