refactor input processor
Signed-off-by: xffxff <[email protected]>
xffxff committed Nov 21, 2024
1 parent 50cacff commit da57824
Showing 1 changed file with 57 additions and 173 deletions.
230 changes: 57 additions & 173 deletions vllm/model_executor/models/aria.py
@@ -1,10 +1,8 @@
import math
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import LlamaConfig
from transformers.utils import logging
from vllm.attention import AttentionMetadata
@@ -21,7 +19,7 @@
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:19:81: E501 Line too long (82 > 80)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput, SamplingMetadata

[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:20:81: E501 Line too long (87 > 80)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear, QKVParallelLinear
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear

[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:22:81: E501 Line too long (85 > 80)
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import (
LlamaAttention,
@@ -46,14 +44,9 @@
repeat_and_pad_placeholder_tokens,
)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

logger = logging.get_logger(__name__)

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_

[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:50:1: E402 Module level import not at top of file
from vllm.config import QuantizationConfig

[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:51:1: E402 Module level import not at top of file
[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:51:25: F811 Redefinition of unused `QuantizationConfig` from line 19
from vllm.model_executor.models.idefics2_vision_model import Idefics2VisionTransformer

[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:52:1: E402 Module level import not at top of file
[GitHub Actions / ruff (3.12)] vllm/model_executor/models/aria.py:52:81: E501 Line too long (86 > 80)
@@ -536,141 +529,48 @@ def build_mm_projector(config):
)


def _select_best_resolution(img_width: int, img_height: int,
target_ratios: List[List[int]], patch_size: int):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
def get_max_multimodal_tokens(ctx):
return max(ctx.model_config.hf_config.image_size2tokens.values())

Args:
img_width: the original width of the image.
img_height: the original height of the image.
target_ratios (2d numpy array): dimension size (M,2)
patch_size (int): image patch size

Returns:
tuple: The best fit resolution in the format (width, height).
"""
def input_mapper_for_aria(ctx, data):
return MultiModalInputs(data)


aspect_ratio = img_width / img_height
best_ratio_diff = float("inf")
best_ratio_w, best_ratio_h = 1, 1
area = np.int32(img_width) * np.int32(img_height)
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio_w, best_ratio_h = ratio[0], ratio[1]
elif (ratio_diff == best_ratio_diff
and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1]):
best_ratio_w, best_ratio_h = ratio[0], ratio[1]

return best_ratio_w, best_ratio_h


def split_image(
image: Image.Image,
split_image: bool,
split_ratio: List[List[int]] = [
[1, 2],
[1, 3],
[1, 4],
[1, 5],
[1, 6],
[1, 7],
[1, 8],
[2, 4],
[2, 3],
[2, 2],
[2, 1],
[3, 1],
[3, 2],
[4, 1],
[4, 2],
[5, 1],
[6, 1],
[7, 1],
[8, 1],
],
patch_size: int = 980,
) -> List[Image.Image]:
def repeat_image_tokens(token_ids: list, image_token_id: int,
repeat_times: list) -> list:
"""
Split image into multiple patches
Repeats the image token in the token_ids list according to the repeat_times list.
Args:
image (PIL.Image): Input image.
split_image (bool): Whether to split the image into patches.
split_ratio (2d numpy array): dimension size (M,2)
patch_size (int): image patch size
token_ids (list): List of token IDs.
image_token_id (int): The token ID that represents an image.
repeat_times (list): List of integers specifying how many times to repeat the image token.
Returns:
List[PIL.Image]: List of split images.
list: A new list with the image token repeated as specified.
Example:
token_ids = [1, 2, 3, 4, 3, 5]
image_token_id = 3
repeat_times = [2, 3]
result = repeat_image_tokens(token_ids, image_token_id, repeat_times)
# result will be [1, 2, 3, 3, 4, 3, 3, 3, 5]
"""
if split_image:
ratio_width, ratio_height = _select_best_resolution(
image.width, image.height, split_ratio, patch_size)
resize_width = patch_size * ratio_width
resize_height = patch_size * ratio_height
blocks = ratio_width * ratio_height
resized_img = image.resize((resize_width, resize_height))
processed_images = []
for i in range(blocks):
box = (
(i % (resize_width // patch_size)) * patch_size,
(i // (resize_width // patch_size)) * patch_size,
((i % (resize_width // patch_size)) + 1) * patch_size,
((i // (resize_width // patch_size)) + 1) * patch_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if len(processed_images) != 1:
processed_images.insert(0, image)
return processed_images
else:
return [image]
if len(repeat_times) != token_ids.count(image_token_id):
raise ValueError(
"The length of repeat_times is not equal to the number of images.")

result = []
repeat_iter = iter(repeat_times)

def get_max_multimodal_tokens(ctx):
return max(ctx.model_config.hf_config.image_size2tokens.values())

for x in token_ids:
if x == image_token_id:
result.extend([image_token_id] * next(repeat_iter))
else:
result.append(x)

def input_mapper_for_aria(ctx, data):
"""
This is almost the same as _default_input_mapper from vllm/multimodal/image.py.
Args:
ctx (ModelExecutorContext): The context object containing necessary parameters.
data (Union[Image.Image, torch.Tensor, List[Union[Image.Image, torch.Tensor]]]): The input data to be processed.
The only difference is that we want to support runtime max_image_size adjustment.
"""
model_config = ctx.model_config
max_image_size = getattr(model_config.multimodal_config, "max_image_size",
980)

# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
try:
batch_data = image_processor.preprocess(
data, max_image_size=max_image_size, return_tensors="pt").data
batch_data.pop("num_crops")
except Exception:
logger.error("Failed to process image (%s)", data)
raise

return MultiModalInputs(batch_data)

# Image embedding
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
return MultiModalInputs({"image_embeds": data})

raise TypeError(f"Invalid image type: {type(data)}")
return result
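
A quick illustrative usage of repeat_image_tokens, mirroring the docstring example above (hypothetical values, not part of this diff):

# Hypothetical usage example for repeat_image_tokens; values are illustrative.
token_ids = [1, 2, 3, 4, 3, 5]
expanded = repeat_image_tokens(token_ids, image_token_id=3, repeat_times=[2, 3])
assert expanded == [1, 2, 3, 3, 4, 3, 3, 3, 5]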


def input_processor(ctx, llm_inputs):
@@ -682,55 +582,43 @@ def input_processor(ctx, llm_inputs):
model_config = ctx.model_config

tokenizer = cached_get_tokenizer(model_config.tokenizer)
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
hf_config = model_config.hf_config

# prepare image tokens; max_image_size determines the patch size used for every image
max_image_size = multi_modal_data.pop("max_image_size", 980)
_split_image = multi_modal_data.pop("split_image", False)

assert isinstance(max_image_size, int) or isinstance(
max_image_size, float), "max_image_size should be float or int"
assert isinstance(max_image_size, (int, float)), "max_image_size should be float or int"
images = (multi_modal_data["image"] if isinstance(
multi_modal_data["image"], list) else [multi_modal_data["image"]])
num_crops = []
splitted_images = []
for image in images:
splitted_image = split_image(image,
_split_image,
patch_size=max_image_size)
splitted_images.extend(splitted_image)
num_crops.append(len(splitted_image))
max_image_size = [max_image_size] * len(images)
# reassign the images because we might split them into mini-patches
multi_modal_data["image"] = splitted_images

# Map the image patch size to the corresponding number of tokens for each image
image_feature_sizes = []
for image_size, num_crop in zip(max_image_size, num_crops):
assert (
image_size in hf_config.image_size2tokens
), f"Invalid image size: {image_size}, available options: {list(hf_config.image_size2tokens.keys())}"
image_feature_sizes.append(hf_config.image_size2tokens[image_size] *
num_crop)

# Set up max_image_size and split_image in the runtime context for the image processor
# TODO: support dynamic image sizes
setattr(model_config.multimodal_config, "max_image_size",
max(max_image_size))

new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(

image_inputs = image_processor.preprocess(images,
max_image_size=max_image_size,
split_image=_split_image,
return_tensors="pt").data
num_crops = image_inputs.pop("num_crops")

prompt_token_ids = llm_inputs["prompt_token_ids"]
prompt_token_ids = repeat_image_tokens(prompt_token_ids,
hf_config.image_token_index,
num_crops)
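# At this point every single image placeholder token has been duplicated once
# per crop reported by the image processor, so the repeat_and_pad call below
# can expand each crop's placeholder to its per-crop token count.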

repeat_count = [hf_config.image_size2tokens[max_image_size]
] * sum(num_crops).item()
new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
None,
prompt_token_ids,
placeholder_token_id=hf_config.image_token_index,
repeat_count=image_feature_sizes,
repeat_count=repeat_count,
)

return token_inputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
# multi_modal_placeholders={"image": ranges},
multi_modal_data={"image": image_inputs},
)
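
For context, a minimal sketch of how these hooks are typically wired up via vLLM's input and multimodal registries; the decorators and the class name below are assumptions based on vLLM's conventions, not lines from this diff:

# Hypothetical registration sketch; not part of this commit.
from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens)
@INPUT_REGISTRY.register_input_processor(input_processor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    ...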


@@ -755,14 +643,10 @@ def __init__(
quant_config = vllm_config.quant_config

# prepare the image_size-to-tokens mapping for image preprocessing; see input_processor
setattr(
config,
"image_size2tokens",
{
int(math.sqrt(k) * config.vision_config.patch_size): v
for k, v in config.projector_patch_to_query_dict.items()
},
)
config.image_size2tokens = {
int(math.sqrt(k) * config.vision_config.patch_size): v
for k, v in config.projector_patch_to_query_dict.items()
}
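# Illustrative numbers only (not taken from this diff): with a vision
# patch_size of 14 and projector_patch_to_query_dict == {1225: 128, 4900: 256},
# this yields image_size2tokens == {490: 128, 980: 256}, because
# int(sqrt(1225) * 14) == 490 and int(sqrt(4900) * 14) == 980.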
self.config = config
self.vision_tower = AriaVisionModel(config.vision_config)
self.multi_modal_projector = build_mm_projector(config)