diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 845780cafa..84f7191581 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -26,7 +26,7 @@ from ..utils.logging import warn_once from .io_binding import TypeHelper from .modeling_ort import ORTModel -from .utils import get_ordered_input_names, logging +from .utils import logging logger = logging.get_logger(__name__) @@ -38,6 +38,11 @@ class ORTModelPart: It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. """ + # should be in an ORTMixin + _prepare_io_binding = ORTModel._prepare_io_binding + _prepare_output_buffer = ORTModel._prepare_output_buffer + _output_shape_inference = ORTModel._output_shape_inference + _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs @@ -48,10 +53,12 @@ def __init__(self, session: InferenceSession, parent_model: "ORTModel"): self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()} self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()} - self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) + self.input_shapes = {input_key.name: input_key.shape for input_key in session.get_inputs()} + self.output_shapes = {output_key.name: output_key.shape for output_key in session.get_outputs()} @property def device(self): @@ -118,15 +125,10 @@ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, use_torch = isinstance(input_ids, torch.Tensor) self.parent_model.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.parent_model.use_io_binding: - model_inputs = [input_ids] - if "attention_mask" in self.input_names: - model_inputs.append(attention_mask) - io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( - self.session, - *model_inputs, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + + if self.parent_model.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs) io_binding.synchronize_inputs() self.session.run_with_iobinding(io_binding) @@ -134,8 +136,6 @@ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -259,7 +259,6 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, - use_cache_branch: None = None, ) -> Seq2SeqLMOutput: # Adding use_cache_branch in the signature here is just a hack for IO Binding @@ -279,6 +278,18 @@ def forward( input_ids, past_key_values, cache_position, use_torch=use_torch ) + model_inputs = { + "input_ids": input_ids, + "encoder_hidden_states": encoder_hidden_states, + "decoder_attention_mask": decoder_attention_mask, + 
"encoder_attention_mask": encoder_attention_mask, + "use_cache_branch": use_cache_branch_tensor, + "cache_position": cache_position, + "labels": labels, + } + if past_key_values is not None: + model_inputs.update(zip(self.key_value_input_names, past_key_values)) + if self.parent_model.use_io_binding: known_output_shapes = self.compute_past_key_values_output_shapes( input_ids, @@ -289,50 +300,22 @@ def forward( outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache) - # TODO: fix transformers generate to have contiguous input_ids here already - # For an unknown reason, calling `contiguous()` here is necessary to not have errors - # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.g - model_inputs = [input_ids.contiguous()] - - if "encoder_hidden_states" in self.input_names: - model_inputs.append(encoder_hidden_states) - - if "decoder_attention_mask" in self.input_names: - model_inputs.append(decoder_attention_mask) - - if "encoder_attention_mask" in self.input_names: - model_inputs.append(encoder_attention_mask) - - if past_key_values is not None: - model_inputs += past_key_values - - if "labels" in self.input_names: - model_inputs.append(labels) - known_output_shapes.update({"loss": []}) - - if use_cache_branch_tensor is not None: - model_inputs.append(use_cache_branch_tensor) - - if "cache_position" in self.input_names: - model_inputs.append(cache_position) - - io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( + io_binding, output_shapes, output_buffers = self._prepare_io_binding( self.session, - *model_inputs, + model_inputs, known_output_shapes=known_output_shapes, - ordered_input_names=self._ordered_input_names, outputs_to_not_bind=outputs_to_not_bind, ) + io_binding.synchronize_inputs() + self.session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + # Set -1 for sequence_length as it could be larger than the real sequence_length for name, shape in output_shapes.items(): if name in self.key_value_output_names: output_shapes[name] = shape[:2] + (-1,) + shape[3:] - io_binding.synchronize_inputs() - self.session.run_with_iobinding(io_binding) - io_binding.synchronize_outputs() - # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the # self-attention layer and 2 to the cross-attention layer) out_past_key_values = () @@ -382,18 +365,6 @@ def forward( else: raise ValueError("Unsupported num_pkv") else: - model_inputs = { - "input_ids": input_ids, - "encoder_hidden_states": encoder_hidden_states, - "decoder_attention_mask": decoder_attention_mask, - "encoder_attention_mask": encoder_attention_mask, - "use_cache_branch": use_cache_branch_tensor, - "cache_position": cache_position, - "labels": labels, - } - if past_key_values is not None: - model_inputs.update(zip(self.key_value_input_names, past_key_values)) - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.session.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 3905a573a3..9d3535384a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -240,33 +240,21 @@ def forward( if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) - if self.use_io_binding: - # TODO: fix transformers generate to have contiguous input_ids here already - # For an unknown 
reason, calling `contiguous()` here is necessary to not have errors - # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding. - # I suspect the reason is the contiguous python list that messes something up? - model_inputs = [input_ids.contiguous()] - - if "attention_mask" in self.input_names: - model_inputs.append(attention_mask) - - if "position_ids" in self.input_names: - model_inputs.append(position_ids.contiguous()) - - if past_key_values is not None: - model_inputs += past_key_values - - if use_cache_branch is not None: - model_inputs.append(use_cache_branch) - - if "labels" in self.input_names: - model_inputs.append(labels) - known_output_shapes.update({"loss": []}) + model_inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "use_cache_branch": use_cache_branch, + "labels": labels, + } + if past_key_values is not None: + model_inputs.update( + zip(self.key_value_input_names, past_key_values), + ) - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - *model_inputs, - known_output_shapes=known_output_shapes, - ordered_input_names=self._ordered_input_names, + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding( + self.model, model_inputs, known_output_shapes=known_output_shapes ) if self.device.type == "cpu": @@ -287,18 +275,6 @@ def forward( if "loss" in self.output_names: loss = output_buffers["loss"].view(output_shapes["loss"]) else: - model_inputs = { - "input_ids": input_ids, - "position_ids": position_ids, - "attention_mask": attention_mask, - "use_cache_branch": use_cache_branch, - "labels": labels, - } - if past_key_values is not None: - model_inputs.update( - zip(self.key_value_input_names, past_key_values), - ) - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 4de03a65c6..35a448fbb8 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -72,7 +72,6 @@ ONNX_WEIGHTS_NAME, check_io_binding, get_device_for_provider, - get_ordered_input_names, get_provider_for_device, parse_device, validate_provider_availability, @@ -276,8 +275,6 @@ def __init__( self.output_names = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())} self.output_dtypes = {output_key.name: output_key.type for output_key in model.get_outputs()} - self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) - @property def dtype(self) -> torch.dtype: """ @@ -773,43 +770,23 @@ def _output_shape_inference(self, axis_name: Union[str, int], dimensions: Dict[s """ if isinstance(axis_name, int): return axis_name - # It is actually covered below, but this is to make things faster. + elif axis_name in dimensions: return dimensions[axis_name] - # Tokens is going to be populated by iterating over every match for the self.output_shape_inference_pattern. - # This pattern matches 4 things: axis names, integer values, operators (+, -, *, /) and parenthesis. 
- tokens = [] - for idx, match_ in enumerate(re.finditer(self.output_shape_inference_pattern, axis_name)): - groups = match_.groups() - matched_group = None - for idx, group in enumerate(groups): - if group is not None: - matched_group = idx - break - - # For every match except an axis name, we simply append the content of the match to the tokens list. - # For an axis name, we check if it is specified in the `dimensions` dictionary. If for some reason it is - # not there, or its value not an integer, the shape inference process stops and we return the axis name as - # is. - if matched_group == 0: - dim = dimensions.get(groups[0], None) - if dim is None or not isinstance(dim, int): - return axis_name - tokens.append(str(dim)) - else: - tokens.append(groups[matched_group]) + # faster way to do the same thing, assuming the axis names are well defined (by us in the exporter config) + tokens = axis_name.split(" ") + for idx, token in enumerate(tokens): + if token in dimensions: + tokens[idx] = str(dimensions[token]) - # Here it should not be problematic to use eval since anything not matching the pattern would trigger an - # exception. - return int(eval(" ".join(tokens))) + return eval(" ".join(tokens)) # TODO: this method is bloated with state arguments (that are accesible using self) why ? def _prepare_io_binding( self, model: ort.InferenceSession, - *model_inputs: torch.Tensor, - ordered_input_names: List[str], + model_inputs: Dict[str, torch.Tensor], known_output_shapes: Optional[Dict[str, Tuple[int]]] = None, outputs_to_not_bind: Optional[Union[Set[str], str]] = None, ) -> Tuple[ort.IOBinding, Dict[str, Tuple[int]], Dict[str, torch.Tensor]]: @@ -819,10 +796,8 @@ def _prepare_io_binding( Args: model (`ort.InferenceSession`): The model for which we want to bind the inputs and outputs. - *model_inputs: - The inputs of the model. - ordered_input_names (`List[str]`): - Names of the inputs, that must match with the order of model_inputs. + model_inputs (`Dict[str, torch.Tensor]`): + The inputs to bind to the model. known_output_shapes (`Optional[Dict[str, Tuple[int]]]`, defaults to `None`): It can be hard to infer all the output shapes from the inputs only. For instance for the past key / values. It is possible to explicitely pass the shape via this argument. @@ -838,24 +813,21 @@ def _prepare_io_binding( name_to_np_type = TypeHelper.get_io_numpy_type_map(model) input_name_to_shape = {} - for idx, tensor in enumerate(model_inputs): - if tensor is None: - continue - name = ordered_input_names[idx] - tensor = tensor.contiguous() - input_name_to_shape[name] = tensor.shape + for input_name in self.input_names.keys(): + tensor = model_inputs[input_name].contiguous() + input_name_to_shape[input_name] = tensor.shape data_ptr = tensor.data_ptr() - if "past" in name and data_ptr == 0: + if "past" in input_name and data_ptr == 0: # During first generation, sequence_length can be 0 when use_cache=True, which results in data_ptr to also be 0. # To keep compatibility with IO binding, we pass the data pointer of input_ids instead. This will have no impact because past_key_values will not be used during the first generation. 
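Note on the `_output_shape_inference` rewrite above: the fast path only works if every symbolic axis expression produced by the export configs is space-delimited (e.g. `past_sequence_length + 1`), since tokens are substituted verbatim and the remaining arithmetic is handed to `eval`. A rough standalone sketch of that resolution step (function and axis names here are illustrative, not the optimum helper itself):

```python
# Minimal sketch of the simplified symbolic-axis resolution, assuming
# space-delimited expressions as emitted by the ONNX export configs.
def resolve_axis(axis_name, dimensions):
    if isinstance(axis_name, int):
        return axis_name                      # static axis, nothing to resolve
    if axis_name in dimensions:
        return dimensions[axis_name]          # plain dynamic axis, e.g. "batch_size"
    tokens = axis_name.split(" ")             # e.g. ["past_sequence_length", "+", "1"]
    for idx, token in enumerate(tokens):
        if token in dimensions:
            tokens[idx] = str(dimensions[token])
    return eval(" ".join(tokens))             # "12 + 1" -> 13


print(resolve_axis("past_sequence_length + 1", {"batch_size": 2, "past_sequence_length": 12}))  # 13
```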
- data_ptr = model_inputs[0].data_ptr() + data_ptr = next(iter(model_inputs.values())).data_ptr() io_binding.bind_input( - name, + input_name, tensor.device.type, IOBindingHelper.get_device_index(self.device), - name_to_np_type[name], + name_to_np_type[input_name], tuple(tensor.shape), data_ptr, ) @@ -902,17 +874,6 @@ def _prepare_io_binding( return io_binding, output_shapes, output_buffers - def prepare_io_binding( - self, *model_inputs, ordered_input_names, outputs_to_not_bind=None, known_output_shapes=None - ): - return self._prepare_io_binding( - self.model, - *model_inputs, - ordered_input_names=ordered_input_names, - known_output_shapes=known_output_shapes, - outputs_to_not_bind=outputs_to_not_bind, - ) - def raise_on_numpy_input_io_binding(self, use_torch: bool): """ Raises an error if IO Binding is requested although the tensor used are numpy arrays. @@ -1091,13 +1052,14 @@ def forward( if token_type_ids is None and "token_type_ids" in self.input_names: token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_ids, - attention_mask, - token_type_ids, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1106,8 +1068,6 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1246,16 +1206,17 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if token_type_ids is None and "token_type_ids" in self.input_names: + if "token_type_ids" in self.input_names and token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_ids, - attention_mask, - token_type_ids, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1264,8 +1225,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1341,13 +1300,10 @@ def forward( if token_type_ids is None and "token_type_ids" in self.input_names: token_type_ids = 
torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_ids, - attention_mask, - token_type_ids, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1358,8 +1314,6 @@ def forward( start_logits = output_buffers["start_logits"].view(output_shapes["start_logits"]) end_logits = output_buffers["end_logits"].view(output_shapes["end_logits"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1451,13 +1405,14 @@ def forward( if token_type_ids is None and "token_type_ids" in self.input_names: token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_ids, - attention_mask, - token_type_ids, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1466,8 +1421,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1544,13 +1497,14 @@ def forward( if token_type_ids is None and "token_type_ids" in self.input_names: token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_ids, - attention_mask, - token_type_ids, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1559,8 +1513,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1630,13 +1582,14 @@ def forward( if token_type_ids 
is None and "token_type_ids" in self.input_names: token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_ids, - attention_mask, - token_type_ids, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1645,8 +1598,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1720,11 +1671,12 @@ def forward( use_torch = isinstance(pixel_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - pixel_values, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "pixel_values": pixel_values, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1733,8 +1685,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"pixel_values": pixel_values} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1808,11 +1758,12 @@ def forward( use_torch = isinstance(pixel_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - pixel_values, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + "pixel_values": pixel_values, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1821,8 +1772,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"pixel_values": pixel_values} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -1935,12 +1884,13 @@ def forward( use_torch = isinstance(model_input, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - model_input, - attention_mask, - ordered_input_names=self._ordered_input_names, - ) + model_inputs = { + self.input_name: model_input, + "attention_mask": attention_mask, + } + + if self.use_io_binding: + io_binding, output_shapes, 
output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -1949,8 +1899,6 @@ def forward( logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {self.input_name: model_input, "attention_mask": attention_mask} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -2012,21 +1960,21 @@ def forward( use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: - input_size = input_values.shape[1] - output_sizes = [] + model_inputs = { + "input_values": input_values, + } - def _conv_output_size(input_size, kernel_size, stride): - return (input_size - kernel_size) // stride + 1 + if self.use_io_binding: + batch_size = input_values.shape[0] + final_input_size = input_values.shape[-1] for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_size = _conv_output_size(input_size, kernel_size, stride) - output_sizes.append(input_size) + final_input_size = (final_input_size - kernel_size) // stride + 1 - known_output_shapes = {"logits": [input_values.shape[0], output_sizes[-1], self.config.vocab_size]} + known_output_shapes = {"logits": [batch_size, final_input_size, self.config.vocab_size]} - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_values, ordered_input_names=self._ordered_input_names, known_output_shapes=known_output_shapes + io_binding, output_shapes, output_buffers = self._prepare_io_binding( + self.model, model_inputs, known_output_shapes=known_output_shapes ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -2036,8 +1984,6 @@ def _conv_output_size(input_size, kernel_size, stride): logits = output_buffers["logits"].view(output_shapes["logits"]) else: - model_inputs = {"input_values": input_values} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -2107,10 +2053,12 @@ def forward( use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_values, ordered_input_names=self._ordered_input_names - ) + model_inputs = { + "input_values": input_values, + } + + if self.use_io_binding: + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.model, model_inputs) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -2121,8 +2069,6 @@ def forward( embeddings = output_buffers["embeddings"].view(output_shapes["embeddings"]) else: - model_inputs = {"input_values": input_values} - onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -2185,7 +2131,7 @@ def forward( use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: + if self.use_io_binding: raise NotImplementedError() else: model_inputs = {"input_values": input_values} 
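The CTC branch above pre-computes `known_output_shapes` for `logits` by folding the feature-extractor convolutions: each 1-D conv layer maps a time axis of length L to `(L - kernel_size) // stride + 1`. A standalone sketch of that computation, with illustrative Wav2Vec2-style kernel/stride values rather than values read from an actual config:

```python
# Illustrative kernel/stride values (Wav2Vec2-style feature extractor).
conv_kernel = [10, 3, 3, 3, 3, 2, 2]
conv_stride = [5, 2, 2, 2, 2, 2, 2]

def conv_output_length(num_samples: int) -> int:
    length = num_samples
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        length = (length - kernel_size) // stride + 1
    return length

# 1 second of 16 kHz audio -> number of frames the ONNX "logits" output will have
print(conv_output_length(16000))  # 49
```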
@@ -2244,26 +2190,27 @@ def forward( ): use_torch = isinstance(pixel_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: - input_shapes = pixel_values.shape - io_binding, output_shapes, output_buffers = self.prepare_io_binding( - pixel_values, - ordered_input_names=self._ordered_input_names, - known_output_shapes={ - "reconstruction": [ - input_shapes[0], - input_shapes[1], - input_shapes[2] * self.config.upscale, - input_shapes[3] * self.config.upscale, - ] - }, + + model_inputs = { + "pixel_values": pixel_values, + } + + if self.use_io_binding: + batch_size, num_channels, height, width = pixel_values.shape + known_output_shapes = { + "reconstruction": [batch_size, num_channels, height * self.config.upscale, width * self.config.upscale] + } + + io_binding, output_shapes, output_buffers = self._prepare_io_binding( + self.model, model_inputs, known_output_shapes=known_output_shapes ) + io_binding.synchronize_inputs() self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() + reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"]) else: - model_inputs = {"pixel_values": pixel_values} onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) @@ -2321,7 +2268,7 @@ def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]): use_torch = isinstance(next(iter(model_inputs.values())), torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - if self.device.type == "cuda" and self.use_io_binding: + if self.use_io_binding: # TODO: should this be used in favor of `model.prepare_io_binding`? io_binding = IOBindingHelper.prepare_io_binding(self, **model_inputs) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index e7e723dc28..8db67b26e7 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -367,11 +367,7 @@ def forward( model_inputs = ( [input_features, attention_mask] if "attention_mask" in self.input_names else [input_features] ) - io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( - self.session, - *model_inputs, - ordered_input_names=self._ordered_input_names, - ) + io_binding, output_shapes, output_buffers = self._prepare_io_binding( + self.session, dict(zip(("input_features", "attention_mask"), model_inputs)) + ) io_binding.synchronize_inputs() self.session.run_with_iobinding(io_binding) @@ -422,12 +418,12 @@ def forward( use_torch = isinstance(pixel_values, torch.Tensor) self.parent_model.raise_on_numpy_input_io_binding(use_torch) + model_inputs = { + "pixel_values": pixel_values, + } + if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: - io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( - self.session, - pixel_values, - ordered_input_names=self._ordered_input_names, - ) + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs) io_binding.synchronize_inputs() self.session.run_with_iobinding(io_binding) @@ -435,16 +431,11 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - onnx_inputs = {"pixel_values": pixel_values.cpu().detach().numpy()} - else: - onnx_inputs = {"pixel_values": pixel_values} + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = 
self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - outputs = self.session.run(None, onnx_inputs) - - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + last_hidden_state = model_outputs["last_hidden_state"] return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -468,15 +459,13 @@ def forward( use_torch = isinstance(flattened_patches, torch.Tensor) self.parent_model.raise_on_numpy_input_io_binding(use_torch) + model_inputs = { + "flattened_patches": flattened_patches, + "attention_mask": attention_mask, + } + if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: - model_inputs = ( - [flattened_patches, attention_mask] if "attention_mask" in self.input_names else [flattened_patches] - ) - io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( - self.session, - *model_inputs, - ordered_input_names=self._ordered_input_names, - ) + io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs) io_binding.synchronize_inputs() self.session.run_with_iobinding(io_binding) @@ -484,25 +473,11 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - onnx_inputs = {"flattened_patches": flattened_patches.cpu().detach().numpy()} - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() - else: - onnx_inputs = {"flattened_patches": flattened_patches} - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - if "attention_mask" in self.input_names: - if self.session.get_inputs()[1].type == "tensor(int64)": - onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64) - - outputs = self.session.run(None, onnx_inputs) - - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + last_hidden_state = model_outputs["last_hidden_state"] return BaseModelOutput(last_hidden_state=last_hidden_state) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 986dd65df3..bc41039616 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1448,14 +1448,18 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForQuestionAnswering.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) io_model = ORTModelForQuestionAnswering.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) tokenizer = get_preprocessor(model_id) - tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt") + tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**tokens) io_outputs = 
io_model(**tokens) @@ -1626,16 +1630,19 @@ def test_compare_to_io_binding(self, model_arch): self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForMaskedLM.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=False).to( - "cuda" + onnx_model = ORTModelForMaskedLM.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" ) - io_model = ORTModelForMaskedLM.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=True).to( - "cuda" + io_model = ORTModelForMaskedLM.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" ) + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + tokenizer = get_preprocessor(model_id) - MASK_TOKEN = tokenizer.mask_token - tokens = tokenizer([f"The capital of France is {MASK_TOKEN}."] * 2, return_tensors="pt") + tokens = tokenizer([f"The capital of France is {tokenizer.mask_token}."] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**tokens) io_outputs = io_model(**tokens) @@ -1643,9 +1650,7 @@ def test_compare_to_io_binding(self, model_arch): self.assertIsInstance(io_outputs.logits, torch.Tensor) # compare tensor outputs - torch.testing.assert_close( - torch.Tensor(io_outputs.logits), onnx_outputs.logits, atol=self.ATOL, rtol=self.RTOL - ) + torch.testing.assert_close(io_outputs.logits, onnx_outputs.logits, atol=self.ATOL, rtol=self.RTOL) gc.collect() @@ -1840,14 +1845,18 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSequenceClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) io_model = ORTModelForSequenceClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) tokenizer = get_preprocessor(model_id) - tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt") + tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**tokens) io_outputs = io_model(**tokens) @@ -2022,14 +2031,18 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForTokenClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) io_model = ORTModelForTokenClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) tokenizer = get_preprocessor(model_id) - tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt") + tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**tokens) io_outputs = io_model(**tokens) @@ -2181,14 +2194,18 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForFeatureExtraction.from_pretrained( - 
self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) io_model = ORTModelForFeatureExtraction.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) tokenizer = get_preprocessor(model_id) - tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt") + tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**tokens) io_outputs = io_model(**tokens) @@ -2296,24 +2313,25 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForMultipleChoice.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") - io_model = ORTModelForMultipleChoice.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=True).to( - "cuda" + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForMultipleChoice.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" ) - tokenizer = get_preprocessor(model_id) + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + num_choices = 4 - first_sentence = ["The sky is blue due to the shorter wavelength of blue light."] * num_choices start = "The color of the sky is" + tokenizer = get_preprocessor(model_id) + first_sentence = ["The sky is blue due to the shorter wavelength of blue light."] * num_choices second_sentence = [start + "blue", start + "green", start + "red", start + "yellow"] inputs = tokenizer(first_sentence, second_sentence, truncation=True, padding=True) - # Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length] for k, v in inputs.items(): inputs[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)] - - inputs = dict(inputs.convert_to_tensors(tensor_type="pt")) + inputs = dict(inputs.convert_to_tensors(tensor_type="pt").to("cuda")) onnx_outputs = onnx_model(**inputs) io_outputs = io_model(**inputs) @@ -2668,7 +2686,6 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) - @pytest.mark.cuda_ep_test # mark as GPU test as well to run the without/with cache timing test on the slow tests def test_compare_with_and_without_past_key_values(self, model_arch): model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False} self._setup(model_args) @@ -2759,11 +2776,17 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForCausalLM.from_pretrained( - self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[test_name], + use_cache=use_cache, + use_io_binding=False, + provider="CUDAExecutionProvider", + ) io_model = ORTModelForCausalLM.from_pretrained( - self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[test_name], + use_cache=use_cache, + use_io_binding=True, + provider="CUDAExecutionProvider", + ) tokenizer = get_preprocessor(model_id) tokens = tokenizer(["This is a 
sample output"] * 2, return_tensors="pt").to("cuda") @@ -2794,10 +2817,15 @@ def test_compare_generation_to_io_binding(self, test_name: str, model_arch: str, self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to( - "cuda" + onnx_model = ORTModelForCausalLM.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" ) - io_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to("cuda") + io_model = ORTModelForCausalLM.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) tokenizer = get_preprocessor(model_id) tokens = tokenizer( @@ -2805,6 +2833,7 @@ def test_compare_generation_to_io_binding(self, test_name: str, model_arch: str, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None, ).to("cuda") + onnx_outputs = onnx_model.generate(**tokens) io_outputs = io_model.generate(**tokens) @@ -3044,16 +3073,20 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForImageClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) io_model = ORTModelForImageClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) preprocessor = get_preprocessor(model_id) url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) - inputs = preprocessor(images=[image] * 2, return_tensors="pt") + inputs = preprocessor(images=[image] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**inputs) io_outputs = io_model(**inputs) @@ -3208,16 +3241,20 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSemanticSegmentation.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) io_model = ORTModelForSemanticSegmentation.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) preprocessor = get_preprocessor(model_id) url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) - inputs = preprocessor(images=[image] * 2, return_tensors="pt") + inputs = preprocessor(images=[image] * 2, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**inputs) io_outputs = io_model(**inputs) @@ -3391,16 +3428,19 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, 
provider="CUDAExecutionProvider" + ) io_model = ORTModelForAudioClassification.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=True - ).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) - processor = AutoFeatureExtractor.from_pretrained(model_id) data = self._generate_random_audio_data() + processor = AutoFeatureExtractor.from_pretrained(model_id) + input_values = processor(data, return_tensors="pt").to("cuda") - input_values = processor(data, return_tensors="pt") onnx_outputs = onnx_model(**input_values) io_outputs = io_model(**input_values) @@ -3485,17 +3525,20 @@ def test_compare_to_io_binding(self, model_arch): self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForCTC.from_pretrained( - self.onnx_model_dirs[model_arch], - use_io_binding=False, - ).to("cuda") - onnx_model.use_io_binding = False - io_model = ORTModelForCTC.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=True).to("cuda") + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForCTC.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) - processor = AutoFeatureExtractor.from_pretrained(model_id) data = self._generate_random_audio_data() - input_values = processor(data, return_tensors="pt") + processor = AutoFeatureExtractor.from_pretrained(model_id) + input_values = processor(data, return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**input_values) io_outputs = io_model(**input_values) @@ -3581,16 +3624,19 @@ def test_compare_to_io_binding(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioXVector.from_pretrained( - self.onnx_model_dirs[model_arch], use_io_binding=False - ).to("cuda") - io_model = ORTModelForAudioXVector.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=True).to( - "cuda" + self.onnx_model_dirs[model_arch], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForAudioXVector.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=True, provider="CUDAExecutionProvider" ) - processor = AutoFeatureExtractor.from_pretrained(model_id) + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + data = self._generate_random_audio_data() + processor = AutoFeatureExtractor.from_pretrained(model_id) + input_values = processor(data, return_tensors="pt").to("cuda") - input_values = processor(data, return_tensors="pt") onnx_outputs = onnx_model(**input_values) io_outputs = io_model(**input_values) @@ -4222,11 +4268,17 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: continue onnx_model = ORTModelForSeq2SeqLM.from_pretrained( - self._get_onnx_model_dir(model_id, model_arch, test_name), use_io_binding=False, use_cache=use_cache - ).to("cuda") + self._get_onnx_model_dir(model_id, model_arch, test_name), + use_io_binding=False, + use_cache=use_cache, + provider="CUDAExecutionProvider", + ) io_model = ORTModelForSeq2SeqLM.from_pretrained( - self._get_onnx_model_dir(model_id, model_arch, test_name), use_io_binding=True, use_cache=use_cache - ).to("cuda") + self._get_onnx_model_dir(model_id, model_arch, test_name), + use_io_binding=True, + 
use_cache=use_cache, + provider="CUDAExecutionProvider", + ) self.assertFalse(onnx_model.use_io_binding) self.assertTrue(io_model.use_io_binding) @@ -4236,8 +4288,9 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: decoder_start_token_id = onnx_model.config.decoder_start_token_id if model_arch != "mbart" else 2 if model_arch == "encoder-decoder": decoder_start_token_id = tokenizer.cls_token_id - - decoder_inputs = {"decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id} + decoder_inputs = { + "decoder_input_ids": torch.ones((2, 1), dtype=torch.long).to("cuda") * decoder_start_token_id + } onnx_outputs = onnx_model(**tokens, **decoder_inputs) io_outputs = io_model(**tokens, **decoder_inputs) @@ -4291,14 +4344,24 @@ def test_compare_generation_to_io_binding( continue onnx_model = ORTModelForSeq2SeqLM.from_pretrained( - self._get_onnx_model_dir(model_id, model_arch, test_name), use_io_binding=False, use_cache=use_cache - ).to("cuda") + self._get_onnx_model_dir(model_id, model_arch, test_name), + use_io_binding=False, + use_cache=use_cache, + provider="CUDAExecutionProvider", + ) io_model = ORTModelForSeq2SeqLM.from_pretrained( - self._get_onnx_model_dir(model_id, model_arch, test_name), use_io_binding=True, use_cache=use_cache - ).to("cuda") + self._get_onnx_model_dir(model_id, model_arch, test_name), + use_io_binding=True, + use_cache=use_cache, + provider="CUDAExecutionProvider", + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) tokenizer = get_preprocessor(model_id) tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda") + onnx_outputs = onnx_model.generate(**tokens, num_beams=num_beams) io_outputs = io_model.generate(**tokens, num_beams=num_beams) @@ -4681,17 +4744,16 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained( - self.onnx_model_dirs[test_name], use_io_binding=False - ).to("cuda") - io_model = ORTModelForSpeechSeq2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to( - "cuda" + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForSpeechSeq2Seq.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" ) self.assertFalse(onnx_model.use_io_binding) self.assertTrue(io_model.use_io_binding) processor = get_preprocessor(model_id) - data = self._generate_random_audio_data() features = processor.feature_extractor([data] * 2, return_tensors="pt").to("cuda") @@ -4742,14 +4804,16 @@ def test_compare_generation_to_io_binding( model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained( - self.onnx_model_dirs[test_name], use_io_binding=False - ).to("cuda") - io_model = ORTModelForSpeechSeq2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to( - "cuda" + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForSpeechSeq2Seq.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" ) - processor = get_preprocessor(model_id) + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + processor = get_preprocessor(model_id) data = self._generate_random_audio_data() features = processor.feature_extractor(data, 
return_tensors="pt").to("cuda") @@ -5177,21 +5241,19 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to( - "cuda" + onnx_model = ORTModelForVision2Seq.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" ) - io_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to( - "cuda" + io_model = ORTModelForVision2Seq.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" ) self.assertFalse(onnx_model.use_io_binding) self.assertTrue(io_model.use_io_binding) - feature_extractor, tokenizer = self._get_preprocessors(model_id) - data = self._get_sample_image() + feature_extractor, tokenizer = self._get_preprocessors(model_id) pixel_values = feature_extractor([data] * 2, return_tensors="pt").pixel_values.to("cuda") - decoder_start_token_id = onnx_model.config.decoder.bos_token_id decoder_input_ids = torch.full((2, 1), decoder_start_token_id, dtype=torch.long).to("cuda") @@ -5233,16 +5295,18 @@ def test_compare_generation_to_io_binding( self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to( - "cuda" + onnx_model = ORTModelForVision2Seq.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" ) - io_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to( - "cuda" + io_model = ORTModelForVision2Seq.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" ) - feature_extractor, tokenizer = self._get_preprocessors(model_id) + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) data = self._get_sample_image() + feature_extractor, tokenizer = self._get_preprocessors(model_id) features = feature_extractor(data, return_tensors="pt").to("cuda") onnx_outputs = onnx_model.generate(**features, num_beams=num_beams) @@ -5325,13 +5389,23 @@ def test_default_pipeline_and_model_device(self, *args, **kwargs): @require_torch_gpu @pytest.mark.cuda_ep_test def test_compare_to_io_binding(self, *args, **kwargs): - model_arch, model_id = args + _, model_id = args + set_seed(SEED) - onnx_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=False).to("cuda") + onnx_model = ORTModelForCustomTasks.from_pretrained( + model_id, use_io_binding=False, provider="CUDAExecutionProvider" + ) set_seed(SEED) - io_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=True).to("cuda") + io_model = ORTModelForCustomTasks.from_pretrained( + model_id, use_io_binding=True, provider="CUDAExecutionProvider" + ) + + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + tokenizer = get_preprocessor(model_id) - tokens = tokenizer("This is a sample output", return_tensors="pt") + tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda") + onnx_outputs = onnx_model(**tokens) io_outputs = io_model(**tokens) @@ -5590,25 +5664,30 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = 
ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False) - io_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True) + onnx_model = ORTModelForPix2Struct.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForPix2Struct.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" + ) self.assertFalse(onnx_model.use_io_binding) self.assertTrue(io_model.use_io_binding) preprocessor = get_preprocessor(model_id) - question = [ "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud and this is even longer and longer and longer and longer and hey", "Who are you?", ] - inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt") + inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt").to( + "cuda" + ) del inputs["decoder_attention_mask"] del inputs["decoder_input_ids"] decoder_start_token_id = onnx_model.config.decoder_start_token_id decoder_inputs = { - "decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id, - "decoder_attention_mask": torch.ones((2, 1), dtype=torch.int64), + "decoder_input_ids": torch.ones((2, 1), dtype=torch.long).to("cuda") * decoder_start_token_id, + "decoder_attention_mask": torch.ones((2, 1), dtype=torch.int64).to("cuda"), } onnx_outputs = onnx_model(**inputs, **decoder_inputs) @@ -5651,15 +5730,24 @@ def test_compare_generation_to_io_binding( self._setup(model_args) model_id = MODEL_NAMES[model_arch] - onnx_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False) - io_model = ORTModelForPix2Struct.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True) + onnx_model = ORTModelForPix2Struct.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=False, provider="CUDAExecutionProvider" + ) + io_model = ORTModelForPix2Struct.from_pretrained( + self.onnx_model_dirs[test_name], use_io_binding=True, provider="CUDAExecutionProvider" + ) - preprocessor = get_preprocessor(model_id) + self.assertFalse(onnx_model.use_io_binding) + self.assertTrue(io_model.use_io_binding) + preprocessor = get_preprocessor(model_id) question = ["What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud", "Who are you?"] - inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt") + inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt").to( + "cuda" + ) del inputs["decoder_attention_mask"] del inputs["decoder_input_ids"] + onnx_outputs = onnx_model.generate(**inputs, num_beams=num_beams) io_outputs = io_model.generate(**inputs, num_beams=num_beams)
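Across all the forwards touched here, `_prepare_io_binding` now receives a plain name-to-tensor dict and binds each session input by name, instead of positional tensors paired with `ordered_input_names`. A minimal, self-contained sketch of that binding pattern against raw onnxruntime; the model path, input/output names, shapes and dtypes are illustrative assumptions, not optimum code:

```python
import numpy as np
import onnxruntime as ort
import torch

# Hypothetical decoder-style model with int64 inputs and a float32 "logits" output.
session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
model_inputs = {
    "input_ids": torch.ones(2, 8, dtype=torch.int64, device="cuda"),
    "attention_mask": torch.ones(2, 8, dtype=torch.int64, device="cuda"),
}

io_binding = session.io_binding()
for name in (inp.name for inp in session.get_inputs()):
    tensor = model_inputs[name].contiguous()  # IO binding requires contiguous memory
    io_binding.bind_input(
        name, tensor.device.type, 0, np.int64, tuple(tensor.shape), tensor.data_ptr()
    )

# Pre-allocated output buffer bound by pointer, so the result never leaves the device.
logits = torch.empty(2, 8, 32000, dtype=torch.float32, device="cuda")
io_binding.bind_output(
    "logits", logits.device.type, 0, np.float32, tuple(logits.shape), logits.data_ptr()
)

io_binding.synchronize_inputs()
session.run_with_iobinding(io_binding)  # fills `logits` in place
io_binding.synchronize_outputs()
```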