Skip to content

Commit

Permalink
[VLM] Merged multimodal processor for Qwen2-Audio (vllm-project#11303)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Dec 19, 2024
1 parent c6b0a7d commit 6142ef0
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 360 deletions.
6 changes: 6 additions & 0 deletions examples/offline_inference_audio_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
2: "What sport and what nursery rhyme are referenced?"
}

# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


# Ultravox 0.3
def run_ultravox(question: str, audio_count: int):
Expand All @@ -33,6 +37,8 @@ def run_ultravox(question: str, audio_count: int):
add_generation_prompt=True)

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
Expand Down
9 changes: 4 additions & 5 deletions tests/models/decoder_only/audio_language/test_ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from vllm.multimodal.audio import resample_audio
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

Expand Down Expand Up @@ -130,16 +131,14 @@ def process(hf_inputs: BatchEncoding, **kwargs):
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
import librosa

hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(librosa.resample(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
audios=[(resample_audio(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]

Expand Down
100 changes: 70 additions & 30 deletions vllm/inputs/registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type)
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
Optional, Protocol, Union)

from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never

from vllm.logger import init_logger
Expand All @@ -26,6 +26,7 @@
logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)


@dataclass(frozen=True)
Expand All @@ -38,24 +39,28 @@ class InputContext:
model_config: "ModelConfig"
"""The configuration of the model."""

def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
def get_hf_config(
self,
typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
/,
) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the model is not of the specified type.
TypeError: If the configuration is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {hf_config_type}, but "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}")

return hf_config

def get_hf_image_processor_config(self) -> Dict[str, Any]:
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
Expand All @@ -74,58 +79,93 @@ def get_mm_config(self):

return mm_config

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
if not isinstance(hf_processor, typ):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {typ}, but "
f"found type: {type(hf_processor)}")

return hf_processor


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
return super().get_hf_processor(
typ,
tokenizer=self.tokenizer,
**kwargs,
)

def resolve_hf_processor_call_kwargs(
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
) -> BatchFeature:
assert callable(hf_processor)

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

return resolve_mm_processor_kwargs(
merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
)

try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

raise RuntimeError(msg) from exc


N = TypeVar("N", bound=Type[nn.Module])
N = TypeVar("N", bound=type[nn.Module])


class DummyData(NamedTuple):
Expand Down Expand Up @@ -232,7 +272,7 @@ def wrapper(model_cls: N) -> N:

return wrapper

def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

Expand All @@ -257,7 +297,7 @@ def wrapper(model_cls: N) -> N:

return wrapper

def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

Expand Down Expand Up @@ -368,14 +408,14 @@ def wrapper(model_cls: N) -> N:

return wrapper

def _get_model_input_processor(self, model_cls: Type[nn.Module]):
def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)

def _ensure_mm_kwargs(
self,
inputs: SingletonInputs,
mm_processor_kwargs: Dict[str, Any],
mm_processor_kwargs: dict[str, Any],
):
if inputs["type"] == "token":
# In case the input processor for that model fails to set it
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def preprocess(__self, *args, **kwargs):
hf_processor.__is_patched__ = True # type: ignore

def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
hf_processor = self.ctx.get_hf_processor(
(LlavaProcessor, PixtralProcessor))

if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
Expand Down
16 changes: 11 additions & 5 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
Expand Down Expand Up @@ -330,20 +329,27 @@ def _get_hf_processor(
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()

def _apply_hf_processor(
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
mm_data: MultiModalDataDict,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._apply_hf_processor(
prompt, mm_data, mm_processor_kwargs)
processed_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
# which will cause OverflowError when decoding the prompt_ids.
# Therefore, we need to do an early replacement here
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids

return processed_outputs

def _get_prompt_replacements(
Expand Down
Loading

0 comments on commit 6142ef0

Please sign in to comment.