diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 02a4364da3f19..d8030ab219ccb 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -93,14 +93,13 @@ steps:
 - label: Models Test
   #mirror_hardwares: [amd]
   commands:
-    - bash ../.buildkite/download-images.sh
-    - pytest -v -s models --ignore=models/test_llava.py
+    - pytest -v -s models -m \"not llava\"
 
 - label: Llava Test
   mirror_hardwares: [amd]
   commands:
     - bash ../.buildkite/download-images.sh
-    - pytest -v -s models/test_llava.py
+    - pytest -v -s models -m llava
 
 - label: Prefix Caching Test
   mirror_hardwares: [amd]
diff --git a/pyproject.toml b/pyproject.toml
index 06f150009aa81..eb691c29724ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,4 +71,5 @@ markers = [
     "skip_global_cleanup",
     "llm: run tests for vLLM API only",
     "openai: run tests for OpenAI API only",
+    "llava: run tests for LLaVA models only",
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
index 55efc56ec3d02..a481daa3c23e4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -29,24 +29,19 @@
 
 # Multi modal related
 # You can use `.buildkite/download-images.sh` to download the assets
-_PIXEL_VALUES_FILES = [
+PIXEL_VALUES_FILES = [
     os.path.join(_TEST_DIR, "images", filename) for filename in
     ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
 ]
-_IMAGE_FEATURES_FILES = [
+IMAGE_FEATURES_FILES = [
     os.path.join(_TEST_DIR, "images", filename) for filename in
     ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
 ]
-_IMAGE_FILES = [
+IMAGE_FILES = [
     os.path.join(_TEST_DIR, "images", filename) for filename in
     ["stop_sign.jpg", "cherry_blossom.jpg"]
 ]
-_IMAGE_PROMPTS = [
-    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
-    "<image>\nUSER: What is the season?\nASSISTANT:"
-]
-assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
-    _IMAGE_FILES) == len(_IMAGE_PROMPTS)
+assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
 
 
 def _read_prompts(filename: str) -> List[str]:
@@ -84,14 +79,9 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     cleanup()
 
 
-@pytest.fixture(scope="session")
-def hf_image_prompts() -> List[str]:
-    return _IMAGE_PROMPTS
-
-
 @pytest.fixture(scope="session")
 def hf_images() -> List[Image.Image]:
-    return [Image.open(filename) for filename in _IMAGE_FILES]
+    return [Image.open(filename) for filename in IMAGE_FILES]
 
 
 @pytest.fixture()
@@ -101,26 +91,17 @@ def vllm_images(request) -> List[MultiModalData]:
             VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
         return [
             ImageFeatureData(torch.load(filename))
-            for filename in _IMAGE_FEATURES_FILES
+            for filename in IMAGE_FEATURES_FILES
         ]
     else:
         return [
-            ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
+            ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
         ]
 
 
 @pytest.fixture()
 def vllm_image_tensors(request) -> List[torch.Tensor]:
-    return [torch.load(filename) for filename in _PIXEL_VALUES_FILES]
-
-
-@pytest.fixture()
-def vllm_image_prompts(request) -> List[str]:
-    vision_language_config = request.getfixturevalue("model_and_config")[1]
-    return [
-        "<image>" * (vision_language_config.image_feature_size - 1) + p
-        for p in _IMAGE_PROMPTS
-    ]
+    return [torch.load(filename) for filename in PIXEL_VALUES_FILES]
 
 
 @pytest.fixture
diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py
index 839a9f78d1bb8..f03dbdbb770e5 100644
--- a/tests/models/test_llava.py
+++ b/tests/models/test_llava.py
@@ -1,14 +1,22 @@
-import gc
-from dataclasses import fields
-from enum import Enum
-from typing import Any, Dict, List, Tuple
+from typing import List, Tuple
 
 import pytest
-import torch
 from transformers import AutoTokenizer
 
 from vllm.config import VisionLanguageConfig
 
+from ..conftest import IMAGE_FILES
+
+pytestmark = pytest.mark.llava
+
+# The image token is placed before "user" on purpose so that the test can pass
+HF_IMAGE_PROMPTS = [
+    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
+    "<image>\nUSER: What is the season?\nASSISTANT:",
+]
+
+assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
+
 
 def iter_llava_configs(model_name: str):
     image_hw_to_feature_size = {
@@ -36,53 +44,35 @@ def iter_llava_configs(model_name: str):
 ]
 
 
-def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
-    """Flatten vision language config to pure args.
-
-    Compatible with what llm entrypoint expects.
-    """
-    result = {}
-    for field in fields(vlm_config):
-        value = getattr(vlm_config, field.name)
-        if isinstance(value, Enum):
-            result[field.name] = value.name.lower()
-        elif isinstance(value, tuple):
-            result[field.name] = ",".join([str(item) for item in value])
-        else:
-            result[field.name] = value
-
-    result["disable_image_processor"] = vlm_config.image_processor is None
-
-    return result
-
-
-def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
-                         vision_language_config: VisionLanguageConfig,
-                         model_id: str):
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
+                      vlm_config: VisionLanguageConfig, model_id: str):
     """Sanitize vllm output to be comparable with hf output.
     The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, x1, x2,
     x3 ... to 1, 32000, x1, x2, x3 ...
     It also reduces `output_str` from "<image><image>bla" to "bla".
     """
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    image_token_str = tokenizer.decode(vision_language_config.image_token_id)
-    image_token_str_len = len(image_token_str)
     input_ids, output_str = vllm_output
-    sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config
-                                                     .image_feature_size - 1:]
-    sanitzied_output_str = output_str[vision_language_config.
-                                      image_feature_size *
-                                      image_token_str_len:]
-    return sanitized_input_ids, sanitzied_output_str
+    image_token_id = vlm_config.image_token_id
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    image_token_str = tokenizer.decode(image_token_id)
+
+    hf_input_ids = [
+        input_id for idx, input_id in enumerate(input_ids)
+        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
+    ]
+    hf_output_str = output_str \
+        .replace(image_token_str * vlm_config.image_feature_size, "")
 
-
-@pytest.mark.parametrize("worker_use_ray", [False])
+    return hf_input_ids, hf_output_str
+
+
+# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
-                vllm_image_prompts, vllm_images, model_and_config, dtype: str,
-                max_tokens: int, worker_use_ray: bool) -> None:
+def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
+                model_and_config, dtype: str, max_tokens: int) -> None:
     """Inference result should be the same between hf and vllm.
 
     All the image fixtures for the test is under tests/images.
@@ -92,36 +82,33 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
-    model_id, vision_language_config = model_and_config
+    model_id, vlm_config = model_and_config
 
     hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
-    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
+    hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                           max_tokens,
                                           images=hf_images)
     del hf_model
 
+    vllm_image_prompts = [
+        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
+        for p in HF_IMAGE_PROMPTS
+    ]
+
     vllm_model = vllm_runner(model_id,
                              dtype=dtype,
-                             worker_use_ray=worker_use_ray,
                              enforce_eager=True,
-                             **as_dict(vision_language_config))
+                             **vlm_config.as_cli_args_dict())
     vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                               max_tokens,
                                               images=vllm_images)
     del vllm_model
 
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    for i in range(len(hf_image_prompts)):
+    for i in range(len(HF_IMAGE_PROMPTS)):
         hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = sanitize_vllm_output(
-            vllm_outputs[i], vision_language_config, model_id)
+        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
+            vllm_outputs[i], vlm_config, model_id)
         assert hf_output_str == vllm_output_str, (
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
-
-
-# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
-# (Requires multiple GPUs)
diff --git a/vllm/config.py b/vllm/config.py
index 8609bfb738fa7..3a14c391b7f00 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,7 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
+                    Union)
 
 import torch
 from transformers import PretrainedConfig
@@ -1114,6 +1115,25 @@ def get_image_input_enum_type(cls, value: str) -> ImageInputType:
                              f"Expecting to choose from "
                              f"{[x.name for x in cls.ImageInputType]}.") from e
 
+    def as_cli_args_dict(self) -> Dict[str, Any]:
+        """Flatten vision language config to pure args.
+
+        Compatible with what llm entrypoint expects.
+        """
+        result: Dict[str, Any] = {}
+        for f in fields(self):
+            value = getattr(self, f.name)
+            if isinstance(value, enum.Enum):
+                result[f.name] = value.name.lower()
+            elif isinstance(value, tuple):
+                result[f.name] = ",".join([str(item) for item in value])
+            else:
+                result[f.name] = value
+
+        result["disable_image_processor"] = self.image_processor is None
+
+        return result
+
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index b964e9ee42624..08fb09d111605 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -75,6 +75,14 @@ def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
 
         self.image = image
 
+    def __repr__(self) -> str:
+        image = self.image
+        if isinstance(image, Image.Image):
+            return f"{type(self).__name__}(image={image})"
+
+        return (f"{type(self).__name__}(image=torch.Tensor(shape="
+                f"{image.shape}, dtype={image.dtype}))")
+
 
 class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
 
@@ -96,10 +104,10 @@ def _default_input_processor(
             self, data: ImagePixelData, model_config: ModelConfig,
             vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
         image = data.image
-        image_processor = self._get_hf_image_processor(model_config,
-                                                        vlm_config)
 
         if isinstance(image, Image.Image):
+            image_processor = self._get_hf_image_processor(
+                model_config, vlm_config)
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available"
                                    "to process the image object")
@@ -127,6 +135,12 @@ class ImageFeatureData(MultiModalData):
     def __init__(self, image_features: torch.Tensor) -> None:
         self.image_features = image_features
 
+    def __repr__(self) -> str:
+        image_features = self.image_features
+
+        return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
+                f"{image_features.shape}, dtype={image_features.dtype}))")
+
 
 class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
 
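Note (not part of the diff): a minimal, self-contained sketch of the flattening behaviour implemented by the new VisionLanguageConfig.as_cli_args_dict() helper above, which the test spreads into vllm_runner(**...). The stand-in dataclass, its field names, and the example values below are assumptions for illustration only; the real config lives in vllm/config.py.

    # Illustrative sketch only: mirrors the as_cli_args_dict() logic added to
    # vllm.config.VisionLanguageConfig in this diff. The dataclass below is a
    # stand-in with assumed fields, not vLLM's actual API.
    import enum
    from dataclasses import dataclass, fields
    from typing import Any, Dict, Optional, Tuple


    class ImageInputType(enum.Enum):
        PIXEL_VALUES = enum.auto()
        IMAGE_FEATURES = enum.auto()


    @dataclass
    class FakeVisionLanguageConfig:
        image_input_type: ImageInputType
        image_input_shape: Tuple[int, ...]
        image_feature_size: int
        image_processor: Optional[str]

        def as_cli_args_dict(self) -> Dict[str, Any]:
            # Enums become lowercase names, tuples become comma-joined strings,
            # everything else passes through unchanged.
            result: Dict[str, Any] = {}
            for f in fields(self):
                value = getattr(self, f.name)
                if isinstance(value, enum.Enum):
                    result[f.name] = value.name.lower()
                elif isinstance(value, tuple):
                    result[f.name] = ",".join(str(item) for item in value)
                else:
                    result[f.name] = value
            result["disable_image_processor"] = self.image_processor is None
            return result


    cfg = FakeVisionLanguageConfig(
        image_input_type=ImageInputType.PIXEL_VALUES,
        image_input_shape=(1, 3, 336, 336),
        image_feature_size=576,
        image_processor=None,
    )
    print(cfg.as_cli_args_dict())
    # {'image_input_type': 'pixel_values', 'image_input_shape': '1,3,336,336',
    #  'image_feature_size': 576, 'image_processor': None,
    #  'disable_image_processor': True}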
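Note (not part of the diff): a toy re-derivation of the image-token sanitization performed by the new vllm_to_hf_output() in tests/models/test_llava.py. The token id 32000 comes from that function's docstring; the "<image>" string and the feature size of 3 are assumptions chosen to keep the example small (the real LLaVA-1.5 feature size is larger).

    # Illustrative sketch only: collapse a run of repeated image-token ids to a
    # single id, and strip the expanded image-token string from the text, the
    # same way the comprehension/replace in vllm_to_hf_output() does.
    from typing import List, Tuple

    IMAGE_TOKEN_ID = 32000       # per the docstring of vllm_to_hf_output()
    IMAGE_TOKEN_STR = "<image>"  # assumed decoded form of the image token
    IMAGE_FEATURE_SIZE = 3       # assumed, purely to keep the example short


    def toy_vllm_to_hf_output(
            vllm_output: Tuple[List[int], str]) -> Tuple[List[int], str]:
        input_ids, output_str = vllm_output
        # Keep an image-token id only if the previous id was not also one.
        hf_input_ids = [
            input_id for idx, input_id in enumerate(input_ids)
            if input_id != IMAGE_TOKEN_ID or input_ids[idx - 1] != IMAGE_TOKEN_ID
        ]
        # Remove the expanded run of image-token strings from the output text.
        hf_output_str = output_str.replace(
            IMAGE_TOKEN_STR * IMAGE_FEATURE_SIZE, "")
        return hf_input_ids, hf_output_str


    ids = [1, 32000, 32000, 32000, 319, 1799]
    text = "<image><image><image>A stop sign."
    print(toy_vllm_to_hf_output((ids, text)))
    # ([1, 32000, 319, 1799], 'A stop sign.')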