Need help with conversion to ONNX #211

Open
surajs52 opened this issue Feb 14, 2024 · 0 comments


Hey @foolwood,
I need help converting the model to ONNX format.
My Python script, which uses torch.onnx.export() for the conversion, looks like this:

from tools.test import *
#from siammask.models import Custom
from custom import Custom

parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')

parser.add_argument('--resume', default='', type=str, required=True,
                    metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--config', dest='config', default='config_davis.json',
                    help='hyper-parameter of SiamMask in json format')
#parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
#parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

cfg = load_config(args)
siammask = Custom(anchors=cfg['anchors'])

siammask.load_state_dict(torch.load('SiamMask_DAVIS.pth')["state_dict"])

siammask.eval().to(device)
siammask.half()

template = torch.randn(1, 3, 127, 127).to(device).half()
search = torch.randn(1, 3, 255, 255).to(device).half()
label_cls = torch.randn(1, 1, 5).to(device).half()
input_dict = {'template': template, 'search': search} #, 'label_cls': label_cls}

torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
                  input_names=['template', 'search'],
                  opset_version=11,
                  do_constant_folding=True,
                  verbose=True,
                  output_names=['rpn_pred_cls', 'rpn_pred_loc', 'pred_mask'],
                  dynamic_axes={'search': {0: 'batch_size'},  # if you want batch size to be dynamic
                                'rpn_pred_cls': {0: 'batch_size'},
                                'rpn_pred_loc': {0: 'batch_size'},
                                'pred_mask': {0: 'batch_size'}})

The output looks like this:
[2024-02-14 15:40:21,552-rk0-features.py# 66] Current training 0 layers:

[2024-02-14 15:40:21,554-rk0-features.py# 66] Current training 1 layers:

====== Diagnostic Run torch.onnx.export version 1.14.0a0+44dac51c.nv23.01 ======
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "../../tools/torch2onnx.py", line 78, in <module>
    torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1533, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 1260, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1467, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'input'

Specifically, I need help figuring out the exact input and output arguments to pass to torch.onnx.export() for this conversion.
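A note on the likely cause, offered as a sketch and not a confirmed fix: torch.onnx.export() documents that a dict in the last position of args is interpreted as named arguments for forward(). The traceback suggests that forward() here takes a single parameter called `input`, so the dict is consumed as keyword arguments and nothing is left to bind to `input`. One workaround the documentation describes is to pass `(input_dict, {})`, so that the trailing empty dict absorbs the named-argument slot. Another is to export a thin wrapper that accepts plain tensors and rebuilds the dict internally. The example below illustrates the wrapper idea with a small stand-in model. `DictInputModel` and `ExportWrapper` are hypothetical names, and the stand-in is not SiamMask; the real Custom model's outputs may be structured differently and may need flattening into tensors in the wrapper.

```python
import torch
import torch.nn as nn


class DictInputModel(nn.Module):
    """Hypothetical stand-in for a model whose forward() takes one dict named `input`."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, input):
        # Pool each feature map down to shape (N, 4) to keep the example small.
        z = self.conv(input['template']).mean(dim=(2, 3))
        x = self.conv(input['search']).mean(dim=(2, 3))
        return z + x, z - x


class ExportWrapper(nn.Module):
    """Takes positional tensors and rebuilds the dict the wrapped model expects."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, template, search):
        return self.model({'template': template, 'search': search})


model = DictInputModel().eval()
wrapper = ExportWrapper(model).eval()

# Same input shapes as in the script above, in float32 on CPU for simplicity.
template = torch.randn(1, 3, 127, 127)
search = torch.randn(1, 3, 255, 255)

with torch.no_grad():
    outputs = wrapper(template, search)

# With positional tensor inputs, the export call would pass a plain tuple.
# Left commented out so that running the sketch does not write a file;
# the file name and output names are placeholders:
# torch.onnx.export(wrapper, (template, search), "wrapped_model.onnx",
#                   input_names=['template', 'search'],
#                   output_names=['out_a', 'out_b'],
#                   opset_version=11)
```

With the wrapper, the names in input_names map one-to-one onto the positional parameters of forward(). If `template` should also have a dynamic batch size, it would need its own entry in dynamic_axes.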
