Skip to content

Commit

Permalink
improved coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoppe committed Jul 8, 2024
1 parent 59f45c9 commit b388446
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion heat/core/tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_vmap(self):
# inputs split along different axes, output split along same axis (one of them different to input split)
x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
x1 = ht.random.randn(10, 5 * ht.MPI_WORLD.size, split=1)
out_dims = (0, 0)
out_dims = 0 # test with out_dims as int (tuple below)

def func(x0, x1, k=2, scale=1e-2):
return torch.topk(torch.linalg.svdvals(x0), k)[0] ** 2, scale * x0 @ x1
Expand Down Expand Up @@ -61,6 +61,28 @@ def func(x0, x1, k=2, scale=1e-2):
self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))
self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch))

# catch wrong number of output dimensions
with self.assertRaises(ValueError):
vfunc = ht.vmap(func, (0, 1, 2))
y0, y1 = vfunc(x0, x1, k=2, scale=2.2)

# one output only
def func(x0, m=1, scale=2):
return (x0 - m) ** scale

vfunc = ht.vmap(func, out_dims=(0,))

x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
y0 = vfunc(x0, m=2, scale=3)[0]

x0_torch = x0.resplit(None).larray
vfunc_torch = torch.vmap(func, (0,), (0,))
y0_torch = vfunc_torch(x0_torch, m=2, scale=3)

print(y0.resplit(None).larray, y0_torch)

self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch))

def test_vmap_with_chunks(self):
# same as before but now with prescribed chunk sizes for the vmap
x0 = ht.random.randn(5 * ht.MPI_WORLD.size, 10, 10, split=0)
Expand Down

0 comments on commit b388446

Please sign in to comment.