diff --git a/test/test_decomp.py b/test/test_decomp.py index 3a48818f736e13..e988faf006f0d6 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -915,6 +915,14 @@ def test_weight_norm_interface(self, device): self.assertTrue(torch.allclose(ref[0], res[0])) self.assertTrue(torch.allclose(ref[1], res[1])) + inp = torch.rand([30, 10], device=device) + inp2 = torch.rand([30, 1], device=device) + + self.assertEqual( + torch.ops.aten._weight_norm_interface(inp, inp2), + torch._decomp.decompositions._weight_norm_interface(inp, inp2) + ) + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyCPU @skipIfCrossRef diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index a7606cdc525ac1..e27ccc576740ac 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4421,7 +4421,7 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): @register_decomposition(torch.ops.aten._weight_norm_interface) -def _weight_norm_interface(x, y, dim): +def _weight_norm_interface(x, y, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 keep_dim = tuple(i for i in range(len(x.shape)) if i != dim) norm = x.norm(2, keep_dim, keepdim=True)