From eb060d42ea812a81a6fd77c106171a79e7895bdd Mon Sep 17 00:00:00 2001
From: zifeitong
Date: Tue, 27 Aug 2024 16:09:02 -0700
Subject: [PATCH] [Model] Add multi-image input support for LLaVA-Next offline
 inference (#7230)

---
 tests/conftest.py                        | 21 +++---
 tests/models/test_llava_next.py          | 93 ++++++++++++++++++++----
 tests/multimodal/test_utils.py           | 35 ++++++++-
 vllm/model_executor/models/clip.py       |  8 +-
 vllm/model_executor/models/llava_next.py | 14 +++-
 vllm/model_executor/models/siglip.py     |  4 +-
 vllm/multimodal/utils.py                 | 51 ++++++++-----
 7 files changed, 174 insertions(+), 52 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index ae362b228d9d8..d8264f65b6149 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,6 +41,10 @@
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
+PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
+PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
+                         List[List[Tuple[np.ndarray, int]]]]
+
 
 def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
@@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
     decoder prompt) tuple.
 
     Returns:
-    
+
     * Encoder prompt list
     * Decoder prompt list (reverse of encoder prompt list)
     '''
@@ -578,8 +582,7 @@ def generate(
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
+        images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
             assert len(prompts) == len(images)
@@ -623,10 +626,8 @@ def generate_w_logprobs(
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
-        audios: Optional[Union[List[Tuple[np.ndarray, int]],
-                               List[List[Tuple[np.ndarray, int]]]]] = None
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
 
@@ -676,10 +677,8 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[Union[List[Image.Image],
-                               List[List[Image.Image]]]] = None,
-        audios: Optional[Union[List[Tuple[np.ndarray, int]],
-                               List[List[Tuple[np.ndarray, int]]]]] = None,
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py
index 9cf55c0858df0..d5fe0cbe32880 100644
--- a/tests/models/test_llava_next.py
+++ b/tests/models/test_llava_next.py
@@ -6,24 +6,22 @@
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
-_PREFACE = (
-    "A chat between a curious human and an artificial intelligence assistant. "
-    "The assistant gives helpful, detailed, and polite answers to the human's "
-    "questions.")
+_LIMIT_IMAGE_PER_PROMPT = 4
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
-    f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
ASSISTANT:", + "[INST] \nWhat's the content of the image? [/INST]", "cherry_blossom": - f"{_PREFACE} USER: \nWhat is the season? ASSISTANT:", + "[INST] \nWhat is the season? [/INST]", }) -models = ["llava-hf/llava-v1.6-vicuna-7b-hf"] +models = ["llava-hf/llava-v1.6-mistral-7b-hf"] def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -114,19 +112,43 @@ def run_test( else: raise ValueError("You must provide either `size_factors` or `sizes`") + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): # max_model_len should be greater than image_feature_size with vllm_runner(model, dtype=dtype, - max_model_len=4096, + max_model_len=10240, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype, @@ -136,7 +158,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets, + model, dtype, max_tokens, + num_logprobs) -> None: + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "[INST] \nDescribe 2 images. [/INST]", + "[INST] \nDescribe 2 images. [/INST]", + "[INST] \nDescribe 4 images. [/INST]", + "[INST] \nWhat is the season? 
[/INST]" + ], + [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], + [ + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index cd1fc91c29374..38cd48629f903 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -6,8 +6,10 @@ import numpy as np import pytest from PIL import Image +from transformers import AutoConfig, AutoTokenizer -from vllm.multimodal.utils import async_fetch_image, fetch_image +from vllm.multimodal.utils import (async_fetch_image, fetch_image, + repeat_and_pad_placeholder_tokens) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], data_image_async = await async_fetch_image(data_url) assert _image_equals(data_image_sync, data_image_async) + + +@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +def test_repeat_and_pad_placeholder_tokens(model): + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + + test_cases = [ + ("", 2, "", [32000, 32000]), + ("", 2, "", [32000, 32000, 32000]), + ("", [3, 2], "", + [32000, 32000, 32000, 32000, 32000]), + ("Image:Image:!", [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), + ("", [3, 2], "", [32000, 32000, 32000]), + ] + + for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer=tokenizer, + prompt=prompt, + prompt_token_ids=tokenizer.encode(prompt, + add_special_tokens=False), + placeholder_token_id=image_token_id, + repeat_count=repeat_count, + ) + assert new_prompt == expected_prompt + assert new_token_ids == expected_token_ids diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 0933966055330..69bb9f6f3afee 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,7 +1,7 @@ -"""Minimal implementation of CLIPVisionModel intended to be only used +"""Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" from array import array -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -84,7 +84,7 @@ def input_processor_for_clip( llm_inputs: LLMInputs, *, image_token_id: int, - image_feature_size_override: Optional[int] = None, + image_feature_size_override: Optional[Union[int, List[int]]] = None, ): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: @@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CLIPEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`CLIPEncoderLayer`]. 
Args: diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c7cb243fa84da..7c096a3794638 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -19,6 +19,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_image_feature_size, @@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): input_height=height, input_width=width, ) + elif is_list_of(image_data, Image.Image): + image_feature_size = [ + get_llava_next_image_feature_size(hf_config, + input_height=img.height, + input_width=img.width) + for img in image_data + ] elif isinstance(image_data, torch.Tensor): image_feature_size = image_data.shape[0] else: @@ -425,7 +433,10 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) - other_patch_embeds = other_patch_embeds \ + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ .view(num_patch_height, num_patch_width, height, width, -1) if "unpad" in strategy: @@ -496,7 +507,6 @@ def _process_image_input( self, image_input: LlavaNextImageInputs, ) -> Union[torch.Tensor, List[torch.Tensor]]: - if image_input["type"] == "image_embeds": return [image_input["data"]] diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 7f6186fa010a4..073f60bb3a056 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -3,7 +3,7 @@ import math from array import array -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from PIL import Image @@ -93,7 +93,7 @@ def input_processor_for_siglip( llm_inputs: LLMInputs, *, image_token_id: int, - image_feature_size_override: Optional[int] = None, + image_feature_size_override: Optional[Union[int, List[int]]] = None, ): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3bf430235462b..989b2e1a814c9 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens( prompt_token_ids: List[int], *, placeholder_token_id: int, - repeat_count: int = 1, + repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, ) -> Tuple[Optional[str], List[int]]: + if isinstance(repeat_count, int): + repeat_count = [repeat_count] + if prompt is None: new_prompt = None else: @@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens( tokenizer.decode(pad_token_left)) pad_token_str_right = (None if pad_token_right is None else tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - placeholder_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) placeholder_token_count = prompt.count(placeholder_token_str) # This is an arbitrary number to distinguish between the two cases @@ -216,28 +212,45 @@ def 
repeat_and_pad_placeholder_tokens( "Please follow the prompt format that is " "documented on HuggingFace which does not involve " "repeating %s tokens.", placeholder_token_str) - elif placeholder_token_count > 1: - logger.warning("Multiple multi-modal input is not supported yet, " - "so any extra placeholder tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1) + if placeholder_token_count < len(repeat_count): + logger.warning( + "The number of multi-modal placeholder tokens in the prompt " + "is less than the number of multi-modal inputs. Extra " + "placeholder tokens will be treated as plain text") + repeat_count = repeat_count[:placeholder_token_count] + + prompt_parts = prompt.split(placeholder_token_str, + maxsplit=len(repeat_count)) + new_prompt = "" + for i, repeat_count_item in enumerate(repeat_count): + replacement_str = "".join( + repeat_and_pad_token( + placeholder_token_str, + repeat_count=repeat_count_item, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + # The image tokens are removed to be consistent with HuggingFace + new_prompt += prompt_parts[i] + replacement_str + new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: replacement_ids = repeat_and_pad_token( placeholder_token_id, - repeat_count=repeat_count, + repeat_count=repeat_count[placeholder_token_idx], pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) new_token_ids.extend(replacement_ids) + placeholder_token_idx += 1 - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break + # No need to further scan the list since we replaced all tokens + if placeholder_token_idx >= len(repeat_count): + new_token_ids.extend(prompt_token_ids[i + 1:]) + break else: new_token_ids.append(token)
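
Usage note (not part of the patch): the sketch below shows how the multi-image offline path enabled by this change could be exercised end to end. The model name and prompt format follow the tests above; the image file names, max_model_len, and sampling settings are illustrative assumptions rather than values taken from this patch.

# Minimal sketch: multi-image offline inference with LLaVA-Next after this
# change. File names and generation settings below are hypothetical.
from PIL import Image

from vllm import LLM, SamplingParams

# Same model as in the tests; allow up to 4 images per prompt, mirroring
# _LIMIT_IMAGE_PER_PROMPT, and leave room in the context for image features.
llm = LLM(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    max_model_len=10240,
    limit_mm_per_prompt={"image": 4},
)

# One <image> placeholder per image, matching the prompt format in the tests.
prompt = "[INST] <image><image>\nDescribe 2 images. [/INST]"
images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)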