[VLM] Support caching in merged multi-modal processor #11341

Status: Open. DarkLight1337 wants to merge 27 commits into main from mm-processor-cache. The diff shown below is from 9 of the 27 commits.

Commits (27, all by DarkLight1337):
faa9b84  Refactor multi-modal processor to support caching  (Dec 19, 2024)
9711a15  Clean up  (Dec 19, 2024)
29e3fcd  Fix cached result being mutated  (Dec 19, 2024)
ab64e85  Rename  (Dec 19, 2024)
81215a2  Fix docs  (Dec 19, 2024)
cf52b3b  Fix a typo  (Dec 19, 2024)
a4a8eb9  Fix unhandled sampling rate in initialization  (Dec 19, 2024)
c48f7c5  format  (Dec 19, 2024)
b84ff42  Change the delimiter  (Dec 19, 2024)
c3f1bde  Fix extra dimension  (Dec 19, 2024)
32e5197  Update  (Dec 19, 2024)
7264d4e  Use the inner processor to enable fine-grained caching  (Dec 20, 2024)
02ea829  Make the cache optional  (Dec 20, 2024)
b981a9d  Fix invalid kwargs being passed to tokenizer  (Dec 20, 2024)
5dde7d0  Fix Phi3V prompt replacement  (Dec 20, 2024)
7339ab8  Refine  (Dec 20, 2024)
509411d  Enable fine-grained caching for audio models  (Dec 20, 2024)
c0454f5  Add fallback  (Dec 20, 2024)
d50ef03  Fix typo  (Dec 20, 2024)
81f7d61  Fix video processor for Qwen2-VL  (Dec 20, 2024)
13eede3  Merge branch 'main' into mm-processor-cache  (Dec 20, 2024)
affbc5c  Fix a bunch of type errors  (Dec 20, 2024)
b4ddfb1  Fix qwen2-vl  (Dec 20, 2024)
4b3db32  Fix  (Dec 20, 2024)
dafbc7f  Simplify Pixtral-HF  (Dec 21, 2024)
38aaff8  Cleanup  (Dec 21, 2024)
5fcb5d6  Fix Pixtral-HF  (Dec 21, 2024)
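Taken together, the commits let the merged multi-modal processor reuse HF processor outputs across requests: the main refactor lands in faa9b84, per-item ("fine-grained") caching in 7264d4e and 509411d, and 02ea829 makes the cache optional. The diffs below only contain the plumbing that makes this possible; the cache itself is outside this 9-commit view. As a rough sketch of the idea, and nothing more: the class below is hypothetical, although the new `blake3` entry added to `docs/source/conf.py` suggests the real implementation keys entries by a blake3 hash.

```python
from collections import OrderedDict
from typing import Callable

import blake3


class MMProcessorCache:
    """Hypothetical LRU cache over per-item HF processor outputs."""

    def __init__(self, capacity: int = 256) -> None:
        self._entries: OrderedDict[str, object] = OrderedDict()
        self._capacity = capacity

    def _key(self, item_bytes: bytes, mm_kwargs: dict) -> str:
        # Key on the raw item plus the processor kwargs, so the same audio
        # processed at a different sampling_rate gets a different entry.
        hasher = blake3.blake3(item_bytes)
        hasher.update(repr(sorted(mm_kwargs.items())).encode())
        return hasher.hexdigest()

    def get_or_compute(
        self,
        item_bytes: bytes,
        mm_kwargs: dict,
        compute: Callable[[], object],
    ) -> object:
        key = self._key(item_bytes, mm_kwargs)
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as most recently used
            return self._entries[key]
        value = compute()
        self._entries[key] = value
        if len(self._entries) > self._capacity:
            self._entries.popitem(last=False)  # evict least recently used
        return value
```

Commit 29e3fcd ("Fix cached result being mutated") points at the main pitfall of such a cache: entries must be copied, or treated as immutable, on the way out, otherwise one request's post-processing corrupts another's.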
docs/source/conf.py (2 additions, 1 deletion)

```diff
@@ -162,6 +162,7 @@ def linkcode_resolve(domain, info):
 
 # Mock out external dependencies here, otherwise the autodoc pages may be blank.
 autodoc_mock_imports = [
+    "blake3",
     "compressed_tensors",
     "cpuinfo",
     "cv2",
@@ -178,7 +179,7 @@ def linkcode_resolve(domain, info):
     "tensorizer",
     "pynvml",
     "outlines",
-    "xgrammar,"
+    "xgrammar",
     "librosa",
     "soundfile",
     "gguf",
```
vllm/inputs/registry.py (7 additions, 11 deletions)

```diff
@@ -132,10 +132,12 @@ def get_hf_processor(
     def call_hf_processor(
         self,
         hf_processor: ProcessorMixin,
-        prompt: str,
-        processor_data: Mapping[str, object],
-        inference_kwargs: Mapping[str, object],
+        data: Mapping[str, object],
+        kwargs: Optional[Mapping[str, object]] = None,
     ) -> BatchFeature:
+        if kwargs is None:
+            kwargs = {}
+
         assert callable(hf_processor)
 
         base_kwargs = self.model_config.mm_processor_kwargs
@@ -144,21 +146,15 @@ def call_hf_processor(
 
         merged_kwargs = resolve_mm_processor_kwargs(
             base_kwargs,
-            inference_kwargs,
+            kwargs,
             hf_processor,
             requires_kw_only=False,
             allow_var_kwargs=True,
         )
 
         try:
-            return hf_processor(
-                text=prompt,
-                **processor_data,
-                **merged_kwargs,
-                return_tensors="pt",
-            )
+            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
         except Exception as exc:
-            data = dict(text=prompt, **processor_data)
             msg = (f"Failed to apply {type(hf_processor).__name__} "
                    f"on data={data} with kwargs={merged_kwargs}")
```
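This signature change is the enabler for the rest of the PR: instead of hard-wiring `text=prompt` into the call, `call_hf_processor` now takes a single `data` mapping (plus optional `kwargs`), so the caller decides whether to pass text together with media, or a single media item on its own. A hedged usage sketch; `ctx` stands for an `InputProcessingContext`, and `num_crops` is borrowed from the Phi-3-vision diff below:

```python
# Full call: the caller packs text and media into one mapping.
hf_inputs = ctx.call_hf_processor(
    hf_processor,
    dict(text=prompt, images=images),
    dict(num_crops=16),  # optional per-call kwargs; defaults to {}
)

# Media-only call through the same entry point. This is what per-item
# caching needs: one image can be processed (and cached) on its own.
image_inputs = ctx.call_hf_processor(hf_processor, dict(images=[image]))
```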
vllm/model_executor/models/llava.py (2 additions, 2 deletions)

```diff
@@ -145,7 +145,7 @@ def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
         hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_mm_kwargs: Mapping[str, object],
     ) -> list[PromptReplacement]:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         image_token_id = hf_config.image_token_index
@@ -218,7 +218,7 @@ def _get_dummy_mm_inputs(
         return ProcessorInputs(
             prompt_text=image_token * num_images,
             mm_data=data,
-            mm_processor_kwargs={},
+            hf_mm_kwargs={},
         )
```
vllm/model_executor/models/phi3v.py (9 additions, 11 deletions)

```diff
@@ -306,11 +306,11 @@ def get_max_phi3v_image_tokens(
     *,
     num_crops: Optional[int] = None,
 ) -> int:
-    mm_processor_kwargs = {}
+    hf_mm_kwargs = {}
     if num_crops:
-        mm_processor_kwargs["num_crops"] = num_crops
+        hf_mm_kwargs["num_crops"] = num_crops
 
-    processor = ctx.get_hf_processor(**mm_processor_kwargs)
+    processor = ctx.get_hf_processor(**hf_mm_kwargs)
 
     return processor.calc_num_image_tokens_from_image_size(
         width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
@@ -331,16 +331,14 @@ def _get_hf_processor(
 
     def _call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
-            hf_processor,
             prompt=prompt,
-            processor_data=processor_data,
-            mm_processor_kwargs=mm_processor_kwargs,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
         # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
@@ -356,7 +354,7 @@ def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
         hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_mm_kwargs: Mapping[str, object],
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
@@ -401,7 +399,7 @@ def _get_dummy_mm_inputs(
         return ProcessorInputs(
             prompt_text="".join(image_tokens[:num_images]),
             mm_data=data,
-            mm_processor_kwargs={},
+            hf_mm_kwargs={},
         )
```
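Context for the truncated comment about -1, -2 placeholders: Phi-3-vision's HF processor emits negative token ids at the positions where image embeddings will be spliced in, and those ids break detokenization, so the override (in the folded lines) collapses them to the regular image token id before handing the ids onward. A toy illustration; treating 32044 as the `<|image|>` token id is an assumption, not taken from this diff:

```python
import torch

_IMAGE_TOKEN_ID = 32044  # assumed <|image|> id for Phi-3-vision

# Negative ids (-1 for image 1, -2 for image 2, ...) are placeholders
# inserted by the HF processor; replace them before decoding.
input_ids = torch.tensor([[1, -1, -1, 29871, -2, 2]])
input_ids[input_ids < 0] = _IMAGE_TOKEN_ID
assert int(input_ids.min()) >= 0
```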
vllm/model_executor/models/qwen2_audio.py (20 additions, 17 deletions)

```diff
@@ -26,7 +26,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, ProcessorMixin
+from transformers import BatchFeature
 from transformers.models.qwen2_audio import (Qwen2AudioConfig,
                                              Qwen2AudioEncoder,
                                              Qwen2AudioProcessor)
@@ -88,56 +88,59 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
 
 class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
 
-    def _get_hf_processor(self) -> Qwen2AudioProcessor:
+    def _get_hf_processor(
+        self,
+        *,
+        # Ignored in initialization
+        sampling_rate: Optional[int] = None,
+    ) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor)
 
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
         return self._get_hf_processor().feature_extractor  # type: ignore
 
-    def _get_processor_data(
+    def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         # resample audio to the model's sampling rate
         feature_extractor = self._get_feature_extractor()
         mm_items.resample_audios(feature_extractor.sampling_rate)
 
-        return super()._get_processor_data(mm_items)
+        return super()._get_hf_mm_data(mm_items)
 
     def _call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processor_data = dict(processor_data)
-        audios = processor_data.pop("audios", [])
+        mm_data = dict(mm_data)
+        audios = mm_data.pop("audios", [])
 
         if audios:
-            processor_data["audios"] = audios
+            mm_data["audios"] = audios
 
             feature_extractor = self._get_feature_extractor()
-            mm_processor_kwargs = dict(
-                **mm_processor_kwargs,
+            mm_kwargs = dict(
+                **mm_kwargs,
                 sampling_rate=feature_extractor.sampling_rate,
             )
         else:
             # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
             pass
 
         return super()._call_hf_processor(
-            hf_processor,
             prompt=prompt,
-            processor_data=processor_data,
-            mm_processor_kwargs=mm_processor_kwargs,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
         hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_mm_kwargs: Mapping[str, object],
     ) -> list[PromptReplacement]:
         hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
         placeholder = hf_config.audio_token_index
@@ -175,7 +178,7 @@ def _get_dummy_mm_inputs(
         return ProcessorInputs(
             prompt_text="<|AUDIO|>" * audio_count,
             mm_data=data,
-            mm_processor_kwargs={},
+            hf_mm_kwargs={},
         )
```
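The keyword-only `sampling_rate` that `_get_hf_processor` now accepts and ignores follows from the new data flow: `_call_hf_processor` injects `sampling_rate` into `mm_kwargs` for the `WhisperFeatureExtractor`, and the base implementation (not visible in this diff) presumably forwards the same `mm_kwargs` both when constructing the processor and when calling it. `Qwen2AudioProcessor` takes no such constructor argument, so the override has to swallow it; that is what commit a4a8eb9 ("Fix unhandled sampling rate in initialization") addresses. Roughly, under that reading:

```python
# Sketch of the assumed base-class flow (method body excerpt, not
# runnable on its own): the same mm_kwargs reaches both places.
def _call_hf_processor(self, prompt, mm_data, mm_kwargs):
    hf_processor = self._get_hf_processor(**mm_kwargs)  # sampling_rate ignored
    return self.ctx.call_hf_processor(
        hf_processor,
        dict(text=prompt, **mm_data),
        mm_kwargs,  # sampling_rate honored here by the feature extractor
    )
```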
vllm/model_executor/models/qwen2_vl.py (3 additions, 3 deletions)

```diff
@@ -784,7 +784,7 @@ def _get_hf_processor(
 
         return hf_processor
 
-    def _get_processor_data(
+    def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -817,7 +817,7 @@ def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
         hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_mm_kwargs: Mapping[str, object],
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_processor = _get_image_processor(hf_processor)
@@ -869,7 +869,7 @@ def _get_dummy_mm_inputs(
         return ProcessorInputs(
             prompt_text=image_token * num_images,
             mm_data=data,
-            mm_processor_kwargs={},
+            hf_mm_kwargs={},
         )
```
vllm/model_executor/models/ultravox.py (24 additions, 19 deletions)

```diff
@@ -72,45 +72,51 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
 
 class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
 
+    def _get_hf_processor(
+        self,
+        *,
+        # Ignored in initialization
+        sampling_rate: Optional[int] = None,
+    ) -> ProcessorMixin:
+        return self.ctx.get_hf_processor()
+
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
         hf_processor = self._get_hf_processor()
         return hf_processor.audio_processor.feature_extractor  # type: ignore
 
-    def _get_processor_data(
+    def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         # resample audio to the model's sampling rate
         feature_extractor = self._get_feature_extractor()
         mm_items.resample_audios(feature_extractor.sampling_rate)
 
-        return super()._get_processor_data(mm_items)
+        return super()._get_hf_mm_data(mm_items)
 
     def _call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processor_data = dict(processor_data)
-        audios = processor_data.pop("audios", [])
+        mm_data = dict(mm_data)
+        audios = mm_data.pop("audios", [])
 
         if not audios:
             return super()._call_hf_processor(
-                hf_processor,
                 prompt=prompt,
-                processor_data=processor_data,
-                mm_processor_kwargs=mm_processor_kwargs,
+                mm_data=mm_data,
+                mm_kwargs=mm_kwargs,
             )
 
         feature_extractor = self._get_feature_extractor()
-        mm_processor_kwargs = dict(
-            **mm_processor_kwargs,
+        mm_kwargs = dict(
+            **mm_kwargs,
             sampling_rate=feature_extractor.sampling_rate,
         )
 
-        # Already resampled by _get_processor_data
+        # Already resampled by _get_hf_mm_data
         assert is_list_of(audios, np.ndarray)
 
         # Ultravox processor doesn't support multiple inputs,
@@ -119,13 +125,12 @@ def _call_hf_processor(
         shared_outputs = {}
         for audio in audios:
             # NOTE: Ultravox processor accepts "audio" instead of "audios"
-            item_processor_data = dict(**processor_data, audio=audio)
+            item_processor_data = dict(**mm_data, audio=audio)
 
             item_outputs = super()._call_hf_processor(
-                hf_processor,
                 prompt=prompt,
-                processor_data=item_processor_data,
-                mm_processor_kwargs=mm_processor_kwargs,
+                mm_data=item_processor_data,
+                mm_kwargs=mm_kwargs,
             )
 
             audio_features.append(item_outputs.pop("audio_values")[0])
@@ -143,7 +148,7 @@ def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
         hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_mm_kwargs: Mapping[str, object],
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         placeholder = hf_processor.audio_token_replacement  # type: ignore
@@ -175,7 +180,7 @@ def _get_dummy_mm_inputs(
         return ProcessorInputs(
             prompt_text="<|audio|>" * audio_count,
             mm_data=data,
-            mm_processor_kwargs={},
+            hf_mm_kwargs={},
         )
```
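Ultravox loops over individual audios because its HF processor only accepts a single `audio` per call, but the loop also dovetails with the PR's goal: per-item calls give the cache per-item granularity, so a clip repeated across requests is processed once (commits 7264d4e and 509411d). A self-contained toy of that effect:

```python
from functools import lru_cache

CALLS = 0

@lru_cache(maxsize=128)
def extract_features(audio_bytes: bytes) -> int:
    """Stand-in for an expensive per-item feature-extractor call."""
    global CALLS
    CALLS += 1
    return len(audio_bytes)

def process_batch(audios: list[bytes]) -> list[int]:
    # Item-level calls mean repeats across (and within) batches hit the cache.
    return [extract_features(a) for a in audios]

process_batch([b"clip-a", b"clip-b"])
process_batch([b"clip-a", b"clip-c"])  # "clip-a" is a cache hit
assert CALLS == 3
```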