From 98519e94c7096b81092f432757ed1ad2294b52aa Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 20 Jul 2024 12:17:24 +0800
Subject: [PATCH] [Misc] Consolidate and optimize logic for building padded
 tensors (#6541)

---
 tests/conftest.py                        | 12 ++---
 vllm/attention/backends/flash_attn.py    |  3 --
 vllm/attention/backends/flashinfer.py    |  3 --
 vllm/attention/backends/utils.py         |  3 --
 vllm/model_executor/sampling_metadata.py | 56 ++++++++--------
 vllm/utils.py                            | 61 ++++++++++++++++++++----
 vllm/worker/cpu_model_runner.py          |  3 --
 vllm/worker/neuron_model_runner.py       |  8 ++--
 vllm/worker/xpu_model_runner.py          |  3 --
 9 files changed, 77 insertions(+), 75 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 71c4a539c4e8a..652d627377786 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,7 +21,8 @@
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.sequence import SampleLogprobs
-from vllm.utils import cuda_device_count_stateless, is_cpu
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
+                        is_cpu)
 
 logger = init_logger(__name__)
 
@@ -124,12 +125,6 @@ def image_assets() -> _ImageAssets:
     return IMAGE_ASSETS
 
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-}
-
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
@@ -151,8 +146,7 @@ def __init__(
         is_vision_model: bool = False,
         is_sparseml_model: bool = False,
     ) -> None:
-        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
 
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index b8a64205b362b..cad3181d3edb7 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -306,11 +306,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 9c25b2cc2ba97..eb8b1f0fcfb39 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -344,11 +344,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
                 cuda_graph_pad_size)
             self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 62d0eeb249bd4..0706e2d3a48b7 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -182,11 +182,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
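The three backend hunks above all delete the same hand-rolled `max_block_table_len` computation, because `make_tensor_with_pad` (reworked in `vllm/utils.py` below) now infers `max_len` from the data when it is omitted. A minimal sketch of the equivalence, with made-up `block_tables` values:

    import torch

    from vllm.utils import make_tensor_with_pad

    block_tables = [[1, 2, 3], [4], [5, 6]]  # illustrative ragged rows

    # Old style: the longest row computed manually at every call site.
    explicit = make_tensor_with_pad(block_tables,
                                    pad=0,
                                    dtype=torch.int,
                                    max_len=max(len(t) for t in block_tables))

    # New style: omit max_len and let the helper infer it.
    inferred = make_tensor_with_pad(block_tables, pad=0, dtype=torch.int)

    assert torch.equal(explicit, inferred)
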
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 29b077cf6d912..390b5d173ebcd 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -2,14 +2,13 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 
 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
-                        maybe_expand_dim)
+                        make_tensor_with_pad, maybe_expand_dim)
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
@@ -466,22 +465,24 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
         do_penalties = prompt_tokens or output_tokens
 
         if do_penalties:
-            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
-                                 default=0)
-            prompt_padded_tokens = np.full(
-                (len(prompt_tokens), prompt_max_len),
+            prompt_t = make_tensor_with_pad(
+                prompt_tokens,
                 vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(prompt_tokens):
-                prompt_padded_tokens[i, :len(tokens)] = tokens
-            output_max_len = max([len(tokens) for tokens in output_tokens],
-                                 default=0)
-            output_padded_tokens = np.full(
-                (len(output_tokens), output_max_len),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+            output_t = make_tensor_with_pad(
+                output_tokens,
                 vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(output_tokens):
-                output_padded_tokens[i, :len(tokens)] = tokens
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+        else:
+            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
+            prompt_t = empty_tensor
+            output_t = empty_tensor
 
         temperatures_t = torch.tensor(
             temperatures,
@@ -531,15 +532,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        if do_penalties:
-            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
-            output_tensor = torch.from_numpy(output_padded_tokens)
-            if pin_memory:
-                prompt_tensor = prompt_tensor.pin_memory()
-                output_tensor = output_tensor.pin_memory()
-        else:
-            prompt_tensor = None
-            output_tensor = None
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -562,16 +554,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
 
-        if do_penalties:
-            prompt_tokens_gpu = prompt_tensor.to(device=device,
-                                                 non_blocking=True)
-            output_tokens_gpu = output_tensor.to(device=device,
-                                                 non_blocking=True)
-        else:
-            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
-            prompt_tokens_gpu = empty_tensor
-            output_tokens_gpu = empty_tensor
-
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -583,8 +565,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                                            non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_tokens_gpu,
-            output_tokens=output_tokens_gpu,
+            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
+            output_tokens=output_t.to(device=device, non_blocking=True),
             sampling_seeds=sampling_seeds_gpu,
             sample_indices=sample_indices_t.to(device=device,
                                                non_blocking=True),
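Here `SamplingTensors.from_lists` now stages the penalty token tensors on the host with `make_tensor_with_pad` and pins them when requested, replacing the NumPy round-trip; the device transfer happens later in a single `.to(device, non_blocking=True)` call per tensor. A minimal sketch of that staging, assuming the helper defined below (the token lists and `vocab_size` are illustrative; `pin_memory` is left False so the snippet also runs on CPU-only hosts):

    import torch

    from vllm.utils import make_tensor_with_pad

    vocab_size = 32000
    prompt_tokens = [[1, 5, 9], [2, 4]]

    # Pad with vocab_size, one past the largest valid token id, so the
    # penalty kernels can recognize and skip the padded slots.
    prompt_t = make_tensor_with_pad(
        prompt_tokens,
        vocab_size,  # pad value
        device="cpu",
        dtype=torch.int64,
        pin_memory=False,  # from_lists passes pin_memory when CUDA is used
    )

    # Later, one asynchronous host-to-device copy:
    # prompt_t = prompt_t.to(device="cuda", non_blocking=True)
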
diff --git a/vllm/utils.py b/vllm/utils.py
index f906d82581233..9e222772eb5b9 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -20,6 +20,7 @@
                     Union)
 
 import numpy as np
+import numpy.typing as npt
 import psutil
 import torch
 import torch.types
@@ -40,6 +41,15 @@
     "fp8_e5m2": torch.uint8,
 }
 
+TORCH_DTYPE_TO_NUMPY_DTYPE = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+    torch.uint8: np.uint8,
+    torch.int32: np.int32,
+    torch.int64: np.int64,
+}
+
 P = ParamSpec('P')
 K = TypeVar("K")
 T = TypeVar("T")
@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
             f"(e.g., 1, 2, 3). Given input: {s}") from e
 
 
-def make_tensor_with_pad(
-    x: List[List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Optional[Union[str, torch.device]],
-) -> torch.Tensor:
-    """Make a padded tensor of a 2D inputs.
+def make_ndarray_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len: Optional[int] = None,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.
 
     The padding is applied to the end of each inner list until it reaches
     `max_len`.
     """
-    padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
+    if max_len is None:
+        # Unlike for most functions, map is faster than a genexpr over `len`
+        max_len = max(map(len, x), default=0)
+
+    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
     for ind, blocktb in enumerate(x):
         assert len(blocktb) <= max_len
         padded_x[ind, :len(blocktb)] = blocktb
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+    return padded_x
+
+
+def make_tensor_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: torch.dtype,
+    *,
+    max_len: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
+) -> torch.Tensor:
+    """
+    Make a padded tensor from 2D inputs.
+
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory()
+
+    return tensor
 
 
 def async_tensor_h2d(
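The consolidation splits padding into two layers: `make_ndarray_with_pad` builds the padded NumPy array (inferring `max_len` from the longest row when it is omitted), and `make_tensor_with_pad` wraps it with `torch.from_numpy`, which shares the array's memory instead of copying, before the optional device move and pinning. A quick behavioral sketch under the definitions above, with illustrative inputs:

    import numpy as np
    import torch

    from vllm.utils import make_ndarray_with_pad, make_tensor_with_pad

    arr = make_ndarray_with_pad([[1, 2, 3], [4]], pad=0, dtype=np.int32)
    assert arr.tolist() == [[1, 2, 3], [4, 0, 0]]  # max_len inferred as 3

    t = make_tensor_with_pad([[1, 2, 3], [4]], pad=0, dtype=torch.int32)
    assert t.shape == (2, 3) and t.dtype == torch.int32
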
+ """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor def async_tensor_h2d( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index db0e178e45f4e..83f4ba69fb728 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -276,11 +276,8 @@ def _prepare_decode( dtype=torch.int, device=self.device) - max_block_table_len = max( - len(block_table) for block_table in block_tables) block_tables = make_tensor_with_pad( block_tables, - max_len=max_block_table_len, pad=0, dtype=torch.int, device=self.device, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 423f44085e310..651319ab14548 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -121,13 +121,13 @@ def _prepare_prompt( max_seq_len = max(seq_lens) assert max_seq_len > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_seq_len, pad=0, + max_len=max_seq_len, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_seq_len, pad=0, + max_len=max_seq_len, dtype=torch.long, device=self.device) input_block_ids = torch.tensor(input_block_ids, @@ -171,13 +171,13 @@ def _prepare_decode( input_block_ids.append(block_table[0]) input_tokens = make_tensor_with_pad(input_tokens, - max_len=1, pad=0, + max_len=1, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_len=1, pad=0, + max_len=1, dtype=torch.long, device=self.device) context_lens = torch.tensor(context_lens, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 876abb3bf94d1..2f0ca42316e13 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -335,11 +335,8 @@ def _prepare_decode( dtype=torch.int, device=self.device) - max_block_table_len = max( - len(block_table) for block_table in block_tables) block_tables = make_tensor_with_pad( block_tables, - max_len=max_block_table_len, pad=0, dtype=torch.int, device=self.device,