Skip to content

Commit

Permalink
feat: allow custom input/output names in onnx files
Browse files Browse the repository at this point in the history
  • Loading branch information
clementpoiret committed Feb 3, 2024
1 parent dde30db commit 5b80121
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion hsf/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def __init__(self, engine_name: str, engine_settings: DictConfig,

def __call__(self, x):
if self.engine_name == "onnxruntime":
return self.engine.run(None, {"input": x})
feed_names = [i.name for i in self.engine.get_inputs()]
assert len(feed_names) == 1, "Only one input is supported"
return self.engine.run(None, {feed_names[0]: x})

elif self.engine_name == "deepsparse":
return self.engine.run([x])
Expand Down

0 comments on commit 5b80121

Please sign in to comment.