Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI/Build] Fix VLM test failures when using transformers v4.46 #9666

Merged
merged 3 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,20 +232,22 @@ def video_assets() -> _VideoAssets:
return VIDEO_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)


class HfRunner:

def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if device is None:
return self.wrap_device(
input, "cpu" if current_platform.is_cpu() else "cuda")
device = "cpu" if current_platform.is_cpu() else "cuda"

if hasattr(input, "device") and input.device.type == device:
return input
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}

return input.to(device)
if hasattr(x, "device") and x.device.type == device:
return x

return x.to(device)

def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions tests/models/decoder_only/vision_language/test_chameleon.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional, Type

import pytest
import transformers
from transformers import AutoModelForVision2Seq, BatchEncoding

from vllm.multimodal.utils import rescale_image_size
Expand Down Expand Up @@ -93,6 +94,10 @@ def process(hf_inputs: BatchEncoding):
)


@pytest.mark.skipif(
transformers.__version__.startswith("4.46.0"),
reason="Model broken in HF, see huggingface/transformers#34379",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
Expand Down
4 changes: 2 additions & 2 deletions tests/models/decoder_only/vision_language/test_minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
models = ["openbmb/MiniCPM-Llama3-V-2_5"]


def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
return BatchEncoding({"model_inputs": hf_inputs})
def _wrap_inputs(hf_inputs: BatchEncoding):
return {"model_inputs": hf_inputs}


def trunc_hf_output(hf_output: Tuple[List[int], str,
Expand Down
15 changes: 12 additions & 3 deletions tests/models/decoder_only/vision_language/test_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip

from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
Expand Down Expand Up @@ -74,6 +75,7 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
Expand All @@ -100,7 +102,14 @@ def run_test(
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype,
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs

with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
Expand Down