From 046f7c2b6255aecb9f2babcccaf811930bcf83e5 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 18 Oct 2024 08:19:53 +0100
Subject: [PATCH] [BugFix] Typing fixes to RequestOutput.prompt and beam search
 (#9473)

Signed-off-by: qishuai

---
 vllm/beam_search.py     |  8 +++---
 vllm/engine/protocol.py | 59 ++++++++++++++++++++---------------------
 vllm/entrypoints/llm.py |  1 +
 vllm/outputs.py         |  3 +--
 4 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/vllm/beam_search.py b/vllm/beam_search.py
index d30af35e2624b..59a808bf3a7a0 100644
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -1,8 +1,7 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional
 
-if TYPE_CHECKING:
-    from vllm.multimodal import MultiModalDataDict
+from vllm.sequence import Logprob
 
 
 @dataclass
@@ -14,6 +13,7 @@ class BeamSearchSequence:
     """
     # The tokens includes the prompt.
     tokens: List[int]
+    logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
     finish_reason: Optional[str] = None
@@ -35,7 +35,7 @@ class BeamSearchInstance:
 
     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
         ]
         self.completed: List[BeamSearchSequence] = []
 
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 6ed9efc271960..12709cff4578a 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -60,8 +60,7 @@ def generate(
 
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
-        model_config: ModelConfig,
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -73,29 +72,25 @@ async def beam_search(
         length_penalty = params.length_penalty
         include_stop_str_in_output = params.include_stop_str_in_output
 
-        tokenizer = await self.get_tokenizer()
-        input_preprocessor = InputPreprocessor(model_config, tokenizer)
-
-        (prompt_text, prompt_token_ids, multi_modal_data,
-         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
-             prompt,
-             request_id=request_id,
-         )
-        tokenized_length = len(prompt_token_ids)
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
 
-        beam_search_params = SamplingParams(
-            logprobs=2 * beam_width,
-            max_tokens=1,
-            temperature=temperature,
-        )
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
         all_beams = [
-            BeamSearchSequence(tokens=prompt_token_ids,
-                               cum_logprob=0,
-                               multi_modal_data=multi_modal_data,
-                               mm_processor_kwargs=mm_processor_kwargs)
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
         ]
         completed = []
 
@@ -129,6 +124,12 @@
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
                             completed.append(
@@ -165,18 +166,16 @@
                 request_id=request_id,
                 prompt=prompt_text,
                 outputs=[
-                    CompletionOutput(text=beam.text,
-                                     cumulative_logprob=beam.cum_logprob,
-                                     token_ids=beam.tokens[tokenized_length:],
-                                     index=i,
-                                     logprobs=beam.cum_logprob,
-                                     finish_reason=beam.finish_reason if
-                                     beam.finish_reason is not None else "length",
-                                     stop_reason=beam.stop_reason)
-                    for (i, beam) in enumerate(best_beams)
+                    CompletionOutput(
+                        text=beam.text,
+                        cumulative_logprob=beam.cum_logprob,
+                        token_ids=beam.tokens[tokenized_length:],
+                        index=i,
+                        logprobs=beam.logprobs,
+                    ) for (i, beam) in enumerate(best_beams)
                 ],
                 finished=True,
-                prompt_token_ids=prompt_token_ids,
+                prompt_token_ids=tokenized_prompt,
                 prompt_logprobs=None)
 
         yield beam_search_output
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 2010381076c7d..088ec35798de8 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -433,6 +433,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 15cb8d53186df..07650241cb638 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -4,7 +4,6 @@
 from typing import Sequence as GenericSequence
 from typing import Union
 
-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
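
A minimal sketch (not part of the diff above) of how the new BeamSearchSequence.logprobs
field is populated per decode step, mirroring the `logprobs=current_beam.logprobs + [logprobs]`
expansion in vllm/engine/protocol.py and vllm/entrypoints/llm.py. The token ids and logprob
values are made up for illustration, and it assumes Logprob can be constructed from just its
logprob value.

    from vllm.beam_search import BeamSearchSequence
    from vllm.sequence import Logprob

    # A beam starts from the tokenized prompt with an empty logprobs history,
    # as in BeamSearchInstance.__init__ and the all_beams initialization above.
    beam = BeamSearchSequence(tokens=[101, 2009], logprobs=[])

    # One decode step yields a {token_id: Logprob} mapping (logprobs=2 * beam_width).
    step_logprobs = {2023: Logprob(logprob=-0.1), 2003: Logprob(logprob=-2.3)}

    # Each candidate extension appends the whole step dict to the beam's history
    # and accumulates the chosen token's logprob, matching the patched loops.
    new_beam = BeamSearchSequence(
        tokens=beam.tokens + [2023],
        logprobs=beam.logprobs + [step_logprobs],
        cum_logprob=beam.cum_logprob + step_logprobs[2023].logprob)

The per-step dicts collected this way are what the patched protocol.py now passes through to
CompletionOutput.logprobs instead of the previous (mistyped) cumulative logprob float.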