
Commit

logprobs is basically all right; need to properly slice prompt logprobs in update_from_output
abf149 committed Nov 4, 2024
1 parent 01d424d commit 37a76c3
Showing 4 changed files with 109 additions and 86 deletions.
63 changes: 45 additions & 18 deletions vllm/v1/core/scheduler.py
@@ -2,6 +2,8 @@
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
@@ -260,29 +262,42 @@ def update_from_output(
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
if request.max_logprobs > 0:
# Construct logprobs, if requested
max_logprobs = request.max_logprobs
if max_logprobs > 0:
# Construct logprobs, if requested (TODO: assumes one
# generated token). Note that Sampler returns
#
# logprob_token_ids =
# <(batch max logprobs) tok ids><sampled tok id>
# logprob_values =
# <(batch max logprobs) tok logprobs><sampled tok logprob>
logprob_token_ids = logprob_token_ids_list[req_index]
logprob_values = logprob_values_list[req_index]
logprobs = {
logprob_cnt = max_logprobs
if token_id not in logprob_token_ids[0:max_logprobs]:
# Sampled token is not in the top logprobs;
# inject it & resort, ensuring that excess logprobs
# not requested by the user have -inf probability
logprob_values[max_logprobs:-1] = (
[float('-inf')] *
(len(logprob_values) - 1 - max_logprobs))
logprob_values, indices = torch.sort(logprob_values,
dim=-1)
logprob_token_ids = torch.gather(
logprob_token_ids, 1, indices)
# There will be one more logprob than the user requested
logprob_cnt = max_logprobs + 1

# Only keep the number of logprobs specified by the request
# (plus possibly the sampled token id & its logprob)
logprob_values = logprob_values[0:logprob_cnt]
logprob_token_ids = logprob_token_ids[0:logprob_cnt]

request.logprobs.append({
lpt: Logprob(lpv, (idx + 1), None)
for idx, (lpv, lpt) in enumerate(
zip(logprob_values, logprob_token_ids))
}
request.logprobs.append(logprobs)
if request.max_prompt_logprobs > 0:
# Construct prompt logprobs, if requested
prompt_logprob_token_ids = prompt_logprob_token_ids_list[
req_index]
prompt_logprob_values = prompt_logprob_values_list[
req_index]
prompt_logprobs = {
lpt: Logprob(lpv, (idx + 1), None)
for idx, (lpv, lpt) in enumerate(
zip(prompt_logprob_values,
prompt_logprob_token_ids))
}
request.prompt_logprobs.append(prompt_logprobs)
})
request.output_token_ids.append(token_id)
sampled.append((request, 1))
# TODO: Update the KV cache manager for prefix caching.
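
For reference, a small standalone sketch of the inject-and-resort step above, using made-up toy tensors rather than the vLLM scheduler's state: the sampler returns a row of batch-wide top logprobs with the sampled token's logprob appended at the end; when the sampled token is not already in this request's top-k, the slots the request did not ask for are masked to -inf, the row is re-sorted, and one extra entry is kept for the sampled token.

import torch

# Toy per-request row: <batch-max top logprobs><sampled-token logprob>.
# Here the batch max is 4 while this request only asked for k = 2 logprobs.
k = 2
logprob_token_ids = torch.tensor([[11, 42, 7, 3, 99]])            # last slot = sampled token id
logprob_values = torch.tensor([[-0.5, -1.2, -2.0, -2.3, -4.0]])   # last slot = sampled logprob
sampled_id = 99

keep = k
if sampled_id not in logprob_token_ids[0, :k]:
    # Sampled token fell outside the requested top-k: blank out the slots the
    # request did not ask for, re-sort so the sampled entry moves up next to
    # the top-k, and keep one extra logprob for it.
    logprob_values[:, k:-1] = float('-inf')
    logprob_values, order = torch.sort(logprob_values, dim=-1, descending=True)
    logprob_token_ids = torch.gather(logprob_token_ids, 1, order)
    keep = k + 1

logprob_values = logprob_values[:, :keep]        # tensor([[-0.5, -1.2, -4.0]])
logprob_token_ids = logprob_token_ids[:, :keep]  # tensor([[11, 42, 99]])
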
@@ -292,6 +307,18 @@
if stopped:
continue

if request.max_prompt_logprobs > 0:
# Construct prompt logprobs, if requested
prompt_logprob_token_ids = prompt_logprob_token_ids_list[
req_index]
prompt_logprob_values = prompt_logprob_values_list[req_index]
prompt_logprobs = {
lpt: Logprob(lpv, (idx + 1), None)
for idx, (lpv, lpt) in enumerate(
zip(prompt_logprob_values, prompt_logprob_token_ids))
}
request.prompt_logprobs.append(prompt_logprobs)

new_running.append(request)
self.running = new_running
return sampled
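As context for the dict comprehensions in this file, a minimal sketch of how the per-token logprob mapping is assembled; the Logprob class below is a stand-in dataclass, not vLLM's own container, and the values are made up.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Logprob:  # stand-in for vLLM's Logprob container (logprob, rank, decoded token)
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None

logprob_values = [-0.5, -1.2, -4.0]
logprob_token_ids = [11, 42, 99]

# Map token id -> Logprob, ranking entries by their 1-based position in the list.
logprobs = {
    tok: Logprob(val, idx + 1, None)
    for idx, (val, tok) in enumerate(zip(logprob_values, logprob_token_ids))
}
# e.g. logprobs[11] -> Logprob(logprob=-0.5, rank=1, decoded_token=None)
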
6 changes: 1 addition & 5 deletions vllm/v1/outputs.py
@@ -15,10 +15,6 @@ class SamplerOutput:
# [num_reqs, max_num_logprobs + 1]
logprobs: Optional[torch.Tensor]

# TODO: Support prompt logprobs.
prompt_logprob_token_ids: Optional[torch.Tensor]
prompt_logprobs: Optional[torch.Tensor]


@dataclass
class ModelRunnerOutput:
@@ -39,4 +35,4 @@ class ModelRunnerOutput:
# [num_reqs, max_num_prompt_logprobs]
prompt_logprob_token_ids_cpu: Optional[torch.Tensor]
# [num_reqs, max_num_prompt_logprobs]
prompt_logprobs_cpu: Optional[torch.Tensor]
prompt_logprobs_cpu: Optional[torch.Tensor]
51 changes: 13 additions & 38 deletions vllm/v1/sample/sampler.py
@@ -1,5 +1,5 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, Optional
from typing import Dict

import torch
import torch.nn as nn
@@ -16,7 +16,6 @@ def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
prompt_logits: Optional[torch.Tensor] = None,
) -> SamplerOutput:
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_top_k_top_p(logits, sampling_metadata)
@@ -28,48 +27,24 @@

if sampling_metadata.max_num_logprobs > 0:
logprobs = self.get_logprobs(logits)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
sampled_logprobs = logprobs[torch.arange(logprobs.shape[0]),
sampled]
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
# Use int32 to reduce the tensor size. Concat sampled token id
topk_indices = torch.cat(
(topk_indices.to(torch.int32), sampled.unsqueeze(-1)), dim=-1)
# Concat sampled token logprobs
topk_logprobs = torch.cat(
(topk_logprobs, sampled_logprobs.unsqueeze(-1)), dim=-1)

else:
topk_logprobs = None
topk_indices = None

max_num_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs
if max_num_prompt_logprobs > 0:
prompt_logits = self.apply_temperature(
prompt_logits, sampling_metadata.temperature)
prompt_logits = self.apply_top_k_top_p(prompt_logits,
sampling_metadata)
prompt_logprobs = self.get_logprobs(prompt_logits)

topk_prompt_logprobs, topk_prompt_indices = torch.topk(
prompt_logprobs,
sampling_metadata.max_num_prompt_logprobs,
dim=-1)
# Use int32 to reduce the tensor size.
topk_prompt_indices = topk_prompt_indices.to(torch.int32)

sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=topk_prompt_indices,
prompt_logprobs=topk_prompt_logprobs,
)
else:
assert prompt_logits is None

sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
)
sampler_output = SamplerOutput(sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs)

return sampler_output

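A toy sketch of the layout the new sampler code builds, with hypothetical shapes and a stand-in argmax instead of the real sampling step (this is not the vLLM Sampler API): take the per-batch top-k logprobs, then append the sampled token's id and logprob as the final column.

import torch

batch_size, vocab_size, max_num_logprobs = 2, 10, 3
logits = torch.randn(batch_size, vocab_size)
sampled = torch.argmax(logits, dim=-1)  # stand-in for the actual sampling step

logprobs = torch.log_softmax(logits, dim=-1)
sampled_logprobs = logprobs[torch.arange(batch_size), sampled]

topk_logprobs, topk_indices = torch.topk(logprobs, max_num_logprobs, dim=-1)
# int32 token ids to shrink the tensor, with the sampled token id appended last
logprob_token_ids = torch.cat(
    (topk_indices.to(torch.int32), sampled.to(torch.int32).unsqueeze(-1)), dim=-1)
# matching logprob values, with the sampled token's logprob appended last
logprob_values = torch.cat(
    (topk_logprobs, sampled_logprobs.unsqueeze(-1)), dim=-1)
# both now have shape [batch_size, max_num_logprobs + 1]
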
75 changes: 50 additions & 25 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from unittest.mock import patch

import numpy as np
@@ -304,6 +304,34 @@ def _prepare_sampling(
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
return sampling_metadata

def _compute_prompt_logprobs(
self,
sampling_metadata: SamplingMetadata,
prompt_logits: Optional[torch.Tensor],
seq_start_loc: torch.Tensor,
num_generated_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Compute prompt lens
prompt_lens = torch.diff(seq_start_loc) - num_generated_tokens

max_num_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs
assert max_num_prompt_logprobs > 0

prompt_logits = self.model.sampler.apply_temperature(
prompt_logits,
torch.repeat_interleave(sampling_metadata.temperature,
prompt_lens))
prompt_logits = self.model.sampler.apply_top_k_top_p(
prompt_logits, sampling_metadata)
prompt_logprobs = self.model.sampler.get_logprobs(prompt_logits)

topk_prompt_logprobs, topk_prompt_indices = torch.topk(
prompt_logprobs, sampling_metadata.max_num_prompt_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_prompt_indices = topk_prompt_indices.to(torch.int32)

return topk_prompt_indices, topk_prompt_logprobs
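
A small sketch of the bookkeeping _compute_prompt_logprobs relies on, with made-up numbers that are not tied to vLLM's real attention metadata: seq_start_loc holds cumulative sequence start offsets, so diff() recovers per-request lengths, and repeat_interleave expands the per-request temperatures to one value per prompt position.

import torch

seq_start_loc = torch.tensor([0, 5, 12, 20])  # 3 requests with total lengths 5, 7, 8
num_generated_tokens = 1                      # one decoded token per request this step

prompt_lens = torch.diff(seq_start_loc) - num_generated_tokens  # tensor([4, 6, 7])

# Per-request temperature expanded to one entry per prompt position, so the
# flattened prompt logits can be scaled in a single elementwise op.
temperature = torch.tensor([1.0, 0.7, 0.9])
per_token_temperature = torch.repeat_interleave(temperature, prompt_lens)
# 4 + 6 + 7 = 17 entries, one per prompt-logit row
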

@torch.inference_mode()
def execute_model(
self,
@@ -321,30 +349,29 @@ def execute_model(
attn_metadata=attn_metadata,
)
sampling_metadata = self._prepare_sampling(scheduler_output)

if sampling_metadata.max_num_logprobs > 0:
do_prompt_logprobs = sampling_metadata.max_num_prompt_logprobs > 0
if do_prompt_logprobs:
# One or more requests require prompt logprobs
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)

# Sample the next token and get logprobs if needed.
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# No requests require prompt logprobs
logits = self.model.compute_logits(hidden_states, None)
mask = torch.ones(input_ids.shape[0], dtype=torch.bool)
mask[logits_indices] = False
prompt_logits = logits[mask, :]
logits = logits[logits_indices, :]
(
prompt_logprob_token_ids,
prompt_logprobs,
) = self._compute_prompt_logprobs(sampling_metadata, prompt_logits,
attn_metadata.seq_start_loc, 1)
else:
# No requests require prompt logprobs
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)

# Sample the next token and get logprobs if needed.
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
prompt_logits=prompt_logits)
# Sample the next token and get logprobs if needed.
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)

# NOTE: CPU-GPU synchronization happens here.
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
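
A toy sketch of the boolean-mask split used in execute_model above, with hypothetical sizes: logits exist for every token position, logits_indices marks the positions used for sampling, and the complement mask picks out the prompt positions whose logprobs were requested.

import torch

num_positions, vocab_size = 6, 10
logits = torch.randn(num_positions, vocab_size)
logits_indices = torch.tensor([2, 5])  # last position of each of two requests

mask = torch.ones(num_positions, dtype=torch.bool)
mask[logits_indices] = False
prompt_logits = logits[mask, :]              # rows 0, 1, 3, 4 -> prompt positions
sampling_logits = logits[logits_indices, :]  # rows 2, 5 -> positions that get sampled
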
@@ -378,15 +405,13 @@ def execute_model(
else:
logprobs = sampler_output.logprobs.cpu()

if sampler_output.prompt_logprob_token_ids is None:
prompt_logprob_token_ids = None
if do_prompt_logprobs:
prompt_logprob_token_ids = prompt_logprob_token_ids.cpu()
prompt_logprobs = prompt_logprobs.cpu()
else:
prompt_logprob_token_ids = (
sampler_output.prompt_logprob_token_ids.cpu())
if sampler_output.prompt_logprobs is None:
prompt_logprob_token_ids = None
prompt_logprobs = None
else:
prompt_logprobs = sampler_output.prompt_logprobs.cpu()

model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids[:num_reqs],
req_id_to_index=self.input_batch.req_id_to_index,
