diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 40829785d3214..259cbe515066d 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -21,6 +21,7 @@ "cherry_blossom": "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", }) +HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 models = ["microsoft/Phi-3.5-vision-instruct"] @@ -184,3 +185,113 @@ def test_regression_7840(hf_runner, vllm_runner, image_assets, model, num_logprobs=10, tensor_parallel_size=1, ) + + +def run_multi_image_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + images: List[Image.Image], + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + max_model_len=4096, + max_num_seqs=1, + limit_mm_per_prompt={"image": len(images)}, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_case + ] + + hf_model_kwargs = {"_attn_implementation": "eager"} + with hf_runner(model, dtype=dtype, + model_kwargs=hf_model_kwargs) as hf_model: + eos_token_id = hf_model.processor.tokenizer.eos_token_id + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + eos_token_id=eos_token_id) + for prompts, images in inputs_per_case + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + run_multi_image_test( + hf_runner, + vllm_runner, + [asset.pil_image for asset in image_assets], + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 2e52531989232..4872929ec36cc 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import re from functools import lru_cache from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, @@ -37,11 +38,11 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import is_list_of -from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, - input_processor_for_clip) +from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings @@ -400,9 +401,20 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_data = multi_modal_data["image"] if isinstance(image_data, Image.Image): w, h = image_data.size - image_feature_size = get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h) + image_feature_size = [ + get_phi3v_image_feature_size(hf_config, + input_width=w, + input_height=h) + ] + image_data = [image_data] + elif is_list_of(image_data, Image.Image): + image_feature_size = [] + for image in image_data: + w, h = image.size + image_feature_size.append( + get_phi3v_image_feature_size(hf_config, + input_width=w, + input_height=h)) elif isinstance(image_data, torch.Tensor): image_feature_size = image_data.shape[0] else: @@ -410,45 +422,61 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): prompt = llm_inputs.get("prompt") if prompt is None: + image_idx = [] new_prompt = None else: + image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) if prompt.count("<|image|>") > 0: logger.warning("Please follow the prompt format that is " "documented on HuggingFace which does not involve " "repeating <|image|> tokens.") - elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1: - logger.warning("Multiple image input is not supported yet, " - "so any extra image tokens will be treated " - "as plain text.") - + elif (num_image_tags := len(image_idx)) > 1: + assert num_image_tags == len( + image_data), "The count of image_placeholder not match image's" new_prompt = prompt - prompt_token_ids = llm_inputs["prompt_token_ids"] - image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1) + prompt_token_ids = llm_inputs["prompt_token_ids"].copy() + + # masked place_holder with image token id + for idx in image_idx: + image_token_ids = _get_image_placeholder_token_ids(model_config, + idx=idx) + for i in range(len(prompt_token_ids) - len(image_token_ids) + 1): + if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids: + prompt_token_ids[i:i + len(image_token_ids)] = [ + _IMAGE_TOKEN_ID + ] * len(image_token_ids) + break + + # merge consecutive tag ids + merged_token_ids: List[int] = [] + for is_placeholder, token_ids in itertools.groupby( + prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID): + if is_placeholder: + merged_token_ids.append(_IMAGE_TOKEN_ID) + else: + merged_token_ids.extend(list(token_ids)) + # TODO: Move this to utils or integrate with clip. new_token_ids: List[int] = [] - for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1): - if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids: - new_token_ids.append(_IMAGE_TOKEN_ID) - - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):]) - break + placeholder_idx = 0 + while merged_token_ids: + token_id = merged_token_ids.pop(0) + if token_id == _IMAGE_TOKEN_ID: + new_token_ids.extend( + repeat_and_pad_token( + _IMAGE_TOKEN_ID, + repeat_count=image_feature_size[placeholder_idx], + )) + placeholder_idx += 1 else: - new_token_ids.append(prompt_token_ids[i]) + new_token_ids.append(token_id) # NOTE: Create a defensive copy of the original inputs llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data) - - return input_processor_for_clip( - model_config, - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - llm_inputs, - image_token_id=_IMAGE_TOKEN_ID, - image_feature_size_override=image_feature_size, - ) + return llm_inputs @MULTIMODAL_REGISTRY.register_image_input_mapper()