diff --git a/tests/test_cosine_similarity.py b/tests/test_cosine_similarity.py
index 43fe2fc..c26a50a 100644
--- a/tests/test_cosine_similarity.py
+++ b/tests/test_cosine_similarity.py
@@ -38,13 +38,13 @@ def test_backward(z_size, y_size, x_size, dim, device):
     else:
         target_dim = (z_size, y_size)
 
-    grad_ouput = torch.randn(target_dim, **factory_kwargs)
+    grad_output = torch.randn(target_dim, **factory_kwargs)
 
     def train(func):
         i = x1.clone()
         j = x2.clone()
         i.requires_grad = j.requires_grad = True
-        func(i, j).backward(grad_ouput, retain_graph=True)
+        func(i, j).backward(grad_output, retain_graph=True)
         return i.grad, j.grad
 
     (x, y) = train(torch.nn.CosineSimilarity(dim))
@@ -52,3 +52,29 @@ def train(func):
 
     assert util.equal(x, a)
     assert util.equal(y, b)
+
+
+@pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)])
+def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
+    factory_kwargs = {"device": device, "dtype": dtype}
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+
+    output = trident.CosineSimilarity(dim).forward(x1, x2)
+    assert output is not None
+    assert output.dtype == dtype
+
+    if dim == 0:
+        target_dim = (y_size, x_size)
+    elif dim == 1:
+        target_dim = (z_size, x_size)
+    else:
+        target_dim = (z_size, y_size)
+
+    grad_output = torch.randn(target_dim, **factory_kwargs)
+
+    output.backward(grad_output)
+    assert x1.grad is not None
+    assert x1.grad.dtype == dtype
+    assert x2.grad is not None
+    assert x2.grad.dtype == dtype
diff --git a/tests/test_dropout.py b/tests/test_dropout.py
index 07a34aa..42a8c88 100644
--- a/tests/test_dropout.py
+++ b/tests/test_dropout.py
@@ -46,7 +46,12 @@ def train(func):
 @pytest.mark.parametrize("x_size, p", [(16, 0.7)])
 def test_dropout(x_size, p, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(x_size, **factory_kwargs)
+    input = torch.randn(x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)
 
     output = trident.Dropout(p).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_group_norm.py b/tests/test_group_norm.py
index 6b11b71..b86b77c 100644
--- a/tests/test_group_norm.py
+++ b/tests/test_group_norm.py
@@ -96,7 +96,8 @@ def train(func):
 @pytest.mark.parametrize("num_batches, y_size, x_size, num_groups", [(1, 8, 1, 4)])
 def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn((num_batches, y_size, x_size), **factory_kwargs)
+    input = torch.randn((num_batches, y_size, x_size), **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     operation = trident.GroupNorm(num_groups, y_size, **factory_kwargs)
     output = operation.forward(input)
@@ -105,3 +106,7 @@ def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
     operation = trident.GroupNorm(num_groups, y_size, affine=True, **factory_kwargs)
     output = operation.forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py
index 297e8bb..53aa4c9 100644
--- a/tests/test_instance_norm.py
+++ b/tests/test_instance_norm.py
@@ -102,7 +102,8 @@ def train(func):
 @pytest.mark.parametrize("num_channels, length", [(1, 64)])
 def test_instance_norm1d(num_channels, length, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_channels, length, **factory_kwargs)
+    input = torch.randn(num_channels, length, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     output = trident.InstanceNorm1d(num_channels, **factory_kwargs).forward(input)
 
@@ -126,11 +127,16 @@ def test_instance_norm1d(num_channels, length, device, dtype):
     assert output is not None
     assert output.dtype == dtype
 
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
+
 
 @pytest.mark.parametrize("num_batches, num_channels, height, width", [(1, 1, 64, 64)])
 def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs)
+    input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     output = trident.InstanceNorm2d(num_channels, **factory_kwargs).forward(input)
 
@@ -153,3 +159,7 @@ def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype
 
     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py
index bffd197..9c791f4 100644
--- a/tests/test_layer_norm.py
+++ b/tests/test_layer_norm.py
@@ -114,11 +114,16 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(1, 32)])
 def test_layer_norm(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
     normalized_shape = (input.shape[-1],)
+    grad_output = torch.rand_like(input)
 
     output = trident.LayerNorm(normalized_shape, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
 
     output = trident.LayerNorm(normalized_shape, elementwise_affine=False, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_leaky_relu.py b/tests/test_leaky_relu.py
index 85e8abc..5da7a33 100644
--- a/tests/test_leaky_relu.py
+++ b/tests/test_leaky_relu.py
@@ -46,9 +46,15 @@ def train(func):
 @pytest.mark.parametrize("num_batches, num_elements", [(1, 100)])
 def test_leaky_relu(num_batches, num_elements, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, num_elements, **factory_kwargs)
+    input = torch.randn(num_batches, num_elements, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     output = trident.LeakyReLU().forward(input)
 
     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_max.py b/tests/test_max.py
index 1947692..2b9c54c 100644
--- a/tests/test_max.py
+++ b/tests/test_max.py
@@ -54,7 +54,13 @@ def test_max(y_size, x_size, dim, device, dtype):
         pytest.skip("Skipping due to bfloat16 dtype")
 
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)
 
-    operation = trident.Max(dim)
-    assert operation.forward(input) is not None
+    output, argmax = trident.Max(dim).forward(input)
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_mean.py b/tests/test_mean.py
index 3732780..164fa01 100644
--- a/tests/test_mean.py
+++ b/tests/test_mean.py
@@ -46,7 +46,13 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size, dim", [(1, 16, 0), (16, 1, 1)])
 def test_mean(y_size, x_size, dim, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)
 
     output = trident.Mean(dim).forward(input)
-    assert output is not None and output.dtype == dtype
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
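Note: every test added in this diff follows the same forward-then-backward pattern: create the input with requires_grad=True, run the module's forward, build a grad_output whose shape matches the output, call backward, and check that the gradients exist with the expected dtype. The sketch below only illustrates that pattern; it substitutes torch.nn.LayerNorm for the trident modules so it runs standalone, and check_forward_backward is a hypothetical helper, not part of this test suite.

import torch


def check_forward_backward(module, input, grad_output):
    # Forward pass: the output must exist and preserve the input dtype.
    output = module(input)
    assert output is not None
    assert output.dtype == input.dtype

    # Backward pass: gradients must be populated with the same dtype.
    output.backward(grad_output)
    assert input.grad is not None
    assert input.grad.dtype == input.dtype


input = torch.randn(2, 32, dtype=torch.float32, requires_grad=True)
check_forward_backward(torch.nn.LayerNorm(32), input, torch.rand_like(input))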