Skip to content

Commit

Permalink
Fix onnx export when no dir is specified.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Feb 28, 2024
1 parent 59235a5 commit 70b4e43
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
install_requires.append(line)

setup(name="weaver-core",
version='0.4.13',
version='0.4.14',
description="A streamlined deep-learning framework for high energy physics",
long_description_content_type="text/markdown",
author="H. Qu, C. Li",
Expand Down
4 changes: 3 additions & 1 deletion weaver/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def onnx(args):
model = model.cpu()
model.eval()

if not os.path.dirname(args.export_onnx):
args.export_onnx = os.path.join(os.path.dirname(model_path), args.export_onnx)
os.makedirs(os.path.dirname(args.export_onnx), exist_ok=True)
inputs = tuple(
torch.ones(model_info['input_shapes'][k], dtype=torch.float32) for k in model_info['input_names'])
Expand Down Expand Up @@ -879,7 +881,7 @@ def _main(args):
del test_loader

if args.predict_output:
if '/' not in args.predict_output:
if not os.path.dirname(predict_output):
predict_output = os.path.join(
os.path.dirname(args.model_prefix),
'predict_output', args.predict_output)
Expand Down

0 comments on commit 70b4e43

Please sign in to comment.