Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Use numpy to speed up padded token processing #6442

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
Expand Down Expand Up @@ -457,16 +458,20 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
prompt_padded_tokens = np.full(
(len(prompt_tokens), prompt_max_len),
vocab_size,
dtype=np.int64)
for i, tokens in enumerate(prompt_tokens):
prompt_padded_tokens[i, :len(tokens)] = tokens
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens
]
output_padded_tokens = np.full(
(len(output_tokens), output_max_len),
vocab_size,
dtype=np.int64)
for i, tokens in enumerate(output_tokens):
output_padded_tokens[i, :len(tokens)] = tokens

temperatures_t = torch.tensor(
temperatures,
Expand Down Expand Up @@ -517,18 +522,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
pin_memory=pin_memory,
)
if do_penalties:
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
output_tensor = torch.tensor(
output_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
prompt_tensor = torch.from_numpy(prompt_padded_tokens)
output_tensor = torch.from_numpy(output_padded_tokens)
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved
else:
prompt_tensor = None
output_tensor = None
Expand Down
Loading