Skip to content

Commit

Permalink
[CI/Build] Fix VLM test failures when using transformers v4.46 (vllm-project#9666)
Browse files (browse the repository at this point in the history)

Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
DarkLight1337 authored and mfournioux committed Nov 20, 2024
1 parent 9fe22ba commit 47d9d12
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
16 changes: 9 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,20 +232,22 @@ def video_assets() -> _VideoAssets:
return VIDEO_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)


class HfRunner:

def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if device is None:
return self.wrap_device(
input, "cpu" if current_platform.is_cpu() else "cuda")
device = "cpu" if current_platform.is_cpu() else "cuda"

if hasattr(input, "device") and input.device.type == device:
return input
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}

return input.to(device)
if hasattr(x, "device") and x.device.type == device:
return x

return x.to(device)

def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions tests/models/decoder_only/vision_language/test_chameleon.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional, Type

import pytest
import transformers
from transformers import AutoModelForVision2Seq, BatchEncoding

from vllm.multimodal.utils import rescale_image_size
Expand Down Expand Up @@ -93,6 +94,10 @@ def process(hf_inputs: BatchEncoding):
)


@pytest.mark.skipif(
transformers.__version__.startswith("4.46.0"),
reason="Model broken in HF, see huggingface/transformers#34379",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
Expand Down
4 changes: 2 additions & 2 deletions tests/models/decoder_only/vision_language/test_minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
models = ["openbmb/MiniCPM-Llama3-V-2_5"]


def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
return BatchEncoding({"model_inputs": hf_inputs})
def _wrap_inputs(hf_inputs: BatchEncoding):
return {"model_inputs": hf_inputs}


def trunc_hf_output(hf_output: Tuple[List[int], str,
Expand Down
15 changes: 12 additions & 3 deletions tests/models/decoder_only/vision_language/test_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip

from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
Expand Down Expand Up @@ -74,6 +75,7 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
Expand All @@ -100,7 +102,14 @@ def run_test(
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype,
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs

with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
Expand Down

0 comments on commit 47d9d12

Please sign in to comment.