Skip to content

Commit

Permalink
Fix (test): ort provider is now mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 3, 2023
1 parent 9d24ace commit 5f5c42a
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/onnx_export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@
"exported_model = export_onnx_qcdq(model, args=inp, export_path=path, opset_version=13)\n",
"\n",
"sess_opt = ort.SessionOptions()\n",
"sess = ort.InferenceSession(path, sess_opt)\n",
"sess = ort.InferenceSession(path, sess_opt, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"pred_onx = sess.run(None, {input_name: inp.numpy()})[0]\n",
"\n",
Expand Down Expand Up @@ -954,7 +954,7 @@
"exported_model = export_onnx_qop(model, args=inp, export_path=path)\n",
"\n",
"sess_opt = ort.SessionOptions()\n",
"sess = ort.InferenceSession(path, sess_opt)\n",
"sess = ort.InferenceSession(path, sess_opt, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"pred_onx = sess.run(None, {input_name: inp.numpy()})[0]\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/quant_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@
"import onnxruntime as ort\n",
"import numpy as np\n",
"\n",
"sess = ort.InferenceSession(export_path)\n",
"sess = ort.InferenceSession(export_path, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32) # (seq_len, batch_size, input_size)\n",
"pred_onnx = sess.run(None, {input_name: np_input})"
Expand Down Expand Up @@ -1031,7 +1031,7 @@
"import onnxruntime as ort\n",
"import numpy as np\n",
"\n",
"sess = ort.InferenceSession(export_path)\n",
"sess = ort.InferenceSession(export_path, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32) # (seq_len, batch_size, input_size)\n",
"pred_onnx = sess.run(None, {input_name: np_input})"
Expand Down
4 changes: 2 additions & 2 deletions notebooks/ONNX_export_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@
"exported_model = export_onnx_qcdq(model, args=inp, export_path=path, opset_version=13)\n",
"\n",
"sess_opt = ort.SessionOptions()\n",
"sess = ort.InferenceSession(path, sess_opt)\n",
"sess = ort.InferenceSession(path, sess_opt, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"pred_onx = sess.run(None, {input_name: inp.numpy()})[0]\n",
"\n",
Expand Down Expand Up @@ -954,7 +954,7 @@
"exported_model = export_onnx_qop(model, args=inp, export_path=path)\n",
"\n",
"sess_opt = ort.SessionOptions()\n",
"sess = ort.InferenceSession(path, sess_opt)\n",
"sess = ort.InferenceSession(path, sess_opt, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"pred_onx = sess.run(None, {input_name: inp.numpy()})[0]\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/quantized_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@
"import onnxruntime as ort\n",
"import numpy as np\n",
"\n",
"sess = ort.InferenceSession(export_path)\n",
"sess = ort.InferenceSession(export_path, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32) # (seq_len, batch_size, input_size)\n",
"pred_onnx = sess.run(None, {input_name: np_input})"
Expand Down Expand Up @@ -1031,7 +1031,7 @@
"import onnxruntime as ort\n",
"import numpy as np\n",
"\n",
"sess = ort.InferenceSession(export_path)\n",
"sess = ort.InferenceSession(export_path, providers=['CPUExecutionProvider'])\n",
"input_name = sess.get_inputs()[0].name\n",
"np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32) # (seq_len, batch_size, input_size)\n",
"pred_onnx = sess.run(None, {input_name: np_input})"
Expand Down
2 changes: 1 addition & 1 deletion tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def compute_ort(export_name, np_input):
run_opt.log_severity_level = 0 # Highest verbosity
run_opt.log_verbosity_level = 0 # Highest verbosity

sess = ort.InferenceSession(export_name, sess_opt)
sess = ort.InferenceSession(export_name, sess_opt, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
pred_onx = sess.run(None, {input_name: np_input}, run_options=run_opt)
return pred_onx
Expand Down

0 comments on commit 5f5c42a

Please sign in to comment.