[VLM] Support caching in merged multi-modal processor #11341

Status: Open. Wants to merge 27 commits into base: main (changes shown from all commits).

Commits (27):
faa9b84
Refactor multi-modal processor to support caching
DarkLight1337 Dec 19, 2024
9711a15
Clean up
DarkLight1337 Dec 19, 2024
29e3fcd
Fix cached result being mutated
DarkLight1337 Dec 19, 2024
ab64e85
Rename
DarkLight1337 Dec 19, 2024
81215a2
Fix docs
DarkLight1337 Dec 19, 2024
cf52b3b
Fix a typo
DarkLight1337 Dec 19, 2024
a4a8eb9
Fix unhandled sampling rate in initialization
DarkLight1337 Dec 19, 2024
c48f7c5
format
DarkLight1337 Dec 19, 2024
b84ff42
Change the delimiter
DarkLight1337 Dec 19, 2024
c3f1bde
Fix extra dimension
DarkLight1337 Dec 19, 2024
32e5197
Update
DarkLight1337 Dec 19, 2024
7264d4e
Use the inner processor to enable fine-grained caching
DarkLight1337 Dec 20, 2024
02ea829
Make the cache optional
DarkLight1337 Dec 20, 2024
b981a9d
Fix invalid kwargs being passed to tokenizer
DarkLight1337 Dec 20, 2024
5dde7d0
Fix Phi3V prompt replacement
DarkLight1337 Dec 20, 2024
7339ab8
Refine
DarkLight1337 Dec 20, 2024
509411d
Enable fine-grained caching for audio models
DarkLight1337 Dec 20, 2024
c0454f5
Add fallback
DarkLight1337 Dec 20, 2024
d50ef03
Fix typo
DarkLight1337 Dec 20, 2024
81f7d61
Fix video processor for Qwen2-VL
DarkLight1337 Dec 20, 2024
13eede3
Merge branch 'main' into mm-processor-cache
DarkLight1337 Dec 20, 2024
affbc5c
Fix a bunch of type errors
DarkLight1337 Dec 20, 2024
b4ddfb1
Fix qwen2-vl
DarkLight1337 Dec 20, 2024
4b3db32
Fix
DarkLight1337 Dec 20, 2024
dafbc7f
Simplify Pixtral-HF
DarkLight1337 Dec 21, 2024
38aaff8
Cleanup
DarkLight1337 Dec 21, 2024
5fcb5d6
Fix Pixtral-HF
DarkLight1337 Dec 21, 2024
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -162,6 +162,7 @@ def linkcode_resolve(domain, info):

# Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [
"blake3",
"compressed_tensors",
"cpuinfo",
"cv2",
@@ -178,7 +179,7 @@ def linkcode_resolve(domain, info):
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"xgrammar",
"librosa",
"soundfile",
"gguf",
67 changes: 53 additions & 14 deletions vllm/inputs/registry.py
@@ -1,11 +1,12 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
from typing import (TYPE_CHECKING, Any, Callable, Literal, Mapping, NamedTuple,
Optional, Protocol, Union)

from torch import nn
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from transformers.models.whisper import WhisperFeatureExtractor
from typing_extensions import TypeVar, assert_never

from vllm.logger import init_logger
@@ -111,6 +112,39 @@ def get_hf_processor(

return hf_processor

def get_modality_processor(
self,
hf_processor: ProcessorMixin,
modality_data_key: Literal["text", "images", "videos", "audios"],
) -> Callable[..., BatchFeature]:
"""
Get the HuggingFace modality-specific processor which is
a child of a :class:`transformers.ProcessorMixin`, identified by
the corresponding keyword argument in its `__call__` method.
"""
if modality_data_key == "text":
attributes = ["tokenizer"]
elif modality_data_key == "images":
attributes = ["image_processor"]
elif modality_data_key == "videos":
attributes = ["video_processor", "image_processor"]
elif modality_data_key == "audios":
attributes = ["audio_processor", "feature_extractor"]
else:
assert_never(modality_data_key)

modality_processor = next(
(getattr(hf_processor, attr)
for attr in attributes if hasattr(hf_processor, attr)),
None,
)
if modality_processor is None:
raise AttributeError(
f"Cannot find HuggingFace processor for {modality_data_key} "
f"inside {type(hf_processor)}")

return modality_processor


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
@@ -131,34 +165,39 @@ def get_hf_processor(

def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
hf_processor: Union[ProcessorMixin, Callable[..., BatchFeature]],
data: Mapping[str, object],
kwargs: Optional[Mapping[str, object]] = None,
) -> BatchFeature:
assert callable(hf_processor)

if kwargs is None:
kwargs = {}

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
# Modality-specific processors should state each kwarg individually
allow_var_kwargs=isinstance(hf_processor, ProcessorMixin),
)

# WhisperFeatureExtractor accepts `raw_speech`
# but the parent HF processor accepts `audios`
# Making `audios` an alias of `raw_speech` simplifies the calling code
if (isinstance(hf_processor, WhisperFeatureExtractor)
and "raw_speech" not in data):
data = dict(data)
data["raw_speech"] = data.pop("audios")

try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

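Note (reviewer sketch, not part of the diff): with the refactor above, call_hf_processor takes a flat data mapping plus optional per-call kwargs, and get_modality_processor exposes each sub-processor of a merged HF processor. A minimal illustration of the intended per-modality call path follows; process_images_separately and its arguments are hypothetical names, and the helper merges mm_processor_kwargs and forces return_tensors="pt" internally.

from typing import Any

from transformers import BatchFeature, ProcessorMixin

from vllm.inputs.registry import InputProcessingContext


def process_images_separately(
    ctx: InputProcessingContext,
    hf_processor: ProcessorMixin,
    images: list[Any],
) -> BatchFeature:
    # Look up the image-specific sub-processor (e.g. the image_processor
    # attribute of the merged processor)...
    image_processor = ctx.get_modality_processor(hf_processor, "images")
    # ...and call it through the shared helper, so only the image inputs
    # need to be reprocessed (or cached) independently of the text.
    return ctx.call_hf_processor(image_processor, {"images": images})
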
46 changes: 22 additions & 24 deletions vllm/model_executor/models/llava.py
@@ -1,5 +1,4 @@
from functools import cached_property
from types import MethodType
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, TypedDict, Union)

@@ -116,36 +115,35 @@ def get_max_llava_image_tokens(ctx: InputContext):

class LlavaMultiModalProcessor(BaseMultiModalProcessor):

def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
return # Already patched

image_processor = hf_processor.image_processor # type: ignore
orig_preprocess = image_processor.preprocess

def preprocess(__self, *args, **kwargs):
hf_inputs = orig_preprocess(*args, **kwargs)
hf_inputs["is_pixtral"] = torch.tensor(True)
return hf_inputs

image_processor.preprocess = MethodType(preprocess, image_processor)
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))

hf_processor.__is_patched__ = True # type: ignore
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor(
(LlavaProcessor, PixtralProcessor))
images = mm_data.get("images", [])
assert isinstance(images, list)

if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
is_pixtral = isinstance(self._get_hf_processor(), PixtralProcessor)
processed_outputs["is_pixtral"] = \
torch.tensor([is_pixtral] * len(images))

return hf_processor
return processed_outputs

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_mm_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
@@ -218,7 +216,7 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
hf_mm_kwargs={},
)


@@ -379,8 +377,8 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
image_embeds = kwargs.pop("image_embeds", None)
is_pixtral = kwargs.pop("is_pixtral", None)

if pixel_values is None and image_embeds is None:
return None
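
Note (reviewer sketch): instead of monkey-patching PixtralProcessor.preprocess, the processor output is now tagged after the fact with one boolean per input image, so the flag travels with each cached item. A tiny self-contained illustration of the resulting tensor shape (values made up):

import torch

num_images = 2          # len(mm_data["images"]) in _call_hf_processor
is_pixtral = True       # isinstance(self._get_hf_processor(), PixtralProcessor)

flag = torch.tensor([is_pixtral] * num_images)
print(flag)             # tensor([True, True]); one entry per image, not a scalar
assert flag.shape == (num_images,)
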
20 changes: 9 additions & 11 deletions vllm/model_executor/models/phi3v.py
@@ -306,11 +306,11 @@ def get_max_phi3v_image_tokens(
*,
num_crops: Optional[int] = None,
) -> int:
mm_processor_kwargs = {}
hf_mm_kwargs = {}
if num_crops:
mm_processor_kwargs["num_crops"] = num_crops
hf_mm_kwargs["num_crops"] = num_crops

processor = ctx.get_hf_processor(**mm_processor_kwargs)
processor = ctx.get_hf_processor(**hf_mm_kwargs)

return processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
@@ -331,16 +331,14 @@ def _get_hf_processor(

def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
@@ -356,7 +354,7 @@ def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_mm_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
@@ -401,7 +399,7 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
hf_mm_kwargs={},
)


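Note (reviewer sketch): the fix-up that follows the "-1, -2 etc as placeholder" comment above is collapsed in this view. Conceptually, the negative placeholder ids emitted by the Phi-3-vision processor have to be mapped back to a real image-token id before the ids are reused; a minimal sketch of that idea is below. The PR's actual code may differ, and _IMAGE_TOKEN_ID is an assumed value.

import torch

_IMAGE_TOKEN_ID = 32044  # assumed id of the Phi-3-vision image placeholder token


def replace_negative_placeholders(input_ids: torch.Tensor) -> torch.Tensor:
    # The HF processor emits -1, -2, ... in place of image tokens; map every
    # negative id to the single placeholder id so downstream prompt
    # replacement sees an ordinary token sequence.
    return torch.where(
        input_ids < 0,
        torch.full_like(input_ids, _IMAGE_TOKEN_ID),
        input_ids,
    )


print(replace_negative_placeholders(torch.tensor([1, -1, -1, 2, -2])))
# tensor([    1, 32044, 32044,     2, 32044])
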
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen.py
@@ -225,7 +225,7 @@ def __init__(
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -266,7 +266,7 @@ def __init__(
layers: int,
heads: int,
mlp_ratio: float = 4.0,
norm_layer: Callable = nn.LayerNorm,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
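
Note (reviewer sketch): the tightened annotation documents that norm_layer is a factory taking the hidden size and returning a module, for example:

from typing import Callable

import torch.nn as nn

norm_layer: Callable[[int], nn.Module] = nn.LayerNorm
norm = norm_layer(1024)  # equivalent to nn.LayerNorm(1024)
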
53 changes: 35 additions & 18 deletions vllm/model_executor/models/qwen2_audio.py
@@ -26,7 +26,7 @@
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, ProcessorMixin
from transformers import BatchFeature
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioEncoder,
Qwen2AudioProcessor)
@@ -88,56 +88,73 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:

class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):

def _get_hf_processor(self) -> Qwen2AudioProcessor:
def _get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = None,
) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)

def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore

def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)

return super()._get_processor_data(mm_items)
return super()._get_hf_mm_data(mm_items)

def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])

if audios:
processor_data["audios"] = audios
mm_data["audios"] = audios

feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
# When fine-grained caching is applied,
# the individual processors are called separately.
return_attention_mask=True,
padding="max_length",
)
else:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
pass

return super()._call_hf_processor(
hf_processor,
processed_outputs = super()._call_hf_processor(
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)

# When fine-grained caching is applied,
# the individual processors are called separately.
if "attention_mask" in processed_outputs:
processed_outputs["feature_attention_mask"] = \
processed_outputs.pop("attention_mask")

return processed_outputs

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
hf_mm_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
placeholder = hf_config.audio_token_index
@@ -175,7 +192,7 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text="<|AUDIO|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
hf_mm_kwargs={},
)


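Note (reviewer sketch, assumed names): when the WhisperFeatureExtractor is called on its own under fine-grained caching, the kwargs that the merged Qwen2AudioProcessor would normally inject have to be supplied explicitly, and its attention mask renamed, roughly as follows. extract_audio_features is illustrative, not the PR's code.

from transformers import BatchFeature, WhisperFeatureExtractor


def extract_audio_features(
    feature_extractor: WhisperFeatureExtractor,
    audios: list,
) -> BatchFeature:
    # Kwargs that Qwen2AudioProcessor would otherwise add on our behalf.
    outputs = feature_extractor(
        audios,
        sampling_rate=feature_extractor.sampling_rate,
        return_attention_mask=True,
        padding="max_length",
        return_tensors="pt",
    )
    # The merged processor exposes the mask as `feature_attention_mask`;
    # replicate that rename so downstream code sees a consistent key.
    if "attention_mask" in outputs:
        outputs["feature_attention_mask"] = outputs.pop("attention_mask")
    return outputs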