Commit 00a8d2c

more fix

zifeitong committed Aug 22, 2024
1 parent a518db7 commit 00a8d2c
Showing 3 changed files with 45 additions and 49 deletions.
35 changes: 0 additions & 35 deletions tests/multimodal/test_image.py

This file was deleted.

33 changes: 32 additions & 1 deletion tests/multimodal/test_utils.py
@@ -3,11 +3,12 @@
 from tempfile import NamedTemporaryFile
 from typing import Dict, Tuple

+from transformers import AutoConfig, AutoTokenizer
 import numpy as np
 import pytest
 from PIL import Image

-from vllm.multimodal.utils import async_fetch_image, fetch_image
+from vllm.multimodal.utils import (async_fetch_image, fetch_image, repeat_and_pad_placeholder_tokens)

 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
 TEST_IMAGE_URLS = [
@@ -80,3 +81,33 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],

     data_image_async = await async_fetch_image(data_url)
     assert _image_equals(data_image_sync, data_image_async)
+
+@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+def test_repeat_and_pad_placeholder_tokens(model):
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    test_cases = [
+        ("<image>", 2, "<image><image>", [32000, 32000]),
+        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
+        ("<image><image>", [3, 2], "<image><image><image><image><image>",
+         [32000, 32000, 32000, 32000, 32000]),
+        ("Image:<image>Image:<image>!", [3, 2],
+         "Image:<image><image><image>Image:<image><image>!",
+         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
+        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
+    ]
+
+    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
+        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            prompt_token_ids=tokenizer.encode(prompt,
+                                              add_special_tokens=False),
+            placeholder_token_id=image_token_id,
+            repeat_count=repeat_count,
+        )
+        assert new_prompt == expected_prompt
+        assert new_token_ids == expected_token_ids
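For readers skimming the diff, the new test pins down the expansion semantics: an int repeat_count expands only the first placeholder, a list expands the first len(repeat_count) placeholders, and any extra placeholders are left as plain text. Below is a minimal, self-contained sketch of that string-side behavior; the helper name expand_placeholders is made up for illustration, padding is omitted, and this is not vLLM's implementation.

from typing import List, Union


def expand_placeholders(prompt: str, placeholder: str,
                        repeat_count: Union[int, List[int]]) -> str:
    # Sketch of the string-side expansion exercised by the test above.
    if isinstance(repeat_count, int):
        repeat_count = [repeat_count]
    # Placeholders beyond len(repeat_count) stay as plain text; repeat
    # counts beyond the number of placeholders are dropped.
    repeat_count = repeat_count[:prompt.count(placeholder)]
    parts = prompt.split(placeholder, maxsplit=len(repeat_count))
    expanded = "".join(part + placeholder * count
                       for part, count in zip(parts, repeat_count))
    return expanded + parts[-1]


# Matches the "Image:<image>Image:<image>!" case in the new test:
assert expand_placeholders("Image:<image>Image:<image>!", "<image>",
                           [3, 2]) == ("Image:<image><image><image>"
                                       "Image:<image><image>!")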
26 changes: 13 additions & 13 deletions vllm/multimodal/utils.py
@@ -183,12 +183,12 @@ def repeat_and_pad_token(
     return replacement


-def repeat_and_pad_image_tokens(
+def repeat_and_pad_placeholder_tokens(
     tokenizer: AnyTokenizer,
     prompt: Optional[str],
     prompt_token_ids: List[int],
     *,
-    image_token_id: int,
+    placeholder_token_id: int,
     repeat_count: Union[int, List[int]],
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
@@ -199,33 +199,33 @@ def repeat_and_pad_image_tokens(
     if prompt is None:
         new_prompt = None
     else:
-        image_token_str = tokenizer.decode(image_token_id)
+        placeholder_token_str = tokenizer.decode(placeholder_token_id)
         pad_token_str_left = (None if pad_token_left is None else
                               tokenizer.decode(pad_token_left))
         pad_token_str_right = (None if pad_token_right is None else
                                tokenizer.decode(pad_token_right))

-        image_token_count = prompt.count(image_token_str)
+        image_token_count = prompt.count(placeholder_token_str)
         # This is an arbitrary number to distinguish between the two cases
         if image_token_count > 16:
             logger.warning(
                 "Please follow the prompt format that is "
                 "documented on HuggingFace which does not involve "
-                "repeating %s tokens.", image_token_str)
+                "repeating %s tokens.", placeholder_token_str)
         if image_token_count < len(repeat_count):
             logger.warning(
                 "The number of image tokens in the prompt is less than "
                 "the number of image inputs. Extra image tokens will be "
                 "treated as plain text")
             repeat_count = repeat_count[:image_token_count]

-        prompt_parts = prompt.split(image_token_str,
+        prompt_parts = prompt.split(placeholder_token_str,
                                     maxsplit=len(repeat_count))
         new_prompt = ""
         for i, repeat_count_item in enumerate(repeat_count):
             replacement_str = "".join(
                 repeat_and_pad_token(
-                    image_token_str,
+                    placeholder_token_str,
                     repeat_count=repeat_count_item,
                     pad_token_left=pad_token_str_left,
                     pad_token_right=pad_token_str_right,
@@ -235,20 +235,20 @@
         new_prompt += prompt_parts[-1]

     new_token_ids: List[int] = []
-    image_token_idx = 0
+    placeholder_token_idx = 0
     for i, token in enumerate(prompt_token_ids):
-        if token == image_token_id:
+        if token == placeholder_token_id:
             replacement_ids = repeat_and_pad_token(
-                image_token_id,
-                repeat_count=repeat_count[image_token_idx],
+                placeholder_token_id,
+                repeat_count=repeat_count[placeholder_token_idx],
                 pad_token_left=pad_token_left,
                 pad_token_right=pad_token_right,
             )
             new_token_ids.extend(replacement_ids)
-            image_token_idx += 1
+            placeholder_token_idx += 1

             # No need to further scan the list since we replaced all tokens
-            if image_token_idx >= len(repeat_count):
+            if placeholder_token_idx >= len(repeat_count):
                 new_token_ids.extend(prompt_token_ids[i + 1:])
                 break
         else:
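The token-id path in this last hunk does the same bookkeeping over integer ids. Here is a minimal sketch under the same assumptions — the helper name expand_placeholder_ids is hypothetical and the optional pad tokens are omitted; this is not vLLM's code.

from typing import List


def expand_placeholder_ids(token_ids: List[int], placeholder_id: int,
                           repeat_count: List[int]) -> List[int]:
    # Sketch of the token-id loop in the hunk above (padding omitted).
    new_ids: List[int] = []
    idx = 0
    for i, token in enumerate(token_ids):
        if token == placeholder_id:
            new_ids.extend([placeholder_id] * repeat_count[idx])
            idx += 1
            # Once every counted placeholder is expanded, the remaining
            # ids are copied verbatim, mirroring the early exit above.
            if idx >= len(repeat_count):
                new_ids.extend(token_ids[i + 1:])
                break
        else:
            new_ids.append(token)
    return new_ids


# Matches the ("<image><image>", [3, 2]) case in the new test:
assert expand_placeholder_ids([32000, 32000], 32000, [3, 2]) == [32000] * 5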
