From 002d7af65fb787b8fc1b34f405efb72a2b07ea05 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 18 Oct 2024 15:29:56 -0400 Subject: [PATCH] [Model] Support Pixtral models in the HF Transformers format (#9036) Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- docs/source/models/supported_models.rst | 2 +- examples/offline_inference_vision_language.py | 17 + vllm/model_executor/layers/activation.py | 2 + vllm/model_executor/models/llava.py | 74 +++- vllm/model_executor/models/pixtral.py | 410 +++++++++++++++++- vllm/model_executor/models/qwen2_vl.py | 6 +- vllm/transformers_utils/processor.py | 4 + 7 files changed, 503 insertions(+), 12 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index ee2844c8b27a0..318139a749d88 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -437,7 +437,7 @@ Text Generation * - :code:`PixtralForConditionalGeneration` - Pixtral - T + I\ :sup:`+` - - :code:`mistralai/Pixtral-12B-2409` + - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - ✅︎ * - :code:`QWenLMHeadModel` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 4c88dcc2f087b..06b424abd50b5 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -277,6 +277,22 @@ def run_qwen2_vl(question: str, modality: str): return llm, prompt, stop_token_ids +# Pixtral HF-format +def run_pixtral_hf(question: str, modality: str): + assert modality == "image" + + model_name = "mistral-community/pixtral-12b" + + llm = LLM( + model=model_name, + max_model_len=8192, + ) + + prompt = f"[INST]{question}\n[IMG][/INST]" + stop_token_ids = None + return llm, prompt, stop_token_ids + + # LLama 3.2 def run_mllama(question: str, modality: str): assert modality == "image" @@ -347,6 +363,7 @@ def run_glm4v(question: str, modality: str): "NVLM_D": run_nvlm_d, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "pixtral_hf": run_pixtral_hf, "mllama": run_mllama, "molmo": run_molmo, "glm4v": run_glm4v, diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index cf99306c9caef..8de3385a257f8 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -264,6 +264,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): lambda: nn.ReLU(), "relu2": lambda: ReLUSquaredActivation(), + "silu": + lambda: nn.SiLU(), "quick_gelu": lambda: QuickGELU(), }) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index fd2827c0eff09..a83b7d05df7aa 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn from PIL import Image -from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig +from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, + SiglipVisionConfig) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -22,6 +23,10 @@ dummy_seq_data_for_clip, get_max_clip_image_tokens, input_processor_for_clip) from .interfaces import SupportsMultiModal, SupportsPP +from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, + dummy_seq_data_for_pixtral_hf, + get_max_pixtral_hf_image_tokens, + input_processor_for_pixtral_hf) from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) @@ -31,8 +36,13 @@ class LlavaImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that `height` or `width` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ class LlavaImageEmbeddingInputs(TypedDict): @@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext): num_image_tokens = get_max_clip_image_tokens(vision_config) elif isinstance(vision_config, SiglipVisionConfig): num_image_tokens = get_max_siglip_image_tokens(vision_config) + elif isinstance(vision_config, PixtralVisionConfig): + num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config) else: msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, mm_data = dummy_image_for_siglip(vision_config, num_images) return seq_data, mm_data + elif isinstance(vision_config, PixtralVisionConfig): + seq_data = dummy_seq_data_for_pixtral_hf( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) + return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) + elif isinstance(vision_config, PixtralVisionConfig): + # We ignore image_feature_size_override since we have non-uniform + # image sizes for Pixtral + return input_processor_for_pixtral_hf( + model_config, + vision_config, + inputs, + image_token_id=hf_config.image_token_index, + ) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig): vision_config, num_hidden_layers_override=num_hidden_layers, ) + elif isinstance(vision_config, PixtralVisionConfig): + # TODO: allow layer override? + return PixtralHFVisionModel(vision_config) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -210,6 +245,15 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config + # NOTE: These are special cases for Pixtral-12B in the HF-format + # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa + if (config.text_config.architectures is None + and config.text_config.model_type == "mistral"): + config.text_config.architectures = ["MistralForCausalLM"] + if (config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu"): + config.projector_hidden_act = "gelu" + # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = _init_vision_tower(config) self.multi_modal_projector = LlavaMultiModalProjector( @@ -246,6 +290,7 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -256,6 +301,26 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Case for models like PixtralHF that have dynamic image sizes + # so we need to produce a list of tensors + if image_sizes is not None: + images = pixel_values + if isinstance(images, torch.Tensor): + # if passed as batch take all images + NN, N, B, C, W, H = images.shape + images = images.reshape(NN * N * B, C, W, H) + images = [images[i] for i in range(images.size(0))] + elif isinstance(images, list): + # if passed as list flatten lists of tensors + while isinstance(images, list) and len(images) == 1: + images = images[0] + + # TODO: Add validation based on image_sizes + return LlavaImagePixelInputs( + type="pixel_values", + data=images, + ) + return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( @@ -286,7 +351,8 @@ def _select_image_features(self, image_features: torch.Tensor, *, def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, + PixtralHFVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index f34d21fdef56f..d09cbe5ca02e9 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -3,18 +3,26 @@ from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union +import numpy import torch import torch.nn as nn import torch.nn.functional as F from mistral_common.protocol.instruct.messages import ImageChunk from PIL import Image -from transformers import PretrainedConfig +from transformers import PixtralVisionConfig, PretrainedConfig +from transformers.models.pixtral.image_processing_pixtral import ( + _num_image_tokens) +from transformers.models.pixtral.modeling_pixtral import ( + PixtralRotaryEmbedding, apply_rotary_pos_emb, + generate_block_attention_mask, position_ids_in_meshgrid) from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.config import CacheConfig, ModelConfig, MultiModalConfig +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -25,6 +33,8 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import init_vllm_registered_model @@ -576,3 +586,397 @@ def __init__(self, args: VisionEncoderArgs, dim: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x))) + + +#### HF Transformers version of Pixtral #### +# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py +# This model follows the Llava family, meaning image embeddings are placed +# instead of the `[IMG]` token placeholders. +# The model uses [`PixtralVisionModel`] for its vision encoder, +# and [`MistralForCausalLM`] for its language decoder. + + +def get_pixtral_hf_patch_grid_length(*, image_size: int, + patch_size: int) -> int: + # Since interpolation is applied, the image size need not be divisible + # assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_max_pixtral_hf_image_feature_size( + hf_config: PixtralVisionConfig) -> int: + return get_pixtral_hf_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: + return get_max_pixtral_hf_image_feature_size(hf_config) + + +def dummy_seq_data_for_pixtral_hf( + hf_config: PixtralVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_prompt_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_image_for_pixtral_hf( + hf_config: PixtralVisionConfig, + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, + image_width: int, + image_height: int) -> Tuple[int, int]: + # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 + # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501 + max_width, max_height = hf_config.image_size, hf_config.image_size + patch_width, patch_height = hf_config.patch_size, hf_config.patch_size + + ratio = max(image_width / max_width, image_height / max_height) + + if ratio > 1: + image_width = int(numpy.ceil(image_width / ratio)) + image_height = int(numpy.ceil(image_height / ratio)) + + num_height_tokens, num_width_tokens = _num_image_tokens( + (image_height, image_width), (patch_height, patch_width)) + + return num_width_tokens, num_height_tokens + + +def input_processor_for_pixtral_hf( + model_config: ModelConfig, + hf_config: PixtralVisionConfig, + inputs: DecoderOnlyInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[Union[int, List[int]]] = None, +) -> DecoderOnlyInputs: + assert image_feature_size_override is None, ( + "image_feature_size_override is not supported for Pixtral") + + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + processor = cached_get_processor(model_config.model) + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_data = [image_data] + elif not is_list_of(image_data, Image.Image): + raise TypeError(f"Invalid image type: {type(image_data)}") + + new_prompt = inputs.get("prompt") + new_token_ids = inputs["prompt_token_ids"] + + # Update new_prompt if present + if new_prompt: + replace_strings = [] + for image in image_data: + w, h = image.size + + (num_width_tokens, + num_height_tokens) = get_pixtral_hf_image_feature_size( + hf_config, image_width=w, image_height=h) + + replace_tokens = [[processor.image_token] * num_width_tokens + + [processor.image_break_token] + ] * num_height_tokens + # Flatten list + replace_tokens = [ + item for sublist in replace_tokens for item in sublist + ] + replace_tokens[-1] = processor.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + new_prompt = new_prompt.replace(processor.image_token, + "", 1) + + while "" in new_prompt: + replace_str = replace_strings.pop(0) + new_prompt = new_prompt.replace("", replace_str, 1) + + # Update new_token_ids + image_token_id = 10 + image_break_id = 12 + image_end_id = 13 + placeholder_token_id = -999 + replace_tokens_list = [] + for image in image_data: + w, h = image.size + + num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size( + hf_config, image_width=w, image_height=h) + + replace_tokens = [[image_token_id] * num_width_tokens + + [image_break_id]] * num_height_tokens + # Flatten list + replace_tokens = [ + item for sublist in replace_tokens for item in sublist + ] + replace_tokens[-1] = image_end_id + replace_tokens_list.append(replace_tokens) + # Replace image id with placeholder id + next_image_index = new_token_ids.index(image_token_id) + new_token_ids[next_image_index] = placeholder_token_id + + while placeholder_token_id in new_token_ids: + replace_tokens = replace_tokens_list.pop(0) + next_image_index = new_token_ids.index(placeholder_token_id) + prefix = new_token_ids[:next_image_index] + postfix = new_token_ids[next_image_index + 1:] + new_token_ids = prefix + replace_tokens + postfix + + # NOTE: Create a defensive copy of the original inputs + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class PixtralHFMLP(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + assert config.intermediate_size is not None + self.gate_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False) + self.up_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False) + self.down_proj = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False) + self.act = get_act_fn(config.hidden_act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) + + +class PixtralHFAttention(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + self.config = config + assert not config.hidden_size % config.num_attention_heads + self.n_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + + self.scale = self.head_dim**-0.5 + + self.q_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + self.k_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + self.v_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + self.o_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.n_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.n_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.n_heads, + self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, + key_states, + cos, + sin, + unsqueeze_dim=0) + + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, patches, -1) + + return self.o_proj(attn_output) + + +class PixtralHFTransformerBlock(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) + self.attention = PixtralHFAttention(config) + self.feed_forward = PixtralHFMLP(config) + self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings) + h = hidden_states + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class PixtralHFTransformer(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(PixtralHFTransformerBlock(config)) + + def forward( + self, + x: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, attention_mask, position_embeddings) + return x + + +class PixtralHFVisionModel(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + + self.config = config + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) + self.transformer = PixtralHFTransformer(config) + self.dtype = next(self.parameters()).dtype + self.device = next(self.parameters()).device + self.patch_positional_embedding = PixtralRotaryEmbedding( + config, self.device) + + def forward( + self, + pixel_values: List[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + pixel_values: tensor of token features for + all tokens of all images of shape (N_toks, D) + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [ + self.patch_conv( + img.reshape(-1, img.shape[-3], img.shape[-2], + img.shape[-1]).to(self.dtype)) + for img in pixel_values + ] + + # flatten to a single sequence + patch_embeds = torch.cat( + [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, + max_width=self.config.image_size // self.config.patch_size).to( + self.device) + + position_embedding = self.patch_positional_embedding( + patch_embeds, position_ids) + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + patch_embeds) + out = self.transformer(patch_embeds, attention_mask, + position_embedding) + + return out + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [] + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f7d632a83cc33..a3540abdc23d3 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from functools import lru_cache, partial +from functools import partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) @@ -63,7 +63,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope -from vllm.transformers_utils.processor import get_processor +from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, get_vit_attn_backend, @@ -544,8 +544,6 @@ def forward( # === Vision input helpers === # -cached_get_processor = lru_cache(get_processor) - def mm_input_mapper_for_qwen2_vl( ctx: InputContext, diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 98663f7f0bd07..f1523667b0466 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Any, cast @@ -37,6 +38,9 @@ def get_processor( return cast(ProcessorMixin, processor) +cached_get_processor = lru_cache(get_processor) + + def get_image_processor( processor_name: str, *args: Any,