Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Add backward tests for float16 and bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaehyun An authored and daemyung committed Sep 15, 2023
1 parent 17d2b30 commit 45cf51e
Show file tree
Hide file tree
Showing 21 changed files with 159 additions and 28 deletions.
30 changes: 28 additions & 2 deletions tests/test_cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,43 @@ def test_backward(z_size, y_size, x_size, dim, device):
else:
target_dim = (z_size, y_size)

grad_ouput = torch.randn(target_dim, **factory_kwargs)
grad_output = torch.randn(target_dim, **factory_kwargs)

def train(func):
i = x1.clone()
j = x2.clone()
i.requires_grad = j.requires_grad = True
func(i, j).backward(grad_ouput, retain_graph=True)
func(i, j).backward(grad_output, retain_graph=True)
return i.grad, j.grad

(x, y) = train(torch.nn.CosineSimilarity(dim))
(a, b) = train(trident.CosineSimilarity(dim))

assert util.equal(x, a)
assert util.equal(y, b)


@pytest.mark.parametrize("z_size, y_size, x_size, dim", [(640, 21, 86, 2)])
def test_cosine_similarity(z_size, y_size, x_size, dim, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs, requires_grad=True)

output = trident.CosineSimilarity(dim).forward(x1, x2)
assert output is not None
assert output.dtype == dtype

if dim == 0:
target_dim = (y_size, x_size)
elif dim == 1:
target_dim = (z_size, x_size)
else:
target_dim = (z_size, y_size)

grad_output = torch.randn(target_dim, **factory_kwargs)

