diff --git a/sharktank/tests/export_test.py b/sharktank/tests/export_test.py index 20b7de734..92992121b 100644 --- a/sharktank/tests/export_test.py +++ b/sharktank/tests/export_test.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + from sharktank.types import ( ReplicatedTensor, SplitPrimitiveTensor, @@ -70,6 +72,10 @@ def testGetFlatArgumentDeviceAffinities(self): } assert_dicts_equal(affinities, expected_affinities) + @pytest.mark.xfail( + torch.__version__ >= (2, 4), + reason="https://github.com/nod-ai/shark-ai/issues/685", + ) def testExportWithArgumentDeviceAffinities(self): args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]])) diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py index 63251c5a9..d5cb6863d 100644 --- a/sharktank/tests/layers/paged_llama_attention_block_test.py +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + import logging logging.basicConfig(level=logging.DEBUG) @@ -118,6 +120,10 @@ def forward(self, h, seq_block_ids, cache_state): asm = str(output.mlir_module) self.assertNotIn("scaled_dot_product_attention", asm) + @pytest.mark.xfail( + torch.__version__ >= (2, 4), + reason="https://github.com/nod-ai/shark-ai/issues/684", + ) def testExportNondecomposed(self): dtype = torch.float32 diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index 9b29e5761..f0153e25b 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -6,6 +6,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + from pathlib import Path import tempfile import torch @@ -181,6 +183,9 @@ def run_test_sharded_conv2d_with_iree( ) +@pytest.mark.xfail( + torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/682" +) def test_sharded_conv2d_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index 581584369..c24dc149e 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -228,6 +228,9 @@ def run_test_sharded_resnet_block_with_iree( strict=True, raises=AssertionError, ) +@pytest.mark.xfail( + torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/683" +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path],