logprobs with decoded token almost working except newline and spaces(?)
abf149 committed Nov 4, 2024
1 parent 6c60d56 commit 0c23022
Showing 2 changed files with 99 additions and 7 deletions.
23 changes: 18 additions & 5 deletions vllm/v1/engine/llm_engine.py
@@ -16,6 +16,7 @@
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
@@ -315,6 +316,8 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
skip_special_tokens=[],
spaces_between_special_tokens=[],
free_req_ids=[], # TODO(woosuk): Implement freeing.
logprobs={},
prompt_logprobs={},
)
for req, num_tokens in sampled:
inputs.req_ids.append(req.request_id)
@@ -330,6 +333,13 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
inputs.spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens)

# Transmit (prompt)logprobs to detokenizer
req_id = req.request_id
if req.logprobs is not None:
inputs.logprobs[req_id] = req.logprobs
if req.prompt_logprobs is not None:
inputs.prompt_logprobs[req_id] = req.prompt_logprobs

# Update the number of lagged steps.
self.num_lagged_steps[req.request_id] += 1
self.detokenizer.send(inputs)
@@ -356,7 +366,9 @@ def recv_from_detokenizer(self) -> List[RequestOutput]:
and req.is_finished())
req_output = self._make_request_output(
req, detokenizer_output.num_output_token_ids[i],
- detokenizer_output.detokenized_texts[i], finished)
+ detokenizer_output.detokenized_texts[i],
+ detokenizer_output.logprobs.get(req_id, None),
+ detokenizer_output.prompt_logprobs.get(req_id, None), finished)
req_outputs.append(req_output)

if finished:
@@ -371,6 +383,8 @@ def _make_request_output(
request: Request,
num_output_tokens: int,
new_output_text: str,
logprobs: Optional[SampleLogprobs],
prompt_logprobs: Optional[PromptLogprobs],
finished: bool,
) -> RequestOutput:
req_output = self.request_outputs.get(request.request_id)
@@ -408,23 +422,22 @@ def _make_request_output(
completion_output.token_ids = (
request.output_token_ids[:num_output_tokens])
if do_logprobs:
- completion_output.logprobs = (
-     request.logprobs[:num_output_tokens])
+ completion_output.logprobs = (logprobs[:num_output_tokens])
elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
completion_output.text = new_output_text
num_prev_tokens = len(completion_output.token_ids)
completion_output.token_ids = request.output_token_ids[
num_prev_tokens:num_output_tokens]
if do_logprobs:
completion_output.logprobs = (
-     request.logprobs[num_prev_tokens:num_output_tokens])
+     logprobs[num_prev_tokens:num_output_tokens])
elif (request.sampling_params.output_kind ==
RequestOutputKind.FINAL_ONLY):
if finished:
completion_output.text = request.output_text
completion_output.token_ids = request.output_token_ids
if do_logprobs:
- completion_output.logprobs = request.logprobs
+ completion_output.logprobs = logprobs
else:
completion_output.text = ""
completion_output.token_ids = []
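
For reference, a minimal sketch (hypothetical token counts, not part of this commit) of how the detokenized logprobs are sliced for each RequestOutputKind above:

# Hypothetical state: 5 tokens generated so far, 3 already returned to the caller.
num_output_tokens = 5
num_prev_tokens = 3
logprobs = [{i: None} for i in range(num_output_tokens)]  # stand-in for SampleLogprobs

cumulative = logprobs[:num_output_tokens]            # CUMULATIVE: all tokens so far
delta = logprobs[num_prev_tokens:num_output_tokens]  # DELTA: only the newest tokens
final = logprobs                                      # FINAL_ONLY: everything, once finished
assert len(cumulative) == 5 and len(delta) == 2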
83 changes: 81 additions & 2 deletions vllm/v1/tokenizer/detokenizer.py
@@ -1,16 +1,19 @@
import multiprocessing
from dataclasses import dataclass
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Union

import msgspec
import zmq
from msgspec import msgpack

from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import get_open_port

AnyLogprobs = Union[Optional[SampleLogprobs], Optional[PromptLogprobs]]


class DetokenizerInputs(msgspec.Struct):

@@ -23,6 +26,13 @@ class DetokenizerInputs(msgspec.Struct):
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]

# If (prompt)logprobs are going to be returned, the decoded token
# field must be computed (detokenized) from the logprob token id.
# Assumption: the decoded token fields of the incoming
# (prompt)logprobs are `None`.
logprobs: Dict[str, SampleLogprobs]
prompt_logprobs: Dict[str, PromptLogprobs]

# [num_free_reqs]
free_req_ids: List[str]

@@ -39,6 +49,14 @@ class DetokenizerOutputs(msgspec.Struct):
# output token ids to be consistent with the detokenized text.
num_output_token_ids: List[int]

# Per-request (prompt)logprobs are returned only if un-detokenized
# (prompt)logprobs were provided as input.
# Unlike the input logprobs, the decoded token fields of the output
# (prompt)logprobs will hold the decoded string representation
# of the logprob token id.
logprobs: Dict[str, SampleLogprobs]
prompt_logprobs: Dict[str, PromptLogprobs]
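
A minimal sketch (assumed values, not taken from this diff) of the per-request contract these two structs describe, using the Logprob class imported above: the engine sends logprobs whose decoded_token fields are None, and the detokenizer returns the same structure with token strings filled in.

from vllm.sequence import Logprob

# What the engine would place in DetokenizerInputs.logprobs for one request:
# one Dict[int, Logprob] per sampled position, decoded_token still None.
inputs_logprobs = {
    "req-0": [
        {1234: Logprob(logprob=-0.1, rank=1, decoded_token=None),
         5678: Logprob(logprob=-2.3, rank=2, decoded_token=None)},
    ],
}

# What would come back in DetokenizerOutputs.logprobs after detokenization:
# same structure, decoded_token now populated with each token's string.
outputs_logprobs = {
    "req-0": [
        {1234: Logprob(logprob=-0.1, rank=1, decoded_token="Hello"),
         5678: Logprob(logprob=-2.3, rank=2, decoded_token="Hi")},
    ],
}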


class Detokenizer:

@@ -117,6 +135,8 @@ def run(self):
for req_id in inputs.free_req_ids:
self.free(req_id)

detokenized_logprobs: Dict[str, SampleLogprobs] = {}
detokenized_prompt_logprobs: Dict[str, PromptLogprobs] = {}
detokenized_texts: List[str] = []
num_output_token_ids: List[int] = []
num_reqs = len(inputs.req_ids)
@@ -132,6 +152,23 @@
)
new_str = self.detokenize(req_id, inputs.new_token_ids[i])
detokenized_texts.append(new_str)
if req_id in inputs.logprobs:
# If there are logprobs associated with this request,
# detokenize them
self.modify_logprobs_in_place(
req_id,
inputs.logprobs[req_id],
)
detokenized_logprobs[req_id] = inputs.logprobs[req_id]
if req_id in inputs.prompt_logprobs:
# If there are prompt logprobs associated with this request,
# detokenize them
self.modify_logprobs_in_place(
req_id,
inputs.prompt_logprobs[req_id],
)
detokenized_prompt_logprobs[
req_id] = inputs.prompt_logprobs[req_id]
req_state = self.request_states[req_id]
num_output_token_ids.append(
len(req_state.token_ids) - req_state.num_prompt_tokens)
@@ -140,7 +177,8 @@
req_ids=inputs.req_ids,
detokenized_texts=detokenized_texts,
num_output_token_ids=num_output_token_ids,
- )
+ logprobs=detokenized_logprobs,
+ prompt_logprobs=detokenized_prompt_logprobs)
self.push_socket.send(self.msgpack_encoder.encode(detokenized),
flags=zmq.NOBLOCK)

@@ -196,6 +234,47 @@ def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
decoded_text += new_decoded_token_text
return decoded_text

def detokenize_logprob_in_place(
self,
skip_special_tokens: bool,
logprob_dict: Dict[int, Logprob],
) -> None:
"""Cmopute `decoded_token` by detokenizing a logprob's token id"""

for token_id in logprob_dict:
# Detokenize logprob for a particular top
# token at a particular token offset
logprob_dict[token_id].decoded_token = (
self.tokenizer.convert_ids_to_tokens(
[token_id], skip_special_tokens=skip_special_tokens))[0]
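
This is likely where the commit message's "newline and spaces(?)" caveat originates: convert_ids_to_tokens returns tokenizer-internal token strings rather than plain text. A small sketch (using the GPT-2 tokenizer purely as an illustration) of the difference:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # byte-level BPE, illustration only
ids = tok.encode(" world\n")

# Token-level strings, as convert_ids_to_tokens returns them:
print(tok.convert_ids_to_tokens(ids))  # e.g. ['Ġworld', 'Ċ']
# The actual text those ids decode to:
print([tok.decode([i]) for i in ids])  # e.g. [' world', '\n']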

def modify_logprobs_in_place(
self,
request_id: str,
logprobs: AnyLogprobs,
) -> None:
"""Compute (in-place) the `decoded_token` field of a request's logprobs
Behavior: for each token offset, for each top token,
compute `decoded_token` for that token.
Args:
request_id
logprobs_list: request logprobs
"""

if logprobs is not None:
# Request has logprobs
req_state = self.request_states[request_id]
skip_special_tokens = req_state.skip_special_tokens
for logprob_dict in logprobs:
if logprob_dict is not None:
# Logprobs at a token offset
self.detokenize_logprob_in_place(
skip_special_tokens,
logprob_dict,
)
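
A small usage sketch (hypothetical request id and token ids). The `is not None` check above matters because, in vLLM, prompt logprobs conventionally hold None at the position of the first prompt token, which has no logprob:

from vllm.sequence import Logprob

# Hypothetical prompt logprobs: no entry for the first prompt token,
# later positions map token id -> Logprob with decoded_token unset.
prompt_logprobs = [
    None,
    {42: Logprob(logprob=-1.7, rank=1, decoded_token=None)},
    {77: Logprob(logprob=-0.4, rank=1, decoded_token=None)},
]

# detokenizer.modify_logprobs_in_place("req-0", prompt_logprobs) would then fill
# each Logprob.decoded_token with the token string for its token id, skipping
# the leading None entry.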


@dataclass
class RequestState: