[Core] Use numpy to speed up padded token processing (vllm-project#6442)
peng1999 authored and dtrifiro committed Jul 17, 2024
1 parent af5b950 commit 6507bab
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions vllm/model_executor/sampling_metadata.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
+import numpy as np
 import torch
 
 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
@@ -457,16 +458,20 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
         if do_penalties:
             prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
                                  default=0)
-            prompt_padded_tokens = [
-                tokens + [vocab_size] * (prompt_max_len - len(tokens))
-                for tokens in prompt_tokens
-            ]
+            prompt_padded_tokens = np.full(
+                (len(prompt_tokens), prompt_max_len),
+                vocab_size,
+                dtype=np.int64)
+            for i, tokens in enumerate(prompt_tokens):
+                prompt_padded_tokens[i, :len(tokens)] = tokens
             output_max_len = max([len(tokens) for tokens in output_tokens],
                                  default=0)
-            output_padded_tokens = [
-                tokens + [vocab_size] * (output_max_len - len(tokens))
-                for tokens in output_tokens
-            ]
+            output_padded_tokens = np.full(
+                (len(output_tokens), output_max_len),
+                vocab_size,
+                dtype=np.int64)
+            for i, tokens in enumerate(output_tokens):
+                output_padded_tokens[i, :len(tokens)] = tokens
 
         temperatures_t = torch.tensor(
             temperatures,
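
For reference, a minimal standalone sketch of the padding scheme this hunk adopts (the helper name pad_token_lists and the example values are illustrative, not part of vLLM): one np.full allocation plus row-wise slice assignment replaces building each padded row with Python-level list concatenation.

import numpy as np

def pad_token_lists(token_lists, vocab_size):
    # Pack variable-length token-id lists into one int64 matrix, padding
    # short rows on the right with `vocab_size` (an id that cannot
    # collide with a real token).
    max_len = max((len(t) for t in token_lists), default=0)
    padded = np.full((len(token_lists), max_len), vocab_size,
                     dtype=np.int64)
    for i, tokens in enumerate(token_lists):
        # Slice assignment copies each row in a single C-level operation
        # instead of a Python `+` and `*` per row.
        padded[i, :len(tokens)] = tokens
    return padded

print(pad_token_lists([[1, 2, 3], [4]], vocab_size=32000))
# [[    1     2     3]
#  [    4 32000 32000]]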
@@ -517,18 +522,11 @@
             pin_memory=pin_memory,
         )
         if do_penalties:
-            prompt_tensor = torch.tensor(
-                prompt_padded_tokens,
-                device="cpu",
-                dtype=torch.long,
-                pin_memory=pin_memory,
-            )
-            output_tensor = torch.tensor(
-                output_padded_tokens,
-                device="cpu",
-                dtype=torch.long,
-                pin_memory=pin_memory,
-            )
+            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
+            output_tensor = torch.from_numpy(output_padded_tokens)
+            if pin_memory:
+                prompt_tensor = prompt_tensor.pin_memory()
+                output_tensor = output_tensor.pin_memory()
         else:
             prompt_tensor = None
             output_tensor = None
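
A sketch of the tensor-construction side (example shapes and values are illustrative): torch.from_numpy wraps the numpy buffer without copying, whereas the old torch.tensor(list_of_lists, ...) call iterated over nested Python lists. Since torch.from_numpy takes no pin_memory argument, the diff pins explicitly afterwards when pin_memory is set.

import numpy as np
import torch

padded = np.full((2, 3), 32000, dtype=np.int64)
t = torch.from_numpy(padded)  # zero-copy; int64 maps to torch.long
assert t.dtype == torch.long
if torch.cuda.is_available():
    # pin_memory() copies the data into page-locked host memory,
    # enabling faster asynchronous host-to-device transfers later.
    t = t.pin_memory()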
