diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 94cef76c365a1..8776748cc85f0 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -16,6 +16,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sequence import PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
@@ -315,6 +316,8 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
             skip_special_tokens=[],
             spaces_between_special_tokens=[],
             free_req_ids=[],  # TODO(woosuk): Implement freeing.
+            logprobs={},
+            prompt_logprobs={},
         )
         for req, num_tokens in sampled:
             inputs.req_ids.append(req.request_id)
@@ -330,6 +333,13 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
             inputs.spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens)
 
+            # Transmit (prompt)logprobs to detokenizer
+            req_id = req.request_id
+            if req.logprobs is not None:
+                inputs.logprobs[req_id] = req.logprobs
+            if req.prompt_logprobs is not None:
+                inputs.prompt_logprobs[req_id] = req.prompt_logprobs
+
             # Update the number of lagged steps.
             self.num_lagged_steps[req.request_id] += 1
         self.detokenizer.send(inputs)
@@ -356,7 +366,9 @@ def recv_from_detokenizer(self) -> List[RequestOutput]:
                         and req.is_finished())
             req_output = self._make_request_output(
                 req, detokenizer_output.num_output_token_ids[i],
-                detokenizer_output.detokenized_texts[i], finished)
+                detokenizer_output.detokenized_texts[i],
+                detokenizer_output.logprobs.get(req_id, None),
+                detokenizer_output.prompt_logprobs.get(req_id, None), finished)
             req_outputs.append(req_output)
 
             if finished:
@@ -371,6 +383,8 @@ def _make_request_output(
         request: Request,
         num_output_tokens: int,
         new_output_text: str,
+        logprobs: Optional[SampleLogprobs],
+        prompt_logprobs: Optional[PromptLogprobs],
         finished: bool,
     ) -> RequestOutput:
         req_output = self.request_outputs.get(request.request_id)
@@ -408,8 +422,7 @@ def _make_request_output(
             completion_output.token_ids = (
                 request.output_token_ids[:num_output_tokens])
             if do_logprobs:
-                completion_output.logprobs = (
-                    request.logprobs[:num_output_tokens])
+                completion_output.logprobs = (logprobs[:num_output_tokens])
         elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
             completion_output.text = new_output_text
             num_prev_tokens = len(completion_output.token_ids)
@@ -417,14 +430,14 @@ def _make_request_output(
                 num_prev_tokens:num_output_tokens]
             if do_logprobs:
                 completion_output.logprobs = (
-                    request.logprobs[num_prev_tokens:num_output_tokens])
+                    logprobs[num_prev_tokens:num_output_tokens])
         elif (request.sampling_params.output_kind ==
               RequestOutputKind.FINAL_ONLY):
             if finished:
                 completion_output.text = request.output_text
                 completion_output.token_ids = request.output_token_ids
                 if do_logprobs:
-                    completion_output.logprobs = request.logprobs
+                    completion_output.logprobs = logprobs
             else:
                 completion_output.text = ""
                 completion_output.token_ids = []
diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py
index 4bbcf4717981e..165667217a082 100644
--- a/vllm/v1/tokenizer/detokenizer.py
+++ b/vllm/v1/tokenizer/detokenizer.py
@@ -1,16 +1,19 @@
 import multiprocessing
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import msgspec
 import zmq
 from msgspec import msgpack
 
+from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.detokenizer_utils import (
     convert_prompt_ids_to_tokens, detokenize_incrementally)
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import get_open_port
 
+AnyLogprobs = Union[Optional[SampleLogprobs], Optional[PromptLogprobs]]
+
 
 class DetokenizerInputs(msgspec.Struct):
 
@@ -23,6 +26,13 @@ class DetokenizerInputs(msgspec.Struct):
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
 
+    # If (prompt)logprobs are going to be returned, the decoded token
+    # field must be computed (detokenized) from the logprob token id.
+    # Assumption: the decoded token fields of the (prompt)logprobs
+    # are `None`.
+    logprobs: Dict[str, SampleLogprobs]
+    prompt_logprobs: Dict[str, PromptLogprobs]
+
     # [num_free_reqs]
     free_req_ids: List[str]
 
@@ -39,6 +49,14 @@ class DetokenizerOutputs(msgspec.Struct):
     # output token ids to be consistent with the detokenized text.
     num_output_token_ids: List[int]
 
+    # (prompt)logprobs outputs are only present for a request if
+    # un-detokenized (prompt)logprobs were provided as input.
+    # Unlike the input logprobs, the decoded token fields of the
+    # (prompt)logprobs (if present) will be the decoded string representation
+    # of the logprob token id.
+    logprobs: Dict[str, SampleLogprobs]
+    prompt_logprobs: Dict[str, PromptLogprobs]
+
 
 class Detokenizer:
 
@@ -117,6 +135,8 @@ def run(self):
             for req_id in inputs.free_req_ids:
                 self.free(req_id)
 
+            detokenized_logprobs: Dict[str, SampleLogprobs] = {}
+            detokenized_prompt_logprobs: Dict[str, PromptLogprobs] = {}
             detokenized_texts: List[str] = []
             num_output_token_ids: List[int] = []
             num_reqs = len(inputs.req_ids)
@@ -132,6 +152,23 @@ def run(self):
                 )
                 new_str = self.detokenize(req_id, inputs.new_token_ids[i])
                 detokenized_texts.append(new_str)
+                if req_id in inputs.logprobs:
+                    # If there are logprobs associated with this request,
+                    # detokenize them
+                    self.modify_logprobs_in_place(
+                        req_id,
+                        inputs.logprobs[req_id],
+                    )
+                    detokenized_logprobs[req_id] = inputs.logprobs[req_id]
+                if req_id in inputs.prompt_logprobs:
+                    # If there are prompt logprobs associated with this request,
+                    # detokenize them
+                    self.modify_logprobs_in_place(
+                        req_id,
+                        inputs.prompt_logprobs[req_id],
+                    )
+                    detokenized_prompt_logprobs[
+                        req_id] = inputs.prompt_logprobs[req_id]
                 req_state = self.request_states[req_id]
                 num_output_token_ids.append(
                     len(req_state.token_ids) - req_state.num_prompt_tokens)
@@ -140,7 +177,8 @@ def run(self):
                 req_ids=inputs.req_ids,
                 detokenized_texts=detokenized_texts,
                 num_output_token_ids=num_output_token_ids,
-            )
+                logprobs=detokenized_logprobs,
+                prompt_logprobs=detokenized_prompt_logprobs)
             self.push_socket.send(self.msgpack_encoder.encode(detokenized),
                                   flags=zmq.NOBLOCK)
 
@@ -196,6 +234,47 @@ def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
             decoded_text += new_decoded_token_text
         return decoded_text
 
+    def detokenize_logprob_in_place(
+        self,
+        skip_special_tokens: bool,
+        logprob_dict: Dict[int, Logprob],
+    ) -> None:
+        """Compute `decoded_token` by detokenizing a logprob's token id."""
+
+        for token_id in logprob_dict:
+            # Detokenize the logprob for a particular top
+            # token at a particular token offset
+            logprob_dict[token_id].decoded_token = (
+                self.tokenizer.convert_ids_to_tokens(
+                    [token_id], skip_special_tokens=skip_special_tokens))[0]
+
+    def modify_logprobs_in_place(
+        self,
+        request_id: str,
+        logprobs: AnyLogprobs,
+    ) -> None:
+        """Compute (in-place) the `decoded_token` field of a request's logprobs
+
+        Behavior: for each token offset, for each top token,
+        compute `decoded_token` for that token.
+
+        Args:
+          request_id: the ID of the request
+          logprobs: the request's sample or prompt logprobs
+        """
+
+        if logprobs is not None:
+            # Request has logprobs
+            req_state = self.request_states[request_id]
+            skip_special_tokens = req_state.skip_special_tokens
+            for logprob_dict in logprobs:
+                if logprob_dict is not None:
+                    # Logprobs at a token offset
+                    self.detokenize_logprob_in_place(
+                        skip_special_tokens,
+                        logprob_dict,
+                    )
+
 
 @dataclass
 class RequestState:
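
For illustration only (not part of the patch above): a minimal, standalone sketch of what the new Detokenizer.detokenize_logprob_in_place helper does, i.e. filling in the `decoded_token` field of a {token_id: Logprob} mapping by converting each token id back to a string. The Hugging Face AutoTokenizer, the simplified Logprob stand-in, and the "gpt2" checkpoint are assumptions made for this example and are not taken from vLLM.

from dataclasses import dataclass
from typing import Dict, Optional

from transformers import AutoTokenizer


@dataclass
class Logprob:
    """Simplified stand-in for vllm.sequence.Logprob (illustrative only)."""
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


def detokenize_logprob_dict_in_place(tokenizer, skip_special_tokens: bool,
                                     logprob_dict: Dict[int, Logprob]) -> None:
    # For each candidate token id at this position, detokenize the id and
    # store the resulting string in the Logprob's decoded_token field.
    for token_id in logprob_dict:
        logprob_dict[token_id].decoded_token = tokenizer.convert_ids_to_tokens(
            [token_id], skip_special_tokens=skip_special_tokens)[0]


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    # Hypothetical top-2 logprobs at a single token position.
    top_logprobs = {464: Logprob(logprob=-0.1), 317: Logprob(logprob=-2.3)}
    detokenize_logprob_dict_in_place(tok, False, top_logprobs)
    print({i: lp.decoded_token for i, lp in top_logprobs.items()})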