From de9f5b68244a06914a89422ab626756fe3ec7250 Mon Sep 17 00:00:00 2001
From: Jaehyun An
Date: Fri, 15 Sep 2023 14:57:03 +0900
Subject: [PATCH] Add backward tests for float16, bfloat16

Fixes for GEGLU and PReLU are included.
---
 tests/test_cosine_similarity.py | 30 ++++++++++++++++++++++++++++--
 tests/test_dropout.py           |  7 ++++++-
 tests/test_geglu.py             |  8 ++++++++
 tests/test_gelu.py              |  8 +++++++-
 tests/test_group_norm.py        |  7 ++++++-
 tests/test_instance_norm.py     | 14 ++++++++++++--
 tests/test_layer_norm.py        |  7 ++++++-
 tests/test_leaky_relu.py        |  8 +++++++-
 tests/test_max.py               | 12 +++++++++---
 tests/test_mean.py              | 10 ++++++++--
 tests/test_prelu.py             |  7 ++++++-
 tests/test_relu.py              |  7 ++++++-
 tests/test_rms_norm.py          |  7 ++++++-
 tests/test_silu.py              |  7 ++++++-
 tests/test_softmax.py           | 11 +++++++++--
 tests/test_sum.py               |  7 ++++++-
 trident/kernel/geglu.py         |  5 +++--
 trident/kernel/prelu.py         |  5 +++--
 trident/operation/geglu.py      |  9 ++++++++-
 trident/operation/prelu.py      |  1 +
 20 files changed, 151 insertions(+), 26 deletions(-)

diff --git a/tests/test_cosine_similarity.py b/tests/test_cosine_similarity.py
index 43fe2fc1..c26a50ac 100644
--- a/tests/test_cosine_similarity.py
+++ b/tests/test_cosine_similarity.py
@@ -38,13 +38,13 @@ def test_backward(z_size, y_size, x_size, dim, device):
     else:
         target_dim = (z_size, y_size)

-    grad_ouput = torch.randn(target_dim, **factory_kwargs)
+    grad_output = torch.randn(target_dim, **factory_kwargs)

     def train(func):
         i = x1.clone()
         j = x2.clone()
         i.requires_grad = j.requires_grad = True
-        func(i, j).backward(grad_ouput, retain_graph=True)
+        func(i, j).backward(grad_output, retain_graph=True)
         return i.grad, j.grad

     (x, y) = train(torch.nn.CosineSimilarity(dim))
@@ -52,3 +52,29 @@ def train(func):

     assert util.equal(x, a)
     assert util.equal(y, b)
+
+
+@pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)])
+def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
+    factory_kwargs = {"device": device, "dtype": dtype}
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+
+    output = trident.CosineSimilarity(dim).forward(x1, x2)
+    assert output is not None
+    assert output.dtype == dtype
+
+    if dim == 0:
+        target_dim = (y_size, x_size)
+    elif dim == 1:
+        target_dim = (z_size, x_size)
+    else:
+        target_dim = (z_size, y_size)
+
+    grad_output = torch.randn(target_dim, **factory_kwargs)
+
+    output.backward(grad_output)
+    assert x1.grad is not None
+    assert x1.grad.dtype == dtype
+    assert x2.grad is not None
+    assert x2.grad.dtype == dtype
diff --git a/tests/test_dropout.py b/tests/test_dropout.py
index 07a34aa0..42a8c88e 100644
--- a/tests/test_dropout.py
+++ b/tests/test_dropout.py
@@ -46,7 +46,12 @@ def train(func):
 @pytest.mark.parametrize("x_size, p", [(16, 0.7)])
 def test_dropout(x_size, p, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(x_size, **factory_kwargs)
+    input = torch.randn(x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)

     output = trident.Dropout(p).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_geglu.py b/tests/test_geglu.py
index b2913e28..a3376484 100644
--- a/tests/test_geglu.py
+++ b/tests/test_geglu.py
@@ -88,6 +88,9 @@ def train(func):
 def test_geglu(num_batches, m_size, n_size, k_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
     input = torch.randn(num_batches, m_size, k_size, **factory_kwargs)
+    x_size = n_size // 2
+    input = torch.randn(num_batches, m_size, k_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(num_batches, m_size, x_size, **factory_kwargs)

     output = trident.GEGLU(m_size, n_size, **factory_kwargs).forward(input)

@@ -98,3 +101,8 @@ def test_geglu(num_batches, m_size, n_size, k_size, device, dtype):

     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_gelu.py b/tests/test_gelu.py
index 3d7d1888..e7dc9da3 100644
--- a/tests/test_gelu.py
+++ b/tests/test_gelu.py
@@ -46,7 +46,13 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(2, 128)])
 def test_gelu(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn((y_size, x_size), **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)

     output = trident.GELU().forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_group_norm.py b/tests/test_group_norm.py
index 6b11b71e..b86b77ca 100644
--- a/tests/test_group_norm.py
+++ b/tests/test_group_norm.py
@@ -96,7 +96,8 @@ def train(func):
 @pytest.mark.parametrize("num_batches, y_size, x_size, num_groups", [(1, 8, 1, 4)])
 def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn((num_batches, y_size, x_size), **factory_kwargs)
+    input = torch.randn((num_batches, y_size, x_size), **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)

     operation = trident.GroupNorm(num_groups, y_size, **factory_kwargs)
     output = operation.forward(input)
@@ -105,3 +106,7 @@ def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
     operation = trident.GroupNorm(num_groups, y_size, affine=True, **factory_kwargs)
     output = operation.forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py
index 297e8bbe..53aa4c92 100644
--- a/tests/test_instance_norm.py
+++ b/tests/test_instance_norm.py
@@ -102,7 +102,8 @@ def train(func):
 @pytest.mark.parametrize("num_channels, length", [(1, 64)])
 def test_instance_norm1d(num_channels, length, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_channels, length, **factory_kwargs)
+    input = torch.randn(num_channels, length, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)

     output = trident.InstanceNorm1d(num_channels, **factory_kwargs).forward(input)

@@ -126,11 +127,16 @@ def test_instance_norm1d(num_channels, length, device, dtype):
     assert output is not None
     assert output.dtype == dtype

+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
+

 @pytest.mark.parametrize("num_batches, num_channels, height, width", [(1, 1, 64, 64)])
 def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs)
+    input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)

     output = trident.InstanceNorm2d(num_channels, **factory_kwargs).forward(input)

@@ -153,3 +159,7 @@ def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype

     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py
index bffd1975..9c791f44 100644
--- a/tests/test_layer_norm.py
+++ b/tests/test_layer_norm.py
@@ -114,11 +114,16 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(1, 32)])
 def test_layer_norm(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
     normalized_shape = (input.shape[-1],)
+    grad_output = torch.rand_like(input)

     output = trident.LayerNorm(normalized_shape, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype

     output = trident.LayerNorm(normalized_shape, elementwise_affine=False, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_leaky_relu.py b/tests/test_leaky_relu.py
index 85e8abc3..5da7a33c 100644
--- a/tests/test_leaky_relu.py
+++ b/tests/test_leaky_relu.py
@@ -46,9 +46,15 @@ def train(func):
 @pytest.mark.parametrize("num_batches, num_elements", [(1, 100)])
 def test_leaky_relu(num_batches, num_elements, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, num_elements, **factory_kwargs)
+    input = torch.randn(num_batches, num_elements, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)

     output = trident.LeakyReLU().forward(input)

     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_max.py b/tests/test_max.py
index 19476929..2b9c54c8 100644
--- a/tests/test_max.py
+++ b/tests/test_max.py
@@ -54,7 +54,13 @@ def test_max(y_size, x_size, dim, device, dtype):
         pytest.skip("Skipping due to bfloat16 dtype")

     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

-    operation = trident.Max(dim)
-    assert operation.forward(input) is not None
+    output, argmax = trident.Max(dim).forward(input)
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_mean.py b/tests/test_mean.py
index 37327807..164fa019 100644
--- a/tests/test_mean.py
+++ b/tests/test_mean.py
@@ -46,7 +46,13 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size, dim", [(1, 16, 0), (16, 1, 1)])
 def test_mean(y_size, x_size, dim, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

     output = trident.Mean(dim).forward(input)
-    assert output is not None and output.dtype == dtype
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_prelu.py b/tests/test_prelu.py
index 495a32db..bf6beab8 100644
--- a/tests/test_prelu.py
+++ b/tests/test_prelu.py
@@ -50,9 +50,14 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(1, 100)])
 def test_prelu(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)

     output = trident.PReLU(x_size, 0.3, **factory_kwargs).forward(input)

     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_relu.py b/tests/test_relu.py
index 33447ea6..5a3211db 100644
--- a/tests/test_relu.py
+++ b/tests/test_relu.py
@@ -46,9 +46,14 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(1, 100)])
 def test_relu(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)

     output = trident.ReLU().forward(input)

     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py
index ccafdd23..e7208625 100644
--- a/tests/test_rms_norm.py
+++ b/tests/test_rms_norm.py
@@ -91,7 +91,8 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(4, 16)])
 def test_rms_norm(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn((y_size, x_size), **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(y_size, x_size, device=device)

     output = trident.RMSNorm(x_size, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
@@ -101,3 +102,7 @@ def test_rms_norm(y_size, x_size, device, dtype):
     assert output is not None and output.dtype == dtype
     output = trident.RMSNorm(x_size, 0.5, bias=True, **factory_kwargs).forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_silu.py b/tests/test_silu.py
index 1c6f5dfb..8cb19c35 100644
--- a/tests/test_silu.py
+++ b/tests/test_silu.py
@@ -46,7 +46,12 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size", [(1, 32)])
 def test_silu(y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.rand_like(input)

     output = trident.SiLU().forward(input)
     assert output is not None and output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_softmax.py b/tests/test_softmax.py
index 29788b4f..aa92f944 100644
--- a/tests/test_softmax.py
+++ b/tests/test_softmax.py
@@ -46,6 +46,13 @@ def train(func, dim):
 @pytest.mark.parametrize("y_size, x_size, dim", [(1, 32, 1)])
 def test_softmax(y_size, x_size, dim, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)

-    assert trident.Softmax(dim).forward(input) is not None
+    output = trident.Softmax(dim).forward(input)
+    assert output is not None
+    assert output.dtype == dtype
+
+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/tests/test_sum.py b/tests/test_sum.py
index 8f0eede5..aaf237f5 100644
--- a/tests/test_sum.py
+++ b/tests/test_sum.py
@@ -46,11 +46,16 @@ def train(func):
 @pytest.mark.parametrize("y_size, x_size, dim", [(1, 16, 0)])
 def test_sum(y_size, x_size, dim, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(y_size, x_size, **factory_kwargs)
+    input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

     output = trident.Sum(dim).forward(input)
     assert output is not None and output.dtype == dtype

+    output.backward(grad_output)
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
+

 @pytest.mark.parametrize("dim", [0, 1])
 def test_sum_issue1(dim, device):
diff --git a/trident/kernel/geglu.py b/trident/kernel/geglu.py
index dbac029c..660cc26e 100644
--- a/trident/kernel/geglu.py
+++ b/trident/kernel/geglu.py
@@ -150,6 +150,7 @@ def backward(
         m_size: tl.int32,
         n_size: tl.int32,
         x_size: tl.int32,
+        dtype: tl.constexpr,
         x_block_size: tl.constexpr,
     ):
         pid = tl.program_id(0)
@@ -202,5 +203,5 @@ def backward(
         gate = tl.load(gate_block_ptr, boundary_check=(1,))
         grad_state = grad_output * language.math.GELU.forward(gate)
         grad_gate = language.math.GELU.backward(grad_output * state, gate)
-        tl.store(grad_state_block_ptr, grad_state, boundary_check=(1,))
-        tl.store(grad_gate_block_ptr, grad_gate, boundary_check=(1,))
+        tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,))
+        tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,))
diff --git a/trident/kernel/prelu.py b/trident/kernel/prelu.py
index 14544ede..b72dad54 100644
--- a/trident/kernel/prelu.py
+++ b/trident/kernel/prelu.py
@@ -102,6 +102,7 @@ def backward(
         batch_stride: tl.int32,
         y_stride: tl.int32,
         x_stride: tl.int32,
+        dtype: tl.constexpr,
         y_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
     ):
@@ -162,5 +163,5 @@ def backward(
         grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2))
         grad_input = language.math.LeakyReLU.backward(grad_output, input, weight)
         grad_weight = grad_output * tl.where(input > 0, 0, input)
-        tl.store(grad_input_block_ptr, grad_input, boundary_check=(1, 2))
-        tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(1, 2))
+        tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2))
+        tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2))
diff --git a/trident/operation/geglu.py b/trident/operation/geglu.py
index 2d41f294..b7cae541 100644
--- a/trident/operation/geglu.py
+++ b/trident/operation/geglu.py
@@ -99,7 +99,14 @@ def grid(meta):

         util.push_trace("kernel.GEGLU.backward")
         kernel.GEGLU.backward[grid](
-            grad_state_gate, grad_output, state_gate, m_size, n_size, x_size, triton.next_power_of_2(x_size)
+            grad_state_gate,
+            grad_output,
+            state_gate,
+            m_size,
+            n_size,
+            x_size,
+            util.dtype(state_gate.dtype),
+            triton.next_power_of_2(x_size),
         )
         util.pop_trace()

diff --git a/trident/operation/prelu.py b/trident/operation/prelu.py
index a08b2540..0b06ff40 100644
--- a/trident/operation/prelu.py
+++ b/trident/operation/prelu.py
@@ -95,6 +95,7 @@ def grid(meta):
             grad_input.stride(0),
             grad_input.stride(1),
             grad_input.stride(2),
+            util.dtype(grad_input.dtype),
         )
         util.pop_trace()
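
Reviewer note, not part of the patch: the kernel changes above cast the computed gradients back to the destination dtype before tl.store, because intermediate math in the backward kernels can be promoted to float32 while the output blocks are float16/bfloat16 when those dtypes are under test. The snippet below is a minimal, self-contained sketch of that store-time cast pattern; scale_kernel, scale, and the torch-to-Triton dtype map are illustrative assumptions standing in for Trident's kernels and util.dtype, not Trident's API.

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(out_ptr, inp_ptr, scale, n_elements, dtype: tl.constexpr, block_size: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements
    x = tl.load(inp_ptr + offsets, mask=mask)
    y = x.to(tl.float32) * scale  # compute in float32
    tl.store(out_ptr + offsets, y.to(dtype), mask=mask)  # cast back to the output dtype before storing


def scale(inp: torch.Tensor, s: float) -> torch.Tensor:
    # Map the torch dtype to the matching Triton dtype, similar in spirit to util.dtype.
    triton_dtype = {torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32}[inp.dtype]
    out = torch.empty_like(inp)
    n_elements = inp.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    scale_kernel[grid](out, inp, s, n_elements, triton_dtype, 1024)
    return out


if __name__ == "__main__":
    x = torch.randn(2048, device="cuda", dtype=torch.float16)
    assert scale(x, 0.5).dtype == torch.float16

Without the explicit .to(dtype), tl.store would receive a float32 value for a float16/bfloat16 destination block, which is the mismatch the GEGLU and PReLU fixes address and the new grad dtype assertions in the tests exercise.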