Use merged multi-modal processor for blip2 and chameleon
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 30, 2024
1 parent bbd5752 · commit 0452b99
Showing 5 changed files with 129 additions and 250 deletions.
4 changes: 4 additions & 0 deletions tests/multimodal/test_processing.py
@@ -624,6 +624,10 @@ def _test_processing_cache_correctness(

# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("rhymes-ai/Aria", {"image"}),
("Salesforce/blip2-opt-2.7b", {"image"}),
("facebook/chameleon-7b", {"image"}),
("adept/fuyu-8b", {"image"}),
("llava-hf/llava-1.5-7b-hf", {"image"}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
("mistral-community/pixtral-12b", {"image"}),
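The two new parametrize entries above run BLIP-2 and Chameleon through the merged processor path. For orientation, here is a minimal offline-inference sketch against one of the newly covered models, assuming vLLM's public LLM API; the prompt template and image path are illustrative assumptions, not taken from this commit.

# Hedged sketch: exercises the merged multi-modal processor for BLIP-2 via
# vLLM's public API. The QA-style prompt and "example.jpg" are placeholders.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="Salesforce/blip2-opt-2.7b")

image = Image.open("example.jpg").convert("RGB")  # any RGB image works here
outputs = llm.generate(
    {
        "prompt": "Question: What is shown in this image? Answer:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)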
74 changes: 0 additions & 74 deletions vllm/model_executor/models/blip.py
@@ -8,18 +8,13 @@
from transformers import Blip2VisionConfig, BlipVisionConfig

from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -33,36 +28,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
return grid_length * grid_length


def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)


def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_image_feature_size(hf_config)


def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)


def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
num_images: int,
@@ -80,45 +45,6 @@ def dummy_image_for_blip(
return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_blip(
model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs

if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs

tokenizer = cached_get_tokenizer(model_config.tokenizer)

if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

139 changes: 52 additions & 87 deletions vllm/model_executor/models/blip2.py
@@ -4,24 +4,25 @@

import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
apply_chunking_to_forward)
from transformers import (BatchFeature, Blip2Config, Blip2Processor,
Blip2QFormerConfig, apply_chunking_to_forward)

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData

from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors

from .blip import BlipVisionModel, dummy_image_for_blip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -396,96 +397,60 @@ def forward(
return sequence_output


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens


def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config

if isinstance(vision_config, Blip2VisionConfig):
return get_max_blip_image_tokens(vision_config)

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)


def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}


def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]

seq_data, ranges = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)

if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config, num_images)
return hf_config.num_query_tokens

return DummyData(seq_data, mm_data, ranges)

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class Blip2MultiModalProcessor(BaseMultiModalProcessor):

def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)

def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)

hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = get_max_blip2_image_tokens(self.ctx)

return [
PromptReplacement(
modality="image",
target="", # An empty target is never matched against
replacement="<image>" * max_image_tokens,
)
]

# The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += inputs["prompt_token_ids"]
placeholder_ranges = [
PlaceholderRange(offset=0, length=image_feature_size)
]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts.get("image", 0)

new_prompt = inputs.get("prompt")
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
data = dummy_image_for_blip(vision_config, num_images)

return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
return ProcessorInputs(
prompt_text="",
mm_data=data,
)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
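Under the merged design shown above, the separate input-mapper, max-token, dummy-data, and input-processor registrations collapse into one processor class attached with a single register_processor decorator. The skeleton below restates that interface for a hypothetical model; the "My" names, the <image> token string, the 64-token count, and the 336x336 dummy image are placeholder assumptions, not part of this commit.

from collections.abc import Mapping

from PIL import Image
from transformers import BatchFeature

from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)


class MyMultiModalProcessor(BaseMultiModalProcessor):
    # _get_hf_processor is omitted here; blip2.py overrides it above only to
    # pin the Blip2Processor class.

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Map the HF processor's output tensors onto the "image" modality.
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        # Expand each image placeholder into a fixed run of image tokens
        # (64 is an arbitrary count for this sketch).
        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement="<image>" * 64,
            )
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        # Profiling inputs; blip2.py instead derives the image size from its
        # vision config via dummy_image_for_blip.
        num_images = mm_counts.get("image", 0)
        image = Image.new("RGB", (336, 336), color=0)
        mm_data = {"image": image if num_images == 1 else [image] * num_images}
        return ProcessorInputs(prompt_text="<image>" * num_images,
                               mm_data=mm_data)


# The processor is then attached to the model class exactly as in the diff:
#   @MULTIMODAL_REGISTRY.register_processor(MyMultiModalProcessor)
#   class MyModelForConditionalGeneration(...): ...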