diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index f99f5f424ab28..6855b4c017cfe 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -6,6 +6,7 @@
 from torch import nn
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
+from transformers.models.whisper import WhisperFeatureExtractor
 from typing_extensions import TypeVar, assert_never
 
 from vllm.logger import init_logger
@@ -186,6 +187,14 @@ def call_hf_processor(
             allow_var_kwargs=isinstance(hf_processor, ProcessorMixin),
         )
 
+        # WhisperFeatureExtractor accepts `raw_speech`
+        # but the parent HF processor accepts `audios`
+        # Making `audios` an alias of `raw_speech` simplifies the calling code
+        if (isinstance(hf_processor, WhisperFeatureExtractor)
+                and "raw_speech" not in data):
+            data = dict(data)
+            data["raw_speech"] = data.pop("audios")
+
         try:
             return hf_processor(**data, **merged_kwargs, return_tensors="pt")
         except Exception as exc:
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 4031a7a7626b2..7d52737400a40 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -93,6 +93,8 @@ def _get_hf_processor(
         *,
         # Ignored in initialization
         sampling_rate: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
+        padding: Optional[str] = None,
     ) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor)
@@ -125,17 +127,29 @@ def _call_hf_processor(
             mm_kwargs = dict(
                 **mm_kwargs,
                 sampling_rate=feature_extractor.sampling_rate,
+                # When fine-grained caching is applied,
+                # the individual processors are called separately.
+                return_attention_mask=True,
+                padding="max_length",
             )
         else:
             # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
             pass
 
-        return super()._call_hf_processor(
+        processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )
 
+        # When fine-grained caching is applied,
+        # the individual processors are called separately.
+        if "attention_mask" in processed_outputs:
+            processed_outputs["feature_attention_mask"] = \
+                processed_outputs.pop("attention_mask")
+
+        return processed_outputs
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index d9b6037a343ad..5c351d31fa91c 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -21,8 +21,8 @@
 from .audio import resample_audio
 from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
-                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
-                     VideoItem)
+                     MultiModalInputsV2, MultiModalKwargs, NestedTensors,
+                     PlaceholderRange, VideoItem)
 
 logger = init_logger(__name__)
@@ -682,7 +682,8 @@ def _cached_call_fine(
         processed_data = dict(**processed_text)
 
         for data_key, items in mm_data.items():
-            processed_modal_items = defaultdict[str, list[torch.Tensor]](list)
+            processed_modal_items = defaultdict[str, Union[
+                list[torch.Tensor], list[NestedTensors]]](list)
 
             for item in items:
                 self.maybe_log_cache_stats(self._fine_mm_cache,
@@ -703,7 +704,14 @@ def _cached_call_fine(
                     # Remove the extra batch dimension
                     processed_modal_items[k].append(v[0])
 
-            processed_data.update(processed_modal_items)
+            for k, vs in processed_modal_items.items():
+                # Try to merge elements into a single tensor
+                if is_list_of(vs, torch.Tensor, check="all") and len(vs) > 0:
+                    first_shape = vs[0].shape
+                    if all(v.shape == first_shape for v in vs):
+                        vs = torch.stack(vs)
+
+                processed_data[k] = vs
 
         return BatchFeature(processed_data)