diff --git a/tests/multimodal/test_image.py b/tests/multimodal/test_image.py
new file mode 100644
index 0000000000000..93a9edec4c8c0
--- /dev/null
+++ b/tests/multimodal/test_image.py
@@ -0,0 +1,35 @@
+import pytest
+from transformers import AutoConfig, AutoTokenizer
+
+from vllm.multimodal.image import repeat_and_pad_image_tokens
+
+
+@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+def test_repeat_and_pad_image_tokens(model):
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    test_cases = [
+        ("<image>", 2, "<image><image>", [32000, 32000]),
+        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
+        ("<image><image>", [3, 2], "<image><image><image><image><image>",
+         [32000, 32000, 32000, 32000, 32000]),
+        ("Image:<image>Image:<image>!", [3, 2],
+         "Image:<image><image><image>Image:<image><image>!",
+         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
+        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
+    ]
+
+    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
+        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            prompt_token_ids=tokenizer.encode(prompt,
+                                              add_special_tokens=False),
+            image_token_id=image_token_id,
+            repeat_count=repeat_count,
+        )
+        assert new_prompt == expected_prompt
+        assert new_token_ids == expected_token_ids
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 978a058b25f5a..d1478eb298178 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -1,7 +1,7 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
 from array import array
-from typing import Iterable, Optional, Tuple, Union, List
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index 84a31dacc3aea..63d70288f02ce 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -3,7 +3,7 @@
 
 import math
 from array import array
-from typing import Iterable, Optional, Tuple, Union, List
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from PIL import Image
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 04656884920e8..4dd6d42f24051 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -44,12 +44,12 @@ def repeat_and_pad_image_tokens(
     prompt_token_ids: List[int],
     *,
     image_token_id: int,
-    repeat_count: Union[int, List[int]] = 1,
+    repeat_count: Union[int, List[int]],
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
 ) -> Tuple[Optional[str], List[int]]:
-    if not isinstance(repeat_count, list):
-        repeat_count = [repeat_count] * len(prompt_token_ids)
+    if isinstance(repeat_count, int):
+        repeat_count = [repeat_count]
 
     if prompt is None:
         new_prompt = None
@@ -59,13 +59,6 @@
                               tokenizer.decode(pad_token_left))
         pad_token_str_right = (None if pad_token_right is None else
                                tokenizer.decode(pad_token_right))
-        replacement_str = "".join(
-            repeat_and_pad_token(
-                image_token_str,
-                repeat_count=repeat_count[0],
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
 
         image_token_count = prompt.count(image_token_str)
         # This is an arbitrary number to distinguish between the two cases
@@ -74,22 +67,45 @@
                 "Please follow the prompt format that is "
                 "documented on HuggingFace which does not involve "
"repeating %s tokens.", image_token_str) - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(image_token_str, replacement_str) + if image_token_count < len(repeat_count): + logger.warning( + "The number of image tokens in the prompt is less than " + "the number of image inputs. Extra image tokens will be " + "treated as plain text") + repeat_count = repeat_count[:image_token_count] + + prompt_parts = prompt.split(image_token_str, + maxsplit=len(repeat_count)) + new_prompt = "" + for i in range(len(repeat_count)): + replacement_str = "".join( + repeat_and_pad_token( + image_token_str, + repeat_count=repeat_count[i], + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + # The image tokens are removed to be consistent with HuggingFace + new_prompt += prompt_parts[i] + replacement_str + new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] - idx = 0 + image_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == image_token_id: replacement_ids = repeat_and_pad_token( image_token_id, - repeat_count=repeat_count[idx], + repeat_count=repeat_count[image_token_idx], pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) new_token_ids.extend(replacement_ids) - idx += 1 + image_token_idx += 1 + + # No need to further scan the list since we replaced all tokens + if image_token_idx >= len(repeat_count): + new_token_ids.extend(prompt_token_ids[i + 1:]) + break else: new_token_ids.append(token)