Commit 18e887d
hot fix io binding, remove its dependency on the order of inputs and make sure it's actually being tested

IlyasMoutawwakil committed Jan 15, 2025
1 parent 941484a commit 18e887d
Showing 5 changed files with 389 additions and 432 deletions.
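Context for the fix below: the old IO-binding path fed tensors to the session as a positional list, so it silently depended on the ONNX graph declaring its inputs in a fixed order. A minimal sketch of the name-keyed alternative this commit switches to; the model path, tensors, and filtering logic here are illustrative, not Optimum's actual helpers:

```python
import numpy as np
import onnxruntime as ort

# Hypothetical two-input model; any ONNX file with named inputs works the same way.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

input_ids = np.zeros((1, 8), dtype=np.int64)
attention_mask = np.ones((1, 8), dtype=np.int64)

# Fragile: a positional list breaks if the export reorders or drops an input.
# Robust: key every candidate tensor by name, then keep only what the graph declares.
candidates = {"input_ids": input_ids, "attention_mask": attention_mask}
declared = {i.name for i in session.get_inputs()}
onnx_inputs = {k: v for k, v in candidates.items() if k in declared and v is not None}

outputs = session.run(None, onnx_inputs)
```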
91 changes: 31 additions & 60 deletions optimum/onnxruntime/base.py
@@ -26,7 +26,7 @@
 from ..utils.logging import warn_once
 from .io_binding import TypeHelper
 from .modeling_ort import ORTModel
-from .utils import get_ordered_input_names, logging
+from .utils import logging


 logger = logging.get_logger(__name__)
@@ -38,6 +38,11 @@ class ORTModelPart:
     It has its own `onnxruntime.InferenceSession`, and can perform a forward pass.
     """

+    # should be in an ORTMixin
+    _prepare_io_binding = ORTModel._prepare_io_binding
+    _prepare_output_buffer = ORTModel._prepare_output_buffer
+    _output_shape_inference = ORTModel._output_shape_inference
+
     _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
     _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

@@ -48,10 +53,12 @@ def __init__(self, session: InferenceSession, parent_model: "ORTModel"):

         self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
         self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
+
         self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
         self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}

-        self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)
+        self.input_shapes = {input_key.name: input_key.shape for input_key in session.get_inputs()}
+        self.output_shapes = {output_key.name: output_key.shape for output_key in session.get_outputs()}

     @property
     def device(self):
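A note on the borrowed methods added above (and the author's own remark that they belong in an ORTMixin): in Python 3, `ORTModel._prepare_io_binding` is a plain function object, so assigning it as a class attribute on `ORTModelPart` makes it behave as if defined there, with `self` bound to the part. A toy illustration with made-up classes:

```python
class Helper:
    def describe(self) -> str:
        # `self` is whatever instance the method is called on,
        # not necessarily a Helper.
        return f"{type(self).__name__} with prefix {self.prefix!r}"

class Borrower:
    # Reuse Helper's function object as if it were defined here.
    describe = Helper.describe

    def __init__(self):
        self.prefix = "ort"

print(Borrower().describe())  # Borrower with prefix 'ort'
```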
@@ -118,24 +125,17 @@ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor,
         use_torch = isinstance(input_ids, torch.Tensor)
         self.parent_model.raise_on_numpy_input_io_binding(use_torch)

-        if self.device.type == "cuda" and self.parent_model.use_io_binding:
-            model_inputs = [input_ids]
-            if "attention_mask" in self.input_names:
-                model_inputs.append(attention_mask)
-            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
-                self.session,
-                *model_inputs,
-                ordered_input_names=self._ordered_input_names,
-            )
+        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+
+        if self.parent_model.use_io_binding:
+            io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

             io_binding.synchronize_inputs()
             self.session.run_with_iobinding(io_binding)
             io_binding.synchronize_outputs()

             last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
         else:
-            model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
             onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
             onnx_outputs = self.session.run(None, onnx_inputs)
             model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
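For context on `output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])` above: IO binding writes outputs into pre-allocated flat buffers, and `view` reinterprets them in place. A rough sketch of that setup, assuming a hypothetical `encoder.onnx` exported with a `last_hidden_state` output and an available CUDA device; the shapes are illustrative:

```python
import numpy as np
import onnxruntime as ort
import torch

session = ort.InferenceSession("encoder.onnx", providers=["CUDAExecutionProvider"])

batch_size, seq_len, hidden_size = 1, 8, 768
shape = (batch_size, seq_len, hidden_size)

# Pre-allocate a flat device buffer for ONNX Runtime to write into.
buffer = torch.empty(int(np.prod(shape)), dtype=torch.float32, device="cuda")

io_binding = session.io_binding()
io_binding.bind_output(
    "last_hidden_state",
    device_type="cuda",
    device_id=0,
    element_type=np.float32,
    shape=list(shape),
    buffer_ptr=buffer.data_ptr(),
)
# ... inputs would be bound the same way, then:
io_binding.synchronize_inputs()
session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

last_hidden_state = buffer.view(shape)  # zero-copy reshape of the flat buffer
```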
@@ -259,7 +259,6 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
-        use_cache_branch: None = None,
     ) -> Seq2SeqLMOutput:
         # Adding use_cache_branch in the signature here is just a hack for IO Binding

@@ -279,6 +278,18 @@
             input_ids, past_key_values, cache_position, use_torch=use_torch
         )

+        model_inputs = {
+            "input_ids": input_ids,
+            "encoder_hidden_states": encoder_hidden_states,
+            "decoder_attention_mask": decoder_attention_mask,
+            "encoder_attention_mask": encoder_attention_mask,
+            "use_cache_branch": use_cache_branch_tensor,
+            "cache_position": cache_position,
+            "labels": labels,
+        }
+        if past_key_values is not None:
+            model_inputs.update(zip(self.key_value_input_names, past_key_values))
+
         if self.parent_model.use_io_binding:
             known_output_shapes = self.compute_past_key_values_output_shapes(
                 input_ids,
@@ -289,50 +300,22 @@

             outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)

-            # TODO: fix transformers generate to have contiguous input_ids here already
-            # For an unknown reason, calling `contiguous()` here is necessary to not have errors
-            # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
-            model_inputs = [input_ids.contiguous()]
-
-            if "encoder_hidden_states" in self.input_names:
-                model_inputs.append(encoder_hidden_states)
-
-            if "decoder_attention_mask" in self.input_names:
-                model_inputs.append(decoder_attention_mask)
-
-            if "encoder_attention_mask" in self.input_names:
-                model_inputs.append(encoder_attention_mask)
-
-            if past_key_values is not None:
-                model_inputs += past_key_values
-
-            if "labels" in self.input_names:
-                model_inputs.append(labels)
-                known_output_shapes.update({"loss": []})
-
-            if use_cache_branch_tensor is not None:
-                model_inputs.append(use_cache_branch_tensor)
-
-            if "cache_position" in self.input_names:
-                model_inputs.append(cache_position)
-
-            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
+            io_binding, output_shapes, output_buffers = self._prepare_io_binding(
                 self.session,
-                *model_inputs,
+                model_inputs,
                 known_output_shapes=known_output_shapes,
-                ordered_input_names=self._ordered_input_names,
                 outputs_to_not_bind=outputs_to_not_bind,
             )

-            io_binding.synchronize_inputs()
-            self.session.run_with_iobinding(io_binding)
-            io_binding.synchronize_outputs()
-
             # Set -1 for sequence_length as it could be larger than the real sequence_length
             for name, shape in output_shapes.items():
                 if name in self.key_value_output_names:
                     output_shapes[name] = shape[:2] + (-1,) + shape[3:]

+            io_binding.synchronize_inputs()
+            self.session.run_with_iobinding(io_binding)
+            io_binding.synchronize_outputs()
+
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
             # self-attention layer and 2 to the cross-attention layer)
             out_past_key_values = ()
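The `shape[:2] + (-1,) + shape[3:]` fix-up above exists because bound KV-cache buffers may be larger than the sequence actually produced, so the declared length can be stale; a `-1` lets `view` infer the real length from the buffer size. A self-contained illustration with made-up dimensions:

```python
import torch

batch, num_heads, head_dim = 2, 4, 8
real_seq_len = 5

# Flat buffer holding a (2, 4, 5, 8) past-key tensor.
flat = torch.arange(batch * num_heads * real_seq_len * head_dim, dtype=torch.float32)

declared = (batch, num_heads, 999, head_dim)  # possibly stale sequence length
fixed = declared[:2] + (-1,) + declared[3:]   # (2, 4, -1, 8)

past_key = flat.view(fixed)
print(past_key.shape)  # torch.Size([2, 4, 5, 8]), inferred rather than declared
```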
@@ -382,18 +365,6 @@ def forward(
             else:
                 raise ValueError("Unsupported num_pkv")
         else:
-            model_inputs = {
-                "input_ids": input_ids,
-                "encoder_hidden_states": encoder_hidden_states,
-                "decoder_attention_mask": decoder_attention_mask,
-                "encoder_attention_mask": encoder_attention_mask,
-                "use_cache_branch": use_cache_branch_tensor,
-                "cache_position": cache_position,
-                "labels": labels,
-            }
-            if past_key_values is not None:
-                model_inputs.update(zip(self.key_value_input_names, past_key_values))
-
             onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
             onnx_outputs = self.session.run(None, onnx_inputs)
             model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
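The deleted TODO comments in this file (and in modeling_decoder.py below) all circle the same pitfall: IO binding passes `tensor.data_ptr()`, a raw storage pointer that knows nothing about strides, so a non-contiguous tensor such as a transposed view would be read in the wrong element order. A short demonstration of what `contiguous()` changes:

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
t = x.t()  # transposed view: same storage, different strides
print(t.is_contiguous())  # False

# A raw pointer only sees the packed storage of `x`, not `t`'s strides,
# so binding t.data_ptr() would feed the elements in the wrong order.
c = t.contiguous()  # materializes a packed copy in t's logical order
print(c.is_contiguous())             # True
print(t.data_ptr() == x.data_ptr())  # True  (shared storage)
print(c.data_ptr() == x.data_ptr())  # False (fresh buffer)
```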
52 changes: 14 additions & 38 deletions optimum/onnxruntime/modeling_decoder.py
@@ -240,33 +240,21 @@ def forward(
         if past_key_values:
             position_ids = position_ids[:, -1].unsqueeze(-1)

-        if self.use_io_binding:
-            # TODO: fix transformers generate to have contiguous input_ids here already
-            # For an unknown reason, calling `contiguous()` here is necessary to not have errors
-            # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
-            # I suspect the reason is the contiguous python list that messes something up?
-            model_inputs = [input_ids.contiguous()]
-
-            if "attention_mask" in self.input_names:
-                model_inputs.append(attention_mask)
-
-            if "position_ids" in self.input_names:
-                model_inputs.append(position_ids.contiguous())
-
-            if past_key_values is not None:
-                model_inputs += past_key_values
-
-            if use_cache_branch is not None:
-                model_inputs.append(use_cache_branch)
-
-            if "labels" in self.input_names:
-                model_inputs.append(labels)
-                known_output_shapes.update({"loss": []})
+        model_inputs = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "use_cache_branch": use_cache_branch,
+            "labels": labels,
+        }
+        if past_key_values is not None:
+            model_inputs.update(
+                zip(self.key_value_input_names, past_key_values),
+            )

-            io_binding, output_shapes, output_buffers = self.prepare_io_binding(
-                *model_inputs,
-                known_output_shapes=known_output_shapes,
-                ordered_input_names=self._ordered_input_names,
+        if self.use_io_binding:
+            io_binding, output_shapes, output_buffers = self._prepare_io_binding(
+                self.model, model_inputs, known_output_shapes=known_output_shapes
             )

             if self.device.type == "cpu":
@@ -287,18 +275,6 @@
             if "loss" in self.output_names:
                 loss = output_buffers["loss"].view(output_shapes["loss"])
         else:
-            model_inputs = {
-                "input_ids": input_ids,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-                "use_cache_branch": use_cache_branch,
-                "labels": labels,
-            }
-            if past_key_values is not None:
-                model_inputs.update(
-                    zip(self.key_value_input_names, past_key_values),
-                )
-
             onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
             onnx_outputs = self.model.run(None, onnx_inputs)
             model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
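Both files now rely on the `model_inputs.update(zip(self.key_value_input_names, past_key_values))` idiom to map a flat tuple of cache tensors onto named graph inputs, which is what lets one dict serve both the IO-binding and plain `session.run` paths. A toy version with hypothetical input names:

```python
import torch

# Hypothetical ONNX input names, in the same order as the flat cache tuple.
key_value_input_names = ["past_key_values.0.key", "past_key_values.0.value"]
past_key_values = (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8))

model_inputs = {"input_ids": torch.zeros(1, 1, dtype=torch.long)}
model_inputs.update(zip(key_value_input_names, past_key_values))

print(sorted(model_inputs))
# ['input_ids', 'past_key_values.0.key', 'past_key_values.0.value']
```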