Skip to content

Commit

Permalink
Get the llava next feature size from pinpoints
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <[email protected]>
  • Loading branch information
alex-jw-brooks committed Oct 24, 2024
1 parent b548d7a commit c159a35
Showing 1 changed file with 28 additions and 13 deletions.
41 changes: 28 additions & 13 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model)

# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448


class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
Expand Down Expand Up @@ -149,11 +146,28 @@ def get_llava_next_image_feature_size(


def get_max_llava_next_image_tokens(ctx: InputContext):
return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
"""Compute the max feature size for all possible image grid pinpoints."""
return _get_pinpoint_with_largest_features(ctx)[0]


def _get_pinpoint_with_largest_features(
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config = ctx.get_hf_config(LlavaNextConfig)
largest_feature_size = 0
largest_feature_pinpoint = None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = (height, width)
if not largest_feature_size or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_size, largest_feature_pinpoint


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
Expand All @@ -162,7 +176,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config
num_images = mm_counts["image"]

image_feature_size = get_max_llava_next_image_tokens(ctx)
image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
max_feat_height, max_feat_width = pinpoint

if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
Expand All @@ -176,8 +191,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width_override=max_feat_height,
image_height_override=max_feat_width,
)

return seq_data, mm_data
Expand All @@ -193,8 +208,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_data = dummy_image_for_siglip(
vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width_override=max_feat_height,
image_height_override=max_feat_width,
)

return seq_data, mm_data
Expand Down

0 comments on commit c159a35

Please sign in to comment.