Fix a bug in InstanceNorm not using input stats
Jaehyun An (steve.an) authored Sep 15, 2023
1 parent eb17599 commit bed3c4d
Showing 3 changed files with 15 additions and 3 deletions.
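The fix targets the running-statistics update that torch.nn.InstanceNorm1d performs when track_running_stats=True: running_var is an exponential moving average of the unbiased per-instance variance, averaged over the batch. Below is a minimal reference sketch of that semantics; the shapes and momentum value are illustrative, not taken from the diff.

import torch

# PyTorch reference: running_var blends in the *unbiased* per-instance variance.
x = torch.randn(2, 8, 1024)  # (num_batches, y_size, x_size), illustrative
norm = torch.nn.InstanceNorm1d(8, momentum=0.1, track_running_stats=True)
norm(x)

var_unbiased = x.var(dim=-1, unbiased=True).mean(dim=0)     # per channel, batch-averaged
expected = torch.ones(8) * (1 - 0.1) + var_unbiased * 0.1   # running_var starts at 1.0
print(torch.allclose(norm.running_var, expected))           # True, within float tolerance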
5 changes: 4 additions & 1 deletion tests/test_instance_norm.py
@@ -19,7 +19,10 @@
 from tests import util
 
 
-@pytest.mark.parametrize("num_batches, y_size, x_size, use_input_stats", [(2, 8, 1024, True), (3, 5, 2000, False)])
+@pytest.mark.parametrize(
+    "num_batches, y_size, x_size, use_input_stats",
+    [(2, 8, 1024, True), (2, 2, 2, True), (3, 5, 2000, False)],
+)
 def test_forward(num_batches, y_size, x_size, use_input_stats, device):
     input = torch.randn(num_batches, y_size, x_size, device=device)
 
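The added (2, 2, 2, True) case looks chosen to make the missing correction visible: the unbiased/biased variance ratio is x_size / (x_size - 1), which is 2.0 at x_size = 2 but only about 1.001 at x_size = 1024, small enough to hide inside a tolerance check.

# Size of the correction factor by x_size (pure arithmetic, illustrative values).
for x_size in (2, 8, 1024, 2000):
    print(x_size, x_size / (x_size - 1))
# 2 -> 2.0, 8 -> ~1.143, 1024 -> ~1.001, 2000 -> ~1.0005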
3 changes: 2 additions & 1 deletion trident/kernel/instance_norm.py
@@ -143,6 +143,7 @@ def forward_running_mean_running_var(
         running_var_ptr: tl.tensor,
         num_batches: tl.int32,
         y_size: tl.int32,
+        x_size: tl.int32,
         momentum: tl.constexpr,
         batch_block_size: tl.constexpr,
     ):
@@ -188,7 +189,7 @@ def forward_running_mean_running_var(
         running_mean = tl.load(running_mean_block_ptr, boundary_check=(0,))
         running_mean = mean * momentum + running_mean * (1 - momentum)
         running_var = tl.load(running_var_block_ptr, boundary_check=(0,))
-        running_var = var * momentum + running_var * (1 - momentum)
+        running_var = var * (x_size / (x_size - 1)) * momentum + running_var * (1 - momentum)
         tl.store(running_mean_block_ptr, running_mean, boundary_check=(0,))
         tl.store(running_var_block_ptr, running_var, boundary_check=(0,))
 
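The one-line kernel change rescales the biased per-instance variance (computed over x_size elements) into an unbiased estimate before folding it into the moving average; this is the standard Bessel correction, var_unbiased = var_biased * x_size / (x_size - 1). A short sketch of the same arithmetic in plain PyTorch, with illustrative shapes:

import torch

x = torch.randn(4, 3, 7)                     # (num_batches, y_size, x_size), illustrative
x_size = x.shape[-1]

var_biased = x.var(dim=-1, unbiased=False)   # population variance over the x dimension
var_unbiased = var_biased * (x_size / (x_size - 1))

# The rescaled value equals torch's unbiased (sample) variance estimator.
print(torch.allclose(var_unbiased, x.var(dim=-1, unbiased=True)))  # True

# Mirrors the form of the fixed update line (running_var shape is illustrative here).
momentum, running_var = 0.1, torch.ones(4, 3)
running_var = var_unbiased * momentum + running_var * (1 - momentum)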
10 changes: 9 additions & 1 deletion trident/operation/instance_norm.py
@@ -96,7 +96,15 @@ def grid(meta):
 
         util.push_trace("kernel.InstanceNorm.forward_running_mean_running_var")
         kernel.InstanceNorm.forward_running_mean_running_var[grid](
-            mean, var, running_mean, running_var, num_batches, y_size, momentum, triton.next_power_of_2(num_batches)
+            mean,
+            var,
+            running_mean,
+            running_var,
+            num_batches,
+            y_size,
+            x_size,
+            momentum,
+            triton.next_power_of_2(num_batches),
         )
         util.pop_trace()
 
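The operation-level change just threads x_size through to the launch. Triton kernels receive their runtime arguments positionally, so the call site has to mirror the extended kernel signature exactly; a generic sketch of that constraint with a hypothetical kernel (not trident's), assuming a CUDA device:

import torch
import triton
import triton.language as tl


@triton.jit
def scale(x_ptr, n, factor, BLOCK: tl.constexpr):
    # Hypothetical kernel: runtime args (x_ptr, n, factor) are matched by position.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * factor, mask=mask)


x = torch.ones(10, device="cuda")
# Adding a parameter to the kernel signature means every launch like this must be updated too.
scale[(triton.cdiv(10, 8),)](x, 10, 2.0, BLOCK=8)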
