diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py index d42d0ed3..297e8bbe 100644 --- a/tests/test_instance_norm.py +++ b/tests/test_instance_norm.py @@ -19,7 +19,10 @@ from tests import util -@pytest.mark.parametrize("num_batches, y_size, x_size, use_input_stats", [(2, 8, 1024, True), (3, 5, 2000, False)]) +@pytest.mark.parametrize( + "num_batches, y_size, x_size, use_input_stats", + [(2, 8, 1024, True), (2, 2, 2, True), (3, 5, 2000, False)], +) def test_forward(num_batches, y_size, x_size, use_input_stats, device): input = torch.randn(num_batches, y_size, x_size, device=device) diff --git a/trident/kernel/instance_norm.py b/trident/kernel/instance_norm.py index 6c879f0e..29a5a388 100644 --- a/trident/kernel/instance_norm.py +++ b/trident/kernel/instance_norm.py @@ -143,6 +143,7 @@ def forward_running_mean_running_var( running_var_ptr: tl.tensor, num_batches: tl.int32, y_size: tl.int32, + x_size: tl.int32, momentum: tl.constexpr, batch_block_size: tl.constexpr, ): @@ -188,7 +189,7 @@ def forward_running_mean_running_var( running_mean = tl.load(running_mean_block_ptr, boundary_check=(0,)) running_mean = mean * momentum + running_mean * (1 - momentum) running_var = tl.load(running_var_block_ptr, boundary_check=(0,)) - running_var = var * momentum + running_var * (1 - momentum) + running_var = var * (x_size / (x_size - 1)) * momentum + running_var * (1 - momentum) tl.store(running_mean_block_ptr, running_mean, boundary_check=(0,)) tl.store(running_var_block_ptr, running_var, boundary_check=(0,)) diff --git a/trident/operation/instance_norm.py b/trident/operation/instance_norm.py index b8235efb..5e6f361b 100644 --- a/trident/operation/instance_norm.py +++ b/trident/operation/instance_norm.py @@ -96,7 +96,15 @@ def grid(meta): util.push_trace("kernel.InstanceNorm.forward_running_mean_running_var") kernel.InstanceNorm.forward_running_mean_running_var[grid]( - mean, var, running_mean, running_var, num_batches, y_size, momentum, triton.next_power_of_2(num_batches) + mean, + var, + running_mean, + running_var, + num_batches, + y_size, + x_size, + momentum, + triton.next_power_of_2(num_batches), ) util.pop_trace()