diff --git a/test/test_mps.py b/test/test_mps.py index fe7a65d3696fc..a294b1b0d2bcc 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -8385,6 +8385,14 @@ def test_cumprod_dim_check(self): self.assertRaises(IndexError, lambda: x.cumprod(2)) self.assertRaises(IndexError, lambda: x.cumprod(-3)) + def test_do_sync_thrice_its_all_right(self): + # Regression test for https://github.com/pytorch/pytorch/commit/9bc9d4cdb4355a385a7d7959f07d04d1648d6904 + # That caused sync calls to deadlock + x = torch.nextafter(torch.ones(1024, device='mps'), torch.zeros(1024, device='mps')) + for _ in range(3): + torch.mps.synchronize() + self.assertLess(x.sum().item(), x.numel()) + class TestLogical(TestCaseMPS): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)