[BugFix] Typing fixes to RequestOutput.prompt and beam search (vllm-project#9473)

Signed-off-by: qishuai <[email protected]>
njhill authored and FerdinandZhong committed Oct 29, 2024
1 parent d102412 commit 046f7c2
Showing 4 changed files with 35 additions and 36 deletions.
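
Net effect of the four files below: RequestOutput.prompt is narrowed to Optional[str], the engine-level beam_search accepts either a raw string or a pre-tokenized List[int] (tokenizing strings itself rather than going through InputPreprocessor and ModelConfig), and BeamSearchSequence gains a per-step logprobs list that is surfaced on CompletionOutput.logprobs. Below is a small type-level sketch of the new shape; the class and helper names are stand-ins, not the real vLLM definitions.

# Stand-in types only; the real definitions are in the hunks below.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

TokenLogprobs = Dict[int, float]  # token id -> logprob, one dict per generated step


@dataclass
class BeamSearchSequenceSketch:
    tokens: List[int]              # prompt tokens followed by generated tokens
    logprobs: List[TokenLogprobs]  # the field this commit adds
    cum_logprob: float = 0.0
    text: Optional[str] = None


def normalize_prompt(prompt: Union[str, List[int]],
                     encode) -> Tuple[Optional[str], List[int]]:
    # Mirrors the new branch in protocol.py: strings are tokenized,
    # token-ID lists pass through and the prompt text stays None.
    if isinstance(prompt, str):
        return prompt, encode(prompt)
    return None, prompt
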
8 changes: 4 additions & 4 deletions vllm/beam_search.py
@@ -1,8 +1,7 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional

-if TYPE_CHECKING:
-    from vllm.multimodal import MultiModalDataDict
+from vllm.sequence import Logprob


 @dataclass
@@ -14,6 +13,7 @@ class BeamSearchSequence:
"""
# The tokens includes the prompt.
tokens: List[int]
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None
finish_reason: Optional[str] = None
@@ -35,7 +35,7 @@ class BeamSearchInstance:

     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
         ]
         self.completed: List[BeamSearchSequence] = []

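As an illustration of what the new logprobs field holds, here is a sketch of one beam-extension step following the new_beam pattern this commit adds in protocol.py and llm.py. It assumes Logprob(logprob=...) from vllm.sequence constructs a single-token logprob; the token IDs and values are made up.

from vllm.beam_search import BeamSearchInstance, BeamSearchSequence
from vllm.sequence import Logprob

# A fresh instance starts with one beam whose logprobs list is empty,
# matching BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) above.
instance = BeamSearchInstance(prompt_tokens=[1, 2, 3])
current_beam = instance.beams[0]

# One decode step: the sampler's {token_id: Logprob} dict is appended whole,
# while cum_logprob accumulates only the chosen token's logprob.
step_logprobs = {42: Logprob(logprob=-0.7), 7: Logprob(logprob=-1.3)}
new_beam = BeamSearchSequence(
    tokens=current_beam.tokens + [42],
    logprobs=current_beam.logprobs + [step_logprobs],
    cum_logprob=current_beam.cum_logprob + step_logprobs[42].logprob)
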
59 changes: 29 additions & 30 deletions vllm/engine/protocol.py
@@ -60,8 +60,7 @@ def generate(

     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
-        model_config: ModelConfig,
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -73,29 +72,25 @@ async def beam_search(
         length_penalty = params.length_penalty
         include_stop_str_in_output = params.include_stop_str_in_output

-        tokenizer = await self.get_tokenizer()
-        input_preprocessor = InputPreprocessor(model_config, tokenizer)
-
-        (prompt_text, prompt_token_ids, multi_modal_data,
-         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
-             prompt,
-             request_id=request_id,
-         )
-        tokenized_length = len(prompt_token_ids)
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)

         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)

-        beam_search_params = SamplingParams(
-            logprobs=2 * beam_width,
-            max_tokens=1,
-            temperature=temperature,
-        )
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
         all_beams = [
-            BeamSearchSequence(tokens=prompt_token_ids,
-                               cum_logprob=0,
-                               multi_modal_data=multi_modal_data,
-                               mm_processor_kwargs=mm_processor_kwargs)
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
         ]
         completed = []

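The tokenized_length computed here is what the final hunk of this file uses to strip the prompt from each returned beam (beam.tokens[tokenized_length:]). A tiny self-contained sketch of that convention, with made-up token IDs:

# Illustrative only: prompt tokens sit at the front of each beam's token list,
# so slicing by the prompt length leaves just the generated continuation.
prompt_tokens = [101, 2023, 2003]            # hypothetical prompt token IDs
beam_tokens = prompt_tokens + [1037, 3231]   # prompt + generated tokens
tokenized_length = len(prompt_tokens)

generated = beam_tokens[tokenized_length:]
assert generated == [1037, 3231]
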
@@ -129,6 +124,12 @@ async def beam_search(
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
                             completed.append(
@@ -165,18 +166,16 @@ async def beam_search(
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
-                CompletionOutput(text=beam.text,
-                                 cumulative_logprob=beam.cum_logprob,
-                                 token_ids=beam.tokens[tokenized_length:],
-                                 index=i,
-                                 logprobs=beam.cum_logprob,
-                                 finish_reason=beam.finish_reason if
-                                 beam.finish_reason is not None else "length",
-                                 stop_reason=beam.stop_reason)
-                for (i, beam) in enumerate(best_beams)
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens[tokenized_length:],
+                    index=i,
+                    logprobs=beam.logprobs,
+                ) for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=prompt_token_ids,
+            prompt_token_ids=tokenized_prompt,
             prompt_logprobs=None)

         yield beam_search_output
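
For context, a hedged usage sketch of the updated interface. Only what this diff shows is assumed about the method itself (an async generator taking prompt, request_id and params); the engine_client object, the request id, and the BeamSearchParams values are placeholders rather than a verified configuration.

from vllm.sampling_params import BeamSearchParams  # assumed import location


async def run_beam_search(engine_client, prompt):
    # prompt may now be either a str or a List[int] of token IDs.
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    async for request_output in engine_client.beam_search(
            prompt=prompt, request_id="beam-demo-0", params=params):
        # RequestOutput.prompt is Optional[str]: the original text for string
        # prompts, None when a token-ID list was passed in.
        print(request_output.prompt)
        for completion in request_output.outputs:
            # CompletionOutput.logprobs now carries the per-step logprob dicts
            # collected on the beam, not a float.
            print(completion.text, completion.cumulative_logprob)
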
1 change: 1 addition & 0 deletions vllm/entrypoints/llm.py
@@ -433,6 +433,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                     for token_id, logprob_obj in logprobs.items():
                         new_beam = BeamSearchSequence(
                             tokens=current_beam.tokens + [token_id],
+                            logprobs=current_beam.logprobs + [logprobs],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)

3 changes: 1 addition & 2 deletions vllm/outputs.py
@@ -4,7 +4,6 @@
 from typing import Sequence as GenericSequence
 from typing import Union

-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
