diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index 610cc31db9c4e..83d2548a506e4 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -267,6 +267,11 @@ def run_qwen2_vl(question: str, modality: str):
         model=model_name,
         max_model_len=8192,
         max_num_seqs=5,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+        },
     )
 
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
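
The hunk above only shows the constructor change; a minimal, self-contained sketch of the same engine-level override follows. The model name is borrowed from the new tests below, while the image path and the bare vision prompt are illustrative placeholders, not part of this change:

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",
    max_model_len=8192,
    max_num_seqs=5,
    # Caps each image at roughly 1280 visual tokens before it reaches the LM
    mm_processor_kwargs={
        "min_pixels": 28 * 28,
        "max_pixels": 1280 * 28 * 28,
    },
)
outputs = llm.generate(
    {
        "prompt": "<|vision_start|><|image_pad|><|vision_end|>Describe the image.",
        "multi_modal_data": {"image": Image.open("example.jpg")},  # placeholder path
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
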
diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py
new file mode 100644
index 0000000000000..d3de5fb26d4b8
--- /dev/null
+++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py
@@ -0,0 +1,160 @@
+from typing import Any, Dict, Tuple
+
+import pytest
+import torch
+from PIL.Image import Image
+from transformers import AutoTokenizer
+
+from vllm.inputs import InputContext, token_inputs
+from vllm.multimodal import MultiModalRegistry
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+MODEL = "Qwen/Qwen2-VL-2B-Instruct"
+MIN_PIXELS = "min_pixels"
+MAX_PIXELS = "max_pixels"
+
+
+# Fixtures lazy-import to avoid initializing CUDA during test collection.
+# NOTE: Qwen2-VL supports multiple input modalities, so it registers multiple
+# input mappers.
+@pytest.fixture()
+def image_input_mapper_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import (
+        image_input_mapper_for_qwen2_vl)
+    return image_input_mapper_for_qwen2_vl
+
+
+@pytest.fixture()
+def input_processor_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import (
+        input_processor_for_qwen2_vl)
+    return input_processor_for_qwen2_vl
+
+
+@pytest.fixture()
+def qwen2_vl_context() -> InputContext:
+    return build_model_context(model_name=MODEL)
+
+
+@pytest.fixture()
+def get_max_qwen2_vl_image_tokens():
+    from vllm.model_executor.models.qwen2_vl import (
+        get_max_qwen2_vl_image_tokens)
+    return get_max_qwen2_vl_image_tokens
+
+
+@pytest.fixture()
+def dummy_data_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
+    return dummy_data_for_qwen2_vl
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
+    ({}, 1225),
+    ({
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, 324),
+])
+def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
+                                   qwen2_vl_context: InputContext,
+                                   mm_processor_kwargs: Dict[str, Any],
+                                   expected_max_tokens: int):
+    """Ensure that the max token calc handles min/max pixels properly."""
+    actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
+                                                      **mm_processor_kwargs)
+    assert actual_max_tokens == expected_max_tokens
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
+    [{}, 1225, (980, 980)],
+    [{
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, 324, (504, 504)],
+])
+def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
+                             qwen2_vl_context: InputContext,
+                             mm_processor_kwargs: Dict[str, Any],
+                             token_count: int, img_size: Tuple[int, int]):
+    """Ensure that the dummy data handles min/max pixels properly."""
+    seq_len = 3000
+    hf_config = qwen2_vl_context.get_hf_config()
+    image_token_id = hf_config.image_token_id
+
+    # NOTE: video value is required, but isn't actually used
+    # when making the dummy data except for error handling currently
+    seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, {
+        "image": 1,
+        "video": 0
+    }, **mm_processor_kwargs)
+
+    # Ensure we have the right number of placeholders for min/max pixel values
+    assert seq_data.get_token_ids().count(image_token_id) == token_count
+
+    # Ensure the images were resized correctly
+    image = mm_data["image"]
+    assert isinstance(image, Image)
+    assert image.size == img_size
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
+    ({}, 1426),
+    ({
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, 330),
+])
+def test_input_processor(input_processor_for_qwen2_vl,
+                         qwen2_vl_context: InputContext,
+                         image_assets: _ImageAssets, num_placeholders: int,
+                         mm_processor_kwargs: Dict[str, Any]):
+    """Ensure that the image processor handles min/max pixels properly."""
+    tokenizer = AutoTokenizer.from_pretrained(MODEL)
+    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
+
+    image = image_assets[0].pil_image
+    hf_config = qwen2_vl_context.get_hf_config()
+    image_token_id = hf_config.image_token_id
+
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": [image]})
+
+    processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
+                                                    **mm_processor_kwargs)
+    assert processed_inputs["prompt_token_ids"].count(
+        image_token_id) == num_placeholders
+    assert len(processed_inputs["multi_modal_data"]["image"]) == 1
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
+    ({}, [5704, 1176]),
+    ({
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, [1320, 1176]),
+])
+def test_image_mapper_override(qwen2_vl_context: InputContext,
+                               image_assets: _ImageAssets,
+                               mm_processor_kwargs: Dict[str, Any],
+                               pixels_shape: Tuple[int, int]):
+    """Ensure that the image mapper handles min/max pixels properly."""
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)
+
+    image = image_assets[0].pil_image
+
+    mapped_output = mm_registry.map_input(
+        qwen2_vl_context.model_config,
+        {"image": image},
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    # Dimension 0 of pixel values should match the product of image_grid_thw
+    actual_pixels_shape = mapped_output["pixel_values"].shape
+    assert list(actual_pixels_shape) == pixels_shape
+    assert actual_pixels_shape[0] == torch.prod(
+        mapped_output["image_grid_thw"])
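
The expected values above follow from Qwen2-VL emitting roughly one visual token per 28x28-pixel patch (14x14 patches merged 2x2): the square dummy image is snapped to 980x980 under the default 1280 * 28 * 28 cap (35 * 35 = 1225 tokens) and to 504x504 under the 512**2 override (18 * 18 = 324 tokens). A back-of-the-envelope check of those constants, assuming a square image rounded down to a multiple of 28 per side:

def approx_max_image_tokens(max_pixels: int):
    # Largest square side that fits under max_pixels and is a multiple of 28,
    # then one visual token per 28x28-pixel patch.
    side = int(max_pixels**0.5) // 28 * 28
    return side, (side // 28)**2

print(approx_max_image_tokens(1280 * 28 * 28))  # (980, 1225) -> default cap
print(approx_max_image_tokens(512**2))          # (504, 324)  -> overridden cap
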
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 9cca6b65e3277..3dc955b12ba0e 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -549,6 +549,9 @@ def mm_input_mapper_for_qwen2_vl(
     ctx: InputContext,
     data: MultiModalData[object],
     data_type_key: str,
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
     if data_type_key == "image" and isinstance(data, dict):
@@ -557,8 +560,19 @@ def mm_input_mapper_for_qwen2_vl(
             "image_grid_thw": data.get("image_grid_thw"),
         })
     model_config = ctx.model_config
+    # Handle mm processor kwargs; we pass these at creation time
+    # because preprocess() in transformers doesn't expose them
+    mm_processor_kwargs = {}
+    if min_pixels:
+        mm_processor_kwargs["min_pixels"] = min_pixels
+    if max_pixels:
+        mm_processor_kwargs["max_pixels"] = max_pixels
+
     image_processor = cached_get_image_processor(
-        model_config.model, trust_remote_code=model_config.trust_remote_code)
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs,
+    )
     if image_processor is None:
         raise RuntimeError("No HuggingFace processor is available "
                            "to process the image object")
@@ -631,25 +645,36 @@ def _get_max_image_info(
     image_processor,
     data_type_key: str = "image",
     mm_count: int = 1,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ):
+    # Limit min / max pixels unless they're explicitly provided
+    if min_pixels is None:
+        min_pixels = max(image_processor.min_pixels, 28 * 28)
+    if max_pixels is None:
+        max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)
+
     return _get_vision_info(
         image_processor,
         height=9999999,
         width=9999999,
-
-        # Limit min / max pixels.
-        min_pixels=max(image_processor.min_pixels, 28 * 28),
-        max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
        data_type_key=data_type_key,
         mm_count=mm_count,
     )
 
 
-def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
+def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
+                               data_type_key: str,
+                               *,
+                               min_pixels=None,
+                               max_pixels=None) -> int:
     image_processor = cached_get_image_processor(ctx.model_config.model)
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key=data_type_key,
-                            mm_count=1)
+                            mm_count=1, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     return max_llm_image_tokens
@@ -660,14 +685,20 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
 
 
 def dummy_data_for_qwen2_vl(
-    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
+    ctx: InputContext,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None
 ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
     image_processor = cached_get_image_processor(ctx.model_config.model)
 
     num_images = mm_counts["image"]
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key="image",
-                            mm_count=num_images)
+                            mm_count=num_images, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     if seq_len - max_llm_image_tokens - 2 < 0:
         raise RuntimeError(
             f"Qwen2-VL cannot process {num_images} images in a prompt, "
             "please increase max_model_len or reduce image limit by "
@@ -678,10 +709,11 @@ def dummy_data_for_qwen2_vl(
 
     num_videos = mm_counts["video"]
     max_resized_height, max_resized_width, max_llm_video_tokens = \
         _get_max_image_info(image_processor, data_type_key="video",
-                            mm_count=num_videos)
+                            mm_count=num_videos, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     if seq_len - max_llm_video_tokens - 2 < 0:
         raise RuntimeError(
-            f"Qwen2-VL cannot process {num_images} videos in a prompt, "
+            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
             "please increase max_model_len or reduce video limit by "
             "--limit-mm-per-prompt.")
@@ -706,6 +738,8 @@ def _get_llm_num_vision_tokens(
     mm_inputs: list,
     data_type_key: str,
     image_processor,
+    min_pixels: int,
+    max_pixels: int,
 ):
     """Get number of vision tokens of multimodal inputs.
 
@@ -715,12 +749,13 @@ def _get_llm_num_vision_tokens(
     image = to_numpy_array(mm_inputs[0])
     input_data_format = infer_channel_dimension_format(image)
     height, width = get_image_size(image, channel_dim=input_data_format)
+
     _, _, llm_num_vision_tokens = _get_vision_info(
         image_processor,
         height=height,
         width=width,
-        min_pixels=image_processor.min_pixels,
-        max_pixels=image_processor.max_pixels,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
         do_resize=image_processor.do_resize,
         data_type_key=data_type_key,
         mm_count=len(mm_inputs),
@@ -730,7 +765,8 @@ def _get_llm_num_vision_tokens(
 
 
 def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
                        data_type_key: str, image_processor: Any,
-                       prompt_token_ids: List[int]) -> List[int]:
+                       prompt_token_ids: List[int], min_pixels: Optional[int],
+                       max_pixels: Optional[int]) -> List[int]:
     """
     Expand pad tokens for multi-modal inputs (e.g., images or videos).
@@ -741,6 +777,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
         data_type_key (str): The type of the multi-modal input.
         image_processor (Any): The image processor used to process the inputs.
         prompt_token_ids (List[int]): The list of token IDs in the prompt.
+        min_pixels (Optional[int]): Min pixels to use for image processing.
+        max_pixels (Optional[int]): Max pixels to use for image processing.
 
     Returns:
         List[int]: The list of token IDs for the multi-modal inputs.
@@ -757,6 +795,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
             [data] if data_type_key == "image" else data,
             data_type_key=data_type_key,
             image_processor=image_processor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
         )
         if cnt == 0:
             end_idx = indices[cnt]
@@ -773,6 +813,9 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
 def input_processor_for_qwen2_vl(
     ctx: InputContext,
     inputs: DecoderOnlyInputs,
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ) -> DecoderOnlyInputs:
     multi_modal_data = inputs.get("multi_modal_data", None)
     if multi_modal_data is None:
@@ -783,6 +826,10 @@ def input_processor_for_qwen2_vl(
     processor = cached_get_processor(ctx.model_config.model)
     image_processor = processor.image_processor
 
+    # Apply processor kwarg overrides for image processor options
+    min_pixels = min_pixels if min_pixels else image_processor.min_pixels
+    max_pixels = max_pixels if max_pixels else image_processor.max_pixels
+
     hf_config = ctx.get_hf_config(Qwen2VLConfig)
 
     # To avoid redundant processing of vision objects (resize, rescale, etc.),
@@ -830,16 +877,22 @@ def input_processor_for_qwen2_vl(
         else:
             prompt_token_ids = _expand_pad_tokens(image_inputs,
                                                   hf_config.image_token_id,
-                                                  make_batched_images, "image",
+                                                  make_batched_images,
+                                                  "image",
                                                   image_processor,
-                                                  prompt_token_ids)
+                                                  prompt_token_ids,
+                                                  min_pixels=min_pixels,
+                                                  max_pixels=max_pixels)
 
     if video_inputs is not None:
         prompt_token_ids = _expand_pad_tokens(video_inputs,
                                               hf_config.video_token_id,
-                                              make_batched_videos, "video",
+                                              make_batched_videos,
+                                              "video",
                                               image_processor,
-                                              prompt_token_ids)
+                                              prompt_token_ids,
+                                              min_pixels=min_pixels,
+                                              max_pixels=max_pixels)
 
     return token_inputs(
         prompt_token_ids=prompt_token_ids,
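
The new keyword-only overrides can also be driven directly against the patched helpers, which is essentially what the new tests exercise; a rough sketch, assuming an InputContext produced by the build_model_context helper from the test utilities (tests/models/utils.py):

from vllm.model_executor.models.qwen2_vl import (dummy_data_for_qwen2_vl,
                                                 get_max_qwen2_vl_image_tokens)

ctx = build_model_context(model_name="Qwen/Qwen2-VL-2B-Instruct")  # test helper

# With the 512**2 cap this matches the 324-token override case in the tests
max_tokens = get_max_qwen2_vl_image_tokens(ctx, min_pixels=64**2,
                                            max_pixels=512**2)

# Dummy image is resized to 504x504 and 324 image placeholders are emitted
seq_data, mm_data = dummy_data_for_qwen2_vl(ctx, 3000, {"image": 1, "video": 0},
                                            min_pixels=64**2, max_pixels=512**2)
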