Skip to content
This repository was archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Add supplementary tests for CosineSimilarity, Dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaehyun An committed Sep 15, 2023
1 parent f23f891 commit 9c13863
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
30 changes: 28 additions & 2 deletions tests/test_cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,43 @@ def test_backward(z_size, y_size, x_size, dim, device):
else:
target_dim = (z_size, y_size)

grad_ouput = torch.randn(target_dim, **factory_kwargs)
grad_output = torch.randn(target_dim, **factory_kwargs)

def train(func):
i = x1.clone()
j = x2.clone()
i.requires_grad = j.requires_grad = True
func(i, j).backward(grad_ouput, retain_graph=True)
func(i, j).backward(grad_output, retain_graph=True)
return i.grad, j.grad

(x, y) = train(torch.nn.CosineSimilarity(dim))
(a, b) = train(trident.CosineSimilarity(dim))

assert util.equal(x, a)
assert util.equal(y, b)


@pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)])
def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
    """Smoke-test trident.CosineSimilarity: forward produces a tensor of the
    requested dtype, and backward populates gradients of that dtype on both
    inputs."""
    factory_kwargs = {"device": device, "dtype": dtype}
    lhs = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
    rhs = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)

    operation = trident.CosineSimilarity(dim)
    output = operation.forward(lhs, rhs)
    assert output is not None
    assert output.dtype == dtype

    # Reducing along `dim` drops that axis; any dim other than 0 or 1 is
    # treated as the last axis, matching the reference branch structure.
    reduced_shape = {
        0: (y_size, x_size),
        1: (z_size, x_size),
    }.get(dim, (z_size, y_size))
    grad_output = torch.randn(reduced_shape, **factory_kwargs)

    output.backward(grad_output)
    for operand in (lhs, rhs):
        assert operand.grad is not None
        assert operand.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def train(func):
@pytest.mark.parametrize("x_size, p", [(16, 0.7)])
def test_dropout(x_size, p, device, dtype):
    """Smoke-test trident.Dropout: forward yields a tensor of the requested
    dtype, and backward populates an input gradient of that dtype."""
    factory_kwargs = {"device": device, "dtype": dtype}
    # Renamed from `input` to avoid shadowing the builtin.
    source = torch.randn(x_size, **factory_kwargs, requires_grad=True)
    grad_output = torch.randn_like(source)

    output = trident.Dropout(p).forward(source)
    assert output is not None and output.dtype == dtype

    output.backward(grad_output)
    assert source.grad is not None
    assert source.grad.dtype == dtype

0 comments on commit 9c13863

Please sign in to comment.