From 322a8af9fe0ba853e51ec9614a679d92a805c359 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 24 Oct 2024 19:02:52 +0200 Subject: [PATCH] create token type ids when needed --- optimum/onnxruntime/modeling_ort.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index ce1d68536ac..8e5a814b689 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -931,7 +931,6 @@ def _prepare_onnx_inputs( self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray] ) -> Dict[str, np.ndarray]: onnx_inputs = {} - # converts pytorch inputs into numpy inputs for onnx for input_name in self.input_names.keys(): onnx_inputs[input_name] = inputs.pop(input_name) @@ -1086,6 +1085,9 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if token_type_ids is None and "token_type_ids" in self.input_names: + token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, @@ -1241,6 +1243,9 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if token_type_ids is None and "token_type_ids" in self.input_names: + token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, @@ -1330,6 +1335,9 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if token_type_ids is None and "token_type_ids" in self.input_names: + token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, @@ -1437,6 +1445,9 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if token_type_ids is None and "token_type_ids" in self.input_names: + token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, @@ -1527,6 +1538,9 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if token_type_ids is None and "token_type_ids" in self.input_names: + token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, @@ -1610,6 +1624,9 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if token_type_ids is None and "token_type_ids" in self.input_names: + token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids,