Skip to content

Commit

Permalink
create token type ids when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 24, 2024
1 parent 59d6f7f commit 322a8af
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,6 @@ def _prepare_onnx_inputs(
self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]
) -> Dict[str, np.ndarray]:
onnx_inputs = {}

# converts pytorch inputs into numpy inputs for onnx
for input_name in self.input_names.keys():
onnx_inputs[input_name] = inputs.pop(input_name)
Expand Down Expand Up @@ -1086,6 +1085,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
Expand Down Expand Up @@ -1241,6 +1243,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
Expand Down Expand Up @@ -1330,6 +1335,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
Expand Down Expand Up @@ -1437,6 +1445,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
Expand Down Expand Up @@ -1527,6 +1538,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
Expand Down Expand Up @@ -1610,6 +1624,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
Expand Down

0 comments on commit 322a8af

Please sign in to comment.