From e503f7c2798d060533e8cc780af16bd17639146f Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:18:39 +0000 Subject: [PATCH] Disable Triton FA path in some problematic Paligemma and Phi3v tests --- tests/distributed/test_multimodal_broadcast.py | 7 ++++++- tests/models/test_paligemma.py | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index 8e0e8ecd675eb..69ed801b8ee47 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -15,7 +15,7 @@ import pytest -from vllm.utils import cuda_device_count_stateless +from vllm.utils import cuda_device_count_stateless, is_hip model = os.environ["TEST_DIST_MODEL"] @@ -23,6 +23,11 @@ from ..models.test_llava import models, run_test elif model.startswith("microsoft/Phi-3-vision"): from ..models.test_phi3v import models, run_test + + # ROCm Triton FA runs into issues with these models, use other backends + # FIXME (mattwong, gshtrasb, hongxiayan) + if is_hip(): + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" else: raise NotImplementedError(f"Unsupported model: {model}") diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index 9d1abb12a0380..1a38b35fc9702 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple, Type import pytest @@ -23,6 +24,11 @@ models = ["google/paligemma-3b-mix-224"] +# ROCm Triton FA runs into issues with these models, use other backends +# FIXME (mattwong, gshtrasb, hongxiayan) +if is_hip(): + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" + def vllm_to_hf_output(vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],