This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Fix a bug in BatchNorm
Jaehyun An committed Sep 15, 2023
1 parent f23f891 commit 7287187
Showing 3 changed files with 18 additions and 4 deletions.
tests/test_batch_norm.py (16 changes: 14 additions & 2 deletions)
@@ -72,18 +72,30 @@ def train(func):
@pytest.mark.parametrize("num_batches, y_size", [(3, 20)])
def test_batch_norm_1d(num_batches, y_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(num_batches, y_size, **factory_kwargs)
input = torch.randn(num_batches, y_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.BatchNorm1d(y_size, **factory_kwargs).forward(input)
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype


@pytest.mark.parametrize("num_batches, z_size, y_size, x_size", [(3, 3, 128, 128)])
def test_batch_norm_2d(num_batches, z_size, y_size, x_size, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
input = torch.randn(num_batches, z_size, y_size, x_size, **factory_kwargs)
input = torch.randn(num_batches, z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.BatchNorm2d(z_size, **factory_kwargs).forward(input)
assert output is not None
assert output.dtype == dtype

output.backward(grad_output)

assert input.grad is not None
assert input.grad.dtype == dtype
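
For context, a minimal sketch of the path the new assertions cover: running backward through trident.BatchNorm2d with a half-precision input and checking that the gradient keeps the input's dtype. The cuda/float16 choice here is an illustrative assumption, not part of the parametrized test matrix.

import torch
import trident

device, dtype = "cuda", torch.float16
x = torch.randn(3, 3, 128, 128, device=device, dtype=dtype, requires_grad=True)

output = trident.BatchNorm2d(3, device=device, dtype=dtype).forward(x)
output.backward(torch.randn_like(x))

# Before this commit the backward kernel presumably stored higher-precision values;
# this dtype check is what the new tests guard against regressing.
assert x.grad is not None and x.grad.dtype == dtype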
trident/kernel/batch_norm.py (5 changes: 3 additions & 2 deletions)
Expand Up @@ -104,6 +104,7 @@ def backward(
    y_size: tl.int32,
    x_size: tl.int32,
    eps: tl.float32,
+   dtype: tl.constexpr,
    batch_block_size: tl.constexpr,
    x_block_size: tl.constexpr,
):
@@ -152,12 +153,12 @@ def backward(
    grad_mean = tl.sum(tl.where(condition, grad_centered_mean, 0.0) / denominator)
    grad_input = grad_centered_mean - grad_mean

-   tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
+   tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(0, 1))

    if grad_weight_ptr:
        input_norm = centered_mean / std
        grad_weight = tl.sum(input_norm * grad_output)
-       tl.store(grad_weight_ptr + pid, grad_weight)
+       tl.store(grad_weight_ptr + pid, grad_weight.to(dtype))

    if grad_bias_ptr:
        grad_bias = tl.sum(grad_output)
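The cast before tl.store is the heart of the fix: the backward math runs in float32, so the result needs an explicit .to(dtype) to match the element type of the output tensor. A self-contained sketch of the same pattern follows; it is not the trident kernel, and the kernel name, shapes, and block size are illustrative.

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(out_ptr, in_ptr, scale, n, dtype: tl.constexpr, block_size: tl.constexpr):
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < n
    x = tl.load(in_ptr + offsets, mask=mask)
    y = x.to(tl.float32) * scale                          # compute in float32
    tl.store(out_ptr + offsets, y.to(dtype), mask=mask)   # cast back before storing

x = torch.randn(1000, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scale_kernel[grid](out, x, 2.0, x.numel(), tl.float16, block_size=256)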
trident/operation/batch_norm.py (1 change: 1 addition & 0 deletions)
@@ -118,6 +118,7 @@ def grid(meta):
    y_size,
    x_size,
    eps,
+   util.dtype(input.dtype),
    batch_block_size=triton.next_power_of_2(num_batches),
    x_block_size=triton.next_power_of_2(x_size),
)
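The operation now forwards a Triton dtype for the new kernel argument. util.dtype is not shown in this commit; a plausible mapping, stated as an assumption rather than trident's actual implementation, would look like:

import torch
import triton.language as tl

def dtype(torch_dtype: torch.dtype):
    # Map a torch dtype to the Triton dtype the kernel casts to before storing.
    if torch_dtype is torch.float32:
        return tl.float32
    if torch_dtype is torch.float16:
        return tl.float16
    if torch_dtype is torch.bfloat16:
        return tl.bfloat16
    raise ValueError(f"Unsupported dtype: {torch_dtype}")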
