Commit

fix tests

zifeitong committed Aug 20, 2024
1 parent 97c1069 commit b4a0bcc
Showing 4 changed files with 69 additions and 18 deletions.
35 changes: 35 additions & 0 deletions tests/multimodal/test_image.py
@@ -0,0 +1,35 @@
+import pytest
+from transformers import AutoConfig, AutoTokenizer
+
+from vllm.multimodal.image import repeat_and_pad_image_tokens
+
+
+@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+def test_repeat_and_pad_image_tokens(model):
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    test_cases = [
+        ("<image>", 2, "<image><image>", [32000, 32000]),
+        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
+        ("<image><image>", [3, 2], "<image><image><image><image><image>",
+         [32000, 32000, 32000, 32000, 32000]),
+        ("Image:<image>Image:<image>!", [3, 2],
+         "Image:<image><image><image>Image:<image><image>!",
+         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
+        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
+    ]
+
+    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
+        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+            tokenizer=tokenizer,
+            prompt=prompt,
+            prompt_token_ids=tokenizer.encode(prompt,
+                                              add_special_tokens=False),
+            image_token_id=image_token_id,
+            repeat_count=repeat_count,
+        )
+        assert new_prompt == expected_prompt
+        assert new_token_ids == expected_token_ids
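
The two [3, 2] cases above exercise per-image repeat counts. As a reading aid, here is a minimal, dependency-free sketch of the string-side expansion the test asserts; toy_expand is a hypothetical helper written for this page, not part of vLLM, and it ignores pad tokens:

# Illustration only: a toy re-implementation of the string-side logic,
# without tokenizers or padding. Not part of the vLLM code base.
def toy_expand(prompt: str, image_token: str, repeat_counts: list) -> str:
    # Extra repeat counts beyond the number of image tokens are dropped,
    # matching the truncation this commit adds to repeat_and_pad_image_tokens.
    repeat_counts = repeat_counts[:prompt.count(image_token)]
    parts = prompt.split(image_token, maxsplit=len(repeat_counts))
    out = ""
    for part, count in zip(parts, repeat_counts):
        out += part + image_token * count
    return out + parts[-1]

assert (toy_expand("Image:<image>Image:<image>!", "<image>", [3, 2]) ==
        "Image:<image><image><image>Image:<image><image>!")
assert toy_expand("<image>", "<image>", [3, 2]) == "<image><image><image>"
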
2 changes: 1 addition & 1 deletion vllm/model_executor/models/clip.py
@@ -1,7 +1,7 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from array import array
from typing import Iterable, Optional, Tuple, Union, List
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
2 changes: 1 addition & 1 deletion vllm/model_executor/models/siglip.py
@@ -3,7 +3,7 @@

 import math
 from array import array
-from typing import Iterable, Optional, Tuple, Union, List
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from PIL import Image
48 changes: 32 additions & 16 deletions vllm/multimodal/image.py
Expand Up @@ -44,12 +44,12 @@ def repeat_and_pad_image_tokens(
     prompt_token_ids: List[int],
     *,
     image_token_id: int,
-    repeat_count: Union[int, List[int]] = 1,
+    repeat_count: Union[int, List[int]],
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
 ) -> Tuple[Optional[str], List[int]]:
-    if not isinstance(repeat_count, list):
-        repeat_count = [repeat_count] * len(prompt_token_ids)
+    if isinstance(repeat_count, int):
+        repeat_count = [repeat_count]
 
     if prompt is None:
         new_prompt = None
@@ -59,13 +59,6 @@ def repeat_and_pad_image_tokens(
                                tokenizer.decode(pad_token_left))
         pad_token_str_right = (None if pad_token_right is None else
                                tokenizer.decode(pad_token_right))
-        replacement_str = "".join(
-            repeat_and_pad_token(
-                image_token_str,
-                repeat_count=repeat_count[0],
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
 
         image_token_count = prompt.count(image_token_str)
         # This is an arbitrary number to distinguish between the two cases
@@ -74,22 +67,45 @@ def repeat_and_pad_image_tokens(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", image_token_str)

# The image tokens are removed to be consistent with HuggingFace
new_prompt = prompt.replace(image_token_str, replacement_str)
if image_token_count < len(repeat_count):
logger.warning(
"The number of image tokens in the prompt is less than "
"the number of image inputs. Extra image tokens will be "
"treated as plain text")
repeat_count = repeat_count[:image_token_count]

prompt_parts = prompt.split(image_token_str,
maxsplit=len(repeat_count))
new_prompt = ""
for i in range(len(repeat_count)):
replacement_str = "".join(
repeat_and_pad_token(
image_token_str,
repeat_count=repeat_count[i],
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1]

new_token_ids: List[int] = []
idx = 0
image_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == image_token_id:
replacement_ids = repeat_and_pad_token(
image_token_id,
repeat_count=repeat_count[idx],
repeat_count=repeat_count[image_token_idx],
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)
idx += 1
image_token_idx += 1

# No need to further scan the list since we replaced all tokens
if image_token_idx >= len(repeat_count):
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
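
The token-id path above walks the original ids once, expands each image token by its own repeat count, and stops scanning as soon as every count has been consumed. A minimal sketch of that loop under the same caveat (pad tokens omitted; toy_expand_ids is a hypothetical helper, not vLLM API), reusing the LLaVA token ids from the new test:

# Illustration only: a toy version of the token-id loop above, with the
# early exit but without pad tokens. Not part of the vLLM code base.
from typing import List

def toy_expand_ids(token_ids: List[int], image_token_id: int,
                   repeat_counts: List[int]) -> List[int]:
    out: List[int] = []
    image_token_idx = 0
    for i, token in enumerate(token_ids):
        if token == image_token_id:
            out.extend([image_token_id] * repeat_counts[image_token_idx])
            image_token_idx += 1
            # Once every image token is expanded, copy the rest verbatim.
            if image_token_idx >= len(repeat_counts):
                out.extend(token_ids[i + 1:])
                break
        else:
            out.append(token)
    return out

# "Image:<image>Image:<image>!" under the LLaVA tokenizer (ids from the test).
assert toy_expand_ids([9833, 28747, 32000, 9833, 28747, 32000, 918],
                      32000, [3, 2]) == [9833, 28747, 32000, 32000, 32000,
                                         9833, 28747, 32000, 32000, 918]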
