export to onnx for torch>=1.10
Signed-off-by: xmfbit <[email protected]>
xmfbit committed May 2, 2022
1 parent 7e49dde commit 31d9a3c
Showing 1 changed file with 30 additions and 3 deletions.
TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py (30 additions, 3 deletions)
@@ -43,6 +43,7 @@
 import os
 import copy
 from collections import defaultdict
+import logging
 import torch
 import torch.nn as nn
 import torch.onnx.symbolic_caffe2
@@ -104,6 +105,24 @@
 }


+def export_to_onnx(*args, **kwargs):
+    """
+    Wrapper around torch.onnx.export.
+    For torch >= 1.10 the deprecated `enable_onnx_checker` argument is dropped
+    and ONNX checker errors are logged instead of raised.
+    """
+    enable_checker = kwargs.get('enable_onnx_checker', None)
+    if version.parse(torch.__version__) >= version.parse("1.10") and not enable_checker:
+        logging.warning('torch >= 1.10: dropping deprecated `enable_onnx_checker` argument before export')
+        # Remove the deprecated argument; use a default so a missing key does not raise KeyError
+        kwargs.pop('enable_onnx_checker', None)
+        try:
+            torch.onnx.export(*args, **kwargs)
+        except torch.onnx.utils.ONNXCheckerError as e:
+            logging.error('ONNX checker error during export (can be ignored): %s', e)
+    else:
+        torch.onnx.export(*args, **kwargs)


if version.parse(torch.__version__) >= version.parse("1.9"):
onnx_subgraph_op_to_pytorch_module_param_name = {
torch.nn.GroupNorm:
@@ -656,10 +675,18 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
 if is_conditional:
     dummy_output = model(*dummy_input)
     scripted_model = torch.jit.script(model)
-    torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output,
-                      enable_onnx_checker=False, **onnx_export_args.kwargs)
+    export_to_onnx(scripted_model,
+                   dummy_input,
+                   temp_file,
+                   example_outputs=dummy_output,
+                   enable_onnx_checker=False,
+                   **onnx_export_args.kwargs)
 else:
-    torch.onnx.export(model, dummy_input, temp_file, enable_onnx_checker=False, **onnx_export_args.kwargs)
+    export_to_onnx(model,
+                   dummy_input,
+                   temp_file,
+                   enable_onnx_checker=False,
+                   **onnx_export_args.kwargs)
 onnx_model = onnx.load(temp_file)
 return onnx_model
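
For reference, a minimal usage sketch of the new wrapper. The model, input shape, and output path below are illustrative only and not part of this commit; the point is that the same call works on torch < 1.10 (where `enable_onnx_checker` is forwarded) and torch >= 1.10 (where the wrapper drops it and logs checker errors instead).

import torch
import torch.nn as nn

from aimet_torch.onnx_utils import export_to_onnx  # wrapper added in this commit


class DummyModel(nn.Module):
    """Tiny stand-in model used only for this illustration."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)


model = DummyModel().eval()
dummy_input = torch.randn(1, 8)

# The wrapper decides whether `enable_onnx_checker` can be passed through
# to torch.onnx.export based on the installed torch version.
export_to_onnx(model, dummy_input, '/tmp/dummy_model.onnx',
               enable_onnx_checker=False)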

