diff --git a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py index f015f150ecd..5a16c59598d 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py @@ -43,6 +43,7 @@ import os import copy from collections import defaultdict +import logging import torch import torch.nn as nn import torch.onnx.symbolic_caffe2 @@ -104,6 +105,24 @@ } +def export_to_onnx(*args, **kwargs): + """ + A wrapper function to export torch module to onnx + + `enable_checker` is ignored for pytorch >= 1.10 + """ + enable_checker = kwargs.get('enable_onnx_checker', None) + if version.parse(torch.__version__) >= version.parse("1.10") and not enable_checker: + logging.warning('Export torch module to onnx with `enable_onnx_checker` deprecated') + kwargs.pop('enable_onnx_checker') + try: + torch.onnx.export(*args, **kwargs) + except torch.onnx.utils.ONNXCheckerError as e: + logging.error('Error when exporting to onnx: {}, could be ignored'.format(e)) + else: + torch.onnx.export(*args, **kwargs) + + if version.parse(torch.__version__) >= version.parse("1.9"): onnx_subgraph_op_to_pytorch_module_param_name = { torch.nn.GroupNorm: @@ -656,10 +675,18 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn if is_conditional: dummy_output = model(*dummy_input) scripted_model = torch.jit.script(model) - torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output, - enable_onnx_checker=False, **onnx_export_args.kwargs) + export_to_onnx(scripted_model, + dummy_input, + temp_file, + example_outputs=dummy_output, + enable_onnx_checker=False, + **onnx_export_args.kwargs) else: - torch.onnx.export(model, dummy_input, temp_file, enable_onnx_checker=False, **onnx_export_args.kwargs) + export_to_onnx(model, + dummy_input, + temp_file, + enable_onnx_checker=False, + **onnx_export_args.kwargs) onnx_model = onnx.load(temp_file) return onnx_model