[Misc] Consolidate and optimize logic for building padded tensors (vl…
DarkLight1337 authored and jimpang committed Jul 24, 2024
1 parent 6aed998 commit 98519e9
Showing 9 changed files with 77 additions and 75 deletions.
12 changes: 3 additions & 9 deletions tests/conftest.py
@@ -21,7 +21,8 @@
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.sequence import SampleLogprobs
-from vllm.utils import cuda_device_count_stateless, is_cpu
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
+                        is_cpu)
 
 logger = init_logger(__name__)
 
@@ -124,12 +125,6 @@ def image_assets() -> _ImageAssets:
     return IMAGE_ASSETS
 
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-}
-
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
@@ -151,8 +146,7 @@ def __init__(
         is_vision_model: bool = False,
         is_sparseml_model: bool = False,
     ) -> None:
-        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
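The test fixtures now reuse the shared STR_DTYPE_TO_TORCH_DTYPE table from vllm.utils instead of a private copy. A minimal sketch of the lookup (not part of the commit; assumes the shared table keeps the same "half"/"bfloat16"/"float" entries the deleted private dict had):

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

# Resolve a dtype name string to a torch.dtype via the shared table.
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE["bfloat16"]
assert torch_dtype is torch.bfloat16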
3 changes: 0 additions & 3 deletions vllm/attention/backends/flash_attn.py
@@ -306,11 +306,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
3 changes: 0 additions & 3 deletions vllm/attention/backends/flashinfer.py
@@ -344,11 +344,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
3 changes: 0 additions & 3 deletions vllm/attention/backends/utils.py
@@ -182,11 +182,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
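The same three-line deletion appears in flash_attn.py, flashinfer.py, and utils.py above: the hand-rolled max_block_table_len is gone because make_tensor_with_pad now derives max_len itself when the argument is omitted (see the vllm/utils.py hunk below). A quick sketch of the equivalence, using hypothetical block tables rather than data from the commit:

import torch
from vllm.utils import make_tensor_with_pad

# Hypothetical ragged block tables like those the backends build.
block_tables = [[0, 1, 2], [3], [4, 5]]

# With max_len omitted, the helper pads to the longest row (width 3),
# exactly what the deleted max(...) computation produced.
t = make_tensor_with_pad(block_tables, pad=0, dtype=torch.int, device="cpu")
assert t.shape == (3, 3)
assert t[1].tolist() == [3, 0, 0]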
56 changes: 19 additions & 37 deletions vllm/model_executor/sampling_metadata.py
@@ -2,14 +2,13 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 
 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
-                        maybe_expand_dim)
+                        make_tensor_with_pad, maybe_expand_dim)
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
@@ -466,22 +465,24 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
         do_penalties = prompt_tokens or output_tokens
 
         if do_penalties:
-            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
-                                 default=0)
-            prompt_padded_tokens = np.full(
-                (len(prompt_tokens), prompt_max_len),
+            prompt_t = make_tensor_with_pad(
+                prompt_tokens,
                 vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(prompt_tokens):
-                prompt_padded_tokens[i, :len(tokens)] = tokens
-            output_max_len = max([len(tokens) for tokens in output_tokens],
-                                 default=0)
-            output_padded_tokens = np.full(
-                (len(output_tokens), output_max_len),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+            output_t = make_tensor_with_pad(
+                output_tokens,
                 vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(output_tokens):
-                output_padded_tokens[i, :len(tokens)] = tokens
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+        else:
+            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
+            prompt_t = empty_tensor
+            output_t = empty_tensor
 
         temperatures_t = torch.tensor(
             temperatures,
@@ -531,15 +532,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        if do_penalties:
-            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
-            output_tensor = torch.from_numpy(output_padded_tokens)
-            if pin_memory:
-                prompt_tensor = prompt_tensor.pin_memory()
-                output_tensor = output_tensor.pin_memory()
-        else:
-            prompt_tensor = None
-            output_tensor = None
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -562,16 +554,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
 
-        if do_penalties:
-            prompt_tokens_gpu = prompt_tensor.to(device=device,
-                                                 non_blocking=True)
-            output_tokens_gpu = output_tensor.to(device=device,
-                                                 non_blocking=True)
-        else:
-            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
-            prompt_tokens_gpu = empty_tensor
-            output_tokens_gpu = empty_tensor
-
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -583,8 +565,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                                                          non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_tokens_gpu,
-            output_tokens=output_tokens_gpu,
+            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
+            output_tokens=output_t.to(device=device, non_blocking=True),
             sampling_seeds=sampling_seeds_gpu,
             sample_indices=sample_indices_t.to(device=device,
                                                non_blocking=True),
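The penalty-token padding previously built NumPy arrays by hand (np.full plus a per-row copy loop), pinned them, and shipped them to the GPU in a separate block; all of that now collapses into two make_tensor_with_pad calls padded with vocab_size. A sketch of the new path with hypothetical inputs:

import torch
from vllm.utils import make_tensor_with_pad

# Hypothetical per-sequence prompt tokens; vocab_size doubles as the pad
# value so padded slots index past the real vocabulary.
prompt_tokens = [[1, 2, 3], [4]]
vocab_size = 32000

prompt_t = make_tensor_with_pad(
    prompt_tokens,
    vocab_size,
    device="cpu",
    dtype=torch.int64,
    pin_memory=False,  # from_lists passes pin_memory=pin_memory here
)
assert prompt_t.tolist() == [[1, 2, 3], [4, vocab_size, vocab_size]]
# The caller then moves it to the target device, e.g.:
# prompt_t.to(device=device, non_blocking=True)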
61 changes: 51 additions & 10 deletions vllm/utils.py
@@ -20,6 +20,7 @@
                     Union)
 
 import numpy as np
+import numpy.typing as npt
 import psutil
 import torch
 import torch.types
@@ -40,6 +41,15 @@
     "fp8_e5m2": torch.uint8,
 }
 
+TORCH_DTYPE_TO_NUMPY_DTYPE = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+    torch.uint8: np.uint8,
+    torch.int32: np.int32,
+    torch.int64: np.int64,
+}
+
 P = ParamSpec('P')
 K = TypeVar("K")
 T = TypeVar("T")
@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
                 f"(e.g., 1, 2, 3). Given input: {s}") from e
 
 
-def make_tensor_with_pad(
-    x: List[List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Optional[Union[str, torch.device]],
-) -> torch.Tensor:
-    """Make a padded tensor of a 2D inputs.
+def make_ndarray_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len: Optional[int] = None,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.
 
     The padding is applied to the end of each inner list until it reaches
     `max_len`.
     """
-    padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
+    if max_len is None:
+        # Unlike for most functions, map is faster than a genexpr over `len`
+        max_len = max(map(len, x), default=0)
+
+    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
     for ind, blocktb in enumerate(x):
         assert len(blocktb) <= max_len
         padded_x[ind, :len(blocktb)] = blocktb
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+    return padded_x
+
+
+def make_tensor_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: torch.dtype,
+    *,
+    max_len: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
+) -> torch.Tensor:
+    """
+    Make a padded tensor from 2D inputs.
+
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory()
+
+    return tensor
 
 
 def async_tensor_h2d(
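For reference, a short usage sketch (not from the commit) of the consolidated pair: make_ndarray_with_pad does the padding in NumPy, and make_tensor_with_pad wraps it with dtype translation via TORCH_DTYPE_TO_NUMPY_DTYPE, device placement, and optional pinning:

import numpy as np
import torch
from vllm.utils import make_ndarray_with_pad, make_tensor_with_pad

x = [[1, 2], [3, 4, 5], []]  # hypothetical ragged input

# NumPy level: max_len defaults to the longest row, here 3.
arr = make_ndarray_with_pad(x, pad=-1, dtype=np.int64)
assert arr.tolist() == [[1, 2, -1], [3, 4, 5], [-1, -1, -1]]

# Tensor level: same padding, plus an explicit max_len and torch dtype.
t = make_tensor_with_pad(x, pad=-1, dtype=torch.int64, max_len=4)
assert t.shape == (3, 4)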
3 changes: 0 additions & 3 deletions vllm/worker/cpu_model_runner.py
@@ -276,11 +276,8 @@ def _prepare_decode(
                                     dtype=torch.int,
                                     device=self.device)
 
-        max_block_table_len = max(
-            len(block_table) for block_table in block_tables)
         block_tables = make_tensor_with_pad(
             block_tables,
-            max_len=max_block_table_len,
             pad=0,
             dtype=torch.int,
             device=self.device,
8 changes: 4 additions & 4 deletions vllm/worker/neuron_model_runner.py
@@ -121,13 +121,13 @@ def _prepare_prompt(
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_seq_len,
                                             pad=0,
+                                            max_len=max_seq_len,
                                             dtype=torch.long,
                                             device=self.device)
         input_positions = make_tensor_with_pad(input_positions,
-                                               max_seq_len,
                                                pad=0,
+                                               max_len=max_seq_len,
                                                dtype=torch.long,
                                                device=self.device)
         input_block_ids = torch.tensor(input_block_ids,
@@ -171,13 +171,13 @@ def _prepare_decode(
             input_block_ids.append(block_table[0])
 
         input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_len=1,
                                             pad=0,
+                                            max_len=1,
                                             dtype=torch.long,
                                             device=self.device)
         input_positions = make_tensor_with_pad(input_positions,
-                                               max_len=1,
                                                pad=0,
+                                               max_len=1,
                                                dtype=torch.long,
                                                device=self.device)
         context_lens = torch.tensor(context_lens,
3 changes: 0 additions & 3 deletions vllm/worker/xpu_model_runner.py
@@ -335,11 +335,8 @@ def _prepare_decode(
                                     dtype=torch.int,
                                     device=self.device)
 
-        max_block_table_len = max(
-            len(block_table) for block_table in block_tables)
         block_tables = make_tensor_with_pad(
             block_tables,
-            max_len=max_block_table_len,
             pad=0,
             dtype=torch.int,
             device=self.device,
