From 630eb5b5ce6ea59b6480440b7f6064be5ca71ae1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 19 Jan 2025 11:16:34 +0800 Subject: [PATCH] [Bugfix] Fix multi-modal processors for transformers 4.48 (#12187) --- vllm/model_executor/models/llava.py | 25 ++++- vllm/model_executor/models/qwen2_audio.py | 72 ++++++++---- vllm/model_executor/models/ultravox.py | 9 +- vllm/transformers_utils/config.py | 9 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/aria.py | 118 ++++++++++++++++++++ 6 files changed, 199 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 722fff98d5c19..6cceded43a79d 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -5,9 +5,11 @@ import torch import torch.nn as nn +from packaging.version import Version from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, SiglipVisionConfig) +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor @@ -716,6 +718,27 @@ def load_weights(self, weights: Iterable[Tuple[str, return loader.load_weights(weights) +class MantisProcessingInfo(LlavaProcessingInfo): + + def get_hf_processor(self): + hf_config = self.get_hf_config() + vision_info = self.get_vision_encoder_info() + + if Version(TRANSFORMERS_VERSION) < Version("4.48"): + # BUG: num_additional_image_tokens = 0 but treated as 1, + # so we set vision_feature_select_strategy to None to offset this + vision_feature_select_strategy = None + else: + # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150 + vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501 + + return self.ctx.get_hf_processor( + LlavaProcessor, + patch_size=vision_info.get_patch_size(), + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + class MantisMultiModalProcessor(LlavaMultiModalProcessor): def apply( @@ -794,7 +817,7 @@ def get_replacement_mantis(item_idx: int): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, - info=LlavaProcessingInfo, + info=MantisProcessingInfo, dummy_inputs=LlavaDummyInputsBuilder) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 0dff9595c6c08..47d56175261e4 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -36,8 +36,9 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -153,29 +154,24 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, Any], ) -> BatchFeature: 
- mm_data = dict(mm_data) - audios = mm_data.pop("audios", []) - - if audios: - mm_data["audios"] = audios - - feature_extractor = self.info.get_feature_extractor(**mm_kwargs) - mm_kwargs = dict( - **mm_kwargs, - sampling_rate=feature_extractor.sampling_rate, - ) - else: - # NOTE: WhisperFeatureExtractor cannot handle empty list of audios - pass + # Text-only input not supported in composite processor + if not mm_data or not mm_data.get("audios", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) - processed_outputs = super()._call_hf_processor( + return super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, ) - return processed_outputs - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -192,8 +188,14 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.info.get_hf_config() - placeholder = hf_config.audio_token_index + processor = self.info.get_hf_processor() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + audio_bos_token = getattr(processor, "audio_bos_token", + "<|audio_bos|>") + audio_eos_token = getattr(processor, "audio_eos_token", + "<|audio_eos|>") feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") if feature_attention_mask is None: @@ -214,12 +216,16 @@ def get_replacement_qwen2_audio(item_idx: int): f"The audio {audio} (len={len(audio)}) is too short " "to be represented inside the model") - return [placeholder] * num_placeholders + return "".join([ + audio_bos_token, + audio_token * num_placeholders, + audio_eos_token, + ]) return [ PromptReplacement( modality="audio", - target=[placeholder], + target=audio_token, replacement=get_replacement_qwen2_audio, ) ] @@ -234,6 +240,26 @@ def _always_apply_prompt_replacements(self) -> bool: # tokens than the number of audio items) return not hasattr(self.info.get_hf_processor(), "audio_token") + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) + + # Only <|AUDIO|> tokens should be considered as placeholders, + # so we ignore the audio_bos_token and audio_eos_token + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"] + 1, + length=p["length"] - 2) for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result + @MULTIMODAL_REGISTRY.register_processor( Qwen2AudioMultiModalProcessor, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 587f18ccaf98f..9301422383696 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -137,7 +137,7 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], ) -> BatchFeature: # Text-only input not supported in composite processor - if not mm_data: + if not mm_data or not mm_data.get("audios", []): prompt_ids = self.info.get_tokenizer().encode(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return 
BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -146,13 +146,6 @@ def _call_hf_processor( audios = mm_data.pop("audios", []) assert isinstance(audios, list) - if not audios: - return super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - ) - feature_extractor = self.info.get_feature_extractor() mm_kwargs = dict( **mm_kwargs, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c97acffa1a719..f57dfded0a62f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -22,10 +22,10 @@ from vllm.logger import init_logger # yapf conflicts with isort for this block # yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, DeepseekVLV2Config, - EAGLEConfig, ExaoneConfig, - H2OVLChatConfig, +from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig, + Cohere2Config, DbrxConfig, + DeepseekVLV2Config, EAGLEConfig, + ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,6 +52,7 @@ } _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + "aria": AriaConfig, "chatglm": ChatGLMConfig, "cohere2": Cohere2Config, "dbrx": DbrxConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f065c56124605..807ef4fbfd0c0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,3 +1,4 @@ +from vllm.transformers_utils.configs.aria import AriaConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -23,6 +24,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ + "AriaConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py index d253da0d96a34..f4b531225b5d0 100644 --- a/vllm/transformers_utils/configs/aria.py +++ b/vllm/transformers_utils/configs/aria.py @@ -1,7 +1,32 @@ +# Copyright 2024 Rhymes AI. All rights reserved. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from typing import Mapping + +from transformers import PretrainedConfig from transformers.models.idefics2.configuration_idefics2 import ( Idefics2VisionConfig) from transformers.models.llama.configuration_llama import LlamaConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + class AriaVisionConfig(Idefics2VisionConfig): model_type = "aria_vision_model" @@ -45,3 +70,96 @@ def __init__( self.moe_num_experts = moe_num_experts self.moe_topk = moe_topk self.moe_num_shared_experts = moe_num_shared_experts + + +class AriaConfig(PretrainedConfig): + """ + Configuration class for Aria model. + This class handles the configuration for both vision and text components of + the Aria model, + as well as additional parameters for image token handling and projector + mapping. + + Args: + vision_config (AriaVisionConfig or dict): Configuration for the vision + component. + text_config (AriaMoELMConfig or dict): Configuration for the text + component. + projector_patch_to_query_dict (dict): Mapping of patch sizes to query + dimensions. + ignore_index (int): Index to ignore in loss calculation. + image_token_index (int): Index used to represent image tokens. + **kwargs: Additional keyword arguments passed to the parent class. + Attributes: + model_type (str): Type of the model, set to "aria". + is_composition (bool): Whether the model is a composition of multiple + components. + ignore_index (int): Index to ignore in loss calculation. + image_token_index (int): Index used to represent image tokens. + projector_patch_to_query_dict (dict): Mapping of patch sizes to query + dimensions. + vision_config (AriaVisionConfig): Configuration for the vision + component. + text_config (AriaMoELMConfig): Configuration for the text component. + """ + + model_type = "aria" + is_composition = False + + def __init__( + self, + vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008 + text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008 + projector_patch_to_query_dict: Mapping[int, int] = { + 1225: 128, + 4900: 256, + }, + ignore_index=-100, + image_token_index=32000, + tie_word_embeddings=False, + **kwargs, + ): + super().__init__(**kwargs) + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.tie_word_embeddings = tie_word_embeddings + attn_implementation = kwargs.pop("attn_implementation", None) + + # Set the default attention implementation to flash_attention_2 if not + # specified + self._attn_implementation = ("flash_attention_2" + if attn_implementation is None else + attn_implementation) + + # Convert the keys and values of projector_patch_to_query_dict to + # integers + # This ensures consistency even if they were provided as strings + self.projector_patch_to_query_dict = { + int(k): int(v) + for k, v in projector_patch_to_query_dict.items() + } + + if isinstance(vision_config, dict) and "model_type" in vision_config: + vision_config = AriaVisionConfig(**vision_config) + if attn_implementation is None: + vision_attn_implementation = "flash_attention_2" + elif attn_implementation == "sdpa": + logger.warning("SDPA is not supported for vit, using " + "flash_attention_2 instead") + vision_attn_implementation = "flash_attention_2" + else: + vision_attn_implementation = attn_implementation + vision_config._attn_implementation = vision_attn_implementation + + self.vision_config = vision_config + + if isinstance(text_config, dict) and "model_type" in text_config: + text_attn_implementation = ("sdpa" if attn_implementation is None + else 
attn_implementation)
+            text_config = AriaMoELMConfig(**text_config)
+            text_config._attn_implementation = text_attn_implementation
+
+        self.text_config = text_config
+
+        # This is needed for the static kv cache
+        self.num_hidden_layers = self.text_config.num_hidden_layers
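
---

Note (not part of the patch): the sketch below illustrates, in isolation, two compatibility patterns this change relies on for the Qwen2-Audio processor: probing the HF processor for the newer `audio_token`/`audio_bos_token`/`audio_eos_token` attributes with `getattr` fallbacks, and shrinking each placeholder range by one token on each side so the BOS/EOS markers are not counted as `<|AUDIO|>` placeholders. The `FakeProcessor` and `Range` types are stand-ins invented for this example; they are not vLLM or transformers APIs.

```python
from dataclasses import dataclass


@dataclass
class Range:
    # Stand-in for vLLM's PlaceholderRange: a span of placeholder tokens
    # in the tokenized prompt (offset = start index, length = token count).
    offset: int
    length: int


class FakeProcessor:
    """Hypothetical processor mimicking the transformers>=4.48 attributes."""
    audio_token = "<|AUDIO|>"
    audio_bos_token = "<|audio_bos|>"
    audio_eos_token = "<|audio_eos|>"


def resolve_audio_tokens(processor: object) -> tuple[str, str, str]:
    # getattr with defaults keeps older transformers versions working,
    # where the processor does not expose these attributes yet.
    return (
        getattr(processor, "audio_token", "<|AUDIO|>"),
        getattr(processor, "audio_bos_token", "<|audio_bos|>"),
        getattr(processor, "audio_eos_token", "<|audio_eos|>"),
    )


def strip_bos_eos(ranges: list[Range]) -> list[Range]:
    # Each replacement is "<bos> <AUDIO>*n <eos>"; only the <AUDIO> tokens
    # should count as placeholders, so drop one token from each end.
    return [Range(offset=r.offset + 1, length=r.length - 2) for r in ranges]


if __name__ == "__main__":
    print(resolve_audio_tokens(FakeProcessor()))
    print(resolve_audio_tokens(object()))  # older transformers: falls back
    print(strip_bos_eos([Range(offset=5, length=102)]))  # offset=6, length=100
```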