From a1716f5f79c668f00674ca95eebcdacf82c041b1 Mon Sep 17 00:00:00 2001 From: Zhong Qishuai Date: Tue, 29 Oct 2024 19:49:47 +0800 Subject: [PATCH] [Frontend] re-enable multi-modality input in the new beam search implementation (#9427) Signed-off-by: Qishuai Ferdinandzhong@gmail.com Signed-off-by: NickLucche --- tests/entrypoints/openai/test_vision.py | 71 +++++++++++++++ vllm/beam_search.py | 9 +- vllm/engine/protocol.py | 88 ++++++++++++------- vllm/entrypoints/openai/protocol.py | 4 +- vllm/entrypoints/openai/serving_chat.py | 7 +- vllm/entrypoints/openai/serving_completion.py | 10 ++- vllm/sampling_params.py | 1 + 7 files changed, 150 insertions(+), 40 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 8311a5cb3c2d4..68804d6833c73 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, + model_name: str, + image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + n=2, + max_tokens=10, + logprobs=True, + top_logprobs=5, + extra_body=dict(use_beam_search=True)) + assert len(chat_completion.choices) == 2 + assert chat_completion.choices[ + 0].message.content != chat_completion.choices[1].message.content + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded( assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_base64encoded_beamsearch( + client: openai.AsyncOpenAI, model_name: str, image_url: str, + base64_encoded_image: Dict[str, str]): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": + f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + n=2, + max_tokens=10, + extra_body=dict(use_beam_search=True)) + assert len(chat_completion.choices) == 2 + assert chat_completion.choices[ + 0].message.content != chat_completion.choices[1].message.content + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 1b48538734dae..026037e5434d1 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -1,8 +1,11 @@ from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from vllm.sequence import Logprob +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + @dataclass class BeamSearchSequence: @@ -16,6 +19,10 @@ class BeamSearchSequence: logprobs: List[Dict[int, Logprob]] cum_logprob: float = 0.0 text: Optional[str] = None + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + multi_modal_data: Optional["MultiModalDataDict"] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None @dataclass diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index b00dd136d4a47..6a09361c56865 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -6,6 +6,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -59,7 +60,8 @@ def generate( async def beam_search( self, - prompt: Union[str, List[int]], + prompt: Union[PromptType, List[int]], + model_config: ModelConfig, request_id: str, params: BeamSearchParams, ) -> AsyncGenerator[RequestOutput, None]: @@ -69,32 +71,40 @@ async def beam_search( ignore_eos = params.ignore_eos temperature = params.temperature length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output - tokenizer = await self.get_tokenizer(lora_request=None) - if isinstance(prompt, str): - tokenized_prompt = tokenizer.encode(prompt) - prompt_text = prompt - else: - tokenized_prompt = prompt - prompt_text = None - tokenized_length = len(tokenized_prompt) + tokenizer = await self.get_tokenizer() + input_preprocessor = InputPreprocessor(model_config, tokenizer) + + (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) = input_preprocessor._extract_prompt_components( + prompt, + request_id=request_id, + ) + tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( tokenizer.eos_token_id, length_penalty) - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) all_beams = [ - BeamSearchSequence(tokens=tokenized_prompt, + BeamSearchSequence(tokens=prompt_token_ids, + cum_logprob=0, logprobs=[], - cum_logprob=0) + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) ] completed = [] for _ in range(max_tokens): prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) + TokensPrompt(prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs) for beam in all_beams ] @@ -120,17 +130,31 @@ async def beam_search( if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) - if token_id == tokenizer.eos_token_id and \ not ignore_eos: - completed.append(new_beam) + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + + [token_id] if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop", + stop_reason=tokenizer.eos_token_id)) else: - new_beams.append(new_beam) + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs)) sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] @@ -151,16 +175,18 @@ async def beam_search( request_id=request_id, prompt=prompt_text, outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - ) for (i, beam) in enumerate(best_beams) + CompletionOutput(text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason if + beam.finish_reason is not None else "length", + stop_reason=beam.stop_reason) + for (i, beam) in enumerate(best_beams) ], finished=True, - prompt_token_ids=tokenized_prompt, + prompt_token_ids=prompt_token_ids, prompt_logprobs=None) yield beam_search_output diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a212c0d608ddb..7f270a81a7692 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -308,7 +308,7 @@ def to_beam_search_params(self, ignore_eos=self.ignore_eos, temperature=temperature, length_penalty=self.length_penalty, - ) + include_stop_str_in_output=self.include_stop_str_in_output) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens @@ -606,7 +606,7 @@ def to_beam_search_params(self, ignore_eos=self.ignore_eos, temperature=temperature, length_penalty=self.length_penalty, - ) + include_stop_str_in_output=self.include_stop_str_in_output) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index cd2883a3b323b..1f951d15a7a32 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -236,9 +236,10 @@ async def create_chat_completion( if isinstance(sampling_params, BeamSearchParams): result_generator = self.engine_client.beam_search( - engine_inputs['prompt_token_ids'], - request_id, - sampling_params, + prompt=engine_inputs, + model_config=self.model_config, + request_id=request_id, + params=sampling_params, ) else: result_generator = self.engine_client.generate( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 56e35950410a0..da521a6012530 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -150,9 +150,13 @@ async def create_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( - prompt_inputs["prompt_token_ids"], - request_id_item, - sampling_params, + prompt={ + "prompt_token_ids": + prompt_inputs["prompt_token_ids"] + }, + model_config=self.model_config, + request_id=request_id, + params=sampling_params, ) else: generator = self.engine_client.generate( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index bac32c991a0e3..5e191c6e715e0 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -500,3 +500,4 @@ class BeamSearchParams( ignore_eos: bool = False temperature: float = 0.0 length_penalty: float = 1.0 + include_stop_str_in_output: bool = False