diff --git a/tests/multimodal/test_image.py b/tests/multimodal/test_image.py deleted file mode 100644 index 93a9edec4c8c0..0000000000000 --- a/tests/multimodal/test_image.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest -from transformers import AutoConfig, AutoTokenizer - -from vllm.multimodal.image import repeat_and_pad_image_tokens - - -@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) -def test_repeat_and_pad_image_tokens(model): - config = AutoConfig.from_pretrained(model) - image_token_id = config.image_token_index - - tokenizer = AutoTokenizer.from_pretrained(model) - - test_cases = [ - ("", 2, "", [32000, 32000]), - ("", 2, "", [32000, 32000, 32000]), - ("", [3, 2], "", - [32000, 32000, 32000, 32000, 32000]), - ("Image:Image:!", [3, 2], - "Image:Image:!", - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), - ("", [3, 2], "", [32000, 32000, 32000]), - ] - - for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: - new_prompt, new_token_ids = repeat_and_pad_image_tokens( - tokenizer=tokenizer, - prompt=prompt, - prompt_token_ids=tokenizer.encode(prompt, - add_special_tokens=False), - image_token_id=image_token_id, - repeat_count=repeat_count, - ) - assert new_prompt == expected_prompt - assert new_token_ids == expected_token_ids diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index cd1fc91c29374..d8930744027e2 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -3,11 +3,12 @@ from tempfile import NamedTemporaryFile from typing import Dict, Tuple +from transformers import AutoConfig, AutoTokenizer import numpy as np import pytest from PIL import Image -from vllm.multimodal.utils import async_fetch_image, fetch_image +from vllm.multimodal.utils import (async_fetch_image, fetch_image, repeat_and_pad_placeholder_tokens) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -80,3 +81,33 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], data_image_async = await async_fetch_image(data_url) assert _image_equals(data_image_sync, data_image_async) + +@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +def test_repeat_and_pad_placeholder_tokens(model): + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + + test_cases = [ + ("", 2, "", [32000, 32000]), + ("", 2, "", [32000, 32000, 32000]), + ("", [3, 2], "", + [32000, 32000, 32000, 32000, 32000]), + ("Image:Image:!", [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), + ("", [3, 2], "", [32000, 32000, 32000]), + ] + + for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer=tokenizer, + prompt=prompt, + prompt_token_ids=tokenizer.encode(prompt, + add_special_tokens=False), + placeholder_token_id=image_token_id, + repeat_count=repeat_count, + ) + assert new_prompt == expected_prompt + assert new_token_ids == expected_token_ids diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index e70bbc9c49076..de94bdc74d1f0 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -183,12 +183,12 @@ def repeat_and_pad_token( return replacement -def repeat_and_pad_image_tokens( +def repeat_and_pad_placeholder_tokens( tokenizer: AnyTokenizer, prompt: Optional[str], prompt_token_ids: List[int], *, - image_token_id: int, + placeholder_token_id: int, repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, @@ -199,19 +199,19 @@ def repeat_and_pad_image_tokens( if prompt is None: new_prompt = None else: - image_token_str = tokenizer.decode(image_token_id) + placeholder_token_str = tokenizer.decode(placeholder_token_id) pad_token_str_left = (None if pad_token_left is None else tokenizer.decode(pad_token_left)) pad_token_str_right = (None if pad_token_right is None else tokenizer.decode(pad_token_right)) - image_token_count = prompt.count(image_token_str) + image_token_count = prompt.count(placeholder_token_str) # This is an arbitrary number to distinguish between the two cases if image_token_count > 16: logger.warning( "Please follow the prompt format that is " "documented on HuggingFace which does not involve " - "repeating %s tokens.", image_token_str) + "repeating %s tokens.", placeholder_token_str) if image_token_count < len(repeat_count): logger.warning( "The number of image tokens in the prompt is less than " @@ -219,13 +219,13 @@ def repeat_and_pad_image_tokens( "treated as plain text") repeat_count = repeat_count[:image_token_count] - prompt_parts = prompt.split(image_token_str, + prompt_parts = prompt.split(placeholder_token_str, maxsplit=len(repeat_count)) new_prompt = "" for i, repeat_count_item in enumerate(repeat_count): replacement_str = "".join( repeat_and_pad_token( - image_token_str, + placeholder_token_str, repeat_count=repeat_count_item, pad_token_left=pad_token_str_left, pad_token_right=pad_token_str_right, @@ -235,20 +235,20 @@ def repeat_and_pad_image_tokens( new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] - image_token_idx = 0 + placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): - if token == image_token_id: + if token == placeholder_token_id: replacement_ids = repeat_and_pad_token( - image_token_id, - repeat_count=repeat_count[image_token_idx], + placeholder_token_id, + repeat_count=repeat_count[placeholder_token_idx], pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) new_token_ids.extend(replacement_ids) - image_token_idx += 1 + placeholder_token_idx += 1 # No need to further scan the list since we replaced all tokens - if image_token_idx >= len(repeat_count): + if placeholder_token_idx >= len(repeat_count): new_token_ids.extend(prompt_token_ids[i + 1:]) break else: