[BugFix] Typing fixes to RequestOutput.prompt and beam search #9473

Merged · 2 commits · Oct 18, 2024
7 changes: 5 additions & 2 deletions vllm/beam_search.py
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
+
+from vllm.sequence import Logprob
 
 
 @dataclass
@@ -11,6 +13,7 @@ class BeamSearchSequence:
     """
     # The tokens includes the prompt.
     tokens: List[int]
+    logprobs: List[Dict[int, Logprob]]
     cum_logprob: float = 0.0
     text: Optional[str] = None
 
@@ -28,7 +31,7 @@ class BeamSearchInstance:
 
     def __init__(self, prompt_tokens: List[int]):
         self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
+            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
         ]
         self.completed: List[BeamSearchSequence] = []
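
Note: the new `logprobs` field stores the full candidate map returned at each decoding step, one entry per generated token. A minimal sketch of how a beam carries it, assuming a simplified `Logprob` stand-in (the real class lives in `vllm.sequence`) and made-up token IDs:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    """Illustrative stand-in for vllm.sequence.Logprob."""
    logprob: float


@dataclass
class BeamSearchSequence:
    tokens: List[int]                   # includes the prompt tokens
    logprobs: List[Dict[int, Logprob]]  # one candidate map per generated step
    cum_logprob: float = 0.0
    text: Optional[str] = None


# A fresh beam holds only the prompt and no per-step logprobs yet.
beam = BeamSearchSequence(tokens=[101, 2023], logprobs=[])

# Each step appends that step's whole candidate map, so len(beam.logprobs)
# tracks the number of generated tokens.
step = {2051: Logprob(-0.3), 2064: Logprob(-1.7)}
chosen = 2051
beam = BeamSearchSequence(
    tokens=beam.tokens + [chosen],
    logprobs=beam.logprobs + [step],
    cum_logprob=beam.cum_logprob + step[chosen].logprob,
)
assert len(beam.logprobs) == 1
```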
29 changes: 19 additions & 10 deletions vllm/engine/protocol.py
@@ -59,7 +59,7 @@ def generate(
 
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
+        prompt: Union[str, List[int]],
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -71,17 +71,25 @@ async def beam_search(
         length_penalty = params.length_penalty
 
         tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
+        if isinstance(prompt, str):
+            tokenized_prompt = tokenizer.encode(prompt)
+            prompt_text = prompt
+        else:
+            tokenized_prompt = prompt
+            prompt_text = None
+        tokenized_length = len(tokenized_prompt)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
 
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        all_beams = [
+            BeamSearchSequence(tokens=tokenized_prompt,
+                               logprobs=[],
+                               cum_logprob=0)
+        ]
         completed = []
 
         for _ in range(max_tokens):
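
Note: the `str`-vs-token-IDs branch in the hunk above is what lets `RequestOutput.prompt` be narrowed to `Optional[str]`: text prompts keep their text, pre-tokenized prompts yield `prompt_text = None`. A standalone sketch of the same branch; the helper name and the toy `encode` are hypothetical, not vLLM API:

```python
from typing import Callable, List, Optional, Tuple, Union


def normalize_prompt(
    prompt: Union[str, List[int]],
    encode: Callable[[str], List[int]],
) -> Tuple[List[int], Optional[str]]:
    """Always produce token IDs; keep the text only if text was passed."""
    if isinstance(prompt, str):
        return encode(prompt), prompt
    return prompt, None


def toy_encode(s: str) -> List[int]:
    """Stand-in tokenizer: one fake token ID per character."""
    return [ord(c) for c in s]


assert normalize_prompt("hi", toy_encode) == ([104, 105], "hi")
assert normalize_prompt([104, 105], toy_encode) == ([104, 105], None)
```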
@@ -114,6 +122,7 @@ async def beam_search(
                 for token_id, logprob_obj in logprobs.items():
                     new_beam = BeamSearchSequence(
                         tokens=current_beam.tokens + [token_id],
+                        logprobs=current_beam.logprobs + [logprobs],
                         cum_logprob=current_beam.cum_logprob +
                         logprob_obj.logprob)
 
@@ -131,22 +140,22 @@
         best_beams = sorted_completed[:beam_width]
 
         for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
 
         beam_search_output = RequestOutput(
             request_id=request_id,
-            prompt=prompt,
+            prompt=prompt_text,
             outputs=[
                 CompletionOutput(
                     text=beam.text,
                     cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
+                    token_ids=beam.tokens[tokenized_length:],
                     index=i,
-                    logprobs=beam.cum_logprob,
+                    logprobs=beam.logprobs,
                 ) for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenizedPrompt,
+            prompt_token_ids=tokenized_prompt,
             prompt_logprobs=None)
 
         yield beam_search_output
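
Note: the `tokenized_length` slicing is what fixes the completion payload; previously `CompletionOutput.token_ids` echoed the prompt tokens as well, since `BeamSearchSequence.tokens` includes the prompt. A toy check of the intended invariant, with made-up IDs:

```python
# Say the prompt tokenized to 3 IDs and the beam then generated 2 tokens.
tokenized_prompt = [11, 12, 13]
tokenized_length = len(tokenized_prompt)
beam_tokens = tokenized_prompt + [21, 22]  # tokens include the prompt

# After the fix, prompt and completion token lists are disjoint and
# concatenate back to the full beam.
completion_ids = beam_tokens[tokenized_length:]
assert completion_ids == [21, 22]
assert tokenized_prompt + completion_ids == beam_tokens
```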
1 change: 1 addition & 0 deletions vllm/entrypoints/llm.py
@@ -433,6 +433,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                 for token_id, logprob_obj in logprobs.items():
                     new_beam = BeamSearchSequence(
                         tokens=current_beam.tokens + [token_id],
+                        logprobs=current_beam.logprobs + [logprobs],
                         cum_logprob=current_beam.cum_logprob +
                         logprob_obj.logprob)
 
3 changes: 1 addition & 2 deletions vllm/outputs.py
@@ -4,7 +4,6 @@
 from typing import Sequence as GenericSequence
 from typing import Union
 
-from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -93,7 +92,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[PromptType],
+        prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
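
Note: with `prompt` narrowed to `Optional[str]`, downstream code can treat it as plain text or `None` (the latter when the request was submitted as token IDs) without handling the other `PromptType` variants. A hypothetical consumer illustrating what now type-checks cleanly:

```python
from typing import List, Optional


def describe(prompt: Optional[str],
             prompt_token_ids: Optional[List[int]]) -> str:
    # `prompt` is either the original text or None; string operations are
    # safe after the None check, exactly what Optional[str] promises.
    if prompt is not None:
        return f"text prompt ({len(prompt)} chars): {prompt!r}"
    return f"pre-tokenized prompt ({len(prompt_token_ids or [])} tokens)"


print(describe("Hello, world", None))
print(describe(None, [11, 12, 13]))
```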