From 196b46197a7e7223c130cd03c74515d6c28709e3 Mon Sep 17 00:00:00 2001
From: Matthew Wong
Date: Fri, 19 Jul 2024 18:01:20 +0000
Subject: [PATCH] Limit PaliGemma to half precision on ROCm

---
 tests/models/test_paligemma.py         | 14 +++++++++++---
 vllm/model_executor/models/__init__.py | 12 +++++++++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py
index e11784558f196..e1c39ee6fecb6 100644
--- a/tests/models/test_paligemma.py
+++ b/tests/models/test_paligemma.py
@@ -24,8 +24,8 @@
 
 models = ["google/paligemma-3b-mix-224"]
 
-# ROCm Triton FA can run into shared memory issues with these models,
-# use other backends in the meantime
+# ROCm Triton FA can run into compilation issues with these models due to
+# excessive use of shared memory. Use other backends in the meantime.
 # FIXME (mattwong, gshtrasb, hongxiayan)
 if is_hip():
     os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
@@ -138,7 +138,15 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float", "half"])
+@pytest.mark.parametrize("dtype", [
+    pytest.param(
+        "float",
+        marks=pytest.mark.skipif(
+            is_hip(),
+            reason=
+            "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
+    ), "half"
+])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 58b6d73e5f113..8e7f470240ba8 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -91,9 +91,15 @@
     "please use CK flash attention by setting "
     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
 _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
-    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
-    "MistralForCausalLM": _ROCM_SWA_REASON,
-    "MixtralForCausalLM": _ROCM_SWA_REASON,
+    "Qwen2ForCausalLM":
+    _ROCM_SWA_REASON,
+    "MistralForCausalLM":
+    _ROCM_SWA_REASON,
+    "MixtralForCausalLM":
+    _ROCM_SWA_REASON,
+    "PaliGemmaForConditionalGeneration":
+    ("ROCm flash attention does not yet "
+     "fully support 32-bit precision on PaliGemma")
 }
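
Aside on the test change: wrapping the "float" case in pytest.param with a
skipif mark skips only that one parametrization on ROCm while "half" still
runs, instead of skipping the whole test. A minimal self-contained sketch of
the same pattern; platform_is_rocm() here is a hypothetical stand-in for
is_hip() from vllm.utils:

import pytest


def platform_is_rocm() -> bool:
    # Hypothetical platform check; the patch uses is_hip() from vllm.utils.
    return False


@pytest.mark.parametrize("dtype", [
    pytest.param(
        "float",
        marks=pytest.mark.skipif(
            platform_is_rocm(),
            reason="fp32 not supported on this platform")),
    "half",
])
def test_dtype_matrix(dtype):
    # On ROCm only the "half" case would run; elsewhere both cases run.
    assert dtype in ("float", "half")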
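
Aside on the registry change: the new _ROCM_PARTIALLY_SUPPORTED_MODELS entry
only records a human-readable reason string keyed by architecture name. A
sketch of how such a table might be consumed when resolving a model
architecture; the warn_if_partially_supported() helper is an assumption for
illustration, not vLLM's actual API:

import logging
from typing import Dict

logger = logging.getLogger(__name__)

_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
}


def warn_if_partially_supported(model_arch: str, on_rocm: bool) -> None:
    # Hypothetical helper: surface the recorded caveat at load time
    # instead of failing outright.
    if on_rocm and model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
        logger.warning("Model architecture %s is partially supported on "
                       "ROCm: %s", model_arch,
                       _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])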