diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 3d8df3c9f8c9f..ad153d2927d6c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -334,6 +334,14 @@ The following modalities are supported depending on the model: - **V**\ ideo - **A**\ udio +Any combination of modalities joined by :code:`+` are supported. + +- e.g.: :code:`T + I` means that the model supports text-only, image-only, and text-with-image inputs. + +On the other hand, modalities separated by :code:`/` are mutually exclusive. + +- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. + .. _supported_vlms: Text Generation @@ -484,6 +492,12 @@ Multimodal Embedding - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`LlavaNextForConditionalGeneration` + - LLaVA-NeXT-based + - T / I + - :code:`royokong/e5-v` + - + - ✅︎ * - :code:`Phi3VForCausalLM` - Phi-3-Vision-based - T + I diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 06b424abd50b5..610cc31db9c4e 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -1,6 +1,6 @@ """ -This example shows how to use vLLM for running offline inference -with the correct prompt format on vision language models. +This example shows how to use vLLM for running offline inference with +the correct prompt format on vision language models for text generation. For most models, the prompt format should follow corresponding examples on HuggingFace model repository. @@ -450,7 +450,7 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' - 'vision language models') + 'vision language models for text generation') parser.add_argument('--model-type', '-m', type=str, diff --git a/examples/offline_inference_vision_language_embedding.py b/examples/offline_inference_vision_language_embedding.py index cfedd145a015d..e1732d045f949 100644 --- a/examples/offline_inference_vision_language_embedding.py +++ b/examples/offline_inference_vision_language_embedding.py @@ -1,22 +1,170 @@ +""" +This example shows how to use vLLM for running offline inference with +the correct prompt format on vision language models for multimodal embedding. + +For most models, the prompt format should follow corresponding examples +on HuggingFace model repository. +""" +from argparse import Namespace +from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args + +from PIL.Image import Image + from vllm import LLM -from vllm.assets.image import ImageAsset - -image = ImageAsset("cherry_blossom").pil_image.convert("RGB") -prompt = "<|image_1|> Represent the given image with the following question: What is in the image" # noqa: E501 - -# Create an LLM. -llm = LLM( - model="TIGER-Lab/VLM2Vec-Full", - task="embedding", - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=2, - mm_processor_kwargs={"num_crops": 16}, -) - -# Generate embedding. The output is a list of EmbeddingRequestOutputs. -outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}}) - -# Print the outputs. 
-for output in outputs: - print(output.outputs.embedding) # list of 3072 floats +from vllm.multimodal.utils import fetch_image +from vllm.utils import FlexibleArgumentParser + + +class TextQuery(TypedDict): + modality: Literal["text"] + text: str + + +class ImageQuery(TypedDict): + modality: Literal["image"] + image: Image + + +class TextImageQuery(TypedDict): + modality: Literal["text+image"] + text: str + image: Image + + +QueryModality = Literal["text", "image", "text+image"] +Query = Union[TextQuery, ImageQuery, TextImageQuery] + + +class ModelRequestData(NamedTuple): + llm: LLM + prompt: str + image: Optional[Image] + + +def run_e5_v(query: Query): + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 + + if query["modality"] == "text": + text = query["text"] + prompt = llama3_template.format( + f"{text}\nSummary above sentence in one word: ") + image = None + elif query["modality"] == "image": + prompt = llama3_template.format( + "\nSummary above image in one word: ") + image = query["image"] + else: + modality = query['modality'] + raise ValueError(f"Unsupported query modality: '{modality}'") + + llm = LLM( + model="royokong/e5-v", + task="embedding", + max_model_len=4096, + ) + + return ModelRequestData( + llm=llm, + prompt=prompt, + image=image, + ) + + +def run_vlm2vec(query: Query): + if query["modality"] == "text": + text = query["text"] + prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501 + image = None + elif query["modality"] == "image": + prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image." # noqa: E501 + image = query["image"] + elif query["modality"] == "text+image": + text = query["text"] + prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 + image = query["image"] + else: + modality = query['modality'] + raise ValueError(f"Unsupported query modality: '{modality}'") + + llm = LLM( + model="TIGER-Lab/VLM2Vec-Full", + task="embedding", + trust_remote_code=True, + mm_processor_kwargs={"num_crops": 4}, + ) + + return ModelRequestData( + llm=llm, + prompt=prompt, + image=image, + ) + + +def get_query(modality: QueryModality): + if modality == "text": + return TextQuery(modality="text", text="A dog sitting in the grass") + + if modality == "image": + return ImageQuery( + modality="image", + image=fetch_image( + "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg" # noqa: E501 + ), + ) + + if modality == "text+image": + return TextImageQuery( + modality="text+image", + text="A cat standing in the snow.", + image=fetch_image( + "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg" # noqa: E501 + ), + ) + + msg = f"Modality {modality} is not supported." 
+ raise ValueError(msg) + + +def run_encode(model: str, modality: QueryModality): + query = get_query(modality) + req_data = model_example_map[model](query) + + mm_data = {} + if req_data.image is not None: + mm_data["image"] = req_data.image + + outputs = req_data.llm.encode({ + "prompt": req_data.prompt, + "multi_modal_data": mm_data, + }) + + for output in outputs: + print(output.outputs.embedding) + + +def main(args: Namespace): + run_encode(args.model_name, args.modality) + + +model_example_map = { + "e5_v": run_e5_v, + "vlm2vec": run_vlm2vec, +} + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for multimodal embedding') + parser.add_argument('--model-name', + '-m', + type=str, + default="vlm2vec", + choices=model_example_map.keys(), + help='The name of the embedding model.') + parser.add_argument('--modality', + type=str, + default="image", + choices=get_args(QueryModality), + help='Modality of the input.') + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 69f590fb7950d..e28514bf403f7 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -1,7 +1,7 @@ """ This example shows how to use vLLM for running offline inference with -multi-image input on vision language models, using the chat template defined -by the model. +multi-image input on vision language models for text generation, +using the chat template defined by the model. """ from argparse import Namespace from typing import List, NamedTuple, Optional @@ -334,7 +334,8 @@ def main(args: Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' - 'vision language models that support multi-image input') + 'vision language models that support multi-image input for text ' + 'generation') parser.add_argument('--model-type', '-m', type=str, diff --git a/tests/conftest.py b/tests/conftest.py index fc8bd1a473476..76f581e0363f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,10 +43,12 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] -PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]] -PromptAudioInput = Union[List[Tuple[np.ndarray, int]], - List[List[Tuple[np.ndarray, int]]]] -PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]] +_M = TypeVar("_M") +_PromptMultiModalInput = Union[List[_M], List[List[_M]]] + +PromptImageInput = _PromptMultiModalInput[Image.Image] +PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]] +PromptVideoInput = _PromptMultiModalInput[np.ndarray] def _read_prompts(filename: str) -> List[str]: @@ -318,12 +320,12 @@ def get_inputs( "text": prompt, "return_tensors": "pt", } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - if videos is not None and videos[i] is not None: - processor_kwargs["videos"] = videos[i] - if audios is not None and audios[i] is not None: - audio, sr = audios[i] + if images is not None and (image := images[i]) is not None: + processor_kwargs["images"] = image + if videos is not None and (video := videos[i]) is not None: + processor_kwargs["videos"] = video + if audios is not None and (audio_tuple := audios[i]) is not None: 
+ audio, sr = audio_tuple processor_kwargs["audio"] = audio processor_kwargs["sampling_rate"] = sr @@ -338,7 +340,7 @@ def generate( self, prompts: List[str], images: Optional[PromptImageInput] = None, - videos: Optional[List[np.ndarray]] = None, + videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[Tuple[List[List[int]], List[str]]]: @@ -368,7 +370,7 @@ def generate_greedy( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, - videos: Optional[List[np.ndarray]] = None, + videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str]]: @@ -409,7 +411,7 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, - videos: Optional[List[np.ndarray]] = None, + videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[List[torch.Tensor]]: @@ -488,7 +490,7 @@ def generate_greedy_logprobs_limit( num_logprobs: int, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, - videos: Optional[List[np.ndarray]] = None, + videos: Optional[PromptVideoInput] = None, **kwargs: Any, ) -> List[TokensTextLogprobs]: all_inputs = self.get_inputs(prompts, @@ -657,15 +659,18 @@ def get_inputs( inputs = [TextPrompt(prompt=prompt) for prompt in prompts] if images is not None: for i, image in enumerate(images): - inputs[i]["multi_modal_data"] = {"image": image} + if image is not None: + inputs[i]["multi_modal_data"] = {"image": image} if videos is not None: for i, video in enumerate(videos): - inputs[i]["multi_modal_data"] = {"video": video} + if video is not None: + inputs[i]["multi_modal_data"] = {"video": video} if audios is not None: for i, audio in enumerate(audios): - inputs[i]["multi_modal_data"] = {"audio": audio} + if audio is not None: + inputs[i]["multi_modal_data"] = {"audio": audio} return inputs @@ -837,13 +842,20 @@ def generate_beam_search( returned_outputs.append((token_ids, texts)) return returned_outputs - def encode(self, prompts: List[str]) -> List[List[float]]: - req_outputs = self.model.encode(prompts) - outputs = [] - for req_output in req_outputs: - embedding = req_output.outputs.embedding - outputs.append(embedding) - return outputs + def encode( + self, + prompts: List[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[List[float]]: + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + + req_outputs = self.model.encode(inputs) + return [req_output.outputs.embedding for req_output in req_outputs] def __enter__(self): return self diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index 2fcc2013d91ef..fd1c44d9c117e 100644 --- a/tests/models/embedding/utils.py +++ b/tests/models/embedding/utils.py @@ -16,7 +16,8 @@ def check_embeddings_close( for prompt_idx, (embeddings_0, embeddings_1) in enumerate( zip(embeddings_0_lst, embeddings_1_lst)): - assert len(embeddings_0) == len(embeddings_1) + assert len(embeddings_0) == len(embeddings_1), ( + f"Length mismatch: {len(embeddings_0)} vs. 
{len(embeddings_1)}") sim = F.cosine_similarity(torch.tensor(embeddings_0), torch.tensor(embeddings_1), diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py new file mode 100644 index 0000000000000..52aef8c34d6f3 --- /dev/null +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -0,0 +1,135 @@ +from typing import List, Type + +import pytest +import torch.nn.functional as F +from transformers import AutoModelForVision2Seq + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ....utils import large_gpu_test +from ..utils import check_embeddings_close + +llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 + +HF_TEXT_PROMPTS = [ + # T -> X + llama3_template.format( + "The label of the object is stop sign\nSummary above sentence in one word: " # noqa: E501 + ), + # T -> X + llama3_template.format( + "cherry blossom\nSummary above sentence in one word: "), +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + # I -> X + "stop_sign": + llama3_template.format("\nSummary above image in one word: "), + # I -> X + "cherry_blossom": + llama3_template.format("\nSummary above image in one word: "), +}) + +MODELS = ["royokong/e5-v"] + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + input_texts: List[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ with vllm_runner(model, + task="embedding", + dtype=dtype, + max_model_len=4096, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.encode(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: + # Patch the issue where image_token_id + # exceeds the maximum allowed vocab size + hf_model.model.resize_token_embeddings( + hf_model.model.language_model.vocab_size + 1) + + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + # Based on: https://huggingface.co/royokong/e5-v + outputs = hf_model.model( + **hf_model.wrap_device(inputs, + device=hf_model.model.device.type), + return_dict=True, + output_hidden_states=True, + ) + pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :], + dim=-1) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) + for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py index 0ca90e6bfa52e..ee411472ba284 100644 --- a/tests/models/embedding/vision_language/test_phi3v.py +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -1,42 +1,53 @@ +from typing import List, Type + import pytest import torch.nn.functional as F -from ....conftest import IMAGE_ASSETS +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ....utils import large_gpu_test from ..utils import check_embeddings_close +HF_TEXT_PROMPTS = [ + # T -> X + "Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501 + # T -> X + "Retrieve an image of this caption: cherry blossom", +] + HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + # T + I -> X "stop_sign": "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 + # I -> X "cherry_blossom": - "<|image_1|> Represent the given image with the following question: What is in the image", # noqa: E501 + "<|image_1|> Represent the given image for classification", # noqa: E501 }) MODELS = ["TIGER-Lab/VLM2Vec-Full"] -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, +def _run_test( + hf_runner: 
Type[HfRunner], + vllm_runner: Type[VllmRunner], + input_texts: List[str], + input_images: PromptImageInput, model: str, + *, dtype: str, ) -> None: # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - task="embedding", - max_model_len=4096, - max_num_seqs=2, - dtype=dtype, + with vllm_runner(model, task="embedding", dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.encode(input_texts, images=input_images) - with hf_runner(model, dtype=dtype) as hf_model: - all_inputs = hf_model.get_inputs(example_prompts) + # use eager mode for hf runner, since phi3_v didn't work with flash_attn + hf_model_kwargs = {"_attn_implementation": "eager"} + with hf_runner(model, dtype=dtype, + model_kwargs=hf_model_kwargs) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) all_outputs = [] for inputs in all_inputs: @@ -61,3 +72,53 @@ def test_models( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) + for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 4dd472b04bb1a..46cba8ebbc583 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -13,11 +13,13 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -28,8 +30,8 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) -from .utils import (AutoWeightsLoader, flatten_bn, 
init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, + init_vllm_registered_model) # Result in the max possible feature size (2x2 grid of 336x336px tiles) MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 @@ -312,6 +314,10 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + # The same model class supports both language generation and embedding + # because the architecture name is the same + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -605,14 +611,12 @@ def forward( image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - + inputs_embeds = embed_multimodal( + input_ids, + self.config.image_token_index, + self.language_model.model.get_input_embeddings, + lambda _: self._process_image_input(image_input), + ) input_ids = None else: inputs_embeds = None @@ -641,6 +645,13 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 91c14e32c946c..9a1083520efd2 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -467,8 +467,6 @@ def input_processor_for_phi3v(ctx: InputContext, prompt_token_ids = inputs["prompt_token_ids"].copy() - print("prompt_token_ids (old)", prompt_token_ids) - # masked placeholder with image token id for idx in image_idx: candidates = _get_image_placeholder_token_id_candidates(model_config, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8745e0cbd97b6..a255b2a2f3982 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -94,6 +94,7 @@ "MistralModel": ("llama", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), # [Multimodal] + "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), } diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index ec1d76d2117f3..d96e988fba384 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, - Protocol, Tuple, Union, overload) +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Protocol, Tuple, Union, overload) import torch import torch.nn as nn @@ -294,10 +294,11 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: _embedding_count_expression(inner) for inner in embeddings) -def 
merge_multimodal_embeddings(input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: NestedTensors, - placeholder_token_id: int) -> torch.Tensor: +def _merge_multimodal_embeddings( + inputs_embeds: torch.Tensor, + is_multimodal: torch.Tensor, + multimodal_embeddings: NestedTensors, +) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the positions in ``inputs_embeds`` corresponding to placeholder tokens in @@ -306,8 +307,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, Note: This updates ``inputs_embeds`` in place. """ - mask = (input_ids == placeholder_token_id) - num_expected_tokens = mask.sum().item() + num_expected_tokens = is_multimodal.sum().item() assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) @@ -317,10 +317,70 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, f"Attempted to assign {expr} = {flattened.shape[0]} " f"multimodal tokens to {num_expected_tokens} placeholders") - inputs_embeds[mask] = flattened + inputs_embeds[is_multimodal] = flattened return inputs_embeds +def embed_multimodal( + input_ids: torch.Tensor, + multimodal_token_id: int, + get_text_embeds: Callable[[torch.Tensor], torch.Tensor], + get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor, + List[torch.Tensor]]], +) -> torch.Tensor: + """ + Embed token IDs and multimodal inputs and combine their embeddings. + + ``multimodal_token_id`` is used to determine whether a token ID should + be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``. + + Compared to ``merge_multimodal_embeddings`, this avoids running + ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]`` + which causes issues when the placeholder token ID exceeds the + vocabulary size of the language model. + """ + is_multimodal = input_ids == multimodal_token_id + is_text = ~is_multimodal + + text_embeds = get_text_embeds(input_ids[is_text]) + multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal]) + + merged_embeds = torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ) + + merged_embeds[is_text] = text_embeds + + return _merge_multimodal_embeddings( + merged_embeds, + is_multimodal, + multimodal_embeds, + ) + + +def merge_multimodal_embeddings( + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: NestedTensors, + placeholder_token_id: int, +) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in + ``input_ids``. + + Note: + This updates ``inputs_embeds`` in place. + """ + return _merge_multimodal_embeddings( + inputs_embeds, + (input_ids == placeholder_token_id), + multimodal_embeddings, + ) + + class LayerFn(Protocol): def __call__(self, prefix: str) -> torch.nn.Module:
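
For reference, here is a minimal standalone sketch (not part of the patch) of what the new embed_multimodal helper does: it embeds text tokens and placeholder tokens through separate callables and then scatters the multimodal embeddings into the placeholder positions. The toy embedding functions and the placeholder token id below are invented for illustration; real callers pass the language model's input embedding layer and the vision tower output, as in the llava_next.py change above.

import torch

from vllm.model_executor.models.utils import embed_multimodal

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id (may exceed the text vocab)
HIDDEN_SIZE = 8

# Toy prompt: two text tokens, three image placeholders, one text token.
input_ids = torch.tensor(
    [1, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 3])


def get_text_embeds(token_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for the language model's input embedding lookup. It is only
    # called on the non-placeholder tokens, which is the point of
    # embed_multimodal: placeholder ids never reach the text embedding table.
    return torch.zeros(token_ids.shape[0], HIDDEN_SIZE)


def get_image_embeds(placeholder_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for the vision encoder + projector; it must yield exactly one
    # embedding row per placeholder token, or the helper raises a ValueError.
    return torch.ones(placeholder_ids.shape[0], HIDDEN_SIZE)


inputs_embeds = embed_multimodal(
    input_ids,
    IMAGE_TOKEN_ID,
    get_text_embeds,
    get_image_embeds,
)

# Rows 2-4 come from the image branch, the rest from the text branch.
print(inputs_embeds[:, 0])  # tensor([0., 0., 1., 1., 1., 0.])

The new example script added in this diff can then be exercised end to end with, e.g., python examples/offline_inference_vision_language_embedding.py --model-name e5_v --modality image.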