From 0452b99b143eab5fc7c4596a9ad167a74bc1f022 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 30 Dec 2024 16:48:44 +0000 Subject: [PATCH] Use merged multi-modal processor for blip2 and chameleon Signed-off-by: DarkLight1337 --- tests/multimodal/test_processing.py | 4 + vllm/model_executor/models/blip.py | 74 ----------- vllm/model_executor/models/blip2.py | 139 ++++++++------------- vllm/model_executor/models/chameleon.py | 157 +++++++++++------------- vllm/multimodal/processing.py | 5 +- 5 files changed, 129 insertions(+), 250 deletions(-) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 1b2847ed0f534..43fb6e4e25e72 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -624,6 +624,10 @@ def _test_processing_cache_correctness( # yapf: disable @pytest.mark.parametrize(("model_id", "modalities"), [ + ("rhymes-ai/Aria", {"image"}), + ("Salesforce/blip2-opt-2.7b", {"image"}), + ("facebook/chameleon-7b", {"image"}), + ("adept/fuyu-8b", {"image"}), ("llava-hf/llava-1.5-7b-hf", {"image"}), ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}), ("mistral-community/pixtral-12b", {"image"}), diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 42a239cadac46..129a0bcecc86f 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -8,18 +8,13 @@ from transformers import Blip2VisionConfig, BlipVisionConfig from vllm.attention.layer import MultiHeadAttention -from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import (cached_get_tokenizer, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceData def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -33,36 +28,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: return grid_length * grid_length -def get_blip_image_feature_size( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: - return get_blip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) - - -def get_max_blip_image_tokens( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: - return get_blip_image_feature_size(hf_config) - - -def dummy_seq_data_for_blip( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = get_blip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ) - - def dummy_image_for_blip( hf_config: Union[BlipVisionConfig, Blip2VisionConfig], num_images: int, @@ -80,45 +45,6 @@ def dummy_image_for_blip( return {"image": image if num_images == 1 else [image] * num_images} -def input_processor_for_blip( - model_config: ModelConfig, - hf_config: Union[BlipVisionConfig, 
Blip2VisionConfig], - inputs: DecoderOnlyInputs, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. - return inputs - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - - if image_feature_size_override is None: - image_feature_size = get_blip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=image_token_id, - repeat_count=image_feature_size, - ) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": ranges}) - - # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa class BlipVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 4e16ae522c9b2..c65acb85aa980 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -4,24 +4,25 @@ import torch import torch.nn as nn -from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, - apply_chunking_to_forward) +from transformers import (BatchFeature, Blip2Config, Blip2Processor, + Blip2QFormerConfig, apply_chunking_to_forward) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors, PlaceholderRange -from vllm.multimodal.utils import consecutive_placeholder_ranges -from vllm.sequence import IntermediateTensors, SequenceData - -from .blip import (BlipVisionModel, dummy_image_for_blip, - get_max_blip_image_tokens) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) +from vllm.sequence import IntermediateTensors + +from .blip import BlipVisionModel, dummy_image_for_blip from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -396,96 +397,60 @@ def forward( return sequence_output -def get_blip2_image_feature_size(hf_config: Blip2Config) -> int: - return hf_config.num_query_tokens - - def get_max_blip2_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(Blip2Config) - vision_config = hf_config.vision_config - - if isinstance(vision_config, Blip2VisionConfig): - return get_max_blip_image_tokens(vision_config) - - msg = f"Unsupported vision config: {type(vision_config)}" - 
raise NotImplementedError(msg) - - -def dummy_seq_data_for_blip2( - hf_config: Blip2Config, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = get_blip2_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ), { - "image": - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - -def dummy_data_for_blip2(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(Blip2Config) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - seq_data, ranges = dummy_seq_data_for_blip2( - hf_config, - seq_len, - num_images, - image_token_id=BLIP2_IMAGE_TOKEN_ID, - ) - - if isinstance(vision_config, Blip2VisionConfig): - mm_data = dummy_image_for_blip(vision_config, num_images) + return hf_config.num_query_tokens - return DummyData(seq_data, mm_data, ranges) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) +class Blip2MultiModalProcessor(BaseMultiModalProcessor): + def _get_hf_processor(self) -> Blip2Processor: + return self.ctx.get_hf_processor(Blip2Processor) -def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) - hf_config = ctx.get_hf_config(Blip2Config) - image_feature_size = get_blip2_image_feature_size(hf_config) + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + max_image_tokens = get_max_blip2_image_tokens(self.ctx) + + return [ + PromptReplacement( + modality="image", + target="", # An empty target is never matched against + replacement="" * max_image_tokens, + ) + ] - # The original model places image tokens at the front - # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 - new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size - new_token_ids += inputs["prompt_token_ids"] - placeholder_ranges = [ - PlaceholderRange(offset=0, length=image_feature_size) - ] + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.ctx.get_hf_config(Blip2Config) + vision_config = hf_config.vision_config + num_images = mm_counts.get("image", 0) - new_prompt = inputs.get("prompt") - if new_prompt is not None: - new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt + data = dummy_image_for_blip(vision_config, num_images) - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": placeholder_ranges}) + return ProcessorInputs( + prompt_text="", + mm_data=data, + ) -@MULTIMODAL_REGISTRY.register_image_input_mapper() 
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) -@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) +@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index afca81f5d4fd6..e027579cdb8ee 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -6,13 +6,13 @@ import torch.nn.functional as F from PIL import Image from torch import nn -from transformers import ChameleonConfig, ChameleonVQVAEConfig +from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, + ChameleonVQVAEConfig) from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -29,11 +29,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import IntermediateTensors, SequenceData +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal, SupportsPP @@ -45,10 +47,6 @@ # and processor files, so we hardcode them in the model file for now. 
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 CHAMELEON_IMAGE_SEQ_LENGTH = 1024 -CHAMELEON_IMAGE_TOKEN_ID = 8711 -CHAMELEON_IMAGE_START_TOKEN_ID = 8197 -CHAMELEON_IMAGE_END_TOKEN_ID = 8196 -CHAMELEON_SEP_TOKEN_ID = 8710 class ChameleonImagePixelInputs(TypedDict): @@ -61,28 +59,6 @@ def get_max_chameleon_image_tokens(ctx: InputContext): return CHAMELEON_IMAGE_SEQ_LENGTH -def dummy_seq_data_for_chameleon( - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ), { - "image": - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - def dummy_image_for_chameleon( num_images: int, *, @@ -100,61 +76,70 @@ def dummy_image_for_chameleon( return {"image": image if num_images == 1 else [image] * num_images} -def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] +class ChameleonMultiModalProcessor(BaseMultiModalProcessor): + + def _get_hf_processor(self) -> ChameleonProcessor: + return self.ctx.get_hf_processor(ChameleonProcessor) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image"), ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + processor = self._get_hf_processor() + + return [ + PromptReplacement( + modality="image", + target="", + replacement="".join([ + processor.image_start_token, + processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH, + processor.image_end_token, + ]), + ) + ] - seq_data, ranges = dummy_seq_data_for_chameleon( - seq_len, - num_images, - image_token_id=CHAMELEON_IMAGE_TOKEN_ID, - ) + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) - mm_data = dummy_image_for_chameleon(num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_chameleon(num_images) + return ProcessorInputs( + prompt_text="" * num_images, + mm_data=data, + ) -def input_processor_for_chameleon(ctx: InputContext, - inputs: DecoderOnlyInputs): + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + # Only tokens should be considered as placeholders, + # so we ignore the image_start_token and image_end_token + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"] + 1, + length=p["length"] - 2) for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } - """ - Processing input prompt to insert required tokens for image placeholder. 
- - See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 - """ # noqa - - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. - return inputs - - model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, - repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, - pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, - pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, - ) - - # Appending sep token for chat mode to follow default processor - # behavior - if new_prompt is not None: - new_prompt += tokenizer.sep_token - new_token_ids += [CHAMELEON_SEP_TOKEN_ID] - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": ranges}) + return result class ChameleonLayerNorm(nn.LayerNorm): @@ -926,10 +911,8 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) -@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) +@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 3ece0762e3228..f7fb5d3bba513 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,7 @@ import pickle import re from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache @@ -352,13 +353,13 @@ def _replace_matches( ) -> list[_S]: out_seqs = list[_S]() prev_end_idx = 0 - next_idx_by_modality = {modality: 0 for modality in mm_item_counts} + next_idx_by_modality = defaultdict[str, int](lambda: 0) for match in _resolve_matches(prompt, matches): modality = match.modality item_idx = next_idx_by_modality[modality] - if item_idx >= mm_item_counts[modality]: + if item_idx >= mm_item_counts.get(modality, 0): continue start_idx = match.start_idx
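Note on the processing.py hunk above: switching to defaultdict plus mm_item_counts.get() means prompt-replacement matches for a modality that has no items in the current request are now skipped instead of raising KeyError. A minimal standalone sketch of that behaviour follows; it is a simplified stand-in for _replace_matches, not the real implementation.

# Toy illustration (assumed simplification, not vLLM's actual _replace_matches)
# of the defaultdict / mm_item_counts.get() change in the hunk above.
from collections import defaultdict
from collections.abc import Mapping, Sequence


def select_matches(match_modalities: Sequence[str],
                   mm_item_counts: Mapping[str, int]) -> list[tuple[str, int]]:
    """Pair each match with the next unused item index for its modality,
    dropping matches once that modality runs out of items (or has none)."""
    next_idx_by_modality = defaultdict[str, int](lambda: 0)
    selected: list[tuple[str, int]] = []
    for modality in match_modalities:
        item_idx = next_idx_by_modality[modality]
        if item_idx >= mm_item_counts.get(modality, 0):
            continue  # no (more) items of this modality in the request
        selected.append((modality, item_idx))
        next_idx_by_modality[modality] += 1
    return selected


# "audio" is absent from mm_item_counts; previously this would raise KeyError.
print(select_matches(["image", "audio", "image", "image"], {"image": 2}))
# -> [('image', 0), ('image', 1)]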
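Note on ChameleonMultiModalProcessor.apply above: the HF processor emits image_start_token + 1024 image tokens + image_end_token, but only the image tokens themselves should be reported as the placeholder, so each range is shrunk by one token on either side. A minimal sketch of that arithmetic, with plain dicts standing in for vLLM's PlaceholderRange:

# Minimal sketch (plain dicts instead of vLLM's PlaceholderRange) of the
# adjustment done in ChameleonMultiModalProcessor.apply: drop the leading
# image-start token and trailing image-end token from each placeholder range.
def trim_placeholder(p: dict) -> dict:
    return {"offset": p["offset"] + 1, "length": p["length"] - 2}


# Raw range covering start token + 1024 image tokens + end token = 1026 tokens,
# beginning right after a 5-token text prefix.
raw = {"offset": 5, "length": 1026}
print(trim_placeholder(raw))  # {'offset': 6, 'length': 1024}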
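Note on usage: this patch does not change the user-facing API; the merged processor registered via MULTIMODAL_REGISTRY.register_processor is picked up automatically. A hedged sketch of offline inference with one of the converted models follows; the prompt template and image size are assumptions for illustration, not taken from this patch.

# Assumed end-user usage, unchanged by this patch: the merged multi-modal
# processor now expands the image placeholder internally.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/chameleon-7b")
image = Image.new("RGB", (512, 512))  # stand-in for a real input image

outputs = llm.generate(
    {
        "prompt": "<image>Describe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)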