Disable Triton FA path in some problematic Paligemma and Phi3v tests
mawong-amd committed Jul 19, 2024
1 parent 39c2cc5 commit e503f7c
Showing 2 changed files with 12 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/distributed/test_multimodal_broadcast.py
@@ -15,14 +15,19 @@

 import pytest

-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_hip

 model = os.environ["TEST_DIST_MODEL"]

 if model.startswith("llava-hf/llava"):
     from ..models.test_llava import models, run_test
 elif model.startswith("microsoft/Phi-3-vision"):
     from ..models.test_phi3v import models, run_test
+
+    # ROCm Triton FA runs into issues with these models, use other backends
+    # FIXME (mattwong, gshtrasb, hongxiayan)
+    if is_hip():
+        os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
 else:
     raise NotImplementedError(f"Unsupported model: {model}")

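Both files apply the same gate: at module import time, check for a ROCm build and, if found, set VLLM_USE_TRITON_FLASH_ATTN=0 so the engine selects a non-Triton attention backend before any model is constructed. A minimal standalone sketch of that pattern, with the gate pulled into a helper (the name maybe_disable_triton_fa is hypothetical and not part of this commit):

import os

from vllm.utils import is_hip


def maybe_disable_triton_fa() -> None:
    # Hypothetical helper mirroring the inline gate these tests add.
    # It must run before vLLM picks its attention backend, i.e. before
    # any engine or model fixture is constructed.
    if is_hip():
        # ROCm Triton FA runs into issues with these models; fall back
        # to another attention backend.
        os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


# Called at import time in a test module, before any vLLM model is built.
maybe_disable_triton_fa()

Because the tests build models through fixtures, the commit places the override at module scope so it is already in effect when the first engine is created.
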
6 changes: 6 additions & 0 deletions tests/models/test_paligemma.py
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple, Type

 import pytest

@@ -23,6 +24,11 @@

 models = ["google/paligemma-3b-mix-224"]

+# ROCm Triton FA runs into issues with these models, use other backends
+# FIXME (mattwong, gshtrasb, hongxiayan)
+if is_hip():
+    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
+

 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                          Optional[SampleLogprobs]],
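For context, is_hip() is assumed to detect a ROCm build of PyTorch; a rough standalone equivalent (an assumption about its behavior, not the actual vllm.utils source) is:

import torch


def looks_like_rocm() -> bool:
    # On ROCm wheels torch.version.hip is a version string; on CUDA
    # builds it is None. This mirrors what is_hip() is assumed to check.
    return torch.version.hip is not None

With either check, the environment variable only steers backend selection on ROCm; non-ROCm runs are left unchanged.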
