diff --git a/models/segmentation/cityscapes.py b/models/segmentation/cityscapes.py index 80bbee7..5f1bee1 100644 --- a/models/segmentation/cityscapes.py +++ b/models/segmentation/cityscapes.py @@ -1,9 +1,11 @@ +import onnx import os import torch from mmengine.config import Config, DictAction from mmengine.runner import Runner from models.base import TorchModelWrapper +from onnxsim import simplify class MmsegmentationModelWrapper(TorchModelWrapper): def load_model(self, eval=True): @@ -43,4 +45,8 @@ def onnx_exporter(self, onnx_path): random_input = torch.randn(1,3,512,1024) # todo: support other input sizes if torch.cuda.is_available(): random_input = random_input.cuda() - torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True) \ No newline at end of file + torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True) + + model = onnx.load(onnx_path) + model_simp, check = simplify(model) + onnx.save(model_simp, onnx_path) \ No newline at end of file diff --git a/optimiser_interface/utils.py b/optimiser_interface/utils.py index ffeaeb1..d54d761 100644 --- a/optimiser_interface/utils.py +++ b/optimiser_interface/utils.py @@ -31,6 +31,7 @@ def opt_cli_launcher(model_name, onnx_path, output_dir, sys.argv += ['--objective', opt_obj] sys.argv += ['--optimiser', opt_solver] sys.argv += ['--optimiser_config_path', opt_cfg_path] + sys.argv += ['--custom_onnx'] main() sys.argv = saved_argv