From 41b8f98193d30f5a02ef7d3b1983e97344b28d15 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 3 Aug 2023 17:09:23 +0200 Subject: [PATCH 01/76] ONNX export decoder model refactorization --- optimum/exporters/onnx/base.py | 2 +- optimum/exporters/onnx/config.py | 2 +- optimum/onnxruntime/modeling_decoder.py | 550 ++++++++++++++++++++++-- optimum/utils/modeling_utils.py | 77 ++++ 4 files changed, 584 insertions(+), 47 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 3c50389726b..7055ef816b1 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -541,7 +541,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: elif self.task == "feature-extraction": common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}}) else: - common_outputs = OrderedDict({"logits": {0: "batch_size"}}) + common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}) if self.use_present_in_outputs: self.add_past_key_values(common_outputs, direction="outputs") return common_outputs diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index e8aef99649f..a9a33e42eb8 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -69,7 +69,7 @@ class TextDecoderOnnxConfig(OnnxConfigWithPast): @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: - common_inputs = {"input_ids": {0: "batch_size"}} + common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}} self.add_past_key_values(common_inputs, direction="inputs") common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} else: diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 1ffc81d8832..ecfbaa54380 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -19,18 +19,19 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +import numpy as np import torch from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError -from transformers import AutoModelForCausalLM, GenerationConfig +from transformers import AutoModelForCausalLM, GenerationConfig, PretrainedConfig from transformers.file_utils import add_start_docstrings_to_model_forward -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from transformers.modeling_outputs import CausalLMOutputWithPast import onnxruntime - -from ..exporters.onnx import main_export +from ..exporters import TasksManager +from ..exporters.onnx import export, main_export from ..onnx.utils import _get_external_data_paths -from ..utils import check_if_transformers_greater +from ..utils import NormalizedConfigManager, check_if_transformers_greater from ..utils.file_utils import validate_file_exists from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .base import ORTDecoder @@ -40,6 +41,7 @@ from .utils import ( ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, + ONNX_WEIGHTS_NAME, get_provider_for_device, parse_device, validate_provider_availability, @@ -56,6 +58,25 @@ from transformers.generation_utils import GenerationMixin +from huggingface_hub import hf_hub_download +from transformers.modeling_outputs import CausalLMOutputWithPast + +from .utils import MULTI_QUERY_ATTN_MODELS + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + +if 
check_if_transformers_greater("4.25.0"): + from transformers.generation import GenerationMixin +else: + from transformers.generation_utils import GenerationMixin + + +from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask + + logger = logging.getLogger(__name__) DECODER_INPUTS_DOCSTRING = r""" @@ -622,7 +643,7 @@ def to(self, device: Union[torch.device, str, int]): return self -class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin): +class ORTModelForCausalLM(ORTModel, GenerationMixin): """ ONNX model with a causal language modeling head for ONNX Runtime inference. """ @@ -630,6 +651,46 @@ class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin): auto_model_class = AutoModelForCausalLM main_input_name = "input_ids" + def __init__( + self, + model: onnxruntime.InferenceSession, + config: "PretrainedConfig", + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ): + if use_io_binding is None: + if model.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"]: + use_io_binding = True + else: + use_io_binding = False + + super().__init__(model, config, use_io_binding, model_save_dir, preprocessors, **kwargs) + + self.num_pkv = 2 + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)] + self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] + self.use_cache = len(self.key_value_input_names) > 0 + + if generation_config is None: + generation_config = GenerationConfig.from_model_config(config) + self.generation_config = generation_config + + # TODO : deprecate + self.onnx_paths = [self.model_path] + + # TODO : deprecate + self.use_merged = "use_cache_branch" in self.inputs_names + + self.use_fp16 = False + for inp in model.get_inputs(): + if inp.name == "past_key_values" and inp.type == "tensor(float16)": + self.use_fp16 = True + break + @add_start_docstrings_to_model_forward( CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + TEXT_GENERATION_EXAMPLE.format( @@ -640,75 +701,474 @@ class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin): ) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, **kwargs, - ) -> CausalLMOutputWithCrossAttentions: - if past_key_values is None or self.use_cache is False: - outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - labels=labels, + ) -> CausalLMOutputWithPast: + use_torch = isinstance(input_ids, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + + inputs = {} + known_output_shapes = {} + use_cache_branch = None + loss = None + if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # Flatten the past_key_values (no need to flatten for models using multi-query attn) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + past_key_values = tuple( + past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer + ) + + # Create dummy past_key_values for decoder first generation step if none 
given + use_cache_branch, past_key_values, known_output_shapes = self.prepare_past_key_values( + input_ids, past_key_values, use_torch ) - elif self.use_merged is True: - outputs = self.decoder( - input_ids=input_ids[:, -1:], - past_key_values=past_key_values, - attention_mask=attention_mask, + + if self.use_io_binding: + # TODO: fix transformers generate to have contiguous input_ids here already + # For an unknown reason, calling `contiguous()` here is necessary to not have errors + # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding. + # I suspect the reason is the contiguous python list that messes something up? + model_inputs = [input_ids.contiguous()] + + if "attention_mask" in self.inputs_names: + model_inputs.append(attention_mask) + + if past_key_values is not None: + model_inputs += past_key_values + + if use_cache_branch is not None: + model_inputs.append(use_cache_branch) + + if "labels" in self.inputs_names: + model_inputs.append(labels) + known_output_shapes.update({"loss": []}) + + io_binding, output_shapes, output_buffers = self._prepare_io_binding( + self.model, + *model_inputs, + known_output_shapes=known_output_shapes, + ordered_input_names=self._ordered_input_names, ) + + if self.device.type == "cpu": + self.model.run_with_iobinding(io_binding) + else: + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + if self.use_cache: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2) + past_key_values = () + for name in self.key_value_output_names: + past_key_values += (output_buffers[name].view(output_shapes[name]),) + + logits = output_buffers["logits"].view(output_shapes["logits"]) + + if "loss" in self.output_names: + loss = output_buffers["loss"].view(output_shapes["loss"]) + else: - outputs = self.decoder_with_past( - input_ids=input_ids[:, -1:], - past_key_values=past_key_values, - attention_mask=attention_mask, - labels=labels, + inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids + if "attention_mask" in self.inputs_names: + inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask + if "labels" in self.inputs_names: + inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels + + # Add the past_key_values to the decoder inputs + if past_key_values is not None: + for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): + inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value + + if use_cache_branch is not None: + inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch + + outputs = self.model.run(None, inputs) + + if self.use_cache: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention) + past_key_values = tuple( + torch.from_numpy(outputs[self.output_names[key]]).to(self.device) + for key in self.key_value_output_names + ) + + logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device) + if "loss" in self.output_names: + loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device) + + if self.use_cache and self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and + # per decoder layer + past_key_values = tuple( + past_key_values[i : i + 
self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) + ) + + return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values) + + def prepare_past_key_values( + self, + input_ids: Union[None, torch.LongTensor, np.ndarray], + past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]], + use_torch: bool, + ): + sequence_length = input_ids.shape[1] + + constructor = torch if use_torch else np + if self.use_merged: + # Uses without/with branch of a merged decoder depending on whether real past key values are passed + use_cache_branch = constructor.full((1,), past_key_values is not None) + else: + # Uses separate decoders + use_cache_branch = None + + if use_torch and use_cache_branch is not None: + use_cache_branch = use_cache_branch.to(self.device) + + # Generate dummy past for the first forward if uses a merged decoder + if past_key_values is None: + batch_size = input_ids.shape[0] + num_attention_heads = self.normalized_config.num_attention_heads + embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads + dtype = constructor.float16 if self.use_fp16 else constructor.float32 + # TODO: find a way to better handle this controlflow + # "1" is the dummy sequence length + if self.config.model_type == "bloom": + shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head) + shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0) + key = constructor.zeros(shape_key, dtype=dtype) + value = constructor.zeros(shape_value, dtype=dtype) + + if use_torch: + key = key.to(self.device) + value = value.to(self.device) + + past_key_values = tuple( + key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value] + ) + elif self.config.model_type in MULTI_QUERY_ATTN_MODELS: + shape_key_and_value = (batch_size, 0, embed_size_per_head * 2) + key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype) + + if use_torch: + key_and_value = key_and_value.to(self.device) + + past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names))) + else: + shape = (batch_size, num_attention_heads, 0, embed_size_per_head) + key_or_value = constructor.zeros(shape, dtype=dtype) + + if use_torch: + key_or_value = key_or_value.to(self.device) + + past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names))) + + pkv_output_shape = {} + for name, value in zip(self.key_value_output_names, past_key_values): + shape = [*value.shape] + # TODO : modify for different pkv shape : bloom / big_code + shape[2] += sequence_length + pkv_output_shape[name] = shape + + return use_cache_branch, past_key_values, pkv_output_shape + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + use_cache: bool = True, + local_files_only: bool = False, + use_merged: Optional[bool] = None, + provider: str = "CPUExecutionProvider", + session_options: Optional[onnxruntime.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + **kwargs, + ) -> "ORTModelForCausalLM": + # model_path = Path(model_id) + + # We do not implement the logic for use_cache=False, use_merged=True + if 
use_cache is False: + if use_merged is True: + raise ValueError( + "The parameters combination use_cache=False, use_merged=True is not supported." + " To use a merged decoder, past key values must be used." + ) + use_merged = False + + # TODO : deprecate + decoder_file_name = kwargs.pop("decoder_file_name", None) + decoder_with_past_file_name = kwargs.pop("decoder_with_past_file_name", None) + + file_name = file_name or (decoder_with_past_file_name if use_cache else decoder_file_name) + + if file_name is None: + decoder_path = None + # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it), + # and use_merged = True (explicitely specified by the user) + if use_merged is not False: + try: + decoder_path = ORTModelForCausalLM.infer_onnx_filename( + model_id, + [DECODER_MERGED_ONNX_FILE_PATTERN], + argument_name=None, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + use_merged = True + file_name = decoder_path.name + except FileNotFoundError as e: + if use_merged is True: + raise FileNotFoundError( + "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()" + " but no ONNX file for a merged decoder could be found in" + f" {str(Path(model_id, subfolder))}, with the error: {e}" + ) + use_merged = False + + exclude_decoder = r"(.*)?((? "ORTModelForCausalLM": + file_name = ONNX_WEIGHTS_NAME + + if use_merged: + logger.warning( + "The `use_merged` argument is deprecated when the model is exported, and not used anymore." + ) + use_merged = False + + if task is None: + task = cls._auto_model_to_task(cls.auto_model_class) + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + model_kwargs = { + "revision": revision, + "use_auth_token": use_auth_token, + "cache_dir": cache_dir, + "subfolder": subfolder, + "local_files_only": local_files_only, + "force_download": force_download, + } + model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) + onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) + onnx_config = onnx_config_constructor(model.config, use_past=use_cache) + + # TODO : create ModelPatcher to patch each architecture + if config.model_type == "bloom": + model.transformer._prepare_attn_mask = _prepare_attn_mask + elif config.model_type == "llama": + model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + elif config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"): + model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + + # Export the model to the ONNX format + export(model=model, config=onnx_config, output=save_dir_path / file_name) + + # TODO : use main_export + config.save_pretrained(save_dir_path) + maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) + + return cls._from_pretrained( + save_dir_path, + config, + use_cache=use_cache, + use_merged=use_merged, + provider=provider, + session_options=session_options, + provider_options=provider_options, + use_io_binding=use_io_binding, + model_save_dir=save_dir, + file_name=file_name, ) # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - - attention_mask = kwargs.get("attention_mask", None) # 
input_ids.new_ones(input_ids.shape) + past_key_values = past_key_values or kwargs.get("past", None) use_cache = kwargs.get("use_cache", None) + # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed + if past_key_values is not None and self.config.model_type == "bloom": + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache, "position_ids": None, - "attention_mask": attention_mask, + "attention_mask": kwargs.get("attention_mask", None), # input_ids.new_ones(input_ids.shape) } - # Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache - @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + if self.config.model_type == "bloom": + return self._reorder_cache_bloom(past_key_values, beam_idx) + + # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values + ) + + # TODO: remove + # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache + def _reorder_cache_bloom( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called for bloom architecture. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) + + # Get a copy of `beam_idx` on all the devices where we need those indices. + device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) + for layer_past in past_key_values + for past_state in layer_past + } + reordered_past = tuple( + ( + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), + ) + for layer_past in standardized_past + ) + return self._convert_to_bloom_cache(reordered_past) + + # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache + @staticmethod + def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache + def _convert_to_standard_cache( + self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...])) + """ + if self.config.model_type != "bloom": + return past_key_value + + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value ) def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" return True - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: "PretrainedConfig", - **kwargs, - ): - if config.model_type == "bloom": - return super()._from_pretrained(model_id, config, init_cls=ORTBloomForCausalLM, **kwargs) - return super()._from_pretrained(model_id, config, init_cls=ORTModelForCausalLM, **kwargs) - class ORTBloomForCausalLM(ORTModelForCausalLM): # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 89f2f5598a6..fcf75aa1e89 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -13,6 +13,9 @@ # limitations under the License. import functools +from typing import Tuple + +import torch def recurse_getattr(obj, attr: str): @@ -39,3 +42,77 @@ def recurse_setattr(module, name, value): else: name, rest = name.split(".", 1) recurse_setattr(getattr(module, name), rest, value) + + +# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + device: torch.device, + past_key_values_length: int, + dtype: torch.dtype = torch.bool, +) -> torch.BoolTensor: + """ + Make causal mask used for bi-directional self-attention. 
+ """ + batch_size, target_length = input_ids_shape + mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) + seq_ids = torch.arange(target_length, device=device) + + mask[:, past_key_values_length:] = ( + (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min + if torch.is_floating_point(mask) + else seq_ids[:, None] < seq_ids[None, :] + ) + + return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + + +# Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask +def _prepare_attn_mask( + attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int +) -> torch.BoolTensor: + from transformers.models.bloom.modeling_bloom import _expand_mask + + # create causal mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]_prepare_decoder_attention_mask + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask +def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + from transformers.models.llama.modeling_llama import _expand_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_causal_mask( + input_shape, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + dtype=inputs_embeds.dtype, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask From f91a0184d129c59c8ed7f84b32e2fe7c1724f7e1 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 4 Aug 2023 16:19:02 +0200 Subject: [PATCH 02/76] fix style --- optimum/onnxruntime/modeling_decoder.py | 31 ++++++------------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index ecfbaa54380..2bcb6f4af4e 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -33,12 +33,14 @@ from ..onnx.utils import _get_external_data_paths from ..utils import NormalizedConfigManager, check_if_transformers_greater from ..utils.file_utils import validate_file_exists +from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .base import ORTDecoder from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN from .modeling_ort import ORTModel from .models.bloom import bloom_convert_to_bloom_cache, 
bloom_convert_to_standard_cache from .utils import ( + MULTI_QUERY_ATTN_MODELS, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME, @@ -51,32 +53,12 @@ if TYPE_CHECKING: from transformers import PretrainedConfig - if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: from transformers.generation_utils import GenerationMixin -from huggingface_hub import hf_hub_download -from transformers.modeling_outputs import CausalLMOutputWithPast - -from .utils import MULTI_QUERY_ATTN_MODELS - - -if TYPE_CHECKING: - from transformers import PretrainedConfig - - -if check_if_transformers_greater("4.25.0"): - from transformers.generation import GenerationMixin -else: - from transformers.generation_utils import GenerationMixin - - -from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask - - logger = logging.getLogger(__name__) DECODER_INPUTS_DOCSTRING = r""" @@ -1012,9 +994,7 @@ def _from_transformers( file_name = ONNX_WEIGHTS_NAME if use_merged: - logger.warning( - "The `use_merged` argument is deprecated when the model is exported, and not used anymore." - ) + logger.warning("The `use_merged` argument is deprecated when the model is exported, and not used anymore.") use_merged = False if task is None: @@ -1064,6 +1044,9 @@ def _from_transformers( # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + + attention_mask = kwargs.get("attention_mask", None) # input_ids.new_ones(input_ids.shape) past_key_values = past_key_values or kwargs.get("past", None) use_cache = kwargs.get("use_cache", None) @@ -1077,7 +1060,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "past_key_values": past_key_values, "use_cache": use_cache, "position_ids": None, - "attention_mask": kwargs.get("attention_mask", None), # input_ids.new_ones(input_ids.shape) + "attention_mask": attention_mask, } # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache From 4ce5fbe006c5a8b647ef48df185d96eb2e1971c5 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 4 Aug 2023 17:33:42 +0200 Subject: [PATCH 03/76] fix index --- optimum/onnxruntime/modeling_decoder.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2bcb6f4af4e..c8085255834 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -23,7 +23,7 @@ import torch from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError -from transformers import AutoModelForCausalLM, GenerationConfig, PretrainedConfig +from transformers import AutoModelForCausalLM, GenerationConfig from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithPast @@ -49,7 +49,6 @@ validate_provider_availability, ) - if TYPE_CHECKING: from transformers import PretrainedConfig @@ -852,8 +851,14 @@ def prepare_past_key_values( pkv_output_shape = {} for name, value in zip(self.key_value_output_names, past_key_values): shape = [*value.shape] - # TODO : modify for different pkv shape : bloom / big_code - shape[2] += sequence_length + index = ( + 1 + if 
self.config.model_type in MULTI_QUERY_ATTN_MODELS + or (self.config.model_type == "bloom" and "value" in name) + else 2 + ) + + shape[index] += sequence_length pkv_output_shape[name] = shape return use_cache_branch, past_key_values, pkv_output_shape @@ -963,8 +968,8 @@ def _from_pretrained( if use_cache ^ model.use_cache: raise ValueError( - f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " - f"Please load your current model with `use_cache={self.use_cache}` or export the original model " + f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={model.use_cache}`. " + f"Please load your current model with `use_cache={model.use_cache}` or export the original model " f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " "To export your model, simply set `export=True`." ) From 9fa05e4ccde771d35a99c8a290dc4a94a34d7af1 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 12 Sep 2023 16:14:02 +0200 Subject: [PATCH 04/76] fix IO bindings --- optimum/onnxruntime/modeling_ort.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index a20f7f0aa22..6eb2d3fa6b0 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -754,13 +754,20 @@ def _prepare_io_binding( name = ordered_input_names[idx] tensor = tensor.contiguous() input_name_to_shape[name] = tensor.shape + + data_ptr = tensor.data_ptr() + if "past" in name and data_ptr == 0: + # During first generation, sequence_length can be 0 when use_cache=True, which results in data_ptr to also be 0. + # To keep compatibility with IO binding, we pass the data pointer of input_ids instead. This will have no impact because past_key_values will not be used during the first generation. 
+ data_ptr = model_inputs[0].data_ptr() + io_binding.bind_input( name, tensor.device.type, IOBindingHelper.get_device_index(self.device), name_to_np_type[name], tuple(tensor.shape), - tensor.data_ptr(), + data_ptr, ) dimensions = {} for input_ in model.get_inputs(): From 3a0d76ab743ed49764cd9726a7c733e275ec2629 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 12 Sep 2023 16:18:31 +0200 Subject: [PATCH 05/76] format --- optimum/onnxruntime/modeling_decoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 824029b3b9b..5eb4a9ed41d 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -49,6 +49,7 @@ validate_provider_availability, ) + if TYPE_CHECKING: from transformers import PretrainedConfig From b0aa23412a72e51d617772888f0df417444cd1d8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 12 Sep 2023 16:20:52 +0200 Subject: [PATCH 06/76] enable mpt support --- optimum/onnxruntime/modeling_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 5eb4a9ed41d..80a17c96a17 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1022,7 +1022,7 @@ def _from_transformers( onnx_config = onnx_config_constructor(model.config, use_past=use_cache) # TODO : create ModelPatcher to patch each architecture - if config.model_type == "bloom": + if config.model_type in {"bloom", "mpt"}: model.transformer._prepare_attn_mask = _prepare_attn_mask elif config.model_type == "llama": model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask From dfabefd75f2c6a7500a3fe1b45254ec53682c2c4 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 12 Sep 2023 16:26:52 +0200 Subject: [PATCH 07/76] format --- optimum/onnxruntime/modeling_decoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 80a17c96a17..716a720e638 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -28,6 +28,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast import onnxruntime + from ..exporters import TasksManager from ..exporters.onnx import export, main_export from ..onnx.utils import _get_external_data_paths From 35df7bdeba675071b1e6b66cd645e6d92971b634 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 13 Sep 2023 16:13:48 +0200 Subject: [PATCH 08/76] add trust remote code --- optimum/onnxruntime/modeling_decoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 716a720e638..a68467819af 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -903,7 +903,6 @@ def _from_pretrained( decoder_with_past_file_name = kwargs.pop("decoder_with_past_file_name", None) file_name = file_name or (decoder_with_past_file_name if use_cache else decoder_file_name) - if file_name is None: decoder_path = None # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it), @@ -1017,7 +1016,9 @@ def _from_transformers( "subfolder": subfolder, "local_files_only": local_files_only, "force_download": force_download, + "trust_remote_code" : trust_remote_code, } + model = 
TasksManager.get_model_from_task(task, model_id, **model_kwargs) onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config, use_past=use_cache) From 469edc837e1ebda51b6b26388cb9a5e2f5039733 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 13 Sep 2023 16:27:42 +0200 Subject: [PATCH 09/76] fix test --- tests/onnxruntime/test_modeling.py | 73 +++++++++++++++++++----------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0fda34c0d94..22d42c8e682 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -862,6 +862,29 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + @parameterized.expand([(False,), (True,)]) + def test_save_load_decoder_model_with_external_data(self, use_cache: bool): + with tempfile.TemporaryDirectory() as tmpdirname: + os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" # force exporting small model with external data + model = ORTModelForCausalLM.from_pretrained( + MODEL_NAMES["gpt2"], + use_cache=use_cache, + export=True, + use_merged=False, + use_io_binding=False, + ) + model.save_pretrained(tmpdirname) + + # verify external data is exported + folder_contents = os.listdir(tmpdirname) + self.assertTrue(ONNX_WEIGHTS_NAME in folder_contents) + self.assertTrue(ONNX_WEIGHTS_NAME + "_data" in folder_contents) + self.assertFalse(use_cache ^ model.use_cache) + + # verify loading from local folder works + model = ORTModelForCausalLM.from_pretrained(tmpdirname, use_cache=use_cache, export=False, use_io_binding=False) + os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + @parameterized.expand([(False,), (True,)]) def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): with tempfile.TemporaryDirectory() as tmpdirname: @@ -1982,13 +2005,15 @@ def test_load_model_from_hub_onnx(self): self.assertFalse(model.use_merged) self.assertTrue(model.use_cache) - self.assertTrue(model.decoder_with_past is not None) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_WITH_PAST_NAME) model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge") self.assertTrue(model.use_merged) self.assertTrue(model.use_cache) - self.assertTrue(model.decoder_with_past is None) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_MERGED_NAME) def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: @@ -2033,12 +2058,14 @@ def test_merge_from_transformers_and_save(self, model_arch): model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_merged=True) with tempfile.TemporaryDirectory() as tmpdir: model.save_pretrained(tmpdir) - save_path = os.path.join(tmpdir, ONNX_DECODER_MERGED_NAME) - self.assertTrue(has_onnx_input(save_path, "use_cache_branch")) + save_path = os.path.join(tmpdir, ONNX_WEIGHTS_NAME) + self.assertFalse(has_onnx_input(save_path, "use_cache_branch")) folder_contents = os.listdir(tmpdir) self.assertTrue(ONNX_DECODER_NAME not in folder_contents) self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) + self.assertTrue(ONNX_DECODER_MERGED_NAME not in folder_contents) 
+ @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): @@ -2054,7 +2081,7 @@ def test_merge_from_onnx_and_save(self, model_arch): model = ORTModelForCausalLM.from_pretrained(tmpdir) self.assertTrue(model.use_merged) - self.assertTrue(model.decoder_with_past is None) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) model.save_pretrained(tmpdir + "_save") save_path = os.path.join(tmpdir + "_save", ONNX_DECODER_MERGED_NAME) @@ -2063,6 +2090,8 @@ def test_merge_from_onnx_and_save(self, model_arch): folder_contents = os.listdir(tmpdir + "_save") self.assertTrue(ONNX_DECODER_NAME not in folder_contents) self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) + self.assertTrue(ONNX_WEIGHTS_NAME not in folder_contents) + @parameterized.expand(grid_parameters(FULL_GRID)) def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): @@ -2087,27 +2116,17 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach use_cache=use_cache, use_io_binding=use_io_binding, ) - if use_merged is False: - model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_NAME) - self.assertFalse(has_onnx_input(model_path, "use_cache_branch")) - self.assertEqual(onnx_model.use_merged, False) - else: - model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_MERGED_NAME) - self.assertTrue(has_onnx_input(model_path, "use_cache_branch")) - self.assertEqual(onnx_model.use_merged, True) - - self.assertIsInstance(onnx_model.decoder, ORTDecoder) - if onnx_model.use_cache is True and onnx_model.use_merged is False: - self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoder) - if onnx_model.use_cache is True and onnx_model.use_merged is True: - self.assertTrue(onnx_model.decoder_with_past is None) + model_path = Path(self.onnx_model_dirs[test_name], ONNX_WEIGHTS_NAME) + self.assertFalse(has_onnx_input(model_path, "use_cache_branch")) + self.assertFalse(onnx_model.use_merged) + self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) transformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = get_preprocessor(model_id) - tokens = tokenizer( + tokens = tokenizer( "This is a sample output", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None, @@ -2125,6 +2144,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach gc.collect() + @parameterized.expand(grid_parameters(FULL_GRID)) def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): if use_cache is False and use_merged is True: @@ -2288,6 +2308,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch): f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", ) + @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool): model_args = { @@ -2314,21 +2335,21 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode model_merged_dir = self.onnx_model_dirs[test_name + "_True"] model_not_merged = ORTModelForCausalLM.from_pretrained(model_not_merged_dir) - not_merged_onnx_path = Path(model_not_merged_dir, ONNX_DECODER_NAME) + 
not_merged_onnx_path = Path(model_not_merged_dir, ONNX_WEIGHTS_NAME) self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch")) - self.assertEqual(model_not_merged.use_merged, False) + self.assertFalse(model_not_merged.use_merged) model_merged = ORTModelForCausalLM.from_pretrained(model_merged_dir) - merged_onnx_path = Path(model_merged_dir, ONNX_DECODER_MERGED_NAME) - self.assertTrue(has_onnx_input(merged_onnx_path, "use_cache_branch")) - self.assertEqual(model_merged.decoder_with_past, None) - self.assertEqual(model_merged.use_merged, True) + merged_onnx_path = Path(model_merged_dir, ONNX_WEIGHTS_NAME) + self.assertFalse(has_onnx_input(merged_onnx_path, "use_cache_branch")) + self.assertFalse(model_merged.use_merged) outputs_model_not_merged = model_not_merged.generate(**tokens) outputs_model_merged = model_merged.generate(**tokens) self.assertTrue(torch.equal(outputs_model_merged, outputs_model_not_merged)) + @parameterized.expand( grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]}) ) From 77cc527ce0bff240abdccf0383a569f485405ba7 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 13 Sep 2023 16:28:07 +0200 Subject: [PATCH 10/76] format --- optimum/onnxruntime/modeling_decoder.py | 2 +- tests/onnxruntime/test_modeling.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index a68467819af..5fb31c952cb 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1016,7 +1016,7 @@ def _from_transformers( "subfolder": subfolder, "local_files_only": local_files_only, "force_download": force_download, - "trust_remote_code" : trust_remote_code, + "trust_remote_code": trust_remote_code, } model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 22d42c8e682..45f86966466 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -882,7 +882,9 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): self.assertFalse(use_cache ^ model.use_cache) # verify loading from local folder works - model = ORTModelForCausalLM.from_pretrained(tmpdirname, use_cache=use_cache, export=False, use_io_binding=False) + model = ORTModelForCausalLM.from_pretrained( + tmpdirname, use_cache=use_cache, export=False, use_io_binding=False + ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @parameterized.expand([(False,), (True,)]) @@ -2066,7 +2068,6 @@ def test_merge_from_transformers_and_save(self, model_arch): self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) self.assertTrue(ONNX_DECODER_MERGED_NAME not in folder_contents) - @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): model_id = MODEL_NAMES[model_arch] @@ -2092,7 +2093,6 @@ def test_merge_from_onnx_and_save(self, model_arch): self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) self.assertTrue(ONNX_WEIGHTS_NAME not in folder_contents) - @parameterized.expand(grid_parameters(FULL_GRID)) def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): if use_cache is False and use_merged is True: @@ -2126,7 +2126,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach set_seed(SEED) transformers_model = 
AutoModelForCausalLM.from_pretrained(model_id) tokenizer = get_preprocessor(model_id) - tokens = tokenizer( + tokens = tokenizer( "This is a sample output", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None, @@ -2144,7 +2144,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach gc.collect() - @parameterized.expand(grid_parameters(FULL_GRID)) def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): if use_cache is False and use_merged is True: @@ -2308,7 +2307,6 @@ def test_compare_with_and_without_past_key_values(self, model_arch): f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", ) - @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool): model_args = { @@ -2349,7 +2347,6 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode self.assertTrue(torch.equal(outputs_model_merged, outputs_model_not_merged)) - @parameterized.expand( grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]}) ) From 4f72a7ebf9f5f66206b2bf32ba27d80759cb3113 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 13 Sep 2023 16:30:47 +0200 Subject: [PATCH 11/76] rm redundant --- tests/onnxruntime/test_modeling.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 45f86966466..6d91173079b 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -834,33 +834,6 @@ def test_save_load_ort_model_with_external_data(self): model = ORTModelForSequenceClassification.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") - @parameterized.expand([(False,), (True,)]) - def test_save_load_decoder_model_with_external_data(self, use_cache: bool): - with tempfile.TemporaryDirectory() as tmpdirname: - os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" # force exporting small model with external data - model = ORTModelForCausalLM.from_pretrained( - MODEL_NAMES["gpt2"], - use_cache=use_cache, - export=True, - use_merged=False, - use_io_binding=False, - ) - model.save_pretrained(tmpdirname) - - # verify external data is exported - folder_contents = os.listdir(tmpdirname) - self.assertTrue(ONNX_DECODER_NAME in folder_contents) - self.assertTrue(ONNX_DECODER_NAME + "_data" in folder_contents) - - if use_cache: - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME in folder_contents) - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME + "_data" in folder_contents) - - # verify loading from local folder works - model = ORTModelForCausalLM.from_pretrained( - tmpdirname, use_cache=use_cache, export=False, use_io_binding=False - ) - os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @parameterized.expand([(False,), (True,)]) def test_save_load_decoder_model_with_external_data(self, use_cache: bool): From 599c31c1598f6030095ba3a11cc05b3c18d662c6 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 13 Sep 2023 16:31:32 +0200 Subject: [PATCH 12/76] format --- tests/onnxruntime/test_modeling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 6d91173079b..3269d3db3d6 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -834,7 
+834,6 @@ def test_save_load_ort_model_with_external_data(self): model = ORTModelForSequenceClassification.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") - @parameterized.expand([(False,), (True,)]) def test_save_load_decoder_model_with_external_data(self, use_cache: bool): with tempfile.TemporaryDirectory() as tmpdirname: From c13b64550e0ae5d4eb3c887e4977153f9de2b717 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 13 Sep 2023 17:45:38 +0200 Subject: [PATCH 13/76] fix --- optimum/onnxruntime/modeling_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 5fb31c952cb..602d538ffd8 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1021,7 +1021,7 @@ def _from_transformers( model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - onnx_config = onnx_config_constructor(model.config, use_past=use_cache) + onnx_config = onnx_config_constructor(model.config, use_past=use_cache, use_past_in_inputs=use_cache) # TODO : create ModelPatcher to patch each architecture if config.model_type in {"bloom", "mpt"}: From a0d0802daaec2df3eaeebba742d62345fc23def7 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 14 Sep 2023 14:56:53 +0200 Subject: [PATCH 14/76] fix quantization --- optimum/onnxruntime/modeling_decoder.py | 5 ----- optimum/onnxruntime/quantization.py | 7 ------- tests/onnxruntime/test_quantization.py | 12 +++--------- 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 602d538ffd8..43f69ef3e5c 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -662,11 +662,7 @@ def __init__( if generation_config is None: generation_config = GenerationConfig.from_model_config(config) self.generation_config = generation_config - - # TODO : deprecate self.onnx_paths = [self.model_path] - - # TODO : deprecate self.use_merged = "use_cache_branch" in self.inputs_names self.use_fp16 = False @@ -1023,7 +1019,6 @@ def _from_transformers( onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) onnx_config = onnx_config_constructor(model.config, use_past=use_cache, use_past_in_inputs=use_cache) - # TODO : create ModelPatcher to patch each architecture if config.model_type in {"bloom", "mpt"}: model.transformer._prepare_attn_mask = _prepare_attn_mask elif config.model_type == "llama": diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index 31140c5b747..aaf2853cae7 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -136,13 +136,6 @@ def from_pretrained( path = None if isinstance(model_or_path, ORTModelForConditionalGeneration): raise NotImplementedError(ort_quantizer_error_message) - elif isinstance(model_or_path, ORTModelForCausalLM): - if model_or_path.use_cache is False: - path = Path(model_or_path.decoder_model_path) - elif model_or_path.use_cache is True and model_or_path.use_merged is False: - raise NotImplementedError(ort_quantizer_error_message) - else: - path = Path(model_or_path.decoder_model_path) elif isinstance(model_or_path, Path) and file_name is None: onnx_files = 
list(model_or_path.glob("*.onnx")) if len(onnx_files) == 0: diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index 111c7338808..448f0fccb24 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -117,25 +117,19 @@ def test_dynamic_quantization_subgraphs(self): # with tempfile.TemporaryDirectory() as tmp_dir: tmp_dir = tempfile.mkdtemp() output_dir = Path(tmp_dir) - model = ORTModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-gpt2", export=True, use_merged=True - ) + model = ORTModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", export=True) self.assertTrue(model.use_merged) model.save_pretrained(tmp_dir) quantizer = ORTQuantizer.from_pretrained(model) - quantizer.quantize( - save_dir=output_dir, - quantization_config=qconfig, - ) - + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) expected_ort_config = ORTConfig(quantization=qconfig) ort_config = ORTConfig.from_pretrained(tmp_dir) # Verify the ORTConfig was correctly created and saved self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict()) - quantized_model = onnx_load(output_dir.joinpath("decoder_model_merged_quantized.onnx")) + quantized_model = onnx_load(output_dir.joinpath("model_quantized.onnx")) num_quantized_matmul = 0 for initializer in quantized_model.graph.initializer: if "weight" in initializer.name and "quantized" in initializer.name: From 7f65ce1e4a0f29d032ba2521671b4a55269fbc56 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 14 Sep 2023 15:26:22 +0200 Subject: [PATCH 15/76] add test --- tests/onnxruntime/test_quantization.py | 63 +++++++++++++------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index 448f0fccb24..1710a4b728d 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -34,6 +34,8 @@ QuantizationConfig, ) +from optimum.utils.testing_utils import grid_parameters + class ORTQuantizerTest(unittest.TestCase): LOAD_CONFIGURATION = { @@ -76,6 +78,10 @@ class ORTDynamicQuantizationTest(unittest.TestCase): (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bart", 32), ) + SUPPORTED_DECODER_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( + (ORTModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 22), + ) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_matmuls): qconfig = QuantizationConfig( @@ -94,11 +100,7 @@ def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_ma model.save_pretrained(tmp_dir) quantizer = ORTQuantizer.from_pretrained(model) - quantizer.quantize( - save_dir=output_dir, - quantization_config=qconfig, - ) - + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) expected_ort_config = ORTConfig(quantization=qconfig) ort_config = ORTConfig.from_pretrained(tmp_dir) # Verify the ORTConfig was correctly created and saved @@ -112,31 +114,33 @@ def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_ma self.assertEqual(expected_quantized_matmuls, num_quantized_matmul) gc.collect() - def test_dynamic_quantization_subgraphs(self): + @parameterized.expand( + grid_parameters( + {"model_arch": SUPPORTED_DECODER_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS, "use_cache": [True, False]} + ) + ) + def 
test_decoder_quantization_with_and_without_cache(self, test_name, model_info, use_cache): + model_cls, model_name, expected_quantized_matmuls = model_info qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True) - # with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = tempfile.mkdtemp() - output_dir = Path(tmp_dir) - model = ORTModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", export=True) - - self.assertTrue(model.use_merged) - model.save_pretrained(tmp_dir) + model = model_cls.from_pretrained(model_name, export=True, use_cache=use_cache) - quantizer = ORTQuantizer.from_pretrained(model) - quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) - expected_ort_config = ORTConfig(quantization=qconfig) - ort_config = ORTConfig.from_pretrained(tmp_dir) - # Verify the ORTConfig was correctly created and saved - self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict()) - - quantized_model = onnx_load(output_dir.joinpath("model_quantized.onnx")) - num_quantized_matmul = 0 - for initializer in quantized_model.graph.initializer: - if "weight" in initializer.name and "quantized" in initializer.name: - num_quantized_matmul += 1 + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + output_dir = Path(tmp_dir) + quantizer = ORTQuantizer.from_pretrained(model) + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) + expected_ort_config = ORTConfig(quantization=qconfig) + ort_config = ORTConfig.from_pretrained(tmp_dir) - self.assertTrue(num_quantized_matmul > 0) - gc.collect() + # Verify the ORTConfig was correctly created and saved + self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict()) + quantized_model = onnx_load(output_dir.joinpath("model_quantized.onnx")) + num_quantized_matmul = 0 + for initializer in quantized_model.graph.initializer: + if "weight" in initializer.name and "quantized" in initializer.name: + num_quantized_matmul += 1 + self.assertEqual(expected_quantized_matmuls, num_quantized_matmul) + gc.collect() class ORTStaticQuantizationTest(unittest.TestCase): @@ -176,10 +180,7 @@ def preprocess_function(examples, tokenizer): dataset_split="train", ) calibration_config = AutoCalibrationConfig.minmax(calibration_dataset) - ranges = quantizer.fit( - dataset=calibration_dataset, - calibration_config=calibration_config, - ) + ranges = quantizer.fit(dataset=calibration_dataset, calibration_config=calibration_config) quantizer.quantize( save_dir=output_dir, calibration_tensors_range=ranges, From 2840b81d6dbb1ae9148f351fa2a01f5db7677b03 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 14 Sep 2023 15:26:41 +0200 Subject: [PATCH 16/76] format --- tests/onnxruntime/test_quantization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index 1710a4b728d..eb007f16b8b 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -33,7 +33,6 @@ ORTQuantizer, QuantizationConfig, ) - from optimum.utils.testing_utils import grid_parameters From 5fa7b2034bef71ebfe45382b5809f64df84e7767 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 14 Sep 2023 16:06:46 +0200 Subject: [PATCH 17/76] format --- optimum/onnxruntime/quantization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index aaf2853cae7..fce5c18f55a 100644 --- a/optimum/onnxruntime/quantization.py +++ 
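
With the causal-LM special cases removed, an ORTModelForCausalLM is quantized like any other ORTModel and the result lands in a single model_quantized.onnx. A rough standalone sketch of the flow the new test exercises, using the same tiny GPT-2 test checkpoint; this is an illustration, not part of the patch:

    import tempfile
    from pathlib import Path

    from optimum.onnxruntime import ORTModelForCausalLM, ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True)
    model = ORTModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2", export=True, use_cache=True
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        quantizer = ORTQuantizer.from_pretrained(model)
        quantizer.quantize(save_dir=tmp_dir, quantization_config=qconfig)
        # One quantized decoder file, whether or not the export kept a cache branch.
        print(Path(tmp_dir, "model_quantized.onnx").exists())
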
b/optimum/onnxruntime/quantization.py @@ -33,7 +33,6 @@ from ..utils.save_utils import maybe_save_preprocessors from . import ORTQuantizableOperator from .configuration import CalibrationConfig, ORTConfig, QuantizationConfig -from .modeling_decoder import ORTModelForCausalLM from .modeling_ort import ORTModel from .modeling_seq2seq import ORTModelForConditionalGeneration from .preprocessors import QuantizationPreprocessor From 80119828e13fa7a6c006144d33d404396673794d Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 14 Sep 2023 17:20:14 +0200 Subject: [PATCH 18/76] fix optimization --- optimum/onnxruntime/modeling_decoder.py | 2 -- optimum/onnxruntime/optimization.py | 18 +++++++----------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 43f69ef3e5c..7dea904e85b 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -883,8 +883,6 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ) -> "ORTModelForCausalLM": - # model_path = Path(model_id) - # We do not implement the logic for use_cache=False, use_merged=True if use_cache is False: if use_merged is True: diff --git a/optimum/onnxruntime/optimization.py b/optimum/onnxruntime/optimization.py index 2db9f753c34..bc65a96882f 100644 --- a/optimum/onnxruntime/optimization.py +++ b/optimum/onnxruntime/optimization.py @@ -97,17 +97,13 @@ def from_pretrained( # Add the decoder with past key/values if present if model_or_path.use_cache: onnx_model_path.append(model_or_path.decoder_with_past_model_path) - elif isinstance(model_or_path, ORTModelForCausalLM): - if model_or_path.use_merged is True: - raise NotImplementedError( - "ORTOptimizer does not support ORTModelForCausalLM models that use a single ONNX for both the without/with past cases." - " Please pass an ORTModelForCausalLM that uses a separate ONNX for each without/with past cases. This can be done" - " by using `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`, or by" - " using the option `--no-post-process` in the optimum-cli ONNX export tool." - ) - onnx_model_path.append(model_or_path.decoder_model_path) - if model_or_path.use_cache: - onnx_model_path.append(model_or_path.decoder_with_past_model_path) + elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged: + raise NotImplementedError( + "ORTOptimizer does not support ORTModelForCausalLM models that use a single ONNX for both the without/with past cases." + " Please pass an ORTModelForCausalLM that uses a separate ONNX for each without/with past cases. This can be done" + " by using `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`, or by" + " using the option `--no-post-process` in the optimum-cli ONNX export tool." 
+ ) else: onnx_model_path.append(model_or_path.model_path) config = model_or_path.config From b6433086e3699f178782fd41d75353c839030772 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 12:05:24 +0200 Subject: [PATCH 19/76] fix opitmization --- optimum/onnxruntime/modeling_decoder.py | 16 +++++++++++----- optimum/onnxruntime/optimization.py | 6 ++---- tests/onnxruntime/test_optimization.py | 21 +++++++++++---------- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 7dea904e85b..839c5c5c6a0 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -922,12 +922,14 @@ def _from_pretrained( ) use_merged = False - exclude_decoder = r"(.*)?((? Date: Fri, 15 Sep 2023 15:18:24 +0200 Subject: [PATCH 20/76] fix compatibility with legacy models --- optimum/onnxruntime/modeling_decoder.py | 14 +++++++++----- tests/onnxruntime/test_optimization.py | 6 ++++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 839c5c5c6a0..a340a048ddb 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -685,8 +685,10 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, + use_cache_branch: None = None, **kwargs, ) -> CausalLMOutputWithPast: + # adding use_cache_branch in the signature here is just a hack for IO Binding use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) @@ -768,6 +770,11 @@ def forward( if use_cache_branch is not None: inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch + for output in self.model.get_outputs(): + if output.name == "logits" and output.shape[1] == 1: + # TODO : modify the static graph + raise ValueError("The model needs to be re-exported or set use_cache=False.") + outputs = self.model.run(None, inputs) if self.use_cache: @@ -938,11 +945,8 @@ def _from_pretrained( file_name = decoder_path.name regular_file_names = [] - for regular_file_name in [ - ONNX_WEIGHTS_NAME, - ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME, - ]: - regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(regular_file_name) + for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]: + regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name) if file_name not in regular_file_names: logger.warning( diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index 82258fe96db..4ffcf9469c4 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -578,7 +578,6 @@ def test_optimization_levels_gpu( use_io_binding=use_io_binding, ) - def test_merged_optimization(self): ort_model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge") self.assertTrue(ort_model.use_cache) @@ -586,4 +585,7 @@ def test_merged_optimization(self): with self.assertRaises(NotImplementedError) as cm: optimizer = ORTOptimizer.from_pretrained(ort_model) - self.assertTrue("ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged" in str(cm.exception)) + self.assertTrue( + "ORTOptimizer does 
not support ORTModelForCausalLM models when without/with past models are merged" + in str(cm.exception) + ) From 144753aca9d98bbbd9e104e2dadf4e52e6f10e89 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 15:52:03 +0200 Subject: [PATCH 21/76] format --- tests/onnxruntime/test_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index 4ffcf9469c4..24683e92d89 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -583,7 +583,7 @@ def test_merged_optimization(self): self.assertTrue(ort_model.use_cache) with self.assertRaises(NotImplementedError) as cm: - optimizer = ORTOptimizer.from_pretrained(ort_model) + ORTOptimizer.from_pretrained(ort_model) self.assertTrue( "ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged" From 4ee61674a4ea8e01970ac4d6a04cac899aa568df Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 17:11:53 +0200 Subject: [PATCH 22/76] fix legacy models --- optimum/onnxruntime/modeling_decoder.py | 33 +++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index a340a048ddb..64c7b9bd204 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -50,6 +50,8 @@ validate_provider_availability, ) +import onnx +from onnx.tools import update_model_dims if TYPE_CHECKING: from transformers import PretrainedConfig @@ -770,11 +772,6 @@ def forward( if use_cache_branch is not None: inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch - for output in self.model.get_outputs(): - if output.name == "logits" and output.shape[1] == 1: - # TODO : modify the static graph - raise ValueError("The model needs to be re-exported or set use_cache=False.") - outputs = self.model.run(None, inputs) if self.use_cache: @@ -954,7 +951,7 @@ def _from_pretrained( f"{cls.__name__} might not behave as expected." ) - model = super()._from_pretrained( + ort_model = super()._from_pretrained( model_id, config, use_auth_token=use_auth_token, @@ -972,15 +969,31 @@ def _from_pretrained( **kwargs, ) - if use_cache ^ model.use_cache: + if use_cache ^ ort_model.use_cache: raise ValueError( - f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={model.use_cache}`. " - f"Please load your current model with `use_cache={model.use_cache}` or export the original model " + f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={ort_model.use_cache}`. " + f"Please load your current model with `use_cache={ort_model.use_cache}` or export the original model " f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " "To export your model, simply set `export=True`." 
) - return model + # Since v1.7.0 decoder with past models have fixed sequence length of 1 + # To keep these models compatible we set this dimension to dynamic + input_dims = {inputs.name: inputs.shape for inputs in ort_model.model.get_inputs()} + # TODO : refactorize ORTModel.from_pretrained to not re-create inference session a second time + if input_dims["input_ids"][1] == 1: + model_path = ort_model.model_path + input_dims["input_ids"][1] = "sequence_length" + output_dims = {output.name: output.shape for output in ort_model.model.get_outputs()} + output_dims["logits"][1] = "sequence_length" + static_model = onnx.load(model_path) + updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims) + onnx.save(updated_model, model_path) + ort_model.model = ORTModel.load_model( + model_path, provider=provider, session_options=session_options, provider_options=provider_options + ) + + return ort_model @classmethod def _from_transformers( From f2d0f8410c587ba502d114c5faca59ea8d23fcd7 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 17:14:46 +0200 Subject: [PATCH 23/76] format --- optimum/onnxruntime/modeling_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 64c7b9bd204..8e484b5d8ea 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -23,10 +23,12 @@ import torch from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError +from onnx.tools import update_model_dims from transformers import AutoModelForCausalLM, GenerationConfig from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithPast +import onnx import onnxruntime from ..exporters import TasksManager @@ -50,8 +52,6 @@ validate_provider_availability, ) -import onnx -from onnx.tools import update_model_dims if TYPE_CHECKING: from transformers import PretrainedConfig From 3ff719a782a72f8a034a57db774d5409ad7c5360 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 17:22:25 +0200 Subject: [PATCH 24/76] fix style --- optimum/onnxruntime/modeling_decoder.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 8e484b5d8ea..3b92c5dd341 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -20,6 +20,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import numpy as np +import onnx +import onnxruntime import torch from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError @@ -28,9 +30,6 @@ from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithPast -import onnx -import onnxruntime - from ..exporters import TasksManager from ..exporters.onnx import export, main_export from ..onnx.utils import _get_external_data_paths From d794141a94c899cdd5ea9e8dd071c06627f974c5 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 17:28:33 +0200 Subject: [PATCH 25/76] format --- optimum/onnxruntime/modeling_decoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py 
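
The dimension rewrite above can be reproduced on its own with onnx.tools. A rough sketch, assuming decoder_with_past_model.onnx is one of the older exports whose input_ids (and therefore logits) second axis is pinned to 1; the file name is only illustrative, and external-data models need the extra handling added later in this series:

    import onnx
    from onnx.tools import update_model_dims

    model_path = "decoder_with_past_model.onnx"
    onnx_model = onnx.load(model_path)

    input_dims = {
        i.name: [d.dim_value or d.dim_param for d in i.type.tensor_type.shape.dim]
        for i in onnx_model.graph.input
    }
    output_dims = {
        o.name: [d.dim_value or d.dim_param for d in o.type.tensor_type.shape.dim]
        for o in onnx_model.graph.output
    }

    if input_dims["input_ids"][1] == 1:
        # Relax the static length-1 axis so the same graph also accepts full-length prompts.
        input_dims["input_ids"][1] = "sequence_length"
        output_dims["logits"][1] = "sequence_length"
        onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)
        onnx.save(onnx_model, model_path)
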
index 3b92c5dd341..31410121b53 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -21,7 +21,6 @@ import numpy as np import onnx -import onnxruntime import torch from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError @@ -30,6 +29,8 @@ from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithPast +import onnxruntime + from ..exporters import TasksManager from ..exporters.onnx import export, main_export from ..onnx.utils import _get_external_data_paths From a34a16e03d91e13989fa6b622ca0e5a7a3cdde0b Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 15 Sep 2023 19:17:06 +0200 Subject: [PATCH 26/76] add export to main_export --- optimum/exporters/onnx/utils.py | 69 +++++++++++++++++-------- optimum/onnxruntime/modeling_decoder.py | 40 ++++++-------- 2 files changed, 65 insertions(+), 44 deletions(-) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 3170ebdcdd2..1c86d8a86e6 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -29,6 +29,8 @@ logging, ) from ...utils.import_utils import _diffusers_version +from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask + from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME @@ -146,16 +148,28 @@ def _get_submodels_for_export_stable_diffusion( def _get_submodels_for_export_decoder( - model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool + model: Union["PreTrainedModel", "TFPreTrainedModel"], + use_past: bool, + legacy: bool = False, ) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]: """ Returns the decoder part of the model. """ models_for_export = {} - models_for_export[ONNX_DECODER_NAME] = model - if use_past: - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model + if legacy: + models_for_export[ONNX_DECODER_NAME] = model + if use_past: + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model + else: + if model.config.model_type in {"bloom", "mpt"}: + model.transformer._prepare_attn_mask = _prepare_attn_mask + elif model.config.model_type == "llama": + model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + elif model.config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"): + model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + + models_for_export["model"] = model return models_for_export @@ -214,6 +228,7 @@ def get_encoder_decoder_models_for_export( def get_decoder_models_for_export( model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", + legacy: bool = False, ) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]: """ Returns two versions of the decoder that can be used together to perform fast generation: @@ -233,31 +248,43 @@ def get_decoder_models_for_export( `Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]: A Dict containing the model and onnx configs for the encoder and decoder parts of the model. 
""" - models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) - onnx_config = config.__class__( - model.config, - task=config.task, - use_past=config.use_past, - use_past_in_inputs=False, - float_dtype=config.float_dtype, - int_dtype=config.int_dtype, - ) - models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) + models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) - if config.use_past: - onnx_config_with_past = config.__class__( + if legacy: + onnx_config = config.__class__( model.config, task=config.task, - use_past=True, - use_past_in_inputs=True, + use_past=config.use_past, + use_past_in_inputs=False, float_dtype=config.float_dtype, int_dtype=config.int_dtype, ) - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( - models_for_export[ONNX_DECODER_WITH_PAST_NAME], - onnx_config_with_past, + models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) + + if config.use_past: + onnx_config_with_past = config.__class__( + model.config, + task=config.task, + use_past=True, + use_past_in_inputs=True, + float_dtype=config.float_dtype, + int_dtype=config.int_dtype, + ) + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( + models_for_export[ONNX_DECODER_WITH_PAST_NAME], + onnx_config_with_past, + ) + else: + onnx_config = config.__class__( + model.config, + task=config.task, + use_past=config.use_past, + use_past_in_inputs=config.use_past, + float_dtype=config.float_dtype, + int_dtype=config.int_dtype, ) + models_for_export["model"] = (models_for_export["model"], onnx_config) return models_for_export diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 31410121b53..26b693b4e05 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1024,33 +1024,27 @@ def _from_transformers( if task is None: task = cls._auto_model_to_task(cls.auto_model_class) + if use_cache: + task += "-with-past" + save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) - model_kwargs = { - "revision": revision, - "use_auth_token": use_auth_token, - "cache_dir": cache_dir, - "subfolder": subfolder, - "local_files_only": local_files_only, - "force_download": force_download, - "trust_remote_code": trust_remote_code, - } - - model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - onnx_config = onnx_config_constructor(model.config, use_past=use_cache, use_past_in_inputs=use_cache) - if config.model_type in {"bloom", "mpt"}: - model.transformer._prepare_attn_mask = _prepare_attn_mask - elif config.model_type == "llama": - model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - elif config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"): - model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - - # Export the model to the ONNX format - export(model=model, config=onnx_config, output=save_dir_path / file_name) + main_export( + model_name_or_path=model_id, + output=save_dir_path, + task=task, + do_validation=False, + no_post_process=False, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + ) - # TODO : use main_export 
config.save_pretrained(save_dir_path) maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) From dfe7e5e4d872d61b63a58beb4a0f044039128f57 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 12:08:38 +0200 Subject: [PATCH 27/76] add legacy to ONNX export --- optimum/commands/export/onnx.py | 6 +++++ optimum/exporters/onnx/__main__.py | 6 ++++- optimum/exporters/onnx/utils.py | 2 +- tests/exporters/onnx/test_onnx_export.py | 28 +++++++++++++++--------- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index b705c98eed6..a175e048e94 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -210,6 +210,11 @@ def parse_args_onnx(parser): default=None, help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"), ) + optional_group.add_argument( + "--legacy", + action="store_true", + help=("Export decoder only models in two (without + with past) model as a single ONNX file."), + ) # deprecated argument parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS) @@ -248,5 +253,6 @@ def run(self): use_subprocess=True, _variant=self.args.variant, library_name=self.args.library_name, + legacy=self.args.legacy, **input_shapes, ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index cc04e8459f2..95c66d8977e 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -67,6 +67,7 @@ def _get_submodels_and_onnx_configs( float_dtype: str = "fp32", fn_get_submodels: Optional[Callable] = None, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): is_stable_diffusion = "stable-diffusion" in task if not custom_architecture: @@ -96,7 +97,7 @@ def _get_submodels_and_onnx_configs( ): models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) elif task.startswith("text-generation") and not monolith: - models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) + models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy) elif model.config.model_type == "sam": models_and_onnx_configs = get_sam_models_for_export(model, onnx_config) else: @@ -174,6 +175,7 @@ def main_export( use_subprocess: bool = False, _variant: str = "default", library_name: Optional[str] = None, + legacy: bool = False, **kwargs_shapes, ): """ @@ -406,6 +408,7 @@ def main_export( fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, _variant=_variant, + legacy=legacy, ) if not is_stable_diffusion: @@ -591,6 +594,7 @@ def main(): pad_token_id=args.pad_token_id, for_ort=args.for_ort, library_name=args.library_name, + legacy=args.legacy, **input_shapes, ) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 1c86d8a86e6..4e3cf922e39 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -249,7 +249,7 @@ def get_decoder_models_for_export( onnx configs for the encoder and decoder parts of the model. 
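
Both layouts come out of the same entry point once the flag is threaded through. A hedged sketch of the two calls on the tiny GPT-2 test checkpoint; the file names in the comments are what the constants used above suggest, and the legacy set can additionally contain a merged decoder when post-processing is left on:

    import tempfile
    from pathlib import Path

    from optimum.exporters.onnx import main_export

    model_id = "hf-internal-testing/tiny-random-gpt2"

    # Default: a single decoder ONNX with the cache branch folded in.
    with tempfile.TemporaryDirectory() as tmp_dir:
        main_export(model_id, output=tmp_dir, task="text-generation-with-past")
        print(sorted(p.name for p in Path(tmp_dir).glob("*.onnx")))  # expected: ['model.onnx']

    # Legacy: separate without-past / with-past decoders, as before this refactor.
    with tempfile.TemporaryDirectory() as tmp_dir:
        main_export(
            model_id, output=tmp_dir, task="text-generation-with-past", legacy=True, no_post_process=True
        )
        print(sorted(p.name for p in Path(tmp_dir).glob("*.onnx")))
        # expected: ['decoder_model.onnx', 'decoder_with_past_model.onnx']
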
""" - models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) + models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) if legacy: onnx_config = config.__class__( diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 10eaeddd13c..4eefafa2d90 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -14,6 +14,7 @@ # limitations under the License. import gc import os +from functools import partial from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict @@ -529,8 +530,8 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} -def fn_get_submodels_custom(model): - return {"decoder_model": model, "decoder_with_past_model": model} +def fn_get_submodels_custom(model, legacy=False): + return {"decoder_model": model, "decoder_with_past_model": model} if legacy else {"model": model} class OnnxCustomExport(TestCase): @@ -568,11 +569,12 @@ def test_custom_export_official_model(self): assert "decoder_attentions.0" in output_names assert "cross_attentions.0" in output_names - @parameterized.expand([(None,), (fn_get_submodels_custom,)]) - def test_custom_export_trust_remote(self, fn_get_submodels): + @parameterized.expand( + grid_parameters({"fn_get_submodels": [None, fn_get_submodels_custom], "legacy": [True, False]}) + ) + def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy): model_id = "fxmarty/tiny-mpt-random-remote-code" config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) - onnx_config = CustomMPTOnnxConfig( config=config, task="text-generation", @@ -581,10 +583,15 @@ def test_custom_export_trust_remote(self, fn_get_submodels): ) onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True) - custom_onnx_configs = { - "decoder_model": onnx_config, - "decoder_with_past_model": onnx_config_with_past, - } + if legacy: + custom_onnx_configs = { + "decoder_model": onnx_config, + "decoder_with_past_model": onnx_config_with_past, + } + else: + custom_onnx_configs = { + "model": onnx_config_with_past, + } with TemporaryDirectory() as tmpdirname: main_export( @@ -594,7 +601,8 @@ def test_custom_export_trust_remote(self, fn_get_submodels): trust_remote_code=True, custom_onnx_configs=custom_onnx_configs, no_post_process=True, - fn_get_submodels=fn_get_submodels, + fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None, + legacy=legacy, opset=14, ) From 8d102f7808318c9f74b9249d011a13809791460e Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 13:40:11 +0200 Subject: [PATCH 28/76] fix test --- optimum/onnxruntime/modeling_decoder.py | 1 + tests/onnxruntime/test_modeling.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 26b693b4e05..30adbf14f1f 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1036,6 +1036,7 @@ def _from_transformers( task=task, do_validation=False, no_post_process=False, + legacy=False, subfolder=subfolder, revision=revision, cache_dir=cache_dir, diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 3269d3db3d6..218d94eda23 100644 --- 
a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2049,7 +2049,7 @@ def test_merge_from_onnx_and_save(self, model_arch): self.skipTest("Unsupported export case") with tempfile.TemporaryDirectory() as tmpdir: - main_export(model_id, tmpdir, task=task) + main_export(model_id, tmpdir, task=task, legacy=True) model = ORTModelForCausalLM.from_pretrained(tmpdir) From 62b897421bffd9aa44dac654693daaccf72968a6 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 13:42:43 +0200 Subject: [PATCH 29/76] fix --- tests/onnx/test_onnx_graph_transformations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/onnx/test_onnx_graph_transformations.py b/tests/onnx/test_onnx_graph_transformations.py index bed539eaccb..c06ac5af971 100644 --- a/tests/onnx/test_onnx_graph_transformations.py +++ b/tests/onnx/test_onnx_graph_transformations.py @@ -85,6 +85,7 @@ def test_merge_decoders(self, *args): tmpdir, task=task, no_post_process=True, + legacy=True, ) decoder = onnx.load(os.path.join(tmpdir, "decoder_model.onnx")) From b8e18c30601fd76e4dfb5f74438da5bc80b636a5 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 13:50:19 +0200 Subject: [PATCH 30/76] rm unused import --- optimum/onnxruntime/modeling_decoder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 30adbf14f1f..9000e434e14 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -32,11 +32,10 @@ import onnxruntime from ..exporters import TasksManager -from ..exporters.onnx import export, main_export +from ..exporters.onnx import main_export from ..onnx.utils import _get_external_data_paths from ..utils import NormalizedConfigManager, check_if_transformers_greater from ..utils.file_utils import validate_file_exists -from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .base import ORTDecoder from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN From 819691ef6cf4776dafff22066bae58e73a439784 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 14:35:01 +0200 Subject: [PATCH 31/76] patch model to fix causal lm generation --- optimum/exporters/onnx/utils.py | 14 +++++++------- optimum/onnxruntime/modeling_decoder.py | 1 - 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 4e3cf922e39..ac0c6747970 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -157,19 +157,19 @@ def _get_submodels_for_export_decoder( """ models_for_export = {} - if legacy: - models_for_export[ONNX_DECODER_NAME] = model - if use_past: - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model - else: + if not legacy or use_past: + # fix causal lm generation for inputs of sequence length 1 if model.config.model_type in {"bloom", "mpt"}: model.transformer._prepare_attn_mask = _prepare_attn_mask elif model.config.model_type == "llama": model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - elif model.config.model_type in ("blenderbot-small", "blenderbot", "opt", "pegasus", "bart"): + elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: model.model.decoder._prepare_decoder_attention_mask = 
_prepare_decoder_attention_mask - models_for_export["model"] = model + models_for_export[ONNX_DECODER_NAME if legacy else "model"] = model + + if legacy and use_past: + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model return models_for_export diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 9000e434e14..659ce025342 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -31,7 +31,6 @@ import onnxruntime -from ..exporters import TasksManager from ..exporters.onnx import main_export from ..onnx.utils import _get_external_data_paths from ..utils import NormalizedConfigManager, check_if_transformers_greater From e259670f20663e16473d7834fcb88b2d1678cd33 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 14:43:40 +0200 Subject: [PATCH 32/76] rm commen --- optimum/utils/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index fcf75aa1e89..50da3556815 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -82,7 +82,7 @@ def _prepare_attn_mask( combined_attention_mask = _make_causal_mask( input_shape, device=device, past_key_values_length=past_key_values_length ) - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]_prepare_decoder_attention_mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask From 2f262019edfb29fa6d7b5bb0fa4864de33594da2 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 15:25:40 +0200 Subject: [PATCH 33/76] add no psot process --- optimum/onnxruntime/modeling_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 659ce025342..68c5fee7dd9 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1033,7 +1033,7 @@ def _from_transformers( output=save_dir_path, task=task, do_validation=False, - no_post_process=False, + no_post_process=True, legacy=False, subfolder=subfolder, revision=revision, From 6d8acb42a3b9384f120e09501c086c0adb5e3823 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 18:57:23 +0200 Subject: [PATCH 34/76] fix --- optimum/onnxruntime/modeling_decoder.py | 123 +++++++++++++++++------- 1 file changed, 87 insertions(+), 36 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index e5abcc89e07..879507bd5fb 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -650,6 +650,7 @@ def __init__( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, preprocessors: Optional[List] = None, generation_config: Optional[GenerationConfig] = None, + use_cache: Optional[bool] = None, **kwargs, ): if use_io_binding is None: @@ -678,6 +679,16 @@ def __init__( self.use_fp16 = True break + if use_cache ^ self.use_cache: + raise ValueError( + f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. 
" + f"Please load your current model with `use_cache={self.use_cache}` or export the original model " + f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " + "To export your model, simply set `export=True`." + ) + + + @add_start_docstrings_to_model_forward( CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + TEXT_GENERATION_EXAMPLE.format( @@ -729,7 +740,7 @@ def forward( if "attention_mask" in self.inputs_names: model_inputs.append(attention_mask) - if "position_ids" in self.input_names: + if "position_ids" in self.inputs_names: if position_ids is None: raise ValueError("position_ids was not passed but is a required input for this ONNX model.") model_inputs.append(position_ids.contiguous()) @@ -778,7 +789,7 @@ def forward( if "labels" in self.inputs_names: inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels - if "position_ids" in self.input_names: + if "position_ids" in self.inputs_names: if position_ids is None: raise ValueError("position_ids was not passed but is a required input for this ONNX model.") onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids @@ -906,6 +917,10 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ) -> "ORTModelForCausalLM": + + + model_path = Path(model_id) + # We do not implement the logic for use_cache=False, use_merged=True if use_cache is False: if use_merged is True: @@ -970,51 +985,87 @@ def _from_pretrained( f"{cls.__name__} might not behave as expected." ) + if config.model_type == "bloom": + init_cls = ORTBloomForCausalLM + elif config.model_type == "mpt": + init_cls = ORTMPTForCausalLM + elif config.model_type == "opt": + init_cls = ORTOPTForCausalLM + else: + init_cls = ORTModelForCausalLM - # TODo : add init_cls - ort_model = super()._from_pretrained( - model_id, - config, - use_auth_token=use_auth_token, - revision=revision, - force_download=force_download, - cache_dir=cache_dir, - file_name=file_name, - subfolder=subfolder, - use_cache=use_cache, - provider=provider, - session_options=session_options, - provider_options=provider_options, - use_io_binding=use_io_binding, - model_save_dir=model_save_dir, - **kwargs, - ) + ################################################################################################## - if use_cache ^ ort_model.use_cache: - raise ValueError( - f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={ort_model.use_cache}`. " - f"Please load your current model with `use_cache={ort_model.use_cache}` or export the original model " - f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " - "To export your model, simply set `export=True`." 
+ preprocessors = None + if model_path.is_dir(): + model_cache_path = model_path / file_name + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) + else: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=file_name, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, ) + # try download external data + try: + hf_hub_download( + repo_id=model_id, + subfolder=subfolder, + filename=file_name + "_data", + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + # model doesn't use external data + pass + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + + # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it + # instead of the path only. + if model_save_dir is None: + model_save_dir = new_model_save_dir + + ################################################################################################## + # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic - input_dims = {inputs.name: inputs.shape for inputs in ort_model.model.get_inputs()} - # TODO : refactorize ORTModel.from_pretrained to not re-create inference session a second time + onnx_model = onnx.load(model_cache_path) + input_dims = {node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input} if input_dims["input_ids"][1] == 1: - model_path = ort_model.model_path input_dims["input_ids"][1] = "sequence_length" - output_dims = {output.name: output.shape for output in ort_model.model.get_outputs()} + output_dims = {node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.output} output_dims["logits"][1] = "sequence_length" - static_model = onnx.load(model_path) + static_model = onnx.load(model_cache_path) updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims) - onnx.save(updated_model, model_path) - ort_model.model = ORTModel.load_model( - model_path, provider=provider, session_options=session_options, provider_options=provider_options - ) + onnx.save(updated_model, model_cache_path) + + model = ORTModel.load_model( + model_cache_path, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + + return init_cls( + model=model, + config=config, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + use_cache=use_cache + ) - return ort_model @classmethod def _from_transformers( From 52c1745799b01610ba925be2deeab1c9817e444a Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 18 Sep 2023 19:02:32 +0200 Subject: [PATCH 35/76] remove bloom caching --- optimum/onnxruntime/modeling_decoder.py | 112 +++--------------------- 1 file changed, 13 insertions(+), 99 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 879507bd5fb..1f18f70078c 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -687,8 +687,6 @@ def __init__( "To export your model, simply set `export=True`." 
) - - @add_start_docstrings_to_model_forward( CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + TEXT_GENERATION_EXAMPLE.format( @@ -707,7 +705,6 @@ def forward( use_cache_branch: None = None, **kwargs, ) -> CausalLMOutputWithPast: - # adding use_cache_branch in the signature here is just a hack for IO Binding use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) @@ -917,8 +914,6 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ) -> "ORTModelForCausalLM": - - model_path = Path(model_id) # We do not implement the logic for use_cache=False, use_merged=True @@ -1041,10 +1036,16 @@ def _from_pretrained( # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic onnx_model = onnx.load(model_cache_path) - input_dims = {node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input} + input_dims = { + node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.input + } if input_dims["input_ids"][1] == 1: input_dims["input_ids"][1] = "sequence_length" - output_dims = {node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.output} + output_dims = { + node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.output + } output_dims["logits"][1] = "sequence_length" static_model = onnx.load(model_cache_path) updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims) @@ -1063,10 +1064,9 @@ def _from_pretrained( use_io_binding=use_io_binding, model_save_dir=model_save_dir, preprocessors=preprocessors, - use_cache=use_cache + use_cache=use_cache, ) - @classmethod def _from_transformers( cls, @@ -1134,7 +1134,6 @@ def _from_transformers( file_name=file_name, ) - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly @@ -1150,13 +1149,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) - - # TODO : rm !!!! - # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed - if past_key_values is not None and self.config.model_type == "bloom": - if past_key_values[0][0].shape[0] == input_ids.shape[0]: - past_key_values = self._convert_to_bloom_cache(past_key_values) - return { "input_ids": input_ids, "past_key_values": past_key_values, @@ -1165,90 +1157,12 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "attention_mask": attention_mask, } - - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache - def _reorder_cache( - self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. - This is required to match `past_key_values` with the correct beam_idx at every generation step. 
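
For architectures that now take position_ids, generation still leans on the usual GPT-2-style recipe in prepare_inputs_for_generation (only its tail is visible in the hunk above): positions are rebuilt from the attention mask and, once a cache exists, only the last one is fed back. A small numeric illustration:

    import torch

    attention_mask = torch.tensor([[0, 1, 1, 1]])        # one left-padded token
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)                                  # tensor([[1, 0, 1, 2]])

    # With past_key_values already populated, only the position of the new token is passed:
    print(position_ids[:, -1:])                          # tensor([[2]])
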
- """ - if self.config.model_type == "bloom": - return self._reorder_cache_bloom(past_key_values, beam_idx) - - # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - - # TODO: remove - # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache - def _reorder_cache_bloom( - self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called for bloom architecture. - This is required to match `past_key_values` with the correct beam_idx at every generation step. - """ - standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) - - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) - for layer_past in past_key_values - for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in standardized_past - ) - return self._convert_to_bloom_cache(reordered_past) - - # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache @staticmethod - def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]: - """ - Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) - """ - batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape - batch_size_times_num_heads = batch_size * num_heads - # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] - # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), - layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - - # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache - def _convert_to_standard_cache( - self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int - ) -> Tuple[Tuple[torch.Tensor]]: - """ - Standardizes the format of the cache so as to match most implementations, i.e. 
to tuple(tuple([batch_size, num_heads, ...])) - """ - if self.config.model_type != "bloom": - return past_key_value - - batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape - num_heads = batch_size_times_num_heads // batch_size - # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] - # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: return tuple( - ( - layer_past[0].view(batch_size, num_heads, head_dim, seq_length), - layer_past[1].view(batch_size, num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past ) def can_generate(self): From 1e9ba7e466d7dda80cc10f578afa556daad637ab Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 11:03:56 +0200 Subject: [PATCH 36/76] fix --- optimum/onnxruntime/modeling_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 1f18f70078c..bcb8f042279 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -789,7 +789,7 @@ def forward( if "position_ids" in self.inputs_names: if position_ids is None: raise ValueError("position_ids was not passed but is a required input for this ONNX model.") - onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids + inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids # Add the past_key_values to the decoder inputs if past_key_values is not None: From 4b68caa3c1b79151ce066ead30c02eb68591b9a2 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 11:04:40 +0200 Subject: [PATCH 37/76] format --- optimum/exporters/onnx/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 520e59aeb8d..9791428d514 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -30,7 +30,6 @@ ) from ...utils.import_utils import _diffusers_version from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask - from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME From e5fd9f8e04a64ac736813c9dd653dadfdd7b8ebc Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 12:13:19 +0200 Subject: [PATCH 38/76] fix dynamic axis for position ids --- optimum/exporters/onnx/config.py | 5 +---- optimum/onnxruntime/modeling_decoder.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 1fa6fa62629..3aca641513c 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: # generating wrong position_ids in the model itself: # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802 if not self.no_position_ids and self.task == "text-generation": - if self.use_past_in_inputs: - common_inputs["position_ids"] = {0: "batch_size"} - else: - common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} + 
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} return common_inputs diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index bcb8f042279..66190069e5d 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -699,8 +699,8 @@ def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, use_cache_branch: None = None, **kwargs, From addad9264ae5dab1e37adf94d66ce97a8841bee9 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 15:26:08 +0200 Subject: [PATCH 39/76] fix external data --- optimum/onnxruntime/modeling_decoder.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 66190069e5d..222fd9a852a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -49,7 +49,7 @@ parse_device, validate_provider_availability, ) - +from ..onnx.utils import check_model_uses_external_data if TYPE_CHECKING: from transformers import PretrainedConfig @@ -1033,9 +1033,14 @@ def _from_pretrained( ################################################################################################## - # Since v1.7.0 decoder with past models have fixed sequence length of 1 + # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic - onnx_model = onnx.load(model_cache_path) + onnx_model = onnx.load(str(model_cache_path), load_external_data=False) + model_uses_external_data = check_model_uses_external_data(onnx_model) + + if model_uses_external_data: + onnx_model = onnx.load(str(model_cache_path), load_external_data=True) + input_dims = { node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input @@ -1047,9 +1052,17 @@ def _from_pretrained( for node in onnx_model.graph.output } output_dims["logits"][1] = "sequence_length" - static_model = onnx.load(model_cache_path) - updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims) - onnx.save(updated_model, model_cache_path) + onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) + + onnx.save( + onnx_model, + str(model_cache_path), + save_as_external_data=model_uses_external_data, + all_tensors_to_one_file=True, + location=model_cache_path.name + "_data", + size_threshold=0, + ) + del onnx_model model = ORTModel.load_model( model_cache_path, From 2c063c0727eada267b969b70d986cc4941a5aac8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 15:27:06 +0200 Subject: [PATCH 40/76] format --- optimum/onnxruntime/modeling_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 222fd9a852a..732597d9308 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -32,7 +32,7 @@ import onnxruntime from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export -from ..onnx.utils import _get_external_data_paths +from ..onnx.utils 
import _get_external_data_paths, check_model_uses_external_data from ..utils import NormalizedConfigManager, check_if_transformers_greater from ..utils.file_utils import validate_file_exists from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors @@ -49,7 +49,7 @@ parse_device, validate_provider_availability, ) -from ..onnx.utils import check_model_uses_external_data + if TYPE_CHECKING: from transformers import PretrainedConfig From 1b47093d1147bf1035adf3013a1ae0b4d216ba16 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 16:02:06 +0200 Subject: [PATCH 41/76] test --- tests/exporters/onnx/test_onnx_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 4eefafa2d90..641e347c83a 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -570,7 +570,7 @@ def test_custom_export_official_model(self): assert "cross_attentions.0" in output_names @parameterized.expand( - grid_parameters({"fn_get_submodels": [None, fn_get_submodels_custom], "legacy": [True, False]}) + grid_parameters({"fn_get_submodels": (None, fn_get_submodels_custom), "legacy": (True, False)}) ) def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy): model_id = "fxmarty/tiny-mpt-random-remote-code" From 35caaf221c2bc82c6851cb760d60eebf7b4e5270 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 16:07:49 +0200 Subject: [PATCH 42/76] test --- tests/exporters/onnx/test_onnx_export.py | 49 ++++++++++++------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 641e347c83a..11e6a53da36 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -569,10 +569,8 @@ def test_custom_export_official_model(self): assert "decoder_attentions.0" in output_names assert "cross_attentions.0" in output_names - @parameterized.expand( - grid_parameters({"fn_get_submodels": (None, fn_get_submodels_custom), "legacy": (True, False)}) - ) - def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy): + @parameterized.expand([(None,), (fn_get_submodels_custom,)]) + def test_custom_export_trust_remote(self, fn_get_submodels): model_id = "fxmarty/tiny-mpt-random-remote-code" config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) onnx_config = CustomMPTOnnxConfig( @@ -583,28 +581,29 @@ def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy): ) onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True) - if legacy: - custom_onnx_configs = { - "decoder_model": onnx_config, - "decoder_with_past_model": onnx_config_with_past, - } - else: - custom_onnx_configs = { - "model": onnx_config_with_past, - } + for legacy in (True, False): + if legacy: + custom_onnx_configs = { + "decoder_model": onnx_config, + "decoder_with_past_model": onnx_config_with_past, + } + else: + custom_onnx_configs = { + "model": onnx_config_with_past, + } - with TemporaryDirectory() as tmpdirname: - main_export( - model_id, - output=tmpdirname, - task="text-generation-with-past", - trust_remote_code=True, - custom_onnx_configs=custom_onnx_configs, - no_post_process=True, - fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None, - legacy=legacy, - opset=14, - ) + with TemporaryDirectory() 
as tmpdirname: + main_export( + model_id, + output=tmpdirname, + task="text-generation-with-past", + trust_remote_code=True, + custom_onnx_configs=custom_onnx_configs, + no_post_process=True, + fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None, + legacy=legacy, + opset=14, + ) def test_custom_export_trust_remote_error(self): model_id = "mohitsha/tiny-ernie-random-remote-code" From 725857beb48485dc755e81811a56da599b759ebe Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 17:37:35 +0200 Subject: [PATCH 43/76] add model patcher --- optimum/exporters/onnx/model_configs.py | 44 +++++++++++-- optimum/exporters/onnx/model_patcher.py | 84 +++++++++++++++++++++++++ optimum/exporters/onnx/utils.py | 13 +--- 3 files changed, 124 insertions(+), 17 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 401d995fdc7..a62499806c2 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -55,7 +55,18 @@ TextSeq2SeqOnnxConfig, VisionOnnxConfig, ) -from .model_patcher import SAMModelPatcher, WavLMModelPatcher +from .model_patcher import ( + SAMModelPatcher, + WavLMModelPatcher, + BloomModelPatcher, + MPTModelPatcher, + BartModelPatcher, + PegasusModelPatcher, + BlenderbotModelPatcher, + BlenderbotSmallModelPatcher, + OPTModelPatcher, + LlamaModelPatcher, +) if TYPE_CHECKING: @@ -215,11 +226,21 @@ class OPTOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return OPTModelPatcher(self, model, model_kwargs=model_kwargs) + class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return LlamaModelPatcher(self, model, model_kwargs=model_kwargs) + class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. @@ -228,6 +249,11 @@ class MPTOnnxConfig(TextDecoderOnnxConfig): num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MPTModelPatcher(self, model, model_kwargs=model_kwargs) + class BloomOnnxConfig(TextDecoderOnnxConfig): # Bloom does not require position_ids input. 
@@ -261,6 +287,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire 1: decoder_sequence_name, } + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BloomModelPatcher(self, model, model_kwargs=model_kwargs) + class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( @@ -400,7 +431,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return int_tensor -class BartOnnxConfig(TextSeq2SeqOnnxConfig): +class M2M100OnnxConfig(TextSeq2SeqOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", @@ -524,11 +555,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): ) -class MBartOnnxConfig(BartOnnxConfig): - pass +class BartOnnxConfig(M2M100OnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BartModelPatcher(self, model, model_kwargs=model_kwargs) -class M2M100OnnxConfig(BartOnnxConfig): +class MBartOnnxConfig(BartOnnxConfig): pass diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index e6b50b6dc08..6dd6bd5c298 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -16,6 +16,7 @@ import functools import inspect from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask from transformers.utils import is_torch_available @@ -342,3 +343,86 @@ def patched_forward( return {"iou_scores": iou_predictions, "pred_masks": low_res_masks} self.patched_forward = patched_forward + + +class BloomModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask") + + def __enter__(self): + super().__enter__() + if self.real_config.task == "text-generation" and self.real_config.use_past: + setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if self.real_config.task == "text-generation" and self.real_config.use_past: + setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask) + + +class MPTModelPatcher(BloomModelPatcher): + pass + + +class LlamaModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask") + + def __enter__(self): + super().__enter__() + if self.real_config.task == "text-generation" and self.real_config.use_past: + setattr(self._model.model, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if self.real_config.task == "text-generation" and self.real_config.use_past: + 
setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + + +class OPTModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") + + def __enter__(self): + super().__enter__() + if self.real_config.task == "text-generation" and self.real_config.use_past: + setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if self.real_config.task == "text-generation" and self.real_config.use_past: + setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + + +class BlenderbotSmallModelPatcher(OPTModelPatcher): + pass + + +class BlenderbotModelPatcher(OPTModelPatcher): + pass + + +class PegasusModelPatcher(OPTModelPatcher): + pass + + +class BartModelPatcher(OPTModelPatcher): + pass diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 9791428d514..784920cea34 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -166,18 +166,7 @@ def _get_submodels_for_export_decoder( """ Returns the decoder part of the model. """ - models_for_export = {} - - if not legacy or use_past: - # fix causal lm generation for inputs of sequence length 1 - if model.config.model_type in {"bloom", "mpt"}: - model.transformer._prepare_attn_mask = _prepare_attn_mask - elif model.config.model_type == "llama": - model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: - model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - - models_for_export[ONNX_DECODER_NAME if legacy else "model"] = model + models_for_export = {ONNX_DECODER_NAME if legacy else "model": model} if legacy and use_past: models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model From 46b26b5ff7888acd91390f6d074a651451149c37 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 17:39:35 +0200 Subject: [PATCH 44/76] format --- optimum/exporters/onnx/model_configs.py | 11 ++++------- optimum/exporters/onnx/model_patcher.py | 7 +++++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index a62499806c2..5b191f9dcea 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -56,16 +56,13 @@ VisionOnnxConfig, ) from .model_patcher import ( - SAMModelPatcher, - WavLMModelPatcher, + BartModelPatcher, BloomModelPatcher, + LlamaModelPatcher, MPTModelPatcher, - BartModelPatcher, - PegasusModelPatcher, - BlenderbotModelPatcher, - BlenderbotSmallModelPatcher, OPTModelPatcher, - LlamaModelPatcher, + SAMModelPatcher, + WavLMModelPatcher, ) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 6dd6bd5c298..8d067bea615 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -345,6 +345,9 @@ def patched_forward( self.patched_forward = patched_forward + + + class BloomModelPatcher(ModelPatcher): def __init__( self, @@ -360,12 +363,14 @@ 
def __enter__(self): if self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask) + def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask) + class MPTModelPatcher(BloomModelPatcher): pass @@ -426,3 +431,5 @@ class PegasusModelPatcher(OPTModelPatcher): class BartModelPatcher(OPTModelPatcher): pass + + From 33957af8c67ac04e8932b2d093b502b33bf9d1de Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 17:48:32 +0200 Subject: [PATCH 45/76] fix --- optimum/onnxruntime/modeling_decoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 732597d9308..511132171a1 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1023,7 +1023,8 @@ def _from_pretrained( except EntryNotFoundError: # model doesn't use external data pass - new_model_save_dir = Path(model_cache_path).parent + model_cache_path = Path(model_cache_path) + new_model_save_dir = model_cache_path.parent preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it From c2ec382a4404ac7f8e358f3396fbed1c17ff648b Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 18:52:18 +0200 Subject: [PATCH 46/76] fix bart model patcher --- optimum/exporters/onnx/model_patcher.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 8d067bea615..0fb55baabb0 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -396,7 +396,7 @@ def __exit__(self, exc_type, exc_value, traceback): setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) -class OPTModelPatcher(ModelPatcher): +class BartModelPatcher(Seq2SeqModelPatcher): def __init__( self, config: "OnnxConfig", @@ -404,32 +404,33 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") + if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: + self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") def __enter__(self): super().__enter__() - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) -class 
BlenderbotSmallModelPatcher(OPTModelPatcher): +class BlenderbotSmallModelPatcher(BartModelPatcher): pass -class BlenderbotModelPatcher(OPTModelPatcher): +class BlenderbotModelPatcher(BartModelPatcher): pass -class PegasusModelPatcher(OPTModelPatcher): +class PegasusModelPatcher(BartModelPatcher): pass -class BartModelPatcher(OPTModelPatcher): +class OPTModelPatcher(BartModelPatcher): pass From d86bce6417743744b6d51f10f58bf1ec4a5dc694 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 18:53:26 +0200 Subject: [PATCH 47/76] format --- optimum/exporters/onnx/model_patcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 0fb55baabb0..ae3823708cc 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -16,10 +16,11 @@ import functools import inspect from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union -from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask from transformers.utils import is_torch_available +from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask + if is_torch_available(): import torch From be836b5d604639c5b72b09040841c5f187d7e723 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 09:46:50 +0200 Subject: [PATCH 48/76] format --- optimum/exporters/onnx/model_patcher.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index ae3823708cc..4f16624d292 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -346,9 +346,6 @@ def patched_forward( self.patched_forward = patched_forward - - - class BloomModelPatcher(ModelPatcher): def __init__( self, @@ -364,14 +361,12 @@ def __enter__(self): if self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask) - def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.real_config.task == "text-generation" and self.real_config.use_past: setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask) - class MPTModelPatcher(BloomModelPatcher): pass @@ -405,17 +400,29 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: + if ( + self.real_config._behavior == "decoder" + and self.real_config.task == "text-generation" + and self.real_config.use_past + ): self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") def __enter__(self): super().__enter__() - if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: + if ( + self.real_config._behavior == "decoder" + and self.real_config.task == "text-generation" + and self.real_config.use_past + ): setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if self.real_config._behavior == "decoder" and self.real_config.task == "text-generation" and self.real_config.use_past: + if ( + 
self.real_config._behavior == "decoder" + and self.real_config.task == "text-generation" + and self.real_config.use_past + ): setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) @@ -433,5 +440,3 @@ class PegasusModelPatcher(BartModelPatcher): class OPTModelPatcher(BartModelPatcher): pass - - From b05f59915d83b27ab414b6b883865de2709c2b8b Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 10:51:15 +0200 Subject: [PATCH 49/76] fix model patcher for opt models --- optimum/exporters/onnx/model_patcher.py | 72 +++++++++++++++---------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 4f16624d292..77a0345a9c8 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -354,23 +354,22 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask") + + self.patch = self.real_config.task == "text-generation" and self.real_config.use_past + if self.patch: + self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask") def __enter__(self): super().__enter__() - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.patch: setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.patch: setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask) -class MPTModelPatcher(BloomModelPatcher): - pass - - class LlamaModelPatcher(ModelPatcher): def __init__( self, @@ -379,16 +378,19 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask") + + self.patch = self.real_config.task == "text-generation" and self.real_config.use_past + if self.patch: + self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask") def __enter__(self): super().__enter__() - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.patch: setattr(self._model.model, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if self.real_config.task == "text-generation" and self.real_config.use_past: + if self.patch: setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) @@ -400,32 +402,49 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - if ( - self.real_config._behavior == "decoder" - and self.real_config.task == "text-generation" - and self.real_config.use_past - ): + self.patch = self.real_config.task == "text-generation" and self.real_config.use_past and self.real_config._behavior == "decoder" + if self.patch: self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") def __enter__(self): super().__enter__() - if ( - self.real_config._behavior == "decoder" - and self.real_config.task == "text-generation" - and self.real_config.use_past - ): + if self.patch: 
setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if ( - self.real_config._behavior == "decoder" - and self.real_config.task == "text-generation" - and self.real_config.use_past - ): + if self.patch: + setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + + +class OPTModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + self.patch = self.real_config.task == "text-generation" and self.real_config.use_past + if self.patch: + self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") + + def __enter__(self): + super().__enter__() + if self.patch: + setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if self.patch: setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + +class MPTModelPatcher(BloomModelPatcher): + pass + + class BlenderbotSmallModelPatcher(BartModelPatcher): pass @@ -437,6 +456,3 @@ class BlenderbotModelPatcher(BartModelPatcher): class PegasusModelPatcher(BartModelPatcher): pass - -class OPTModelPatcher(BartModelPatcher): - pass From 26d97e8fd671e5ad0ddbc20241478a7d93e4b2c0 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 10:53:18 +0200 Subject: [PATCH 50/76] fix format --- optimum/exporters/onnx/model_patcher.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 77a0345a9c8..e8a1574128b 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -402,7 +402,11 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past and self.real_config._behavior == "decoder" + self.patch = ( + self.real_config.task == "text-generation" + and self.real_config.use_past + and self.real_config._behavior == "decoder" + ) if self.patch: self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") @@ -440,7 +444,6 @@ def __exit__(self, exc_type, exc_value, traceback): setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) - class MPTModelPatcher(BloomModelPatcher): pass @@ -455,4 +458,3 @@ class BlenderbotModelPatcher(BartModelPatcher): class PegasusModelPatcher(BartModelPatcher): pass - From 4b6c3ed2453b7550ff6f76fdea02aa53db65426e Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 13:47:31 +0200 Subject: [PATCH 51/76] add tmp onnxruntime max version --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7a7f4546844..78f65e40d48 100644 --- a/setup.py +++ b/setup.py @@ -45,14 +45,14 @@ EXTRAS_REQUIRE = { "onnxruntime": [ "onnx", - "onnxruntime>=1.11.0", + "onnxruntime>=1.11.0,<1.16.0", "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", ], "onnxruntime-gpu": [ "onnx", - "onnxruntime-gpu>=1.11.0", + 
"onnxruntime-gpu>=1.11.0,<1.16.0", "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", From 615a21984d3fee1844c4aa97b59c2cd28aa1ad1f Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 14:06:23 +0200 Subject: [PATCH 52/76] add test --- optimum/onnxruntime/modeling_decoder.py | 7 +++++++ tests/onnxruntime/test_modeling.py | 18 ++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 511132171a1..e376a9039d3 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -687,6 +687,13 @@ def __init__( "To export your model, simply set `export=True`." ) + if use_io_binding and not use_cache: + raise ValueError( + "The parameters combination use_cache=False, use_io_binding=True is not supported. " + "Please either pass use_cache=True, use_io_binding=True (default), or use_cache=False, use_io_binding=False." + ) + + @add_start_docstrings_to_model_forward( CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + TEXT_GENERATION_EXAMPLE.format( diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index d115f7b1f1a..63615611150 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1970,14 +1970,24 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 - def test_inference_old_onnx_model(self): - model = ORTModelForCausalLM.from_pretrained("optimum/gpt2") - tokenizer = get_preprocessor("optimum/gpt2") + def test_inference_old_onnx_model(self): + model_id = "optimum/gpt2" + model = AutoModelForCausalLM.from_pretrained("gpt2") + tokenizer = get_preprocessor(model_id) text = "This is a sample output" tokens = tokenizer(text, return_tensors="pt") - model.generate(**tokens) + for use_cache in (True, False): + onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache) + + self.assertEqual(onnx_model.use_cache, use_cache) + self.assertEqual( + onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME + ) + outputs_onnx = onnx_model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) + self.assertTrue(torch.allclose(outputs_onnx, outputs)) def test_load_model_from_hub_onnx(self): model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-without-merge") From b3525f8d8438e8a294d8a4efac820362dc88e429 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 14:06:52 +0200 Subject: [PATCH 53/76] format --- optimum/onnxruntime/modeling_decoder.py | 1 - tests/onnxruntime/test_modeling.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index e376a9039d3..db2651fdbc6 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -693,7 +693,6 @@ def __init__( "Please either pass use_cache=True, use_io_binding=True (default), or use_cache=False, use_io_binding=False." 
) - @add_start_docstrings_to_model_forward( CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + TEXT_GENERATION_EXAMPLE.format( diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 63615611150..1a219f6345f 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1970,7 +1970,6 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 - def test_inference_old_onnx_model(self): model_id = "optimum/gpt2" model = AutoModelForCausalLM.from_pretrained("gpt2") @@ -1985,7 +1984,9 @@ def test_inference_old_onnx_model(self): self.assertEqual( onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME ) - outputs_onnx = onnx_model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) + outputs_onnx = onnx_model.generate( + **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 + ) outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) self.assertTrue(torch.allclose(outputs_onnx, outputs)) From e0e2bae1d55ae404af4f97077472e88aae083507 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 14:08:39 +0200 Subject: [PATCH 54/76] tmp fix onnxruntime max version --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 78f65e40d48..8abe5d0ca61 100644 --- a/setup.py +++ b/setup.py @@ -58,9 +58,9 @@ "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. ], - "exporters": ["onnx", "onnxruntime", "timm"], - "exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"], - "exporters-tf": ["tensorflow>=2.4,<=2.12.1", "tf2onnx", "onnx", "onnxruntime", "timm", "h5py", "numpy<1.24.0"], + "exporters": ["onnx", "onnxruntime<1.16.0", "timm"], + "exporters-gpu": ["onnx", "onnxruntime-gpu<1.16.0", "timm"], + "exporters-tf": ["tensorflow>=2.4,<=2.12.1", "tf2onnx", "onnx", "onnxruntime<1.16.0", "timm", "h5py", "numpy<1.24.0"], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.11.0", "openvino": "optimum-intel[openvino]>=1.11.0", From cbc935fd799e835ae6fda0df2307456221c791b3 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 14:08:55 +0200 Subject: [PATCH 55/76] format --- setup.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8abe5d0ca61..a4b952d6523 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,15 @@ ], "exporters": ["onnx", "onnxruntime<1.16.0", "timm"], "exporters-gpu": ["onnx", "onnxruntime-gpu<1.16.0", "timm"], - "exporters-tf": ["tensorflow>=2.4,<=2.12.1", "tf2onnx", "onnx", "onnxruntime<1.16.0", "timm", "h5py", "numpy<1.24.0"], + "exporters-tf": [ + "tensorflow>=2.4,<=2.12.1", + "tf2onnx", + "onnx", + "onnxruntime<1.16.0", + "timm", + "h5py", + "numpy<1.24.0", + ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.11.0", "openvino": "optimum-intel[openvino]>=1.11.0", From 624d91da9a5a6735a6f2f72c5414077160f9fbc9 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 16:02:14 +0200 Subject: [PATCH 56/76] add test --- tests/onnxruntime/test_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index eb007f16b8b..108003624f8 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -121,7 +121,7 @@ def test_dynamic_quantization(self, 
model_cls, model_name, expected_quantized_ma def test_decoder_quantization_with_and_without_cache(self, test_name, model_info, use_cache): model_cls, model_name, expected_quantized_matmuls = model_info qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True) - model = model_cls.from_pretrained(model_name, export=True, use_cache=use_cache) + model = model_cls.from_pretrained(model_name, export=True, use_cache=use_cache, use_io_binding=use_cache) with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) From c5584504da72480c089275dbcd26197ab5d0cb14 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 18:13:24 +0200 Subject: [PATCH 57/76] fix ort docker --- tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu b/tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu index d1cd13396fb..fd7f4df09b4 100644 --- a/tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu +++ b/tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu @@ -8,7 +8,7 @@ ENV DEBIAN_FRONTEND noninteractive # Install and update tools to minimize security vulnerabilities RUN apt-get update RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \ - bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \ + bzip2 ca-certificates libglib2.0-0 ffmpeg libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \ apt-get clean RUN unattended-upgrade RUN apt-get autoremove -y From e72526d4a4942bb5097fcd34cfd19a7b87ede239 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 20 Sep 2023 18:52:02 +0200 Subject: [PATCH 58/76] fix format --- optimum/exporters/onnx/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 784920cea34..9aac510255b 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -29,7 +29,7 @@ logging, ) from ...utils.import_utils import _diffusers_version -from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask +from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask # noqa: F401 from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME From 44ef0f1beaf2b8968651c59dafbf78ac5b43c119 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 22 Sep 2023 12:19:12 +0200 Subject: [PATCH 59/76] add test --- tests/onnxruntime/test_quantization.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index bbcaa1f4f37..4062c556ea9 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -115,6 +115,31 @@ def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_ma self.assertEqual(expected_quantized_matmuls, num_quantized_matmul) gc.collect() + @unittest.skipIf(parse(ort_version) == Version("1.16.0"), "not supported with this onnxruntime version") + def test_dynamic_quantization_subgraphs(self): + qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True) + tmp_dir = tempfile.mkdtemp() + output_dir = Path(tmp_dir) + model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge", use_merged=True) + 
self.assertTrue(model.use_merged) + model.save_pretrained(tmp_dir) + + quantizer = ORTQuantizer.from_pretrained(model) + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) + expected_ort_config = ORTConfig(quantization=qconfig) + ort_config = ORTConfig.from_pretrained(tmp_dir) + # Verify the ORTConfig was correctly created and saved + self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict()) + + quantized_model = onnx_load(output_dir.joinpath("decoder_model_merged_quantized.onnx")) + num_quantized_matmul = 0 + for initializer in quantized_model.graph.initializer: + if "weight" in initializer.name and "quantized" in initializer.name: + num_quantized_matmul += 1 + + self.assertTrue(num_quantized_matmul > 0) + gc.collect() + @parameterized.expand( grid_parameters( {"model_arch": SUPPORTED_DECODER_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS, "use_cache": [True, False]} From ed8e74f10842975b7f68b68f44ff3b53d95b4eb0 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 25 Sep 2023 12:28:25 +0200 Subject: [PATCH 60/76] fix bart model patcher --- optimum/exporters/onnx/model_patcher.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index e8a1574128b..6f06c02d4a8 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -402,11 +402,8 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.patch = ( - self.real_config.task == "text-generation" - and self.real_config.use_past - and self.real_config._behavior == "decoder" - ) + + self.patch = self.real_config.task == "text-generation" and self.real_config.use_past if self.patch: self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") From c13a170a98c8118cde0baac542046296de296910 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 25 Sep 2023 14:13:42 +0200 Subject: [PATCH 61/76] raise when unsupported model --- optimum/onnxruntime/modeling_decoder.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index db2651fdbc6..e130be0bca6 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -962,20 +962,33 @@ def _from_pretrained( use_merged = False if use_merged is False: + pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN # exclude decoder file for first iteration decoder_path = ORTModelForCausalLM.infer_onnx_filename( model_id, - [ - r"^((?!decoder).)*.onnx", - DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN, - ], - "file_name", + [r"^((?!decoder).)*.onnx", pattern], + argument_name=None, subfolder=subfolder, use_auth_token=use_auth_token, revision=revision, ) file_name = decoder_path.name + MODEL_TO_PATCH_FOR_PAST = { + "bloom", + "mpt", + "llama", + "blenderbot-small", + "blenderbot", + "opt", + "pegasus", + "bart", + } + if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST: + raise ValueError( + f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False." 
+ ) + regular_file_names = [] for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]: regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name) From 524b6682b579e963a98680199dc8431805e0dfb8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 25 Sep 2023 14:24:12 +0200 Subject: [PATCH 62/76] add cached file --- optimum/onnxruntime/modeling_decoder.py | 50 +++-------- optimum/onnxruntime/modeling_ort.py | 110 +++++++++++++++--------- 2 files changed, 78 insertions(+), 82 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index e130be0bca6..b13b7ed7d63 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1008,51 +1008,23 @@ def _from_pretrained( else: init_cls = ORTModelForCausalLM - ################################################################################################## - - preprocessors = None - if model_path.is_dir(): - model_cache_path = model_path / file_name - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass - model_cache_path = Path(model_cache_path) - new_model_save_dir = model_cache_path.parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + model_cache_path, preprocessors = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + new_model_save_dir = model_cache_path.parent # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it # instead of the path only. if model_save_dir is None: model_save_dir = new_model_save_dir - ################################################################################################## - # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic onnx_model = onnx.load(str(model_cache_path), load_external_data=False) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index b190432fe6d..b58a37eb43a 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -486,55 +486,30 @@ def _from_pretrained( "not behave as expected." 
) - preprocessors = None - if model_path.is_dir(): - model = ORTModel.load_model( - model_path / file_name, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass - - model = ORTModel.load_model( - model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options - ) - new_model_save_dir = Path(model_cache_path).parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + model_cache_path, preprocessors = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + new_model_save_dir = model_cache_path.parent # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it # instead of the path only. if model_save_dir is None: model_save_dir = new_model_save_dir + model = ORTModel.load_model( + model_cache_path, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + return cls( model=model, config=config, @@ -828,6 +803,55 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool): " with model.use_io_binding = False, or pass torch.Tensor inputs instead." ) + @staticmethod + def _cached_file( + model_path: Union[Path, str], + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + ): + model_path = Path(model_path) + + # locates a file in a local folder and repo, downloads and cache it if necessary. 
+ if model_path.is_dir(): + model_cache_path = model_path / file_name + preprocessors = maybe_load_preprocessors(model_path.as_posix()) + else: + model_cache_path = hf_hub_download( + repo_id=model_path.as_posix(), + filename=file_name, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + # try download external data + try: + hf_hub_download( + repo_id=model_path.as_posix(), + subfolder=subfolder, + filename=file_name + "_data", + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + # model doesn't use external data + pass + + model_cache_path = Path(model_cache_path) + preprocessors = maybe_load_preprocessors(model_path.as_posix(), subfolder=subfolder) + + return model_cache_path, preprocessors + FEATURE_EXTRACTION_EXAMPLE = r""" Example of feature extraction: From 8951ddf4f429bfdc5e1225a9cf12d10ee44464dd Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 3 Oct 2023 12:45:20 +0200 Subject: [PATCH 63/76] minor --- optimum/commands/export/onnx.py | 2 +- optimum/exporters/onnx/model_configs.py | 3 +- optimum/exporters/onnx/model_patcher.py | 16 - optimum/onnxruntime/modeling_decoder.py | 517 +----------------------- tests/onnxruntime/test_modeling.py | 44 +- 5 files changed, 26 insertions(+), 556 deletions(-) diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index 91c33b32573..c2c38473a88 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -220,7 +220,7 @@ def parse_args_onnx(parser): optional_group.add_argument( "--legacy", action="store_true", - help=("Export decoder only models in two (without + with past) model as a single ONNX file."), + help="Export decoder only models in two (without + with past) model as a single ONNX file.", ) # deprecated argument diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 5b191f9dcea..52f2a0caab2 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -59,7 +59,6 @@ BartModelPatcher, BloomModelPatcher, LlamaModelPatcher, - MPTModelPatcher, OPTModelPatcher, SAMModelPatcher, WavLMModelPatcher, @@ -249,7 +248,7 @@ class MPTOnnxConfig(TextDecoderOnnxConfig): def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": - return MPTModelPatcher(self, model, model_kwargs=model_kwargs) + return BloomModelPatcher(self, model, model_kwargs=model_kwargs) class BloomOnnxConfig(TextDecoderOnnxConfig): diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 6f06c02d4a8..0928dab5957 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -439,19 +439,3 @@ def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) - - -class MPTModelPatcher(BloomModelPatcher): - pass - - -class BlenderbotSmallModelPatcher(BartModelPatcher): - pass - - -class BlenderbotModelPatcher(BartModelPatcher): - pass - - -class PegasusModelPatcher(BartModelPatcher): - pass diff --git a/optimum/onnxruntime/modeling_decoder.py 
b/optimum/onnxruntime/modeling_decoder.py index b13b7ed7d63..2616e40fef2 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -123,516 +123,6 @@ """ -class ORTModelDecoder(ORTModel): - """ - Base class for implementing models with a causal language modeling head using ONNX Runtime inference. - """ - - def __init__( - self, - decoder_session: onnxruntime.InferenceSession, - config: "PretrainedConfig", - onnx_paths: List[str], - decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None, - use_cache: bool = True, - use_io_binding: Optional[bool] = None, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - preprocessors: Optional[List] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ): - """ - Args: - decoder_session (`onnxruntime.InferenceSession`): - The ONNX Runtime inference session associated to the decoder. - config ([`~transformers.PretrainedConfig`]): - An instance of the configuration associated to the model. Initializing with a config file does - not load the weights associated with the model, only the configuration. - decoder_with_past_session (`Optional[onnxruntime.InferenceSession]`, defaults to `None`): - The ONNX Runtime inference session associated to the decoder with past key values. This argument should not - be set if use_merged=True is used. - onnx_paths (`List[str]`): - Path to ONNX files associated with the model. - use_cache (`bool`, defaults to `True`): - Whether or not past key/values cache should be used. Defaults to `True`. - use_io_binding (`Optional[bool]`, defaults to `None`): - Whether to use IOBinding during inference to avoid memory copy between the host and devices. Defaults to - `True` if the execution provider is CPUExecutionProvider or CUDAExecutionProvider, otherwise defaults to `False`. - model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `""`): - The directory under which the model exported to ONNX was saved. - preprocessors (`Optional[List]`, defaults to `None`): - The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. - generation_config (`Optional[GenerationConfig]`, defaults to `None`): - The generation configuration used by default when calling `generate()`. - Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate. - """ - if use_io_binding is None: - if decoder_session.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"]: - use_io_binding = True - else: - use_io_binding = False - - self.shared_attributes_init( - decoder_session, - use_io_binding=use_io_binding, - model_save_dir=model_save_dir, - ) - self.config = config - - # TODO: remove at version 2.0 - def show_deprecated_argument(arg_name): - if kwargs.pop(arg_name, None) is not None: - logger.warning( - f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used " - "anymore." - ) - - show_deprecated_argument("last_decoder_model_name") - show_deprecated_argument("last_decoder_with_past_model_name") - if kwargs: - raise ValueError( - f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." 
- ) - - if use_cache is True: - # Auto-detect whether the provided session is a merged non-past / with-past or not - # TODO: make __init__ private and pass `use_merged` as an argument - use_merged = "use_cache_branch" in [inp.name for inp in decoder_session.get_inputs()] - - if use_merged is True and decoder_with_past_session is not None: - raise ValueError( - "Detected a merged decoder, but decoder_with_past_session was provided." - "Please only set decoder_session, or provide a non-merged decoder_session." - ) - if use_cache is True and use_merged is False and decoder_with_past_session is None: - raise ValueError( - "The parameter use_cache was set as True, but neither decoder_with_past_session was passed" - " nor a use_cache branch can be found in the decoder_session." - " Please pass a decoder_with_past_session or set use_cache=False." - ) - else: - use_merged = False - - if decoder_with_past_session is not None: - raise ValueError( - "The parameter decoder_with_past_session was passed, although use_cache is False." - "Please pass use_cache=True for decoder_with_past_session to be used." - ) - - if use_cache is False and use_io_binding is True: - raise ValueError( - "When using CUDAExecutionProvider, the parameters combination use_cache=False, use_io_binding=True" - " is not supported. Please either pass use_cache=True, use_io_binding=True (default)," - " or use_cache=False, use_io_binding=False." - ) - - self.onnx_paths = onnx_paths - self.use_cache = use_cache - self.use_merged = use_merged - self.decoder = ORTDecoder(decoder_session, self) - self.decoder_model_path = Path(decoder_session._model_path) - self.decoder_model_name = self.decoder_model_path.name - - # Reference: https://github.com/huggingface/optimum/pull/1381 - model_type = config.model_type.replace("_", "-") - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.decoder.input_names: - logger.warning( - f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." - ) - - self.decoder_with_past = None - self.decoder_with_past_model_path = None - self.decoder_with_past_model_name = None - if self.use_cache is True and self.use_merged is False: - self.decoder_with_past = ORTDecoder(decoder_with_past_session, self) - self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) - self.decoder_with_past_model_name = self.decoder_with_past_model_path.name - - if generation_config is None: - generation_config = GenerationConfig.from_model_config(config) - self.generation_config = generation_config - - @staticmethod - def _generate_regular_names_for_filename(filename: str): - name, extension = filename.rsplit(".", maxsplit=1) - return [ - filename, - f"{name}_quantized.{extension}", - f"{name}_optimized.{extension}", - f"{name}_merged.{extension}", - ] - - @staticmethod - def load_model( - decoder_path: Union[str, Path], - decoder_with_past_path: Optional[Union[str, Path]] = None, - provider: str = "CPUExecutionProvider", - session_options: Optional[onnxruntime.SessionOptions] = None, - provider_options: Optional[Dict] = None, - ): - """ - Creates an instance of [`~optimum.onnxruntime.ORTModelDecoder`]. - Three inference sessions will be created for respectively the decoder and decoder with past key values - models. 
The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. - - Args: - decoder_path (`str` or `Path`): - The path of the decoder ONNX model. - decoder_with_past_path (`str` or `Path`, *optional*): - The path of the decoder with past key values ONNX model. - provider(`str`, *optional*, defaults to `"CPUExecutionProvider"`): - The ONNX Runtime provider to use for loading the model. - session_options (`Optional[onnxruntime.SessionOptions]`, *optional*),: - ONNX Runtime session options to use for loading the model. - provider_options (`Optional[Dict]`, *optional*): - Provider option dictionary corresponding to the provider used. See available options - for each provider: https://onnxruntime.ai/docs/api/c/group___global.html. - """ - decoder_session = ORTModel.load_model(decoder_path, provider, session_options, provider_options) - - decoder_with_past_session = None - # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs - # will be enabled - if decoder_with_past_path is not None: - decoder_with_past_session = ORTModel.load_model( - decoder_with_past_path, provider, session_options, provider_options - ) - - return decoder_session, decoder_with_past_session - - def _save_pretrained(self, save_directory: Union[str, Path]): - """ - Saves the model decoder and decoder with past key values as well as its configuration file to a - directory, so that it can be re-loaded using the - [`~optimum.onnxruntime.modeling_causal.ORTModelDecoder.from_pretrained`] class method. - - Args: - save_directory (`str` or `Path`): - The directory where to save the model files. - """ - save_directory = Path(save_directory) - src_paths = [Path(path) for path in self.onnx_paths] - dst_paths = [save_directory / path.name for path in src_paths] - - # add external data paths in case of large models - src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths) - - for src_path, dst_path in zip(src_paths, dst_paths): - shutil.copyfile(src_path, dst_path) - - self.generation_config.save_pretrained(save_directory) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: "PretrainedConfig", - init_cls: Type["ORTModelDecoder"], - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - decoder_file_name: str = ONNX_DECODER_NAME, - decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, - subfolder: str = "", - local_files_only: bool = False, - use_cache: bool = True, - use_merged: Optional[bool] = None, - provider: str = "CPUExecutionProvider", - session_options: Optional[onnxruntime.SessionOptions] = None, - provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: Optional[bool] = None, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - **kwargs, - ): - model_path = Path(model_id) - - # We do not implement the logic for use_cache=False, use_merged=True - if use_cache is False: - if use_merged is True: - raise ValueError( - "The parameters combination use_cache=False, use_merged=True is not supported." - " To use a merged decoder, past key values must be used." 
- ) - use_merged = False - - decoder_merged_path = None - # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it), - # and use_merged = True (explicitely specified by the user) - if use_merged is not False: - try: - decoder_merged_path = ORTModelDecoder.infer_onnx_filename( - model_id, - [DECODER_MERGED_ONNX_FILE_PATTERN], - argument_name=None, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - ) - use_merged = True - decoder_path = decoder_merged_path - except FileNotFoundError as e: - if use_merged is True: - raise FileNotFoundError( - "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()" - " but no ONNX file for a merged decoder could be found in" - f" {str(Path(model_id, subfolder))}, with the error: {e}" - ) - use_merged = False - - decoder_without_past_path = None - decoder_with_past_path = None - if use_merged is False: - if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): - decoder_without_past_path = ORTModelDecoder.infer_onnx_filename( - model_id, - [DECODER_ONNX_FILE_PATTERN], - "decoder_file_name", - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - ) - else: - decoder_without_past_path = model_path / subfolder / decoder_file_name - - decoder_path = decoder_without_past_path - - decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME) - if decoder_path.name not in decoder_regular_onnx_filenames: - logger.warning( - f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the " - f"{cls.__name__} might not behave as expected." - ) - - # If the decoder without / with past has been merged, we do not need to look for any additional file - if use_cache is True: - if not validate_file_exists( - model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision - ): - try: - decoder_with_past_path = ORTModelDecoder.infer_onnx_filename( - model_id, - [DECODER_WITH_PAST_ONNX_FILE_PATTERN], - "decoder_with_past_file_name", - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - ) - except FileNotFoundError as e: - raise FileNotFoundError( - "The parameter `use_cache=True` was passed to ORTModelForCausalLM.from_pretrained()" - " but no ONNX file using past key values could be found in" - f" {str(Path(model_id, subfolder))}, with the error: {e}" - ) - else: - decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name - - decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename( - ONNX_DECODER_WITH_PAST_NAME - ) - - if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames: - logger.warning( - f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, " - f"the {cls.__name__} might not behave as expected." 
- ) - - preprocessors = None - if model_path.is_dir(): - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) - else: - attribute_name_to_filename = { - "last_decoder_model_name": decoder_path.name if use_merged is False else None, - "last_decoder_with_past_model_name": decoder_with_past_path.name - if (use_cache is True and use_merged is False) - else None, - "last_decoder_merged_name": decoder_merged_path.name if use_merged is True else None, - } - paths = {} - for attr_name, filename in attribute_name_to_filename.items(): - if filename is None: - continue - model_cache_path = hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=filename, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=filename + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass - - paths[attr_name] = Path(model_cache_path).name - new_model_save_dir = Path(model_cache_path).parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) - - if use_merged is True: - decoder_path = new_model_save_dir / paths["last_decoder_merged_name"] - decoder_merged_path = new_model_save_dir / paths["last_decoder_merged_name"] - else: - decoder_path = new_model_save_dir / paths["last_decoder_model_name"] - decoder_without_past_path = new_model_save_dir / paths["last_decoder_model_name"] - - if use_cache is True: - decoder_with_past_path = new_model_save_dir / paths["last_decoder_with_past_model_name"] - - ort_inference_sessions = cls.load_model( - decoder_path=decoder_path, - decoder_with_past_path=None if use_merged is True or use_cache is False else decoder_with_past_path, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - - if model_save_dir is None: - model_save_dir = new_model_save_dir - - generation_config = None - try: - generation_config = GenerationConfig.from_pretrained( - model_id, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - ) - except OSError: - logger.info("Generation config file not found, using a generation config created from the model config.") - - onnx_paths = [] - if use_merged is False: - onnx_paths.append(decoder_without_past_path) - if use_cache is True: - onnx_paths.append(decoder_with_past_path) - else: - onnx_paths.append(decoder_merged_path) - - return init_cls( - ort_inference_sessions[0], - config, - decoder_with_past_session=ort_inference_sessions[1], - use_cache=use_cache, - use_io_binding=use_io_binding, - model_save_dir=model_save_dir, - preprocessors=preprocessors, - generation_config=generation_config, - onnx_paths=onnx_paths, - ) - - @classmethod - def _from_transformers( - cls, - model_id: str, - config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, - revision: str = "main", - force_download: bool = True, - cache_dir: Optional[str] = None, - subfolder: str = "", - local_files_only: bool = False, - trust_remote_code: bool = False, - use_cache: bool = True, - use_merged: bool = False, - provider: str = "CPUExecutionProvider", - session_options: 
Optional[onnxruntime.SessionOptions] = None, - provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: Optional[bool] = None, - task: Optional[str] = None, - ) -> "ORTModelDecoder": - if task is None: - task = cls._auto_model_to_task(cls.auto_model_class) - - if use_cache is True: - task = task + "-with-past" - - if use_cache is False and use_merged is True: - raise ValueError( - "The incompatible arguments use_cache=False, use_merged=True were passed to ORTModelForCausalLM.from_pretrained()." - " Please pass either use_cache=False, use_merged=False to disable past key value caching, or use_cache=True, use_merged=False" - " to disable the merging of the decoder not using / using past key and value." - ) - - save_dir = TemporaryDirectory() - save_dir_path = Path(save_dir.name) - - main_export( - model_name_or_path=model_id, - output=save_dir_path, - task=task, - do_validation=False, - no_post_process=not use_merged, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - force_download=force_download, - trust_remote_code=trust_remote_code, - ) - - config.save_pretrained(save_dir_path) - maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) - - return cls._from_pretrained( - save_dir_path, - config, - use_cache=use_cache, - use_merged=use_merged, - provider=provider, - session_options=session_options, - provider_options=provider_options, - use_io_binding=use_io_binding, - model_save_dir=save_dir, - ) - - def to(self, device: Union[torch.device, str, int]): - """ - Changes the ONNX Runtime provider according to the device. - - Args: - device (`Union[torch.device, str, int]`): - Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run - the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. - - Returns: - `ORTModel`: the model placed on the requested device. 
- """ - device, provider_options = parse_device(device) - - if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": - return self - - provider = get_provider_for_device(device) - validate_provider_availability(provider) # raise error if the provider is not available - self.device = device - self.decoder.session.set_providers([provider], provider_options=[provider_options]) - if self.decoder_with_past is not None: - self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) - self.providers = self.decoder.session.get_providers() - - return self - - @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) class ORTModelForCausalLM(ORTModel, GenerationMixin): """ @@ -654,10 +144,7 @@ def __init__( **kwargs, ): if use_io_binding is None: - if model.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"]: - use_io_binding = True - else: - use_io_binding = False + use_io_binding = model.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"] super().__init__(model, config, use_io_binding, model_save_dir, preprocessors, **kwargs) @@ -708,7 +195,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, - use_cache_branch: None = None, + use_cache_branch: bool = None, **kwargs, ) -> CausalLMOutputWithPast: # adding use_cache_branch in the signature here is just a hack for IO Binding diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 638c6fad1b4..41074e741fe 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -138,12 +138,12 @@ def __init__(self, *args, **kwargs): def test_load_model_from_local_path(self): model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_hub(self): model = ORTModel.from_pretrained(self.ONNX_MODEL_ID) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_hub_subfolder(self): @@ -151,11 +151,11 @@ def test_load_model_from_hub_subfolder(self): model = ORTModelForSequenceClassification.from_pretrained( "fxmarty/tiny-bert-sst2-distilled-subfolder", subfolder="my_subfolder", export=True ) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) model = ORTModel.from_pretrained("fxmarty/tiny-bert-sst2-distilled-onnx-subfolder", subfolder="my_subfolder") - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_seq2seq_model_from_hub_subfolder(self): @@ -178,7 +178,7 @@ def test_load_model_from_cache(self): model = ORTModel.from_pretrained(self.TINY_ONNX_MODEL_ID, local_files_only=True) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, 
onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_empty_cache(self): @@ -768,7 +768,7 @@ def test_stable_diffusion_model_on_gpu_str(self): @require_hf_token def test_load_model_from_hub_private(self): model = ORTModel.from_pretrained(self.ONNX_MODEL_ID, use_auth_token=os.environ.get("HF_AUTH_TOKEN", None)) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_save_model(self): @@ -1100,7 +1100,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForQuestionAnswering.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1267,7 +1267,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForMaskedLM.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1429,7 +1429,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForSequenceClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1602,7 +1602,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForTokenClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1725,7 +1725,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForFeatureExtraction.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1870,7 +1870,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForMultipleChoice.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1992,14 +1992,14 @@ def 
test_load_model_from_hub_onnx(self): self.assertFalse(model.use_merged) self.assertTrue(model.use_cache) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_WITH_PAST_NAME) model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge") self.assertTrue(model.use_merged) self.assertTrue(model.use_cache) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_MERGED_NAME) def test_load_vanilla_transformers_which_is_not_supported(self): @@ -2041,7 +2041,7 @@ def test_merge_from_onnx_and_save(self, model_arch): model = ORTModelForCausalLM.from_pretrained(tmpdir) self.assertTrue(model.use_merged) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) model.save_pretrained(tmpdir + "_save") save_path = os.path.join(tmpdir + "_save", ONNX_DECODER_MERGED_NAME) @@ -2079,7 +2079,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach model_path = Path(self.onnx_model_dirs[test_name], ONNX_WEIGHTS_NAME) self.assertFalse(has_onnx_input(model_path, "use_cache_branch")) self.assertFalse(onnx_model.use_merged) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2440,7 +2440,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] if model_arch in MODEL_NAMES else self.ARCH_MODEL_MAP[model_arch] onnx_model = ORTModelForImageClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2580,7 +2580,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSemanticSegmentation.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2735,7 +2735,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2887,7 +2887,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = 
ORTModelForCTC.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2946,7 +2946,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioXVector.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -3038,7 +3038,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioFrameClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) From 2491ef339345751177dadd2fa826315b94dd466f Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 4 Oct 2023 12:47:45 +0200 Subject: [PATCH 64/76] add position warning --- optimum/onnxruntime/modeling_decoder.py | 29 ++++++++++--------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2616e40fef2..c9f8dc2310d 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -14,16 +14,13 @@ """Classes handling causal-lm related architectures in ONNX Runtime.""" import logging -import shutil from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import onnx import torch -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError from onnx.tools import update_model_dims from transformers import AutoModelForCausalLM, GenerationConfig from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward @@ -32,23 +29,13 @@ import onnxruntime from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export -from ..onnx.utils import _get_external_data_paths, check_model_uses_external_data +from ..onnx.utils import check_model_uses_external_data from ..utils import NormalizedConfigManager, check_if_transformers_greater -from ..utils.file_utils import validate_file_exists -from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors -from .base import ORTDecoder +from ..utils.save_utils import maybe_save_preprocessors from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache -from .utils import ( - MULTI_QUERY_ATTN_MODELS, - ONNX_DECODER_NAME, - ONNX_DECODER_WITH_PAST_NAME, - ONNX_WEIGHTS_NAME, - 
get_provider_for_device, - parse_device, - validate_provider_availability, -) +from .utils import MULTI_QUERY_ATTN_MODELS, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME if TYPE_CHECKING: @@ -166,6 +153,14 @@ def __init__( self.use_fp16 = True break + # Reference: https://github.com/huggingface/optimum/pull/1381 + model_type = config.model_type.replace("_", "-") + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.decoder.input_names: + logger.warning( + f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. " + "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." + ) + if use_cache ^ self.use_cache: raise ValueError( f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " From 0ab6e61e73afc4cbc3aa7ad71b12f43c1dc288b0 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 5 Oct 2023 12:07:27 +0200 Subject: [PATCH 65/76] fixes --- optimum/commands/export/onnx.py | 14 +++----- optimum/exporters/onnx/__main__.py | 15 ++++---- optimum/exporters/onnx/model_patcher.py | 34 ++++++++++++------- optimum/onnxruntime/modeling_decoder.py | 12 ++++--- optimum/onnxruntime/optimization.py | 4 +-- optimum/utils/modeling_utils.py | 13 +++++-- .../exporters/onnx/test_exporters_onnx_cli.py | 23 ++++++++++++- tests/onnxruntime/test_modeling.py | 5 +-- 8 files changed, 75 insertions(+), 45 deletions(-) diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index c2c38473a88..85661ccf6cf 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -136,14 +136,6 @@ def parse_args_onnx(parser): default=None, help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"), ) - optional_group.add_argument( - "--no-position-ids", - action="store_true", - help=( - "Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." - ), - ) - input_group = parser.add_argument_group( "Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)." ) @@ -220,7 +212,10 @@ def parse_args_onnx(parser): optional_group.add_argument( "--legacy", action="store_true", - help="Export decoder only models in two (without + with past) model as a single ONNX file.", + help=( + "Export decoder only models in three files (without + with past and the resulting merged model)." + "Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." 
+ ), ) # deprecated argument @@ -260,7 +255,6 @@ def run(self): use_subprocess=True, _variant=self.args.variant, library_name=self.args.library_name, - no_position_ids=self.args.no_position_ids, legacy=self.args.legacy, **input_shapes, ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 977126401c5..1b601cdfb8d 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -68,7 +68,6 @@ def _get_submodels_and_onnx_configs( float_dtype: str = "fp32", fn_get_submodels: Optional[Callable] = None, preprocessors: Optional[List[Any]] = None, - no_position_ids: bool = False, legacy: bool = False, ): is_stable_diffusion = "stable-diffusion" in task @@ -83,8 +82,8 @@ def _get_submodels_and_onnx_configs( model=model, exporter="onnx", task=task ) onnx_config_kwargs = {} - if task.startswith("text-generation") and no_position_ids: - onnx_config_kwargs["no_position_ids"] = no_position_ids + if task.startswith("text-generation") and legacy: + onnx_config_kwargs["no_position_ids"] = legacy onnx_config = onnx_config_constructor( model.config, @@ -185,7 +184,6 @@ def main_export( use_subprocess: bool = False, _variant: str = "default", library_name: Optional[str] = None, - no_position_ids: bool = False, legacy: bool = False, **kwargs_shapes, ): @@ -266,8 +264,8 @@ def main_export( library_name (`Optional[str]`, defaults to `None`): The library of the model(`"tansformers"` or `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect the library name for the checkpoint. - no_position_ids (`bool`, defaults to `False`): - Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum. + legacy (`bool`, defaults to `False`): + Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -355,9 +353,9 @@ def main_export( is_stable_diffusion = "stable-diffusion" in task model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") - if no_position_ids and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"): + if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"): logger.warning( - f"no_position_ids=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `no_position_ids=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381" + f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. 
Reference: https://github.com/huggingface/optimum/pull/1381" ) if not is_stable_diffusion: @@ -426,7 +424,6 @@ def main_export( fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, _variant=_variant, - no_position_ids=no_position_ids, legacy=legacy, ) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 0928dab5957..8116fedacaf 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -357,17 +357,17 @@ def __init__( self.patch = self.real_config.task == "text-generation" and self.real_config.use_past if self.patch: - self.orig_prepare_attn_mask = getattr(self._model.transformer, "_prepare_attn_mask") + self.orig_prepare_attn_mask = self._model.transformer._prepare_attn_mask def __enter__(self): super().__enter__() if self.patch: - setattr(self._model.transformer, "_prepare_attn_mask", _prepare_attn_mask) + self._model.transformer._prepare_attn_mask = _prepare_attn_mask.__get__(self._model.transformer) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: - setattr(self._model.transformer, "_prepare_attn_mask", self.orig_prepare_attn_mask) + self._model.transformer._prepare_attn_mask = self.orig_prepare_attn_mask.__get__(self._model.transformer) class LlamaModelPatcher(ModelPatcher): @@ -381,17 +381,19 @@ def __init__( self.patch = self.real_config.task == "text-generation" and self.real_config.use_past if self.patch: - self.orig_prepare_attn_mask = getattr(self._model.model, "_prepare_decoder_attention_mask") + self.orig_prepare_attn_mask = self._model.model._prepare_decoder_attention_mask def __enter__(self): super().__enter__() if self.patch: - setattr(self._model.model, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) + self._model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__( + self._model.model + ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: - setattr(self._model.model, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + self._model.model._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__(self._model.model) class BartModelPatcher(Seq2SeqModelPatcher): @@ -405,17 +407,21 @@ def __init__( self.patch = self.real_config.task == "text-generation" and self.real_config.use_past if self.patch: - self.orig_prepare_attn_mask = getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") + self.orig_prepare_attn_mask = self._model.model.decoder._prepare_decoder_attention_mask def __enter__(self): super().__enter__() if self.patch: - setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) + self._model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__( + self._model.model.decoder + ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: - setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + self._model.model.decoder._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__( + self._model.model.decoder + ) class OPTModelPatcher(ModelPatcher): @@ -428,14 +434,18 @@ def __init__( super().__init__(config, model, model_kwargs) self.patch = self.real_config.task == "text-generation" and self.real_config.use_past if self.patch: - self.orig_prepare_attn_mask = 
getattr(self._model.model.decoder, "_prepare_decoder_attention_mask") + self.orig_prepare_attn_mask = self._model.model.decoder._prepare_decoder_attention_mask def __enter__(self): super().__enter__() if self.patch: - setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", _prepare_decoder_attention_mask) + self._model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__( + self._model.model.decoder + ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: - setattr(self._model.model.decoder, "_prepare_decoder_attention_mask", self.orig_prepare_attn_mask) + self._model.model.decoder._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__( + self._model.model.decoder + ) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index c9f8dc2310d..2f401e29039 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -153,7 +153,7 @@ def __init__( self.use_fp16 = True break - # Reference: https://github.com/huggingface/optimum/pull/1381 + # Reference: https://github.com/huggingface/optimum/pull/1381 model_type = config.model_type.replace("_", "-") if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.decoder.input_names: logger.warning( @@ -413,11 +413,13 @@ def _from_pretrained( ) use_merged = False - # TODO : deprecate - decoder_file_name = kwargs.pop("decoder_file_name", None) - decoder_with_past_file_name = kwargs.pop("decoder_with_past_file_name", None) + decoder_name = "decoder_file_name" if use_cache else "decoder_with_past_file_name" + decoder_file_name = kwargs.pop(decoder_name, None) + + if decoder_file_name is not None: + logger.warning(f"The `{decoder_name}` argument is deprecated, please use `file_name` instead.") + file_name = file_name or decoder_file_name - file_name = file_name or (decoder_with_past_file_name if use_cache else decoder_file_name) if file_name is None: decoder_path = None # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it), diff --git a/optimum/onnxruntime/optimization.py b/optimum/onnxruntime/optimization.py index 15b219441a6..9e62a3f324c 100644 --- a/optimum/onnxruntime/optimization.py +++ b/optimum/onnxruntime/optimization.py @@ -99,8 +99,8 @@ def from_pretrained( onnx_model_path.append(model_or_path.decoder_with_past_model_path) elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged: raise NotImplementedError( - "ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged." - " Please re-export your model. This can be done by using `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`" + "ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged. " + "Please re-export your model. This can be done by using the optimum-cli ONNX export tool or `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`." 
) else: onnx_model_path.append(model_or_path.model_path) diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 50da3556815..2970e8c892b 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -69,7 +69,10 @@ def _make_causal_mask( # Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask def _prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, ) -> torch.BoolTensor: from transformers.models.bloom.modeling_bloom import _expand_mask @@ -92,7 +95,13 @@ def _prepare_attn_mask( # Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask -def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): +def _prepare_decoder_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: torch.Tensor, + past_key_values_length: int, +): from transformers.models.llama.modeling_llama import _expand_mask # create causal mask diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index b1cdedbea84..efdbaba4235 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -19,6 +19,7 @@ from tempfile import TemporaryDirectory from typing import Dict, Optional +import onnx import pytest from parameterized import parameterized from transformers import AutoModelForSequenceClassification, AutoTokenizer, is_torch_available @@ -26,7 +27,12 @@ from optimum.exporters.error_utils import MinimumVersionError from optimum.exporters.onnx.__main__ import main_export -from optimum.onnxruntime import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME +from optimum.onnxruntime import ( + ONNX_DECODER_MERGED_NAME, + ONNX_DECODER_NAME, + ONNX_DECODER_WITH_PAST_NAME, + ONNX_ENCODER_NAME, +) from optimum.utils.testing_utils import require_diffusers, require_timm @@ -413,6 +419,21 @@ def test_stable_diffusion(self): check=True, ) + def test_legacy(self): + with TemporaryDirectory() as tmpdirname: + subprocess.run( + f"python3 -m optimum.exporters.onnx --model hf-internal-testing/tiny-random-gpt2 --task text-generation-with-past --legacy {tmpdirname}", + shell=True, + capture_output=True, + ) + folder_contents = os.listdir(tmpdirname) + self.assertIn(ONNX_DECODER_NAME, folder_contents) + self.assertIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents) + self.assertIn(ONNX_DECODER_MERGED_NAME, folder_contents) + + model = onnx.load(Path(tmpdirname) / ONNX_DECODER_MERGED_NAME) + self.assertNotIn("position_ids", {node.name for node in model.graph.input}) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_vision @require_torch_gpu diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 41074e741fe..f59260a10ac 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2086,10 +2086,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach transformers_model = AutoModelForCausalLM.from_pretrained(model_id) transformers_model = transformers_model.eval() tokenizer = get_preprocessor(model_id) - tokens = tokenizer( - "This is a sample output", - return_tensors="pt", - ) + tokens = tokenizer("This is a sample output", return_tensors="pt") 
position_ids = None if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: input_shape = tokens["input_ids"].shape From 1a7d4919ce172b990a3cc5631ab32e9b8a77e2ff Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 5 Oct 2023 14:30:55 +0200 Subject: [PATCH 66/76] enable post process after export to remove tied weights --- optimum/onnxruntime/modeling_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2f401e29039..f82343be55b 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -596,7 +596,7 @@ def _from_transformers( output=save_dir_path, task=task, do_validation=False, - no_post_process=True, + no_post_process=False, legacy=False, subfolder=subfolder, revision=revision, From cd8d4be8b463b3977704e42c06cc4a3bed127974 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 5 Oct 2023 15:24:18 +0200 Subject: [PATCH 67/76] comment --- optimum/onnxruntime/modeling_decoder.py | 11 +---------- optimum/utils/modeling_utils.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index f82343be55b..e38d0e5442c 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -31,6 +31,7 @@ from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export from ..onnx.utils import check_model_uses_external_data from ..utils import NormalizedConfigManager, check_if_transformers_greater +from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST from ..utils.save_utils import maybe_save_preprocessors from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel @@ -458,16 +459,6 @@ def _from_pretrained( ) file_name = decoder_path.name - MODEL_TO_PATCH_FOR_PAST = { - "bloom", - "mpt", - "llama", - "blenderbot-small", - "blenderbot", - "opt", - "pegasus", - "bart", - } if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST: raise ValueError( f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False." diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 2970e8c892b..e34c3a88b7d 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -18,6 +18,18 @@ import torch +MODEL_TO_PATCH_FOR_PAST = { + "bart", + "blenderbot", + "blenderbot-small", + "bloom", + "llama", + "mpt", + "opt", + "pegasus", +} + + def recurse_getattr(obj, attr: str): """ Recursive `getattr`. @@ -67,6 +79,11 @@ def _make_causal_mask( return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) +# NOTE: For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, the attention masks will be generated incorrectly for other sequence length +# https://github.com/huggingface/transformers/blob/0ee45906845c8d58b9bd2df5acd90e09b00047ff/src/transformers/models/bloom/modeling_bloom.py#L654 +# The method taking care of the decoder mask generation of the models from these architectures must be patched during export for sequence length of 1. 
+ + # Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask def _prepare_attn_mask( self, From e6de5e760847fd055f9d7483abfff2d038ba31bd Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 5 Oct 2023 15:25:55 +0200 Subject: [PATCH 68/76] remove test --- tests/onnxruntime/test_modeling.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index f59260a10ac..2f2c789781b 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2008,25 +2008,6 @@ def test_load_vanilla_transformers_which_is_not_supported(self): self.assertIn("Unrecognized configuration class", str(context.exception)) - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_merge_from_transformers_and_save(self, model_arch): - if "text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type( - model_arch.replace("_", "-"), exporter="onnx" - ): - self.skipTest("Unsupported -with-past export case") - - model_id = MODEL_NAMES[model_arch] - model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_merged=True) - with tempfile.TemporaryDirectory() as tmpdir: - model.save_pretrained(tmpdir) - save_path = os.path.join(tmpdir, ONNX_WEIGHTS_NAME) - self.assertFalse(has_onnx_input(save_path, "use_cache_branch")) - - folder_contents = os.listdir(tmpdir) - self.assertTrue(ONNX_DECODER_NAME not in folder_contents) - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) - self.assertTrue(ONNX_DECODER_MERGED_NAME not in folder_contents) - @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): model_id = MODEL_NAMES[model_arch] From 4a32f7a1c8ac7e0f6d7612eee503949277ce2b0f Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 5 Oct 2023 16:01:17 +0200 Subject: [PATCH 69/76] fix test --- optimum/onnxruntime/modeling_decoder.py | 2 +- tests/onnxruntime/test_modeling.py | 60 ++++++++++++------------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index e38d0e5442c..bd7f7715662 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -156,7 +156,7 @@ def __init__( # Reference: https://github.com/huggingface/optimum/pull/1381 model_type = config.model_type.replace("_", "-") - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.decoder.input_names: + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names: logger.warning( f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. " "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." 
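The warning above exists because, without a position_ids input in the exported graph, batched generation with left-padded prompts cannot be computed correctly: padding tokens shift the positions assigned to the real tokens. The sketch below is illustrative only (it is not part of this patch) and shows the usual way position_ids are derived from the attention mask for such models, assuming a left-padded batch:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded prompt
                                   [1, 1, 1, 1, 1]])  # full-length prompt
    # positions advance only over non-padding tokens
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

When past key values are reused, only the position of the newly generated token is needed (position_ids[:, -1:]).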
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 2f2c789781b..0e432719c8b 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1967,25 +1967,22 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 - def test_inference_old_onnx_model(self): + @parameterized.expand([(False,), (True,)]) + def test_inference_old_onnx_model(self, use_cache): model_id = "optimum/gpt2" model = AutoModelForCausalLM.from_pretrained("gpt2") tokenizer = get_preprocessor(model_id) text = "This is a sample output" tokens = tokenizer(text, return_tensors="pt") + onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache) - for use_cache in (True, False): - onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache) - - self.assertEqual(onnx_model.use_cache, use_cache) - self.assertEqual( - onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME - ) - outputs_onnx = onnx_model.generate( - **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 - ) - outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) - self.assertTrue(torch.allclose(outputs_onnx, outputs)) + self.assertEqual(onnx_model.use_cache, use_cache) + self.assertEqual(onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME) + outputs_onnx = onnx_model.generate( + **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 + ) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) + self.assertTrue(torch.allclose(outputs_onnx, outputs)) def test_load_model_from_hub_onnx(self): model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-without-merge") @@ -2023,15 +2020,14 @@ def test_merge_from_onnx_and_save(self, model_arch): self.assertTrue(model.use_merged) self.assertIsInstance(model.model, onnxruntime.InferenceSession) - model.save_pretrained(tmpdir + "_save") save_path = os.path.join(tmpdir + "_save", ONNX_DECODER_MERGED_NAME) self.assertTrue(has_onnx_input(save_path, "use_cache_branch")) folder_contents = os.listdir(tmpdir + "_save") - self.assertTrue(ONNX_DECODER_NAME not in folder_contents) - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) - self.assertTrue(ONNX_WEIGHTS_NAME not in folder_contents) + self.assertNotIn(ONNX_DECODER_NAME, folder_contents) + self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents) + self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents) @parameterized.expand(grid_parameters(FULL_GRID)) def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): @@ -2268,18 +2264,10 @@ def test_compare_with_and_without_past_key_values(self, model_arch): @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool): - model_args = { - "test_name": test_name + "_True", - "model_arch": model_arch, - "use_cache": use_cache, - "use_merged": True, - } - self._setup(model_args) model_args = { "test_name": test_name + "_False", "model_arch": model_arch, "use_cache": use_cache, - "use_merged": False, } self._setup(model_args) @@ -2287,19 +2275,29 @@ def 
test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode tokenizer = get_preprocessor(model_id) text = "My Name is Philipp and i live" tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None) - model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"] - model_merged_dir = self.onnx_model_dirs[test_name + "_True"] - model_not_merged = ORTModelForCausalLM.from_pretrained(model_not_merged_dir) not_merged_onnx_path = Path(model_not_merged_dir, ONNX_WEIGHTS_NAME) self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch")) self.assertFalse(model_not_merged.use_merged) + model_merged_dir = Path(model_not_merged_dir) / "merged" + task = model_not_merged.export_feature + if use_cache: + task += "-with-past" + + main_export( + model_id, + output=model_merged_dir, + task=task, + no_post_process=False, + legacy=True, + ) + model_merged = ORTModelForCausalLM.from_pretrained(model_merged_dir) - merged_onnx_path = Path(model_merged_dir, ONNX_WEIGHTS_NAME) - self.assertFalse(has_onnx_input(merged_onnx_path, "use_cache_branch")) - self.assertFalse(model_merged.use_merged) + merged_onnx_path = Path(model_merged_dir, ONNX_DECODER_MERGED_NAME) + self.assertTrue(has_onnx_input(merged_onnx_path, "use_cache_branch")) + self.assertTrue(model_merged.use_merged) outputs_model_not_merged = model_not_merged.generate(**tokens) outputs_model_merged = model_merged.generate(**tokens) From a51686ecb6375493cdca032bd3679e7db91f3c41 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 6 Oct 2023 15:47:43 +0200 Subject: [PATCH 70/76] modify model --- tests/onnxruntime/utils_onnxruntime_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 985d0340350..3acf3ff65ba 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -40,7 +40,7 @@ "clip": "hf-internal-testing/tiny-random-CLIPModel", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", "convnext": "hf-internal-testing/tiny-random-convnext", - "codegen": "hf-internal-testing/tiny-random-CodeGenModel", + "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel", "data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel", @@ -62,7 +62,7 @@ "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM", "groupvit": "hf-internal-testing/tiny-random-groupvit", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-IBertModel", From e2f8a3b62a2a5cd23ff139602091e6e1bedbdb4f Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 6 Oct 2023 15:51:40 +0200 Subject: [PATCH 71/76] remove deprecated use_merged in test --- tests/onnxruntime/test_modeling.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0e432719c8b..4f34a963435 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1958,7 +1958,6 @@ class 
ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): FULL_GRID = { "model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False, True], - "use_merged": [False, True], } ORTMODEL_CLASS = ORTModelForCausalLM @@ -2030,10 +2029,7 @@ def test_merge_from_onnx_and_save(self, model_arch): self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents) @parameterized.expand(grid_parameters(FULL_GRID)) - def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): - if use_cache is False and use_merged is True: - self.skipTest("use_cache=False, use_merged=True are uncompatible") - + def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool): use_io_binding = None if use_cache is False: use_io_binding = False @@ -2042,7 +2038,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach "test_name": test_name, "model_arch": model_arch, "use_cache": use_cache, - "use_merged": use_merged, } self._setup(model_args) @@ -2100,10 +2095,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach gc.collect() @parameterized.expand(grid_parameters(FULL_GRID)) - def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): - if use_cache is False and use_merged is True: - self.skipTest("use_cache=False, use_merged=True are uncompatible") - + def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bool): use_io_binding = None if use_cache is False: use_io_binding = False @@ -2112,7 +2104,6 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo "test_name": test_name, "model_arch": model_arch, "use_cache": use_cache, - "use_merged": use_merged, } self._setup(model_args) From b76f43a88f64ff22b80affb22251877cc5a85481 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 9 Oct 2023 12:43:07 +0200 Subject: [PATCH 72/76] Add mistral model patcher --- optimum/exporters/onnx/model_configs.py | 6 ++ optimum/exporters/onnx/model_patcher.py | 98 +++++++++++-------------- optimum/onnxruntime/modeling_decoder.py | 7 +- optimum/utils/modeling_utils.py | 36 +++++++++ 4 files changed, 90 insertions(+), 57 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c8eef4078db..a83c8a91fa5 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -60,6 +60,7 @@ BartModelPatcher, BloomModelPatcher, LlamaModelPatcher, + MistralModelPatcher, OPTModelPatcher, SAMModelPatcher, WavLMModelPatcher, @@ -250,6 +251,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MistralModelPatcher(self, model, model_kwargs=model_kwargs) + class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. 
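The model_patcher.py refactor that follows stops assigning a plain function with setattr and instead binds the replacement function to the instance through the descriptor protocol (function.__get__(instance)), so that it receives the instance as self like the original method and can be restored on exit. A minimal sketch of that idiom, independent of Optimum and using made-up names:

    class Model:
        def _prepare_mask(self, mask):
            return mask

    def _patched_prepare_mask(self, mask):
        # bound like a normal method: `self` is the Model instance
        return mask + 1

    model = Model()
    original = model._prepare_mask                              # keep a handle to restore later
    model._prepare_mask = _patched_prepare_mask.__get__(model)  # patch this instance only
    assert model._prepare_mask(1) == 2
    model._prepare_mask = original                              # undo the patch

Assigning the bare function (model._prepare_mask = _patched_prepare_mask) would not pass the instance as self, which is why the patchers bind it explicitly before and after patching.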
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 8116fedacaf..7acd7515ea4 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -19,7 +19,11 @@ from transformers.utils import is_torch_available -from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask +from ...utils.modeling_utils import ( + _prepare_attn_mask, + _prepare_decoder_attention_mask, + _prepare_decoder_sliding_window_attention_mask, +) if is_torch_available(): @@ -346,7 +350,7 @@ def patched_forward( self.patched_forward = patched_forward -class BloomModelPatcher(ModelPatcher): +class CausalAttentionMaskModelPatcher(ModelPatcher): def __init__( self, config: "OnnxConfig", @@ -357,95 +361,79 @@ def __init__( self.patch = self.real_config.task == "text-generation" and self.real_config.use_past if self.patch: - self.orig_prepare_attn_mask = self._model.transformer._prepare_attn_mask + self._orig_func = getattr(self._model_to_patch, self._orig_func_name) def __enter__(self): super().__enter__() if self.patch: - self._model.transformer._prepare_attn_mask = _prepare_attn_mask.__get__(self._model.transformer) + setattr(self._model_to_patch, self._orig_func_name, self._patch_func.__get__(self._model_to_patch)) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: - self._model.transformer._prepare_attn_mask = self.orig_prepare_attn_mask.__get__(self._model.transformer) + setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch)) -class LlamaModelPatcher(ModelPatcher): +class BloomModelPatcher(CausalAttentionMaskModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): + self._model_to_patch = model.transformer + self._patch_func = _prepare_attn_mask + self._orig_func_name = "_prepare_attn_mask" super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past - if self.patch: - self.orig_prepare_attn_mask = self._model.model._prepare_decoder_attention_mask - - def __enter__(self): - super().__enter__() - if self.patch: - self._model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__( - self._model.model - ) - - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) - if self.patch: - self._model.model._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__(self._model.model) - -class BartModelPatcher(Seq2SeqModelPatcher): +class OPTModelPatcher(CausalAttentionMaskModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): + self._model_to_patch = model.model.decoder + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past - if self.patch: - self.orig_prepare_attn_mask = self._model.model.decoder._prepare_decoder_attention_mask - - def __enter__(self): - super().__enter__() - if self.patch: - self._model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__( - self._model.model.decoder - ) - def __exit__(self, exc_type, exc_value, 
traceback): - super().__exit__(exc_type, exc_value, traceback) - if self.patch: - self._model.model.decoder._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__( - self._model.model.decoder - ) +class LlamaModelPatcher(CausalAttentionMaskModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + self._model_to_patch = model.model + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + super().__init__(config, model, model_kwargs) -class OPTModelPatcher(ModelPatcher): +class MistralModelPatcher(CausalAttentionMaskModelPatcher): def __init__( self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): + self._model_to_patch = model.model + self._patch_func = _prepare_decoder_sliding_window_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past - if self.patch: - self.orig_prepare_attn_mask = self._model.model.decoder._prepare_decoder_attention_mask - def __enter__(self): - super().__enter__() - if self.patch: - self._model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask.__get__( - self._model.model.decoder - ) - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) - if self.patch: - self._model.model.decoder._prepare_decoder_attention_mask = self.orig_prepare_attn_mask.__get__( - self._model.model.decoder - ) +class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + self._model_to_patch = model.model.decoder + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + super().__init__(config, model, model_kwargs) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index bd7f7715662..2707c6eeab2 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -332,8 +332,11 @@ def prepare_past_key_values( # Generate dummy past for the first forward if uses a merged decoder if past_key_values is None: batch_size = input_ids.shape[0] - num_attention_heads = self.normalized_config.num_attention_heads - embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads + if self.config.model_type in {"mistral", "llama"}: + num_attention_heads = self.normalized_config.num_key_value_heads + else: + num_attention_heads = self.normalized_config.num_attention_heads + embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads dtype = constructor.float16 if self.use_fp16 else constructor.float32 # TODO: find a way to better handle this controlflow # "1" is the dummy sequence length diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index e34c3a88b7d..145e913a058 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -24,6 +24,7 @@ "blenderbot-small", "bloom", "llama", + "mistral", "mpt", "opt", "pegasus", @@ -142,3 +143,38 @@ def _prepare_decoder_attention_mask( ) return 
combined_attention_mask + + +# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask +def _prepare_decoder_sliding_window_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: int, +): + from transformers.models.mistral.modeling_mistral import _make_sliding_window_causal_mask, _expand_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_sliding_window_causal_mask( + input_shape, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask From 5b3d445356a0c79a1769005b420a7a85c79ed447 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 9 Oct 2023 13:18:17 +0200 Subject: [PATCH 73/76] fix test --- tests/onnxruntime/test_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 4be5108a691..1f08b1c33a8 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -847,7 +847,7 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): # verify external data is exported folder_contents = os.listdir(tmpdirname) self.assertTrue(ONNX_WEIGHTS_NAME in folder_contents) - self.assertTrue(ONNX_WEIGHTS_NAME + "_data" in folder_contents) + # self.assertTrue(ONNX_WEIGHTS_NAME + "_data" in folder_contents) self.assertFalse(use_cache ^ model.use_cache) # verify loading from local folder works From 5406f95b52756a831e2ef24d609c9a42628164e8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 9 Oct 2023 14:02:03 +0200 Subject: [PATCH 74/76] add slow test --- tests/onnxruntime/nightly_test_trainer.py | 6 +----- tests/onnxruntime/test_modeling.py | 10 +++++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/onnxruntime/nightly_test_trainer.py b/tests/onnxruntime/nightly_test_trainer.py index 38bdfd07973..2eb3ca433f7 100644 --- a/tests/onnxruntime/nightly_test_trainer.py +++ b/tests/onnxruntime/nightly_test_trainer.py @@ -40,11 +40,7 @@ default_data_collator, is_torch_available, ) -from transformers.testing_utils import ( - require_deepspeed, - require_torch, - slow, -) +from transformers.testing_utils import require_deepspeed, require_torch, slow from transformers.training_args import OptimizerNames diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 1f08b1c33a8..b7695cbd651 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -59,7 +59,7 @@ ) from transformers.modeling_utils import no_init_weights from transformers.onnx.utils import get_preprocessor -from transformers.testing_utils import get_gpu_count, require_torch_gpu +from transformers.testing_utils import get_gpu_count, require_torch_gpu, slow from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin from optimum.exporters import TasksManager @@ -832,11 
+832,12 @@ def test_save_load_ort_model_with_external_data(self): os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @parameterized.expand([(False,), (True,)]) + @pytest.mark.run_slow + @slow def test_save_load_decoder_model_with_external_data(self, use_cache: bool): with tempfile.TemporaryDirectory() as tmpdirname: - os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" # force exporting small model with external data model = ORTModelForCausalLM.from_pretrained( - MODEL_NAMES["gpt2"], + "gpt2-large", use_cache=use_cache, export=True, use_merged=False, @@ -847,14 +848,13 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): # verify external data is exported folder_contents = os.listdir(tmpdirname) self.assertTrue(ONNX_WEIGHTS_NAME in folder_contents) - # self.assertTrue(ONNX_WEIGHTS_NAME + "_data" in folder_contents) + self.assertTrue(ONNX_WEIGHTS_NAME + "_data" in folder_contents) self.assertFalse(use_cache ^ model.use_cache) # verify loading from local folder works model = ORTModelForCausalLM.from_pretrained( tmpdirname, use_cache=use_cache, export=False, use_io_binding=False ) - os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @parameterized.expand([(False,), (True,)]) def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): From 52e0c6998463050f832d381efc8624767e155c15 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 9 Oct 2023 15:33:12 +0200 Subject: [PATCH 75/76] add workflow --- .github/workflows/test_onnxruntime_slow.yml | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/test_onnxruntime_slow.yml diff --git a/.github/workflows/test_onnxruntime_slow.yml b/.github/workflows/test_onnxruntime_slow.yml new file mode 100644 index 00000000000..20371f79150 --- /dev/null +++ b/.github/workflows/test_onnxruntime_slow.yml @@ -0,0 +1,33 @@ +name: ONNX Runtime slow / Python - Test + +on: + workflow_dispatch: + schedule: + - cron: 0 7 * * * # every day at 7am + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: [3.8, 3.9] + os: [ubuntu-20.04] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for export + run: | + pip install .[tests,onnxruntime] + - name: Test with unittest + working-directory: tests + run: | + RUN_SLOW=1 pytest onnxruntime -s -m "run_slow" --durations=0 From 888332364c2e0091da1fc974737c7e277af168bf Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 9 Oct 2023 15:43:43 +0200 Subject: [PATCH 76/76] fix --- optimum/exporters/onnx/model_patcher.py | 47 +++++++++++++++---------- optimum/utils/modeling_utils.py | 2 +- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 7acd7515ea4..aa14526bd8c 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -358,10 +358,7 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past - if self.patch: - self._orig_func = getattr(self._model_to_patch, self._orig_func_name) def __enter__(self): super().__enter__() @@ -381,10 +378,12 @@ def __init__( model: 
Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): - self._model_to_patch = model.transformer - self._patch_func = _prepare_attn_mask - self._orig_func_name = "_prepare_attn_mask" super().__init__(config, model, model_kwargs) + if self.patch: + self._model_to_patch = model.transformer + self._patch_func = _prepare_attn_mask + self._orig_func_name = "_prepare_attn_mask" + self._orig_func = self._model_to_patch._prepare_attn_mask class OPTModelPatcher(CausalAttentionMaskModelPatcher): @@ -394,11 +393,14 @@ def __init__( model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): - self._model_to_patch = model.model.decoder - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" super().__init__(config, model, model_kwargs) + if self.patch: + self._model_to_patch = model.model.decoder + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask + class LlamaModelPatcher(CausalAttentionMaskModelPatcher): def __init__( @@ -407,11 +409,14 @@ def __init__( model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): - self._model_to_patch = model.model - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" super().__init__(config, model, model_kwargs) + if self.patch: + self._model_to_patch = model.model + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask + class MistralModelPatcher(CausalAttentionMaskModelPatcher): def __init__( @@ -420,11 +425,14 @@ def __init__( model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): - self._model_to_patch = model.model - self._patch_func = _prepare_decoder_sliding_window_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" super().__init__(config, model, model_kwargs) + if self.patch: + self._model_to_patch = model.model + self._patch_func = _prepare_decoder_sliding_window_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask + class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher): def __init__( @@ -433,7 +441,10 @@ def __init__( model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None, ): - self._model_to_patch = model.model.decoder - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" super().__init__(config, model, model_kwargs) + + if self.patch: + self._model_to_patch = model.model.decoder + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 145e913a058..67e12861eb5 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -154,7 +154,7 @@ def _prepare_decoder_sliding_window_attention_mask( past_key_values_length: int, sliding_window: int, ): - from transformers.models.mistral.modeling_mistral import _make_sliding_window_causal_mask, 
_expand_mask + from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
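For intuition, a sliding-window causal mask keeps the usual causal constraint (no attending to future positions) but also drops keys that fall outside a fixed window behind each query. A small standalone sketch of that pattern, purely illustrative and independent of the transformers helpers imported above (the exact boundary convention of the real `_make_sliding_window_causal_mask` may differ by one position); 0 = attend, -inf = masked:

import torch


def toy_sliding_window_causal_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    i = torch.arange(seq_len).unsqueeze(-1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    visible = (j <= i) & (i - j < sliding_window)
    mask = torch.zeros(seq_len, seq_len)
    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    mask.masked_fill_(~visible, float("-inf"))
    return mask  # broadcast to [bsz, 1, tgt_seq_len, src_seq_len] in practice


print(toy_sliding_window_causal_mask(seq_len=5, sliding_window=3))

In the helper above, this banded mask is additionally built with a `past_key_values_length` offset and combined with the expanded padding mask from `_expand_mask`, both of which the toy version ignores.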