Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Compute Llava Next Max Tokens / Dummy Data From Gridpoints #9650

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion tests/models/decoder_only/vision_language/test_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pytest
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer

from vllm.inputs import InputContext
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ...utils import check_logprobs_close
from ...utils import build_model_context, check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 4

Expand All @@ -22,6 +23,19 @@
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]


@pytest.fixture()
def get_max_llava_next_image_tokens():
from vllm.model_executor.models.llava_next import (
get_max_llava_next_image_tokens)
return get_max_llava_next_image_tokens


@pytest.fixture()
def dummy_data_for_llava_next():
from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
return dummy_data_for_llava_next


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
Expand Down Expand Up @@ -281,3 +295,52 @@ def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
([[336, 336]], 1176),
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
])
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
get_max_llava_next_image_tokens):
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")

# Update the config image_grid_pinpoints
# and calculate the resulting max tokens
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints

actual_max_tokens = get_max_llava_next_image_tokens(
InputContext(ctx.model_config))

assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize(
"gridpoints,expected_size",
[
# One point; it has to be the largest
([[336, 336]], (336, 336)),
# Default for most llava next models; the 2x2 tile is the largest
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
(672, 672)),
# If two rectangular gridpoints are the same, the more vertical
# one has the higher feature count due to newline features
([[336, 672], [672, 336]], (672, 336))
])
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
gridpoints, expected_size):
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")

# Update the config image_grid_pinpoints
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
seq_len = 5000 # bigger than the max feature size for any image

seq_data, mm_data = dummy_data_for_llava_next(
ctx,
seq_len=seq_len,
mm_counts={"image": 1},
)

# The dummy data dims should match the gridpoint with the biggest feat size
assert mm_data["image"].size == expected_size
assert len(seq_data.get_token_ids()) >= seq_len
41 changes: 28 additions & 13 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model)

# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448


class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
Expand Down Expand Up @@ -149,11 +146,28 @@ def get_llava_next_image_feature_size(


def get_max_llava_next_image_tokens(ctx: InputContext):
return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
"""Compute the max feature size for all possible image grid pinpoints."""
return _get_pinpoint_with_largest_features(ctx)[0]


def _get_pinpoint_with_largest_features(
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config = ctx.get_hf_config(LlavaNextConfig)
largest_feature_size = 0
largest_feature_pinpoint = None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = (height, width)
if not largest_feature_size or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_size, largest_feature_pinpoint


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
Expand All @@ -162,7 +176,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config
num_images = mm_counts["image"]

image_feature_size = get_max_llava_next_image_tokens(ctx)
image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
max_feat_height, max_feat_width = pinpoint

if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
Expand All @@ -176,8 +191,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width_override=max_feat_width,
image_height_override=max_feat_height,
)

return seq_data, mm_data
Expand All @@ -193,8 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_siglip(
vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width_override=max_feat_width,
image_height_override=max_feat_height,
)

return seq_data, mm_data
Expand Down