From dd654cb27f437e0e2d30908cef5be8c3cd2337b0 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 9 May 2023 10:21:26 +0000 Subject: [PATCH] [onnx] Fix grad op domain --- pytorch_pfn_extras/onnx/_grad.py | 2 +- tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_pfn_extras/onnx/_grad.py b/pytorch_pfn_extras/onnx/_grad.py index 244976fa7..00bd8c4de 100644 --- a/pytorch_pfn_extras/onnx/_grad.py +++ b/pytorch_pfn_extras/onnx/_grad.py @@ -85,7 +85,7 @@ def _grad( # type: ignore @staticmethod def symbolic(g, output, grad_output, *inputs): # type: ignore return g.op( - "ai.onnx.preview::Gradient", + "ai.onnx.preview.training::Gradient", *inputs, xs_s=input_names, zs_s=[], diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py index 4a885ca0f..a33bfbeec 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py @@ -62,7 +62,7 @@ def forward(self, x): @pytest.mark.parametrize("use_pfto", [False, True]) -@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") +@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_grad(use_pfto: bool): @@ -103,6 +103,7 @@ def forward(self, x): ) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) + print(actual_onnx) named_nodes = {n.name: n for n in actual_onnx.graph.node} if pytorch_pfn_extras.requires("1.13") and not use_pfto: assert '/_ppe_as_out_module/conv/Conv' in named_nodes @@ -136,7 +137,7 @@ def forward(self, x): @pytest.mark.parametrize("use_pfto", [False, True]) -@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") +@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_grad_multiple_times(use_pfto: bool): @@ -218,7 +219,7 @@ def forward(self, x): @pytest.mark.parametrize("use_pfto", [False, True]) -@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning") +@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning") @pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning") @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_grad_with_multiple_inputs(use_pfto: bool):