output.backward(grad_output)
assert x1.grad is not None
assert x1.grad.dtype == dtype
assert x2.grad is not None
assert x2.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def train(func):
@pytest.mark.parametrize("x_size, p", [(16, 0.7)])
def test_dropout(x_size, p, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(x_size, **factory_kwargs)
input = torch.randn(x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.Dropout(p).forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
8 changes: 8 additions & 0 deletions tests/test_geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def train(func):
def test_geglu(num_batches, m_size, n_size, k_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(num_batches, m_size, k_size, **factory_kwargs)
x_size = n_size // 2
input = torch.randn(num_batches, m_size, k_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn(num_batches, m_size, x_size, **factory_kwargs)

output = trident.GEGLU(m_size, n_size, **factory_kwargs).forward(input)

Expand All @@ -98,3 +101,8 @@ def test_geglu(num_batches, m_size, n_size, k_size, device, dtype):

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype
8 changes: 7 additions & 1 deletion tests/test_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size", [(2, 128)])
def test_gelu(y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn((y_size, x_size), **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.GELU().forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def train(func):
@pytest.mark.parametrize("num_batches, y_size, x_size, num_groups", [(1, 8, 1, 4)])
def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn((num_batches, y_size, x_size), **factory_kwargs)
input = torch.randn(num_batches, y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.rand_like(input)

operation = trident.GroupNorm(num_groups, y_size, **factory_kwargs)
output = operation.forward(input)
Expand All @@ -105,3 +106,7 @@ def test_group_norm(num_batches, y_size, x_size, num_groups, device, dtype):
operation = trident.GroupNorm(num_groups, y_size, affine=True, **factory_kwargs)
output = operation.forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
14 changes: 12 additions & 2 deletions tests/test_instance_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def train(func):
@pytest.mark.parametrize("num_channels, length", [(1, 64)])
def test_instance_norm1d(num_channels, length, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(num_channels, length, **factory_kwargs)
input = torch.randn(num_channels, length, **factory_kwargs, requires_grad=True)
grad_output = torch.rand_like(input)

output = trident.InstanceNorm1d(num_channels, **factory_kwargs).forward(input)

Expand All @@ -126,11 +127,16 @@ def test_instance_norm1d(num_channels, length, device, dtype):
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype


@pytest.mark.parametrize("num_batches, num_channels, height, width", [(1, 1, 64, 64)])
def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs)
input = torch.randn(num_batches, num_channels, height, width, **factory_kwargs, requires_grad=True)
grad_output = torch.rand_like(input)

output = trident.InstanceNorm2d(num_channels, **factory_kwargs).forward(input)

Expand All @@ -153,3 +159,7 @@ def test_instance_norm2d(num_batches, num_channels, height, width, device, dtype

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,16 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size", [(1, 32)])
def test_layer_norm(y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
normalized_shape = (input.shape[-1],)
grad_output = torch.rand_like(input)

output = trident.LayerNorm(normalized_shape, **factory_kwargs).forward(input)
assert output is not None and output.dtype == dtype

output = trident.LayerNorm(normalized_shape, elementwise_affine=False, **factory_kwargs).forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
8 changes: 7 additions & 1 deletion tests/test_leaky_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,15 @@ def train(func):
@pytest.mark.parametrize("num_batches, num_elements", [(1, 100)])
def test_leaky_relu(num_batches, num_elements, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(num_batches, num_elements, **factory_kwargs)
input = torch.randn(num_batches, num_elements, **factory_kwargs, requires_grad=True)
grad_output = torch.rand_like(input)

output = trident.LeakyReLU().forward(input)

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype
12 changes: 9 additions & 3 deletions tests/test_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def test_max(y_size, x_size, dim, device, dtype):
pytest.skip("Skipping due to bfloat16 dtype")

factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

operation = trident.Max(dim)
assert operation.forward(input) is not None
output, argmax = trident.Max(dim).forward(input)
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
10 changes: 8 additions & 2 deletions tests/test_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size, dim", [(1, 16, 0), (16, 1, 1)])
def test_mean(y_size, x_size, dim, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

output = trident.Mean(dim).forward(input)
assert output is not None and output.dtype == dtype
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size", [(1, 100)])
def test_prelu(y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.rand_like(input)

output = trident.PReLU(x_size, 0.3, **factory_kwargs).forward(input)

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size", [(1, 100)])
def test_relu(y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.ReLU().forward(input)

assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size", [(4, 16)])
def test_rms_norm(y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn((y_size, x_size), **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.RMSNorm(x_size, **factory_kwargs).forward(input)
assert output is not None and output.dtype == dtype
Expand All @@ -101,3 +102,7 @@ def test_rms_norm(y_size, x_size, device, dtype):
assert output is not None and output.dtype == dtype
output = trident.RMSNorm(x_size, 0.5, bias=True, **factory_kwargs).forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
10 changes: 8 additions & 2 deletions tests/test_shift_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def train(func):
@pytest.mark.parametrize("num_batches, y_size, x_size", [(2, 10, 1000)])
def test_shift_gelu(num_batches, y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn((num_batches, y_size, x_size), **factory_kwargs)
input = torch.randn(num_batches, y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.ShiftGELU(x_size, **factory_kwargs).forward(input)
assert output is not None and output.dtype == dtype
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_silu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size", [(1, 32)])
def test_silu(y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.rand_like(input)

output = trident.SiLU().forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
11 changes: 9 additions & 2 deletions tests/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def train(func, dim):
@pytest.mark.parametrize("y_size, x_size, dim", [(1, 32, 1)])
def test_softmax(y_size, x_size, dim, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

assert trident.Softmax(dim).forward(input) is not None
output = trident.Softmax(dim).forward(input)
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype
7 changes: 6 additions & 1 deletion tests/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ def train(func):
@pytest.mark.parametrize("y_size, x_size, dim", [(1, 16, 0)])
def test_sum(y_size, x_size, dim, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(y_size, x_size, **factory_kwargs)
input = torch.randn(y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn(x_size if dim == 0 else y_size, **factory_kwargs)

output = trident.Sum(dim).forward(input)
assert output is not None and output.dtype == dtype

output.backward(grad_output)
assert input.grad is not None
assert input.grad.dtype == dtype


@pytest.mark.parametrize("dim", [0, 1])
def test_sum_issue1(dim, device):
Expand Down
5 changes: 3 additions & 2 deletions trident/kernel/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def backward(
m_size: tl.int32,
n_size: tl.int32,
x_size: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
):
pid = tl.program_id(0)
Expand Down Expand Up @@ -202,5 +203,5 @@ def backward(
gate = tl.load(gate_block_ptr, boundary_check=(1,))
grad_state = grad_output * language.math.GELU.forward(gate)
grad_gate = language.math.GELU.backward(grad_output * state, gate)
tl.store(grad_state_block_ptr, grad_state, boundary_check=(1,))
tl.store(grad_gate_block_ptr, grad_gate, boundary_check=(1,))
tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,))
tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,))
5 changes: 3 additions & 2 deletions trident/kernel/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def backward(
batch_stride: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
dtype: tl.constexpr,
y_block_size: tl.constexpr,
x_block_size: tl.constexpr,
):
Expand Down Expand Up @@ -162,5 +163,5 @@ def backward(
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1, 2))
grad_input = language.math.LeakyReLU.backward(grad_output, input, weight)
grad_weight = grad_output * tl.where(input > 0, 0, input)
tl.store(grad_input_block_ptr, grad_input, boundary_check=(1, 2))
tl.store(grad_weight_staging_block_ptr, grad_weight, boundary_check=(1, 2))
tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1, 2))
tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1, 2))
9 changes: 8 additions & 1 deletion trident/operation/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,14 @@ def grid(meta):

util.push_trace("kernel.GEGLU.backward")
kernel.GEGLU.backward[grid](
grad_state_gate, grad_output, state_gate, m_size, n_size, x_size, triton.next_power_of_2(x_size)
grad_state_gate,
grad_output,
state_gate,
m_size,
n_size,
x_size,
util.dtype(state_gate.dtype),
triton.next_power_of_2(x_size),
)
util.pop_trace()

Expand Down
1 change: 1 addition & 0 deletions trident/operation/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def grid(meta):
grad_input.stride(0),
grad_input.stride(1),
grad_input.stride(2),
util.dtype(grad_input.dtype),
)
util.pop_trace()

Expand Down

0 comments on commit 45cf51e

Please sign in to comment.