diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 8db67b26e7..07f20dc184 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -431,9 +431,9 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_model_outputs(use_torch, onnx_outputs) + model_outputs = self._prepare_model_outputs(use_torch, *onnx_outputs) last_hidden_state = model_outputs["last_hidden_state"] @@ -473,9 +473,9 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) - model_outputs = self._prepare_model_outputs(use_torch, onnx_outputs) + model_outputs = self._prepare_model_outputs(use_torch, *onnx_outputs) last_hidden_state = model_outputs["last_hidden_state"]