Skip to content

Commit

Permalink
simplify onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Nov 14, 2023
1 parent fd8942e commit 97c4bea
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
8 changes: 7 additions & 1 deletion models/segmentation/cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import onnx
import os
import torch

from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from models.base import TorchModelWrapper
from onnxsim import simplify

class MmsegmentationModelWrapper(TorchModelWrapper):
def load_model(self, eval=True):
Expand Down Expand Up @@ -43,4 +45,8 @@ def onnx_exporter(self, onnx_path):
random_input = torch.randn(1,3,512,1024) # todo: support other input sizes
if torch.cuda.is_available():
random_input = random_input.cuda()
torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True)
torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True)

model = onnx.load(onnx_path)
model_simp, check = simplify(model)
onnx.save(model_simp, onnx_path)
1 change: 1 addition & 0 deletions optimiser_interface/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def opt_cli_launcher(model_name, onnx_path, output_dir,
sys.argv += ['--objective', opt_obj]
sys.argv += ['--optimiser', opt_solver]
sys.argv += ['--optimiser_config_path', opt_cfg_path]
sys.argv += ['--custom_onnx']
main()
sys.argv = saved_argv

Expand Down

0 comments on commit 97c4bea

Please sign in to comment.