
Commit

[V1] Support per-request seed (vllm-project#9945)
Signed-off-by: Nick Hill <[email protected]>
njhill authored Nov 3, 2024
1 parent 3bb4bef commit 1f1b6d6
Showing 3 changed files with 41 additions and 48 deletions.
5 changes: 2 additions & 3 deletions vllm/v1/sample/metadata.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict

 import torch

@@ -16,7 +16,6 @@ class SamplingMetadata:
     no_top_p: bool
     no_top_k: bool

-    generators: List[Optional[torch.Generator]]
-    no_generator: bool
+    generators: Dict[int, torch.Generator]

     max_num_logprobs: int
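
The structural change here: per-request generators move from a dense per-row list plus a separate no_generator flag to a sparse dict keyed by batch row index, so only seeded requests carry an entry. A minimal sketch of the new shape (illustrative values, not vLLM code):

from typing import Dict

import torch

# Sparse mapping: batch row index -> seeded generator. Rows without an entry
# keep using the default RNG stream.
generators: Dict[int, torch.Generator] = {}

gen = torch.Generator()
gen.manual_seed(1234)
generators[2] = gen  # only the request at row 2 asked for a seed

# The old no_generator flag reduces to an emptiness check.
no_seeded_requests = not generators
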
23 changes: 10 additions & 13 deletions vllm/v1/sample/sampler.py
@@ -1,5 +1,5 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import List, Optional
from typing import Dict

import torch
import torch.nn as nn
@@ -84,22 +84,21 @@ def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
     def random_sample(
         self,
         probs: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
-        no_generator: bool,
+        generators: Dict[int, torch.Generator],
     ) -> torch.Tensor:
         q = torch.empty_like(probs)
         # NOTE(woosuk): To batch-process the requests without their own seeds,
         # which is the common case, we first assume that every request does
         # not have its own seed. Then, we overwrite the values for the requests
         # that have their own seeds.
-        q.exponential_()
-        if not no_generator:
-            assert len(generators) == probs.shape[0]
+        if len(generators) != probs.shape[0]:
+            # This might still be done here unnecessarily if there are greedies
+            q.exponential_()
+        if generators:
             # TODO(woosuk): This can be slow because we handle each request
             # one by one. Optimize this.
-            for i, generator in enumerate(generators):
-                if generator is not None:
-                    q[i].exponential_(generator=generator)
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
         return probs.div_(q).argmax(dim=-1).view(-1)

     def sample(
@@ -112,13 +111,11 @@ def sample(
         if sampling_metadata.all_greedy:
             return self.greedy_sample(probs)
         if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators,
-                                      sampling_metadata.no_generator)
+            return self.random_sample(probs, sampling_metadata.generators)

         greedy_sampled = self.greedy_sample(probs)
         random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators,
-                                            sampling_metadata.no_generator)
+                                            sampling_metadata.generators)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
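
Worth spelling out what random_sample computes: q is filled with Exponential(1) noise and argmax(probs / q) over the vocab dimension is an exact draw from the categorical distribution given by probs (the exponential-race form of the Gumbel-max trick); seeded rows simply overwrite their slice of q using their own torch.Generator. A standalone sketch of the same idea, assuming probs has shape [num_reqs, vocab_size] (function and variable names are illustrative, not the Sampler class itself):

from typing import Dict

import torch


def random_sample(probs: torch.Tensor,
                  generators: Dict[int, torch.Generator]) -> torch.Tensor:
    """Sample one token id per row of probs."""
    q = torch.empty_like(probs)
    # Batch-fill Exponential(1) noise from the default RNG for the common,
    # unseeded case.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    # Overwrite the rows that carry their own seeded generator.
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    # argmax over probs / q is a draw from Categorical(probs) per row.
    return probs.div(q).argmax(dim=-1)


probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.25, 0.25]])
seeded = torch.Generator().manual_seed(42)
tokens = random_sample(probs, {1: seeded})  # row 1 is reproducible across runs
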
61 changes: 29 additions & 32 deletions vllm/v1/worker/gpu_model_runner.py
@@ -128,13 +128,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Add new requests to the cached states.
         for req_data in scheduler_output.scheduled_new_reqs:
             req_id = req_data.req_id
+            sampling_params = req_data.sampling_params
+            if sampling_params.seed is not None:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_params.seed)
+            else:
+                generator = None
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
                 multi_modal_data=req_data.multi_modal_data,
-                sampling_params=req_data.sampling_params,
-                generator=None,  # TODO
+                sampling_params=sampling_params,
+                generator=generator,
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
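
Each new request that sets sampling_params.seed now gets its own torch.Generator, seeded once at admission, so its random stream is independent of batch composition and reproducible across runs. The commit seeds the generator on self.device; the small illustration below uses CPU so it runs anywhere (the helper name draws() is illustrative):

import torch


def draws(seed: int, n: int = 4) -> torch.Tensor:
    # One generator per request, seeded once when the request is added.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return torch.empty(n).exponential_(generator=generator)


assert torch.equal(draws(1234), draws(1234))      # same seed, same stream
assert not torch.equal(draws(1234), draws(5678))  # different seed, different stream
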
@@ -342,11 +349,9 @@ def execute_model(
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators[i]
+                generator = self.input_batch.generators.get(i)
                 if generator is not None:
-                    offset = generator.get_offset()
-                    generator = generator.set_offset(offset - 1)
-                    self.input_batch.generators[i] = generator
+                    generator.set_offset(generator.get_offset() - 1)

         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
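
For a partially scheduled (chunked) request the sampled token is thrown away, so the seeded generator has to be rewound as if the draw never happened; the commit does that with get_offset()/set_offset(), which recent PyTorch exposes on torch.Generator. A hedged sketch of the same "undo one draw" idea using the portable get_state()/set_state() API instead, since offset arithmetic is an implementation detail of the underlying RNG:

import torch

generator = torch.Generator(device="cpu")
generator.manual_seed(0)

# Snapshot the RNG state before a draw that may need to be discarded.
state = generator.get_state()
discarded = torch.empty(1).exponential_(generator=generator)

# "Rewind": restore the snapshot so the next draw replays the same value.
generator.set_state(state)
replayed = torch.empty(1).exponential_(generator=generator)
assert torch.equal(discarded, replayed)
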
@@ -494,8 +499,8 @@ def __init__(
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()

-        self.generators: List[Optional[torch.Generator]] = [None
-                                                            ] * max_num_reqs
+        # req_index -> generator
+        self.generators: Dict[int, torch.Generator] = {}

         self.num_logprobs: Dict[str, int] = {}
         self.prompt_logprob_reqs: Set[str] = set()
@@ -509,8 +514,9 @@ def add_request(
         req_index = self.num_reqs
         assert req_index < self.max_num_reqs

-        self.req_ids[req_index] = request.req_id
-        self.req_id_to_index[request.req_id] = req_index
+        req_id = request.req_id
+        self.req_ids[req_index] = req_id
+        self.req_id_to_index[req_id] = req_index

         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
@@ -528,27 +534,24 @@ def add_request(
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            self.greedy_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM:
-            self.random_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-            # TODO(woosuk): Support per-request random seed.
-            raise NotImplementedError("Per-request seed is not supported yet.")
+            self.greedy_reqs.add(req_id)
+        else:
+            self.random_reqs.add(req_id)

         self.top_p_cpu[req_index] = sampling_params.top_p
         if sampling_params.top_p < 1:
-            self.top_p_reqs.add(req_index)
+            self.top_p_reqs.add(req_id)
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
-            self.top_k_reqs.add(req_index)
+            self.top_k_reqs.add(req_id)

         self.generators[req_index] = request.generator

         num_logprobs = sampling_params.logprobs
         if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[request.req_id] = num_logprobs
+            self.num_logprobs[req_id] = num_logprobs
         if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_index)
+            self.prompt_logprob_reqs.add(req_id)

     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
@@ -560,7 +563,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
-        self.generators[req_index] = None
+        self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
         return req_index
@@ -612,7 +615,9 @@ def condense(self, empty_req_indices: List[int]) -> None:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.generators[empty_index] = self.generators[last_req_index]
+            generator = self.generators.pop(last_req_index, None)
+            if generator is not None:
+                self.generators[empty_index] = generator

             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
@@ -636,8 +641,7 @@ def make_sampling_metadata(
             top_k=self.top_k[:self.num_reqs],
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
-            generators=self.generators[:self.num_reqs],
-            no_generator=self.no_generator,
+            generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
         )

@@ -661,16 +665,9 @@ def no_top_p(self) -> bool:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0

-    @property
-    def no_generator(self) -> bool:
-        return len(self.generators) == 0
-
     @property
     def max_num_logprobs(self) -> int:
-        if self.num_logprobs:
-            return max(self.num_logprobs.values())
-        else:
-            return 0
+        return max(self.num_logprobs.values()) if self.num_logprobs else 0

     @property
     def no_logprob(self) -> bool:
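
Taken together, the three files let a V1 request opt into deterministic sampling by setting seed on its SamplingParams, while unseeded requests keep sharing the default RNG. A hedged end-to-end sketch against the public vLLM API (the model name is illustrative; reproducibility of a seeded request across runs is the behavior this commit enables in the V1 engine):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model; the name is illustrative

seeded = SamplingParams(temperature=1.0, seed=42, max_tokens=16)

# Two generations with the same seed should yield identical token sequences,
# independent of whatever else happens to be batched alongside them.
out1 = llm.generate(["The capital of France is"], seeded)
out2 = llm.generate(["The capital of France is"], seeded)
assert out1[0].outputs[0].text == out2[0].outputs[0].text
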
