From 2b10cb96df5662bf0e106ad958fda7d60b2ae91b Mon Sep 17 00:00:00 2001
From: Jaehyun An
Date: Fri, 15 Sep 2023 13:55:27 +0900
Subject: [PATCH] Fix a bug in BatchNorm

Cast grad_input and grad_weight to the input's element type before they
are stored in the backward kernel, and extend the BatchNorm tests to
exercise the backward pass and check the gradient dtype.
---
 tests/test_batch_norm.py        | 16 ++++++++++++++--
 trident/kernel/batch_norm.py    |  5 +++--
 trident/operation/batch_norm.py |  1 +
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/tests/test_batch_norm.py b/tests/test_batch_norm.py
index eca524c8..51528221 100644
--- a/tests/test_batch_norm.py
+++ b/tests/test_batch_norm.py
@@ -72,18 +72,30 @@ def train(func):
 @pytest.mark.parametrize("num_batches, y_size", [(3, 20)])
 def test_batch_norm_1d(num_batches, y_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, y_size, **factory_kwargs)
+    input = torch.randn(num_batches, y_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)
 
     output = trident.BatchNorm1d(y_size, **factory_kwargs).forward(input)
     assert output is not None
     assert output.dtype == dtype
 
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
+
 
 @pytest.mark.parametrize("num_batches, z_size, y_size, x_size", [(3, 3, 128, 128)])
 def test_batch_norm_2d(num_batches, z_size, y_size, x_size, device, dtype):
     factory_kwargs = {"device": device, "dtype": dtype}
-    input = torch.randn(num_batches, z_size, y_size, x_size, **factory_kwargs)
+    input = torch.randn(num_batches, z_size, y_size, x_size, **factory_kwargs, requires_grad=True)
+    grad_output = torch.randn_like(input)
 
     output = trident.BatchNorm2d(z_size, **factory_kwargs).forward(input)
     assert output is not None
     assert output.dtype == dtype
+
+    output.backward(grad_output)
+
+    assert input.grad is not None
+    assert input.grad.dtype == dtype
diff --git a/trident/kernel/batch_norm.py b/trident/kernel/batch_norm.py
index 13d2feeb..3609c7e5 100644
--- a/trident/kernel/batch_norm.py
+++ b/trident/kernel/batch_norm.py
@@ -104,6 +104,7 @@ def backward(
         y_size: tl.int32,
         x_size: tl.int32,
         eps: tl.float32,
+        dtype: tl.constexpr,
         batch_block_size: tl.constexpr,
         x_block_size: tl.constexpr,
     ):
@@ -152,12 +153,12 @@ def backward(
         grad_mean = tl.sum(tl.where(condition, grad_centered_mean, 0.0) / denominator)
 
         grad_input = grad_centered_mean - grad_mean
-        tl.store(grad_input_block_ptr, grad_input, boundary_check=(0, 1))
+        tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(0, 1))
 
         if grad_weight_ptr:
             input_norm = centered_mean / std
             grad_weight = tl.sum(input_norm * grad_output)
-            tl.store(grad_weight_ptr + pid, grad_weight)
+            tl.store(grad_weight_ptr + pid, grad_weight.to(dtype))
 
         if grad_bias_ptr:
             grad_bias = tl.sum(grad_output)
diff --git a/trident/operation/batch_norm.py b/trident/operation/batch_norm.py
index 93ad057a..2d13a41a 100644
--- a/trident/operation/batch_norm.py
+++ b/trident/operation/batch_norm.py
@@ -118,6 +118,7 @@ def grid(meta):
             y_size,
             x_size,
             eps,
+            util.dtype(input.dtype),
             batch_block_size=triton.next_power_of_2(num_batches),
             x_block_size=triton.next_power_of_2(x_size),
         )
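
Note on the kernel change: the gradient arithmetic in backward is carried out
in float32 (the 0.0 literal and the division promote the intermediates), so
the results must be cast back to the tensor's element type before tl.store
when the input is float16 or bfloat16. Below is a minimal standalone sketch of
this cast-before-store pattern; the kernel name, scale parameter, and sizes
are illustrative only and are not part of Trident or this patch.

# Illustrative sketch: compute in float32, cast back to the destination
# dtype before storing, mirroring the grad_input.to(dtype) and
# grad_weight.to(dtype) casts in the patch above.
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(output_ptr, input_ptr, scale, size, dtype: tl.constexpr, block_size: tl.constexpr):
    offsets = tl.program_id(0) * block_size + tl.arange(0, block_size)
    mask = offsets < size
    x = tl.load(input_ptr + offsets, mask=mask)
    y = x * scale  # the float32 scalar promotes the product to float32
    tl.store(output_ptr + offsets, y.to(dtype), mask=mask)  # cast back before storing


size = 4096
x = torch.randn(size, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
scale_kernel[(triton.cdiv(size, 1024),)](y, x, 2.0, size, tl.float16, block_size=1024)
assert y.dtype == torch.float16

In trident/operation/batch_norm.py, the matching Triton dtype is derived from
the torch tensor via util.dtype(input.dtype) and forwarded to the kernel as a
constexpr argument.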
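
For completeness, the user-visible behavior the new tests pin down can be
sketched as follows; this assumes a CUDA device and float16, with shapes
mirroring test_batch_norm_1d above.

# Hypothetical repro mirroring test_batch_norm_1d with an explicit device/dtype.
import torch
import trident

input = torch.randn(3, 20, device="cuda", dtype=torch.float16, requires_grad=True)
grad_output = torch.randn_like(input)

output = trident.BatchNorm1d(20, device="cuda", dtype=torch.float16).forward(input)
output.backward(grad_output)

# With this patch, the gradient keeps the input's dtype.
assert input.grad is not None
assert input.grad.dtype == torch.float16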