Skip to content

Commit

Permalink
Removes vae test pytest dependencies as it didn't order tests properly
Browse files Browse the repository at this point in the history
Moves irpa generation into vae setup
  • Loading branch information
IanNod committed Dec 9, 2024
1 parent ca4fa44 commit 96a2f3e
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions sharktank/tests/models/vae/vae_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def setUp(self):
filename="vae/vae.safetensors",
)
torch.manual_seed(12345)

@pytest.mark.dependency()
def testIrpaConversion(self):
f32_dataset = import_hf_config(
"sdxl_vae/vae/config.json",
"sdxl_vae/vae/diffusion_pytorch_model.safetensors",
Expand All @@ -75,7 +72,6 @@ def testIrpaConversion(self):
)
f16_dataset.save("sdxl_vae/vae_f16.irpa", io_report_callback=print)

@pytest.mark.dependency(depends=["testIrpaConversion"])
def testCompareF32EagerVsHuggingface(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
Expand All @@ -89,7 +85,6 @@ def testCompareF32EagerVsHuggingface(self):
torch.testing.assert_close(ref_results, results)

@pytest.mark.skip(reason="running fp16 on cpu is extremely slow")
@pytest.mark.dependency(depends=["testIrpaConversion"])
def testCompareF16EagerVsHuggingface(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
Expand All @@ -102,7 +97,6 @@ def testCompareF16EagerVsHuggingface(self):

torch.testing.assert_close(ref_results, results)

@pytest.mark.dependency(depends=["testIrpaConversion"])
def testVaeIreeVsHuggingFace(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
Expand Down

0 comments on commit 96a2f3e

Please sign in to comment.