diff --git a/docs/source/conf.py b/docs/source/conf.py
index ded6742ea2e5c..d24ed03207824 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -112,6 +112,8 @@ def setup(app):
     "tensorizer",
     "pynvml",
     "outlines",
+    "librosa",
+    "soundfile",
     "gguf",
     "lark",
 ]
diff --git a/docs/source/models/enabling_multimodal_inputs.rst b/docs/source/models/enabling_multimodal_inputs.rst
index 20be920b5f699..dc76f921d5b09 100644
--- a/docs/source/models/enabling_multimodal_inputs.rst
+++ b/docs/source/models/enabling_multimodal_inputs.rst
@@ -15,14 +15,14 @@ This document walks you through the steps to extend a vLLM model so that it acce
 It is assumed that you have already implemented the model in vLLM according to :ref:`these steps `.
 Further update the model as follows:
 
-- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
+- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
 
   .. code-block:: diff
 
-    + from vllm.model_executor.models.interfaces import SupportsVision
+    + from vllm.model_executor.models.interfaces import SupportsMultiModal
 
     - class YourModelForImage2Seq(nn.Module):
-    + class YourModelForImage2Seq(nn.Module, SupportsVision):
+    + class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 .. note::
     The model class does not have to be named :code:`*ForCausalLM`.
@@ -51,11 +51,11 @@ This decorator accepts a function that maps multi-modal inputs to the keyword ar
 
 .. code-block:: diff
 
-      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.model_executor.models.interfaces import SupportsMultiModal
     + from vllm.multimodal import MULTIMODAL_REGISTRY
 
     + @MULTIMODAL_REGISTRY.register_image_input_mapper()
-      class YourModelForImage2Seq(nn.Module, SupportsVision):
+      class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 A default mapper is available for each modality in the core vLLM library.
 This input mapper will be used if you do not provide your own function.
@@ -72,13 +72,13 @@ and register it via :meth:`INPUT_REGISTRY.register_dummy_data
     )
     @INPUT_REGISTRY.register_dummy_data()
-    class YourModelForImage2Seq(nn.Module, SupportsVision):
+    class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 Here are some examples:
@@ -98,13 +98,13 @@ In such cases, you can define your own dummy data by registering a factory metho
 .. code-block:: diff
 
      from vllm.inputs import INPUT_REGISTRY
-     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY
 
      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens()
    + @INPUT_REGISTRY.register_dummy_data()
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 .. note::
     The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
@@ -128,14 +128,14 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
 ..
code-block:: diff from vllm.inputs import INPUT_REGISTRY - from vllm.model_executor.models.interfaces import SupportsVision + from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens() @INPUT_REGISTRY.register_dummy_data() + @INPUT_REGISTRY.register_input_processor() - class YourModelForImage2Seq(nn.Module, SupportsVision): + class YourModelForImage2Seq(nn.Module, SupportsMultiModal): A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation. Here are some examples: diff --git a/requirements-common.txt b/requirements-common.txt index 1876ab3c7d48b..170c3e06ba226 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,4 +20,6 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq +librosa # Required for audio processing +soundfile # Required for audio processing gguf == 0.9.1 diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py new file mode 100644 index 0000000000000..3c2c652fd317d --- /dev/null +++ b/tests/entrypoints/openai/test_audio.py @@ -0,0 +1,351 @@ +import math +import sys +import time +from typing import Dict, List, Optional, Tuple, Union, cast +from unittest.mock import patch + +import librosa +import numpy as np +import openai +import pytest +import requests +import torch + +from vllm import ModelRegistry +from vllm.config import MultiModalConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.inputs.data import LLMInputs +from vllm.inputs.registry import InputContext +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import (cached_get_tokenizer, + repeat_and_pad_image_tokens) +from vllm.multimodal.utils import encode_audio_base64, fetch_audio +from vllm.utils import get_open_port + +from ...utils import VLLM_PATH + +chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" +assert chatml_jinja_path.exists() + +MODEL_NAME = "facebook/opt-125m" +TEST_AUDIO_URLS = [ + "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg", +] + + +def server_function(port): + + def fake_input_mapper(ctx: InputContext, data: object): + assert isinstance(data, tuple) + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + + # Resample it to 1 sample per second + audio = librosa.resample(audio, orig_sr=sr, target_sr=1) + return MultiModalInputs({"processed_audio": torch.from_numpy(audio)}) + + def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "audio" not in multi_modal_data: + return llm_inputs + + audio, sr = multi_modal_data.get("audio") + audio_duration = math.ceil(len(audio) / sr) + + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + cached_get_tokenizer(ctx.model_config.tokenizer), + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=62, # "_" + repeat_count=audio_duration) + + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + 
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper) + @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", lambda *_, **__: 100) + @INPUT_REGISTRY.register_input_processor(fake_input_processor) + class FakeAudioModel(OPTForCausalLM, SupportsMultiModal): + + def __init__(self, *args, multimodal_config: MultiModalConfig, + **kwargs): + assert multimodal_config is not None + super().__init__(*args, **kwargs) + + def forward( + self, + *args, + processed_audio: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return super().forward(*args, **kwargs) + + ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel) + + with patch("vllm.entrypoints.chat_utils._mm_token_str", + lambda *_, **__: "_"): + sys.argv = ["placeholder.py"] + \ + (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 " + "--dtype bfloat16 --enforce-eager --api-key token-abc123 " + f"--port {port} --chat-template {chatml_jinja_path} " + "--disable-frontend-multiprocessing").split() + import runpy + runpy.run_module('vllm.entrypoints.openai.api_server', + run_name='__main__') + + +@pytest.fixture(scope="module") +def client(): + port = get_open_port() + ctx = torch.multiprocessing.get_context("spawn") + server = ctx.Process(target=server_function, args=(port, )) + server.start() + MAX_SERVER_START_WAIT_S = 60 + client = openai.AsyncOpenAI( + base_url=f"http://localhost:{port}/v1", + api_key="token-abc123", + ) + # run health check + health_url = f"http://localhost:{port}/health" + start = time.time() + while True: + try: + if requests.get(health_url).status_code == 200: + break + except Exception as err: + result = server.exitcode + if result is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_SERVER_START_WAIT_S: + raise RuntimeError("Server failed to start in time.") from err + + try: + yield client + finally: + server.kill() + + +@pytest.fixture(scope="session") +def base64_encoded_audio() -> Dict[str, str]: + return { + audio_url: encode_audio_base64(*fetch_audio(audio_url)) + for audio_url in TEST_AUDIO_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +async def test_single_chat_session_audio(client: openai.AsyncOpenAI, + model_name: str, audio_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + { + "type": "text", + "text": "What's happening in this audio?" 
+ }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=36, total_tokens=46) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +async def test_single_chat_session_audio_base64encoded( + client: openai.AsyncOpenAI, model_name: str, audio_url: str, + base64_encoded_audio: Dict[str, str]): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": + f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" + } + }, + { + "type": "text", + "text": "What's happening in this audio?" + }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=36, total_tokens=46) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +async def test_chat_streaming_audio(client: openai.AsyncOpenAI, + model_name: str, audio_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + { + "type": "text", + "text": "What's happening in this audio?" 
+ }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + output = chat_completion.choices[0].message.content + stop_reason = chat_completion.choices[0].finish_reason + + # test streaming + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert delta.content + assert "".join(chunks) == output + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, + audio_url: str): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + { + "type": "text", + "text": "What's happening in this audio?" + }, + ], + }] + + with pytest.raises(openai.BadRequestError): # test multi-audio input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 1197c70d88ae3..4a0b0f879e8ef 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from functools import lru_cache from pathlib import Path -from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast +from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, + Union, cast) # yapf conflicts with isort for this block # yapf: disable @@ -21,12 +22,27 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import async_get_and_parse_image +from vllm.multimodal.utils import (async_get_and_parse_audio, + async_get_and_parse_image) from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded audio data. 
+ """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + class CustomChatCompletionContentPartParam(TypedDict, total=False): __pydantic_config__ = ConfigDict(extra="allow") # type: ignore @@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartAudioParam, CustomChatCompletionContentPartParam] @@ -97,34 +114,41 @@ def load_chat_template( @lru_cache(maxsize=None) -def _image_token_str(model_config: ModelConfig, - tokenizer: PreTrainedTokenizer) -> Optional[str]: +def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer, + modality: Literal["image", "audio"]) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) - model_type = model_config.hf_config.model_type - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer - return "<|image_1|>" - if model_type == "minicpmv": - return "(./)" - if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): - # These models do not use image tokens in the prompt - return None - if model_type.startswith("llava"): - return tokenizer.decode(model_config.hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat"): - return "" - raise TypeError(f"Unknown model type: {model_type}") - - -# TODO: Let user specify how to insert image tokens into prompt + if modality == "image": + model_type = model_config.hf_config.model_type + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return "<|image_1|>" + if model_type == "minicpmv": + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): + # These models do not use image tokens in the prompt + return None + if model_type.startswith("llava"): + return tokenizer.decode(model_config.hf_config.image_token_index) + if model_type in ("chameleon", "internvl_chat"): + return "" + + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "audio": + raise TypeError("No audio models are supported yet.") + else: + raise TypeError(f"Unknown modality: {modality}") + + +# TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: - """Combine image and text prompts for vision language model""" +def _get_full_multimodal_text_prompt(placeholder_token_str: str, + text_prompt: str) -> str: + """Combine multimodal prompts for a multimodal language model""" # NOTE: For now we assume all model architectures use the same - # image + text prompt format. This may change in the future. - return f"{image_token_str}\n{text_prompt}" + # placeholder + text prompt format. This may change in the future. 
+ return f"{placeholder_token_str}\n{text_prompt}" def _parse_chat_message_content_parts( @@ -135,6 +159,7 @@ def _parse_chat_message_content_parts( ) -> ChatMessageParseResult: texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] + modality: Literal["image", "audio"] = "image" for part in parts: part_type = part["type"] @@ -142,9 +167,10 @@ def _parse_chat_message_content_parts( text = cast(ChatCompletionContentPartTextParam, part)["text"] texts.append(text) elif part_type == "image_url": + modality = "image" if len(mm_futures) > 0: raise NotImplementedError( - "Multiple 'image_url' input is currently not supported.") + "Multiple multimodal inputs is currently not supported.") image_url = cast(ChatCompletionContentPartImageParam, part)["image_url"] @@ -156,21 +182,32 @@ def _parse_chat_message_content_parts( image_future = async_get_and_parse_image(image_url["url"]) mm_futures.append(image_future) + elif part_type == "audio_url": + modality = "audio" + if len(mm_futures) > 0: + raise NotImplementedError( + "Multiple multimodal inputs is currently not supported.") + + audio_url = cast(ChatCompletionContentPartAudioParam, + part)["audio_url"] + audio_future = async_get_and_parse_audio(audio_url["url"]) + mm_futures.append(audio_future) else: raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) if mm_futures: - image_token_str = _image_token_str(model_config, tokenizer) - if image_token_str is not None: - if image_token_str in text_prompt: + placeholder_token_str = _mm_token_str(model_config, tokenizer, + modality) + if placeholder_token_str is not None: + if placeholder_token_str in text_prompt: logger.warning( - "Detected image token string in the text prompt. " + "Detected multi-modal token string in the text prompt. " "Skipping prompt formatting.") else: - text_prompt = _get_full_image_text_prompt( - image_token_str=image_token_str, + text_prompt = _get_full_multimodal_text_prompt( + placeholder_token_str=placeholder_token_str, text_prompt=text_prompt, ) diff --git a/vllm/envs.py b/vllm/envs.py index 26d0c33707fea..ca8ec96d07aa3 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -44,6 +44,7 @@ VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 + VLLM_AUDIO_FETCH_TIMEOUT: int = 5 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None @@ -321,6 +322,11 @@ def get_default_config_root(): "VLLM_IMAGE_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), + # Timeout for fetching audio when serving multimodal models + # Default is 5 seconds + "VLLM_AUDIO_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")), + # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. "VLLM_XLA_CACHE_PATH": diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 493d3dd29b376..b07a05828ed15 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -38,7 +38,7 @@ safetensors_weights_iterator) from vllm.model_executor.models.interfaces import (has_inner_state, supports_lora, - supports_vision) + supports_multimodal) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -131,7 +131,7 @@ def _get_model_initialization_kwargs( "be added in the future. 
If this is important to you, " "please open an issue on github.") - if supports_vision(model_class): + if supports_multimodal(model_class): if multimodal_config is None: raise ValueError("Provide vision related configurations " "through LLM entrypoint or engine arguments.") diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 4968d6d900ac2..c64428a4d7c75 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -20,8 +20,8 @@ from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) -from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from .interfaces import SupportsMultiModal +from .utils import merge_multimodal_embeddings _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", @@ -457,7 +457,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) -class Blip2ForConditionalGeneration(nn.Module, SupportsVision): +class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, config: Blip2Config, @@ -621,9 +621,9 @@ def forward( vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) - inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, - vision_embeddings, - BLIP2_IMAGE_TOKEN_ID) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + BLIP2_IMAGE_TOKEN_ID) input_ids = None else: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 2b6e5ee975172..fd694119932df 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -35,7 +35,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.utils import print_warning_once -from .interfaces import SupportsVision +from .interfaces import SupportsMultiModal logger = init_logger(__name__) @@ -886,7 +886,7 @@ def forward( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) -class ChameleonForConditionalGeneration(nn.Module, SupportsVision): +class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__( self, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 41e8b13990e81..5bb871d5a093b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -40,8 +40,8 @@ cached_get_tokenizer) from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData -from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from .interfaces import SupportsMultiModal +from .utils import merge_multimodal_embeddings logger = init_logger(__name__) @@ -209,7 +209,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) -class FuyuForCausalLM(nn.Module, SupportsVision): +class FuyuForCausalLM(nn.Module, SupportsMultiModal): def __init__(self, config: FuyuConfig, 
@@ -271,9 +271,9 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.model.embed_tokens(input_ids) - inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, - vision_embeddings, - self.image_token_id) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) else: inputs_embeds = None diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index db0d6b429d64d..2f323ea552ccb 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -10,12 +10,15 @@ @runtime_checkable -class SupportsVision(Protocol): - """The interface required for all vision language models (VLMs).""" +class SupportsMultiModal(Protocol): + """ + The interface required for all multimodal (vision or audio) language + models. + """ - supports_vision: ClassVar[Literal[True]] = True + supports_multimodal: ClassVar[Literal[True]] = True """ - A flag that indicates this model supports vision inputs. + A flag that indicates this model supports multimodal inputs. Note: There is no need to redefine this flag if this class is in the @@ -29,30 +32,31 @@ def __init__(self, *, multimodal_config: MultiModalConfig) -> None: # We can't use runtime_checkable with ClassVar for issubclass checks # so we need to treat the class as an instance and use isinstance instead @runtime_checkable -class _SupportsVisionType(Protocol): - supports_vision: Literal[True] +class _SupportsMultiModalType(Protocol): + supports_multimodal: Literal[True] def __call__(self, *, multimodal_config: MultiModalConfig) -> None: ... @overload -def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]: +def supports_multimodal( + model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]: ... @overload -def supports_vision(model: object) -> TypeIs[SupportsVision]: +def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ... 
-def supports_vision( +def supports_multimodal( model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]: +) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: if isinstance(model, type): - return isinstance(model, _SupportsVisionType) + return isinstance(model, _SupportsMultiModalType) - return isinstance(model, SupportsVision) + return isinstance(model, SupportsMultiModal) @runtime_checkable diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e34a486f56e38..bf772a80a343c 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -27,9 +27,9 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) -from .interfaces import SupportsVision +from .interfaces import SupportsMultiModal from .utils import (filter_weights, init_vllm_registered_model, - merge_vision_embeddings) + merge_multimodal_embeddings) IMG_START = '' IMG_END = '' @@ -292,7 +292,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) -class InternVLChatModel(nn.Module, SupportsVision): +class InternVLChatModel(nn.Module, SupportsMultiModal): def __init__(self, config: PretrainedConfig, @@ -451,9 +451,9 @@ def forward( inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, - vision_embeddings, - self.img_context_token_id) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.img_context_token_id) input_ids = None else: inputs_embeds = None diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 71a46256040c6..d4faf82b49697 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -19,12 +19,12 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, input_processor_for_clip) -from .interfaces import SupportsVision +from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) from .utils import (filter_weights, init_vllm_registered_model, - merge_vision_embeddings) + merge_multimodal_embeddings) class LlavaImagePixelInputs(TypedDict): @@ -181,7 +181,7 @@ def _init_vision_tower(hf_config: LlavaConfig): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava) -class LlavaForConditionalGeneration(nn.Module, SupportsVision): +class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, config: LlavaConfig, @@ -338,7 +338,7 @@ def forward( inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - inputs_embeds = merge_vision_embeddings( + inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.config.image_token_index) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 8331cbe8bcd1e..4ae545461eef8 100644 --- 
a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -23,13 +23,13 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) -from .interfaces import SupportsVision +from .interfaces import SupportsMultiModal from .llava import LlavaMultiModalProjector from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) from .utils import (filter_weights, init_vllm_registered_model, - merge_vision_embeddings) + merge_multimodal_embeddings) logger = init_logger(__name__) @@ -275,7 +275,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) -class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, config: LlavaNextConfig, @@ -571,7 +571,7 @@ def forward( inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - inputs_embeds = merge_vision_embeddings( + inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.config.image_token_index) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ab2b2c81ef4db..70002e2d532d4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsVision +from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.qwen2 import Qwen2Model @@ -479,7 +479,7 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): return llm_inputs -class MiniCPMVBaseModel(nn.Module, SupportsVision): +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): """ The abstract class of MiniCPMV can only be inherited, but cannot be instantiated. 
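These per-model changes (here and in the model files that follow) are mechanical renames: SupportsVision becomes SupportsMultiModal and merge_vision_embeddings becomes merge_multimodal_embeddings, while the registration pattern itself does not change. For orientation, the sketch below condenses the audio wiring from tests/entrypoints/openai/test_audio.py earlier in this patch into one place. It is illustrative only: FakeAudioModel is the test's stand-in built on OPT, not a real audio model, and the test additionally registers an input processor (omitted here for brevity) that inserts one placeholder token per second of audio via repeat_and_pad_image_tokens.

    # Condensed from tests/entrypoints/openai/test_audio.py; illustrative only.
    from typing import Optional, Tuple, Union, cast

    import librosa
    import numpy as np
    import torch

    from vllm import ModelRegistry
    from vllm.config import MultiModalConfig
    from vllm.inputs.registry import InputContext
    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.model_executor.models.opt import OPTForCausalLM
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.base import MultiModalInputs


    def fake_input_mapper(ctx: InputContext, data: object) -> MultiModalInputs:
        # `data` is the (waveform, sampling_rate) tuple stored under the new
        # "audio" key of multi_modal_data (see MultiModalDataBuiltins).
        audio, sr = cast(Tuple[np.ndarray, Union[float, int]], data)
        # The test resamples to 1 sample/s so one placeholder token covers one
        # second of audio; a real model would do its own feature extraction.
        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})


    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio",
                                                        lambda *_, **__: 100)
    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):

        def __init__(self, *args, multimodal_config: MultiModalConfig,
                     **kwargs):
            assert multimodal_config is not None
            super().__init__(*args, **kwargs)

        def forward(self,
                    *args,
                    processed_audio: Optional[torch.Tensor] = None,
                    **kwargs) -> torch.Tensor:
            # A real model would consume `processed_audio`; the test ignores it.
            return super().forward(*args, **kwargs)


    # The test then swaps this class in for the stock OPT implementation.
    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)

On the serving side, the chat_utils changes above route an "audio_url" content part through async_get_and_parse_audio, so this mapper receives exactly the (waveform, sampling_rate) tuple that librosa.load produces in fetch_audio.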
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 3b9470774f843..51ff8c5d6fd13 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -19,10 +19,10 @@ from vllm.multimodal.image import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput -from .interfaces import SupportsVision +from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import merge_vision_embeddings +from .utils import merge_multimodal_embeddings logger = init_logger(__name__) @@ -130,7 +130,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) -class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision): +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, config: PaliGemmaConfig, @@ -244,7 +244,7 @@ def forward(self, inputs_embeds = self.language_model.get_input_embeddings(input_ids) - inputs_embeds = merge_vision_embeddings( + inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.config.image_token_index) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index dd921c6af0538..f0ae0b6fdfb93 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -42,8 +42,8 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, input_processor_for_clip) -from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from .interfaces import SupportsMultiModal +from .utils import merge_multimodal_embeddings logger = init_logger(__name__) @@ -453,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) -class Phi3VForCausalLM(nn.Module, SupportsVision): +class Phi3VForCausalLM(nn.Module, SupportsMultiModal): def __init__(self, config: PretrainedConfig, @@ -568,9 +568,9 @@ def forward(self, if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.model.get_input_embeddings(input_ids) - inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, - vision_embeddings, - self.image_token_id) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) input_ids = None else: inputs_embeds = None diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index d1bb030c6c90f..91b414b1fd91a 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -54,41 +54,42 @@ def init_vllm_registered_model( ) -def merge_vision_embeddings(input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - vision_embeddings: BatchedTensors, - image_token_id: int) -> torch.Tensor: +def merge_multimodal_embeddings(input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: BatchedTensors, + placeholder_token_id: int) -> torch.Tensor: """ - Merge ``vision_embeddings`` into ``inputs_embeds`` by 
overwriting the - positions in ``inputs_embeds`` corresponding to placeholder image tokens in + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in ``input_ids``. Note: This updates ``inputs_embeds`` in place. """ - mask = (input_ids == image_token_id) + mask = (input_ids == placeholder_token_id) num_expected_tokens = mask.sum() - if isinstance(vision_embeddings, torch.Tensor): - batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape + if isinstance(multimodal_embeddings, torch.Tensor): + batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape total_tokens = batch_size * batch_tokens if num_expected_tokens != total_tokens: expr = f"{batch_size} x {batch_tokens}" raise ValueError( f"Attempted to assign {expr} = {total_tokens} " - f"image tokens to {num_expected_tokens} placeholders") + f"multimodal tokens to {num_expected_tokens} placeholders") - inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim) + inputs_embeds[mask] = multimodal_embeddings.view( + total_tokens, embed_dim) else: - size_per_batch = [t.shape[0] for t in vision_embeddings] + size_per_batch = [t.shape[0] for t in multimodal_embeddings] total_tokens = sum(size_per_batch) if num_expected_tokens != total_tokens: expr = ' + '.join(map(str, size_per_batch)) raise ValueError( f"Attempted to assign {expr} = {total_tokens} " - f"image tokens to {num_expected_tokens} placeholders") + f"multimodal tokens to {num_expected_tokens} placeholders") - inputs_embeds[mask] = torch.cat(vision_embeddings) + inputs_embeds[mask] = torch.cat(multimodal_embeddings) return inputs_embeds diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py new file mode 100644 index 0000000000000..b4bf4b4541db8 --- /dev/null +++ b/vllm/multimodal/audio.py @@ -0,0 +1,17 @@ +from vllm.inputs.registry import InputContext +from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin + + +class AudioPlugin(MultiModalPlugin): + """Plugin for audio data.""" + + def get_data_key(self) -> str: + return "audio" + + def _default_input_mapper(self, ctx: InputContext, + data: object) -> MultiModalInputs: + raise NotImplementedError("There is no default audio input mapper") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + raise NotImplementedError( + "There is no default maximum multimodal tokens") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index aefb5f438c5ad..7717d77198a19 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -3,8 +3,9 @@ from collections import UserDict, defaultdict from typing import Any, Callable, Dict, List, Optional from typing import Sequence as GenericSequence -from typing import Type, TypedDict, TypeVar, Union, cast +from typing import Tuple, Type, TypedDict, TypeVar, Union, cast +import numpy as np import torch import torch.types from PIL import Image @@ -121,6 +122,9 @@ class MultiModalDataBuiltins(TypedDict, total=False): image: Image.Image """The input image.""" + audio: Tuple[np.ndarray, Union[int, float]] + """The input audio and its sampling rate.""" + MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] """ diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index d8e1b68178acd..19c26123c2df3 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -6,6 +6,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger +from .audio import AudioPlugin from .base import 
(MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, MultiModalPlugin, MultiModalTokensCalc) from .image import ImagePlugin @@ -19,7 +20,7 @@ class MultiModalRegistry: :class:`~vllm.multimodal.MultiModalPlugin` for each modality. """ - DEFAULT_PLUGINS = (ImagePlugin(), ) + DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin()) def __init__( self, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 8f7e613cdf90a..d1e624cdb8ace 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,11 +1,14 @@ import base64 from io import BytesIO -from typing import Union +from typing import Tuple, Union +import librosa +import numpy as np +import soundfile from PIL import Image from vllm.connections import global_http_connection -from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT +from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT from vllm.multimodal.base import MultiModalDataDict @@ -63,11 +66,62 @@ async def async_fetch_image(image_url: str, return image.convert(image_mode) +def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: + """ + Load audio from a URL. + """ + if audio_url.startswith("http"): + audio_bytes = global_http_connection.get_bytes( + audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) + elif audio_url.startswith("data:audio"): + _, audio_base64 = audio_url.split(",", 1) + audio_bytes = base64.b64decode(audio_base64) + else: + raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " + "with either 'data:audio' or 'http'.") + + return librosa.load(BytesIO(audio_bytes), sr=None) + + +async def async_fetch_audio( + audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: + """ + Asynchronously fetch audio from a URL. + """ + if audio_url.startswith("http"): + audio_bytes = await global_http_connection.async_get_bytes( + audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) + elif audio_url.startswith("data:audio"): + _, audio_base64 = audio_url.split(",", 1) + audio_bytes = base64.b64decode(audio_base64) + else: + raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " + "with either 'data:audio' or 'http'.") + + return librosa.load(BytesIO(audio_bytes), sr=None) + + +async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: + audio, sr = await async_fetch_audio(audio_url) + return {"audio": (audio, sr)} + + async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: image = await async_fetch_image(image_url) return {"image": image} +def encode_audio_base64( + audio: np.ndarray, + sampling_rate: int, +) -> str: + """Encode audio as base64.""" + buffered = BytesIO() + soundfile.write(buffered, audio, sampling_rate, format="WAV") + + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + def encode_image_base64( image: Image.Image, *, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cfbbb6698cd8a..a4ce1b512dd05 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -40,7 +40,7 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, - supports_vision) + supports_multimodal) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) @@ -900,9 +900,9 @@ def load_model(self) -> None: if self.lora_config: assert supports_lora(self.model), "Model does not support LoRA" - assert 
not supports_vision(
+            assert not supports_multimodal(
                 self.model
-            ), "To be tested: vision language model with LoRA settings."
+            ), "To be tested: multimodal language model with LoRA settings."
 
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1054,7 +1054,7 @@ def profile_run(self) -> None:
         # of images processed.
         model_config = self.model_config
 
-        if supports_vision(self.model):
+        if supports_multimodal(self.model):
             max_mm_tokens = MULTIMODAL_REGISTRY \
                 .get_max_multimodal_tokens(model_config)
             max_num_seqs_orig = max_num_seqs
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 112e494fadede..a1e1c1bef6336 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -12,7 +12,7 @@
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.interfaces import supports_vision
+from vllm.model_executor.models.interfaces import supports_multimodal
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
 from vllm.sampling_params import SamplingParams
@@ -165,7 +165,7 @@ def profile_run(self) -> None:
         # of images processed.
         model_config = self.model_config
 
-        if supports_vision(self.model):
+        if supports_multimodal(self.model):
             max_mm_tokens = MULTIMODAL_REGISTRY \
                 .get_max_multimodal_tokens(model_config)
             max_num_seqs_orig = max_num_seqs
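With the pieces above in place, the OpenAI-compatible server accepts an "audio_url" content part alongside the existing "image_url" part: the URL is fetched within VLLM_AUDIO_FETCH_TIMEOUT (or decoded from a data:audio base64 URL), loaded with librosa, and handed to the registered "audio" input mapper as a (waveform, sampling_rate) tuple. Below is a hedged client-side sketch mirroring tests/entrypoints/openai/test_audio.py; "my-audio-model" is a placeholder, since this patch adds the plumbing but no in-tree audio model yet (_mm_token_str still raises "No audio models are supported yet.").

    # Client-side sketch only; "my-audio-model" stands in for any model
    # registered with an "audio" input mapper, as in the new test suite.
    import openai

    from vllm.multimodal.utils import encode_audio_base64, fetch_audio

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    audio_url = ("https://upload.wikimedia.org/wikipedia/en/b/bf/"
                 "Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg")

    # Either pass a plain URL (the server fetches it, subject to
    # VLLM_AUDIO_FETCH_TIMEOUT) ...
    audio_part = {"type": "audio_url", "audio_url": {"url": audio_url}}

    # ... or inline the audio as a base64 data URL, as the second test does.
    audio, sr = fetch_audio(audio_url)
    audio_part_b64 = {
        "type": "audio_url",
        "audio_url": {
            "url": f"data:audio/wav;base64,{encode_audio_base64(audio, int(sr))}",
        },
    }

    chat_completion = client.chat.completions.create(
        model="my-audio-model",
        messages=[{
            "role": "user",
            "content": [
                audio_part,  # or audio_part_b64; only one audio part per request
                {"type": "text", "text": "What's happening in this audio?"},
            ],
        }],
        max_tokens=10,
    )
    print(chat_completion.choices[0].message.content)

Offline inference should follow the same shape through the new MultiModalDataBuiltins entry, i.e. multi_modal_data={"audio": (waveform, sampling_rate)}.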