diff --git a/pytest.ini b/pytest.ini
index 8fc65f7eb..177fb9c25 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -21,7 +21,6 @@ filterwarnings =
     # For packages importing distutils in py 3.10 (tensorboard)
     ignore:.*distutils package is deprecated and slated:DeprecationWarning
     # For warnings from torch 1.13
-    ignore:'torch.onnx._patch_torch._graph_op' is deprecated:FutureWarning
     # For ipywidgets 8.0.3
     ignore:Widget.widgets is deprecated.:DeprecationWarning
     ignore:Widget.widget_types is deprecated.:DeprecationWarning
diff --git a/pytorch_pfn_extras/onnx/__init__.py b/pytorch_pfn_extras/onnx/__init__.py
index 813dc73d9..e2e110bcd 100644
--- a/pytorch_pfn_extras/onnx/__init__.py
+++ b/pytorch_pfn_extras/onnx/__init__.py
@@ -8,6 +8,7 @@
 from pytorch_pfn_extras.onnx.annotate import annotate  # NOQA
 from pytorch_pfn_extras.onnx.annotate import apply_annotation  # NOQA
 from pytorch_pfn_extras.onnx.annotate import scoped_anchor  # NOQA
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings  # NOQA
 from pytorch_pfn_extras.onnx._as_output import as_output  # NOQA
 from pytorch_pfn_extras.onnx._grad import grad  # NOQA
 from pytorch_pfn_extras.onnx.load import load_model  # NOQA
diff --git a/pytorch_pfn_extras/onnx/_as_output.py b/pytorch_pfn_extras/onnx/_as_output.py
index fdbbf1dce..8c0b24b9f 100644
--- a/pytorch_pfn_extras/onnx/_as_output.py
+++ b/pytorch_pfn_extras/onnx/_as_output.py
@@ -4,6 +4,7 @@
 import threading
 from contextlib import contextmanager
 import warnings
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings
 
 
 _outputs = threading.local()
@@ -98,7 +99,8 @@ def trace(
         _outputs.outputs = None
 
 
-# Add Identity function to prevent constant folding in torch.onnx
+# Add Identity function to prevent constant folding in torch.onnx
+@suppress_symbolic_warnings
 class _ExplicitIdentity(torch.autograd.Function):
     @staticmethod
     def forward(  # type: ignore
diff --git a/pytorch_pfn_extras/onnx/_grad.py b/pytorch_pfn_extras/onnx/_grad.py
index 244976fa7..aeabe4e65 100644
--- a/pytorch_pfn_extras/onnx/_grad.py
+++ b/pytorch_pfn_extras/onnx/_grad.py
@@ -4,6 +4,7 @@
 import torch
 import torch.onnx
 import threading
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings
 
 from pytorch_pfn_extras.onnx._as_output import as_output
 
@@ -48,6 +49,7 @@ def grad(
         input_names.append(input_name)
         inputs_l[i] = as_output(input_name, input, add_identity=False)
 
+    @suppress_symbolic_warnings
     class _Gradient(torch.autograd.Function):
         @staticmethod
         def forward(  # type: ignore
diff --git a/pytorch_pfn_extras/onnx/_helper.py b/pytorch_pfn_extras/onnx/_helper.py
index 757bbc9eb..a5bb7f885 100644
--- a/pytorch_pfn_extras/onnx/_helper.py
+++ b/pytorch_pfn_extras/onnx/_helper.py
@@ -1,5 +1,7 @@
 import torch
-from typing import Callable, Any
+from typing import Callable, Any, Type, TypeVar
+
+import pytorch_pfn_extras as ppe
 
 
 def _detach(x: Any) -> Any:
@@ -20,3 +22,40 @@ def no_grad(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
         out = fn(*args, **kwargs)
     # torch.no_grad() does not export `detach` op when tracing
     return _detach(out)
+
+
+T = TypeVar('T')
+
+
+# Using hack from https://stackoverflow.com/a/56856290
+def suppress_symbolic_warnings(cls: Type[T]) -> Type[T]:
+    global torch
+    assert issubclass(cls, torch.autograd.Function)
+    assert hasattr(cls, "symbolic")
+
+    if (not ppe.requires("1.13")) or ppe.requires("2.0"):
+        return cls
+
+    import torch.onnx._internal.jit_utils
+    import torch.onnx._globals
+
+    orig_symbolic = cls.symbolic
+
+    # Untyped due to type checker in torch.onnx
+    @staticmethod  # type: ignore[misc]
+    def new_symbolic(g, *args, **kwargs):  # type: ignore[no-untyped-def]
+        if isinstance(g, torch._C.Graph):
+            ctx = torch.onnx._internal.jit_utils.GraphContext(
+                graph=g,
+                block=g.block(),
+                opset=torch.onnx._globals.GLOBALS.export_onnx_opset_version,
+                original_node=None,  # type: ignore[arg-type]
+                params_dict=torch.onnx.utils._params_dict,
+                env={},
+            )
+            return orig_symbolic(ctx, *args, **kwargs)
+        return orig_symbolic(g, *args, **kwargs)
+
+    cls.symbolic = new_symbolic
+
+    return cls
diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py
index 59a427a33..7ce89e03a 100644
--- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py
+++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py
@@ -3,6 +3,7 @@
 import torch
 
 from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings
 
 
 def test_simple():
@@ -37,6 +38,7 @@ def forward(self, x):
 
 @pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
 def test_symbolic_function():
+    @suppress_symbolic_warnings
     class Func(torch.autograd.Function):
         @staticmethod
         def forward(ctx, a):
@@ -202,6 +204,7 @@ def forward(self, *hidden):
 
 @pytest.mark.filterwarnings("ignore:The shape inference of org.chainer..Add type is missing:UserWarning")
 def test_custom_opsets():
+    @suppress_symbolic_warnings
     class Func(torch.autograd.Function):
         @staticmethod
         def forward(ctx, a):
diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_lax.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_lax.py
index 55f5f2ac8..2628291c6 100644
--- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_lax.py
+++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_lax.py
@@ -13,7 +13,6 @@
 from pytorch_pfn_extras_tests.onnx_tests.test_export_testcase import _helper
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_fori_loop_no_export():
     if not pytorch_pfn_extras.requires("1.8.0"):
         pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -41,7 +40,6 @@ def forward(self, x):
     torch.testing.assert_close(y, y_expected)
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_fori_loop():
     if not pytorch_pfn_extras.requires('1.8.0'):
         pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -80,7 +78,6 @@ def forward(self, x):
     torch.testing.assert_close(expected, torch.tensor(actual[0]))
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_fori_loop_with_tuple_state():
     if not pytorch_pfn_extras.requires('1.8.0'):
         pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -123,7 +120,6 @@ def body(it, val):
     torch.testing.assert_close(expected, torch.tensor(actual[0]))
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_while_loop_no_export():
     if not pytorch_pfn_extras.requires('1.8.0'):
         pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -153,7 +149,6 @@ def body_fn(x):
     assert out.sum().item() > 100
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 @pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect:torch.jit.TracerWarning")
 def test_while_loop():
     if not pytorch_pfn_extras.requires('1.8.0'):
@@ -230,7 +225,6 @@ def false_fn(x):
     assert out[1] == -1
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 @pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect:torch.jit.TracerWarning")
 def test_cond():
     if not pytorch_pfn_extras.requires('1.8.0'):
@@ -277,7 +271,6 @@ def false_fn(x):
     torch.testing.assert_close(expected, torch.tensor(actual[0]))
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_lax_multiple_times():
     if not pytorch_pfn_extras.requires('1.8.0'):
         pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -323,7 +316,6 @@ def body1(it, h):
     torch.testing.assert_close(expected, torch.tensor(actual[0]))
 
 
-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_lax_nested():
     if not pytorch_pfn_extras.requires('1.8.0'):
         pytest.skip('skip for PyTorch 1.7 or earlier')
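For reference, a minimal usage sketch of the new `suppress_symbolic_warnings` decorator, modeled on the `test_symbolic_function` change above. The `Square` function, the `Net` module, and the commented-out export call are illustrative assumptions and not part of this diff; only the decorator and its import path come from the patch.

```python
# Sketch: decorating a custom autograd.Function that defines an ONNX symbolic.
# `Square` and `Net` are hypothetical names used only for illustration.
import torch
from pytorch_pfn_extras.onnx import suppress_symbolic_warnings


@suppress_symbolic_warnings
class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Plain eager computation; nothing ONNX-specific here.
        return x * x

    @staticmethod
    def symbolic(g, x):
        # On torch>=1.13,<2.0 the decorator wraps a raw torch._C.Graph in a
        # GraphContext before calling this method, so g.op() avoids the
        # deprecated `_graph_op` path whose FutureWarning filter is removed
        # from pytest.ini in this patch. On other versions the class is
        # returned unchanged.
        return g.op("Mul", x, x)


class Net(torch.nn.Module):
    def forward(self, x):
        return Square.apply(x)


# Exporting would then exercise the wrapped symbolic, e.g.:
# torch.onnx.export(Net(), torch.rand(2, 3), "square.onnx")
```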