Add backward tests for FP16 and BF16
Jaehyun An committed Sep 15, 2023
1 parent f23f891 commit df66dbd
Showing 8 changed files with 82 additions and 13 deletions.
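
Note: every test below receives device and dtype as pytest fixture arguments, but the fixtures themselves are not part of this diff. A minimal sketch of what they might look like in a conftest.py, assuming CUDA-only execution and a dtype sweep over FP32/FP16/BF16 (hypothetical, not the repository's actual configuration):

import pytest
import torch


# Hypothetical conftest.py sketch: Trident kernels are Triton-based, so a
# CUDA device is assumed; parametrizing dtype over FP32, FP16, and BF16 is
# what makes the new backward assertions cover the half-precision paths.
@pytest.fixture(params=["cuda"])
def device(request):
    return request.param


@pytest.fixture(params=[torch.float32, torch.float16, torch.bfloat16])
def dtype(request):
    return request.param
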
30 changes: 28 additions & 2 deletions tests/test_cosine_similarity.py
@@ -38,17 +38,43 @@ def test_backward(z_size, y_size, x_size, dim, device):
     else:
         target_dim = (z_size, y_size)
 
-    grad_ouput = torch.randn(target_dim, **factory_kwargs)
+    grad_output = torch.randn(target_dim, **factory_kwargs)
 
     def train(func):
         i = x1.clone()
         j = x2.clone()
         i.requires_grad = j.requires_grad = True
-        func(i, j).backward(grad_ouput, retain_graph=True)
+        func(i, j).backward(grad_output, retain_graph=True)
         return i.grad, j.grad
 
     (x, y) = train(torch.nn.CosineSimilarity(dim))
     (a, b) = train(trident.CosineSimilarity(dim))
 
     assert util.equal(x, a)
     assert util.equal(y, b)
+
+
+@pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)])
+def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
+    factory_kwargs = {"device": device, "dtype": dtype}
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+
+    output = trident.CosineSimilarity(dim).forward(x1, x2)
+    assert output is not None
+    assert output.dtype == dtype
+
+    if dim == 0:
+        target_dim = (y_size, x_size)
+    elif dim == 1:
+        target_dim = (z_size, x_size)
+    else:
+        target_dim = (z_size, y_size)
+
+    grad_output = torch.randn(target_dim, **factory_kwargs)
+
+    output.backward(grad_output)
+    assert x1.grad is not None
+    assert x1.grad.dtype == dtype
+    assert x2.grad is not None
+    assert x2.grad.dtype == dtype
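
Note: test_backward compares Trident gradients against torch.nn.CosineSimilarity through util.equal, whose implementation is not shown in this diff. For the FP16 and BF16 runs some tolerance is presumably required; a sketch of such a dtype-aware comparison, under that assumption and not the repository's actual helper:

import torch


# Assumed dtype-aware comparison: loose tolerances for the half-precision
# dtypes added in this commit, a tight check otherwise.
def equal(x: torch.Tensor, y: torch.Tensor) -> bool:
    if x.dtype in (torch.float16, torch.bfloat16):
        return torch.allclose(x, y, rtol=1e-2, atol=1e-2)
    return torch.allclose(x, y)
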
7 changes: 6 additions & 1 deletion tests/test_dropout.py
@@ -46,7 +46,12 @@ def train(func):
 @pytest.mark.parametrize("x_size, p", [(16, 0.7)])
 def test_dropout(x_size, p, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(x_size, **factory_kwargs)
+    input = torch.randn(x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)
 
     output = trident.Dropout(p).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
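
Note: the gradient asserted above follows standard inverted dropout, where kept elements are scaled by 1 / (1 - p) in both forward and backward. A PyTorch-only illustration, independent of Trident's implementation:

import torch

# Backward of inverted dropout: the upstream gradient is multiplied by the
# same scaled mask used in the forward pass, so kept positions receive
# 1 / (1 - p) and dropped positions receive 0.
p = 0.7
x = torch.randn(16, requires_grad=True)
out = torch.nn.functional.dropout(x, p=p, training=True)
out.backward(torch.ones_like(out))
kept = out != 0
assert torch.allclose(x.grad[kept], torch.full_like(x.grad[kept], 1.0 / (1.0 - p)))
assert torch.all(x.grad[~kept] == 0)
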
7 changes: 6 additions & 1 deletion tests/test_group_norm.py
@@ -96,7 +96,8 @@ def train(func):
 @pytest.mark.parametrize("num_batches, y_size, x_size, num_groups", [(1, 8, 1, 4)])
 def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn((num_batches, y_size, x_size), **factory_kwargs)
+    input = torch.randn((num_batches, y_size, x_size), **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     operation = trident.GroupNorm(num_groups, y_size, **factory_kwargs)
     output = operation.forward(input)
@@ -105,3 +106,7 @@ def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
     operation = trident.GroupNorm(num_groups, y_size, affine=True, **factory_kwargs)
     output = operation.forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
14 changes: 12 additions & 2 deletions tests/test_instance_norm.py
@@ -102,7 +102,8 @@ def train(func):
 @pytest.mark.parametrize("num_channels, length", [(1, 64)])
 def test_instance_norm1d(num_channels, length, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_channels, length, **factory_kwargs)
+    input = torch.randn(num_channels, length, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     output = trident.InstanceNorm1d(num_channels, **factory_kwargs).forward(input)
 
@@ -126,11 +127,16 @@ def test_instance_norm1d(num_channels, length, device, dtype):
     assert output is not None
     assert output.dtype == dtype
 
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
+
 
 @pytest.mark.parametrize("num_batches, num_channels, height, width", [(1, 1, 64, 64)])
 def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs)
+    input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     output = trident.InstanceNorm2d(num_channels, **factory_kwargs).forward(input)
 
@@ -153,3 +159,7 @@ def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype):
 
     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_layer_norm.py
@@ -114,11 +114,16 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(1, 32)])
 def test_layer_norm(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
     normalized_shape = (input.shape[-1],)
+    grad_output = torch.rand_like(input)
 
     output = trident.LayerNorm(normalized_shape, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
 
     output = trident.LayerNorm(normalized_shape, elementwise_affine=False, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
8 changes: 7 additions & 1 deletion tests/test_leaky_relu.py
@@ -46,9 +46,15 @@ def train(func):
 @pytest.mark.parametrize("num_batches, num_elements", [(1, 100)])
 def test_leaky_relu(num_batches, num_elements, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, num_elements, **factory_kwargs)
+    input = torch.randn(num_batches, num_elements, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)
 
     output = trident.LeakyReLU().forward(input)
 
     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
12 changes: 9 additions & 3 deletions tests/test_max.py
@@ -54,7 +54,13 @@ def test_max(y_size, x_size, dim, device, dtype):
         pytest.skip("Skipping due to bfloat16 dtype")
 
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)
 
-    operation = trident.Max(dim)
-    assert operation.forward(input) is not None
+    output, argmax = trident.Max(dim).forward(input)
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
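
Note: the updated test unpacks (output, argmax) from trident.Max(dim).forward(input) and then runs backward. For reference, the backward of a max reduction routes the upstream gradient only to the argmax positions; a PyTorch-only illustration, independent of Trident's implementation:

import torch

# Backward of a max reduction along dim=1: each row's upstream gradient
# flows only to the column that attained the maximum (ties aside), and
# every other entry of the gradient is zero.
x = torch.randn(4, 8, requires_grad=True)
output, argmax = x.max(dim=1)
output.backward(torch.ones_like(output))
assert torch.equal(x.grad.nonzero()[:, 1], argmax)
assert torch.all(x.grad.sum(dim=1) == 1)
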
10 changes: 8 additions & 2 deletions tests/test_mean.py
@@ -46,7 +46,13 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size, dim", [(1, 16, 0), (16, 1, 1)])
 def test_mean(y_size, x_size, dim, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)
 
     output = trident.Mean(dim).forward(input)
-    assert output is not None and output.dtype == dtype
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
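
Note: for mean, the backward pass spreads each upstream gradient value uniformly over the reduced dimension (grad_output / x_size per element). A PyTorch-only illustration, independent of Trident's implementation:

import torch

# Backward of a mean reduction over dim=1: every element of a row receives
# grad_output / x_size, here 1 / 16.
x = torch.randn(1, 16, requires_grad=True)
output = x.mean(dim=1)
output.backward(torch.ones_like(output))
assert torch.allclose(x.grad, torch.full_like(x, 1.0 / 16))
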
