diff --git a/README.md b/README.md index 5cd9222a6c..ed925ac666 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # Hugging Face Optimum -🤗 Optimum is an extension of 🤗 Transformers, providing a set of performance optimization tools enabling maximum efficiency to train and run models on targeted hardware. +🤗 Optimum is an extension of 🤗 Transformers, providing a set of optimization tools enabling maximum efficiency to train and run models on targeted hardware. The AI ecosystem evolves quickly and more and more specialized hardware along with their own optimizations are emerging every day. As such, Optimum enables users to efficiently use any of these platforms with the same ease inherent to transformers. @@ -10,28 +10,14 @@ As such, Optimum enables users to efficiently use any of these platforms with th ## Integration with Hardware Partners -🤗 Optimum aims at providing more diversity towards the kind of hardware users can target to train and finetune their models. +Optimum aims at providing more diversity towards the kind of hardware users can target to train and finetune their models. To achieve this, we are collaborating with the following hardware manufacturers in order to provide the best transformers integration: - [Graphcore IPUs](https://github.com/huggingface/optimum-graphcore) - IPUs are a completely new kind of massively parallel processor to accelerate machine intelligence. More information [here](https://www.graphcore.ai/products/ipu). - [Habana Gaudi Processor (HPU)](https://github.com/huggingface/optimum-habana) - [HPUs](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) are designed to maximize training throughput and efficiency. More information [here](https://habana.ai/training/). -- [Intel](https://github.com/huggingface/optimum-intel) - Enabling the usage of Intel tools to accelerate end-to-end pipelines on Intel architectures. More information about [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) and [OpenVINO](https://docs.openvino.ai/latest/index.html). +- [Intel](https://github.com/huggingface/optimum-intel) - Enabling the usage of Intel tools to accelerate inference on Intel architectures. More information about [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) and [OpenVINO](https://docs.openvino.ai/latest/index.html). - More to come soon! :star: -## Optimizing models towards inference - -Along with supporting dedicated AI hardware for training, Optimum also provides inference optimizations towards various frameworks and -platforms. - -Optimum enables the usage of popular compression techniques such as quantization and pruning by supporting [ONNX Runtime](https://onnxruntime.ai/docs/) along with [Intel Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html). - -| Features | ONNX Runtime | Intel Neural Compressor | -|:----------------------------------:|:---------------------:|:-----------------------:| -| Post-training Dynamic Quantization | :heavy_check_mark: | :heavy_check_mark: | -| Post-training Static Quantization | :heavy_check_mark: | :heavy_check_mark: | -| Quantization Aware Training (QAT) | Stay tuned! 
:star: | :heavy_check_mark: | -| Pruning | N/A | :heavy_check_mark: | - ## Installation @@ -64,86 +50,180 @@ For the accelerator-specific features, you can install them by appending `#egg=o python -m pip install git+https://github.com/huggingface/optimum.git#egg=optimum[onnxruntime] ``` + +## Optimizing models towards inference + +Along with supporting dedicated AI hardware for training, Optimum also provides inference optimizations towards various frameworks and +platforms. + +Optimum enables the usage of popular compression techniques such as quantization and pruning by supporting [ONNX Runtime](https://onnxruntime.ai/docs/) along with Intel [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) and OpenVINO [NNCF](https://docs.openvino.ai/latest/tmo_introduction.html). + +| Features | ONNX Runtime | Neural Compressor | OpenVINO | +|:----------------------------------:|:---------------------:|:-----------------------:|:-----------------------:| +| Post-training Dynamic Quantization | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Post-training Static Quantization | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Quantization Aware Training (QAT) | Stay tuned! :star: | :heavy_check_mark: | N/A | +| Pruning | N/A | :heavy_check_mark: | Stay tuned! :star: | + ## Quick tour Check out the examples below to see how 🤗 Optimum can be used to train and run inference on various hardware accelerators. -### Accelerated training +## Accelerated inference + +#### ONNX Runtime + +To accelerate inference with ONNX Runtime, 🤗 Optimum uses _configuration objects_ to define parameters for graph optimization and quantization. These objects are then used to instantiate dedicated _optimizers_ and _quantizers_. + +Before applying quantization or optimization, first we need to load our model. To load a model and run inference with ONNX Runtime, you can just replace the canonical Transformers [`AutoModelForXxx`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModel) class with the corresponding [`ORTModelForXxx`](https://huggingface.co/docs/optimum/onnxruntime/package_reference/modeling_ort#optimum.onnxruntime.ORTModel) class. If you want to load from a PyTorch checkpoint, set `from_transformers=True` to export your model to the ONNX format. 
+
+```python
+from optimum.onnxruntime import ORTModelForSequenceClassification
+from transformers import AutoTokenizer
+
+model_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
+save_directory = "tmp/onnx/"
+# Load a model from transformers and export it to ONNX
+tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint, from_transformers=True)
+# Save the ONNX model and tokenizer
+ort_model.save_pretrained(save_directory)
+tokenizer.save_pretrained(save_directory)
+```
+
+Let's now see how we can apply dynamic quantization with ONNX Runtime:
+
+```python
+from optimum.onnxruntime.configuration import AutoQuantizationConfig
+from optimum.onnxruntime import ORTQuantizer
+
+# Define the quantization methodology
+qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
+quantizer = ORTQuantizer.from_pretrained(ort_model)
+# Apply dynamic quantization on the model
+quantizer.quantize(save_dir=save_directory, quantization_config=qconfig)
+```
+
+In this example, we've quantized a model from the Hugging Face Hub. In the same manner, we can quantize a model hosted locally by providing the path to the directory containing the model weights. The result from applying the `quantize()` method is a `model_quantized.onnx` file that can be used to run inference. Here's an example of how to load an ONNX Runtime model and generate predictions with it:
+
+```python
+from optimum.onnxruntime import ORTModelForSequenceClassification
+from transformers import pipeline, AutoTokenizer
+
+model = ORTModelForSequenceClassification.from_pretrained(save_directory, file_name="model_quantized.onnx")
+tokenizer = AutoTokenizer.from_pretrained(save_directory)
+classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
+results = classifier("I love burritos!")
+```
+
+You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/quickstart) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime).
 
-#### Optimum Graphcore
 
-To train transformers on Graphcore's IPUs, 🤗 Optimum provides a `IPUTrainer` that is very similar to the [🤗 Transformers trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example:
+#### Intel
+
+To load a model and run inference with OpenVINO Runtime, you can just replace your `AutoModelForXxx` class with the corresponding `OVModelForXxx` class.
+If you want to load a PyTorch checkpoint, set `from_transformers=True` to convert your model to the OpenVINO IR (Intermediate Representation).
+
+```diff
+- from transformers import AutoModelForSequenceClassification
++ from optimum.intel.openvino import OVModelForSequenceClassification
+  from transformers import AutoTokenizer, pipeline
+
+  # Download a tokenizer and model from the Hub and convert to OpenVINO format
+  model_id = "distilbert-base-uncased-finetuned-sst-2-english"
+  tokenizer = AutoTokenizer.from_pretrained(model_id)
+- model = AutoModelForSequenceClassification.from_pretrained(model_id)
++ model = OVModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
+
+  # Run inference!
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) + results = classifier("He's a dreadful magician.") +``` + +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/intel/inference) and in the [examples](https://github.com/huggingface/optimum-intel/tree/main/examples/openvino). + + +## Accelerated training + +#### Habana + +To train transformers on Habana's Gaudi processors, 🤗 Optimum provides a `GaudiTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: ```diff - from transformers import Trainer, TrainingArguments -+ from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments ++ from optimum.habana import GaudiTrainer, GaudiTrainingArguments # Download a pretrained model from the Hub model = AutoModelForXxx.from_pretrained("bert-base-uncased") # Define the training arguments - training_args = TrainingArguments( -+ training_args = IPUTrainingArguments( ++ training_args = GaudiTrainingArguments( output_dir="path/to/save/folder/", -+ ipu_config_name="Graphcore/bert-base-ipu", # Any IPUConfig on the Hub or stored locally ++ use_habana=True, ++ use_lazy_mode=True, ++ gaudi_config_name="Habana/bert-base-uncased", ... ) - # Define the configuration to compile and put the model on the IPU -+ ipu_config = IPUConfig.from_pretrained(training_args.ipu_config_name) - # Initialize the trainer - trainer = Trainer( -+ trainer = IPUTrainer( ++ trainer = GaudiTrainer( model=model, -+ ipu_config=ipu_config args=training_args, - train_dataset=train_dataset + train_dataset=train_dataset, ... ) - # Use Graphcore IPU for training! + # Use Habana Gaudi processor for training! trainer.train() ``` +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/habana/quickstart) and in the [examples](https://github.com/huggingface/optimum-habana/tree/main/examples). + -#### Optimum Habana +#### Graphcore -To train transformers on Habana's Gaudi processors, 🤗 Optimum provides a `GaudiTrainer` that is very similar to the [🤗 Transformers trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: +To train transformers on Graphcore's IPUs, 🤗 Optimum provides a `IPUTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: ```diff - from transformers import Trainer, TrainingArguments -+ from optimum.habana import GaudiTrainer, GaudiTrainingArguments ++ from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments # Download a pretrained model from the Hub model = AutoModelForXxx.from_pretrained("bert-base-uncased") # Define the training arguments - training_args = TrainingArguments( -+ training_args = GaudiTrainingArguments( ++ training_args = IPUTrainingArguments( output_dir="path/to/save/folder/", -+ use_habana=True, -+ use_lazy_mode=True, -+ gaudi_config_name="Habana/bert-base-uncased", ++ ipu_config_name="Graphcore/bert-base-ipu", # Any IPUConfig on the Hub or stored locally ... ) + # Define the configuration to compile and put the model on the IPU ++ ipu_config = IPUConfig.from_pretrained(training_args.ipu_config_name) + # Initialize the trainer - trainer = Trainer( -+ trainer = GaudiTrainer( ++ trainer = IPUTrainer( model=model, ++ ipu_config=ipu_config args=training_args, - train_dataset=train_dataset, + train_dataset=train_dataset ... 
) - # Use Habana Gaudi processor for training! + # Use Graphcore IPU for training! trainer.train() ``` +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/graphcore/quickstart) and in the [examples](https://github.com/huggingface/optimum-graphcore/tree/main/examples). + + #### ONNX Runtime -To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum provides a `ORTTrainer` that is very similar to the [🤗 Transformers trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: +To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum provides a `ORTTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: ```diff - from transformers import Trainer, TrainingArguments @@ -174,71 +254,4 @@ To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum pr trainer.train() ``` - -### Accelerated inference - -#### ONNX Runtime - -To accelerate inference with ONNX Runtime, 🤗 Optimum uses _configuration objects_ to define parameters for optimization. These objects are then used to instantiate dedicated _optimizers_ and _quantizers_. - -Before applying quantization or optimization, first export our model to the ONNX format: - -```python -from optimum.onnxruntime import ORTModelForSequenceClassification -from transformers import AutoTokenizer - -model_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english" -save_directory = "tmp/onnx/" -# Load a model from transformers and export it to ONNX -tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) -ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint, from_transformers=True) -# Save the onnx model and tokenizer -ort_model.save_pretrained(save_directory) -tokenizer.save_pretrained(save_directory) -``` - -Let's see now how we can apply dynamic quantization with ONNX Runtime: - -```python -from optimum.onnxruntime.configuration import AutoQuantizationConfig -from optimum.onnxruntime import ORTQuantizer - -# Define the quantization methodology -qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False) -quantizer = ORTQuantizer.from_pretrained(ort_model) -# Apply dynamic quantization on the model -quantizer.quantize(save_dir=save_directory, quantization_config=qconfig) -``` - -In this example, we've quantized a model from the Hugging Face Hub, but it could also be a path to a local model directory. The result from applying the `quantize()` method is a `model_quantized.onnx` file that can be used to run inference. 
Here's an example of how to load an ONNX Runtime model and generate predictions with it: - -```python -from optimum.onnxruntime import ORTModelForSequenceClassification -from transformers import pipeline, AutoTokenizer - -model = ORTModelForSequenceClassification.from_pretrained(save_directory, file_name="model_quantized.onnx") -tokenizer = AutoTokenizer.from_pretrained(save_directory) -classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) -results = classifier("I love burritos!") -``` - -#### Optimum Intel - -Here is an example on how to perform inference with the OpenVINO Runtime: - -```diff -- from transformers import AutoModelForSequenceClassification -+ from optimum.intel.openvino import OVModelForSequenceClassification - from transformers import AutoTokenizer, pipeline - - # Download a tokenizer and model from the Hub and convert to OpenVINO format - tokenizer = AutoTokenizer.from_pretrained(model_id) - model_id = "distilbert-base-uncased-finetuned-sst-2-english" -- model = AutoModelForSequenceClassification.from_pretrained(model_id) -+ model = OVModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) - - # Run inference! - classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) - results = classifier("He's a dreadful magician.") -``` - +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training). diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 631bd01a13..670de86244 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -114,7 +114,7 @@ - local: bettertransformer/tutorials/contribute title: How to add support for new architectures? title: Tutorials - title: BetterTransformer integration + title: BetterTransformer isExpanded: false - sections: - local: utils/dummy_input_generators diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx index 7355f8f569..8153c365d0 100644 --- a/docs/source/bettertransformer/overview.mdx +++ b/docs/source/bettertransformer/overview.mdx @@ -28,11 +28,13 @@ The list of supported model below: - [BERT](https://arxiv.org/abs/1810.04805) - [BERT-generation](https://arxiv.org/abs/1907.12461) - [CamemBERT](https://arxiv.org/abs/1911.03894) +- [CLIP](https://arxiv.org/abs/2103.00020) - [Data2VecText](https://arxiv.org/abs/2202.03555) - [DistilBert](https://arxiv.org/abs/1910.01108) - [DeiT](https://arxiv.org/abs/2012.12877) - [Electra](https://arxiv.org/abs/2003.10555) - [Ernie](https://arxiv.org/abs/1904.09223) +- [Flava](https://arxiv.org/abs/2112.04482) - [FSMT](https://arxiv.org/abs/1907.06616) - [HuBERT](https://arxiv.org/pdf/2106.07447.pdf) - [LayoutLM](https://arxiv.org/abs/1912.13318) @@ -52,6 +54,8 @@ The list of supported model below: - [XLMRoberta](https://arxiv.org/abs/1911.02116) - [YOLOS](https://arxiv.org/abs/2106.00666) +Note that for encoder-decoder models, only the encoder part is supported by PyTorch's BetterTransformer for now. + Let us know by opening an issue in 🤗 Optimum if you want more models to be supported, or check out the contribution guideline if you want to add it by yourself! 
### Quick usage diff --git a/docs/source/onnxruntime/package_reference/modeling_ort.mdx b/docs/source/onnxruntime/package_reference/modeling_ort.mdx index 08896b4279..66786b58cf 100644 --- a/docs/source/onnxruntime/package_reference/modeling_ort.mdx +++ b/docs/source/onnxruntime/package_reference/modeling_ort.mdx @@ -20,6 +20,10 @@ specific language governing permissions and limitations under the License. [[autodoc]] onnxruntime.ORTModelForCausalLM +## ORTModelForCustomTasks + +[[autodoc]] onnxruntime.ORTModelForCustomTasks + ## ORTModelForFeatureExtraction [[autodoc]] onnxruntime.ORTModelForFeatureExtraction diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index a92e243fde..6b46c3e064 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -17,7 +17,9 @@ AlbertLayerBetterTransformer, BartEncoderLayerBetterTransformer, BertLayerBetterTransformer, + CLIPLayerBetterTransformer, DistilBertLayerBetterTransformer, + FlavaLayerBetterTransformer, FSMTEncoderLayerBetterTransformer, MBartEncoderLayerBetterTransformer, ViltLayerBetterTransformer, @@ -75,6 +77,14 @@ # FSMTModel: "EncoderLayer": FSMTEncoderLayerBetterTransformer, "ViltLayer": ViltLayerBetterTransformer, + # Flava: + "FlavaLayer": FlavaLayerBetterTransformer, + # CLIP + "CLIPEncoderLayer": CLIPLayerBetterTransformer, +} + +EXCLUDE_FROM_TRANSFORM = { + "clip": ["text_model"], # text model uses causal attention, that is most likely not supported in BetterTransformer } diff --git a/optimum/bettertransformer/models/base.py b/optimum/bettertransformer/models/base.py index 5e799e6cc6..63518f802e 100644 --- a/optimum/bettertransformer/models/base.py +++ b/optimum/bettertransformer/models/base.py @@ -14,12 +14,18 @@ import torch import torch.nn as nn +from ...utils import logging + KNOWN_ACTIVATION_ATTRIBUTES = ["hidden_act", "activation", "act_fn", "activation_function"] KNOWN_POS_EMB_ATTRIBUTES = ["position_embedding_type"] KNOWN_NUM_LAYERS = ["num_hidden_layers", "num_layers", "encoder_layers", "n_layers"] SUPPORTED_ACTIVATION_FUNCTIONS = ["gelu", "relu", "gelu_new"] +USE_AT_OWN_RISK_ACTIVATION_FUNCTIONS = ["quick_gelu"] + + +logger = logging.get_logger(__name__) class BetterTransformerBaseLayer(nn.Module): @@ -39,6 +45,10 @@ def __init__(self, config): self.act_fn = getattr(config, attr) break + # if act_fn not found in the config, fall back to the private `_get_activation_function` if available + if self.act_fn is None and hasattr(self, "_get_activation_function"): + self.act_fn = self._get_activation_function(config) + # Get pos emb type for attr in KNOWN_POS_EMB_ATTRIBUTES: if hasattr(config, attr): @@ -77,7 +87,12 @@ def validate_bettertransformer(self): raise ValueError("norm1_eps and norm2_eps must be equal for `BetterTransformer` integration.") # Check activation function - if self.act_fn not in SUPPORTED_ACTIVATION_FUNCTIONS: + if self.act_fn in USE_AT_OWN_RISK_ACTIVATION_FUNCTIONS: + logger.warning( + f"Overridding {self.act_fn} activation with gelu. Use the transformed model at your own risk, the output logits could be significantly different." + ) + self.act_fn = "gelu" + elif self.act_fn not in SUPPORTED_ACTIVATION_FUNCTIONS: raise ValueError( f"Activation function {self.act_fn} not supported" " for `BetterTransformer` integration." 
) diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index 8340b552e1..eff20567f4 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -11,12 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + import torch import torch.nn as nn from .base import BetterTransformerBaseLayer +if TYPE_CHECKING: + from transformers import PretrainedConfig + + class AlbertLayerBetterTransformer(BetterTransformerBaseLayer): def __init__(self, albert_layer, config): r""" @@ -1095,3 +1101,206 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): elif hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) return (hidden_states, attention_mask) + + +class CLIPLayerBetterTransformer(BetterTransformerBaseLayer): + def __init__(self, layer, config): + r""" + A simple conversion of the CLIPEncoderLayer to its `BetterTransformer` implementation. + **The implementation is valid only for the vision model, that does not use `causal_attention_mask`.** + Args: + layer (`torch.nn.Module`): + The original `CLIPEncoderLayer` where the weights needs to be retrieved. + """ + super().__init__(config) + # In_proj layer + self.in_proj_weight = nn.Parameter( + torch.cat( + [ + layer.self_attn.q_proj.weight, + layer.self_attn.k_proj.weight, + layer.self_attn.v_proj.weight, + ] + ) + ) + self.in_proj_bias = nn.Parameter( + torch.cat( + [ + layer.self_attn.q_proj.bias, + layer.self_attn.k_proj.bias, + layer.self_attn.v_proj.bias, + ] + ) + ) + + # Out proj layer + self.out_proj_weight = layer.self_attn.out_proj.weight + self.out_proj_bias = layer.self_attn.out_proj.bias + + # Linear layer 1 + self.linear1_weight = layer.mlp.fc1.weight + self.linear1_bias = layer.mlp.fc1.bias + + # Linear layer 2 + self.linear2_weight = layer.mlp.fc2.weight + self.linear2_bias = layer.mlp.fc2.bias + + # Layer norm 1 + self.norm1_eps = layer.layer_norm1.eps + self.norm1_weight = layer.layer_norm1.weight + self.norm1_bias = layer.layer_norm1.bias + + # Layer norm 2 + self.norm2_eps = layer.layer_norm2.eps + self.norm2_weight = layer.layer_norm2.weight + self.norm2_bias = layer.layer_norm2.bias + + # Model hyper parameters + self.num_heads = layer.self_attn.num_heads + self.embed_dim = layer.self_attn.embed_dim + + # Last step: set the last layer to `False` -> this will be set to `True` when converting the model + self.is_last_layer = False + self.norm_first = True + + self.validate_bettertransformer() + + def forward(self, hidden_states, attention_mask, *_, **__): + r""" + This is just a wrapper around the forward function proposed in: + https://github.com/huggingface/transformers/pull/19553 + """ + super().forward_checker() + + # we expect attention_mask to be None in the vision model + if attention_mask is not None: + raise ValueError( + "Please do not use attention masks when using `BetterTransformer` converted vision models" + ) + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + 
self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + + return (hidden_states,) + + def _get_activation_function(self, config: "PretrainedConfig"): + if hasattr(config, "vision_config") and hasattr(config, "text_config"): + assert config.vision_config.hidden_act == config.text_config.hidden_act + return config.vision_config.hidden_act + else: + return config.hidden_act + + +class FlavaLayerBetterTransformer(BetterTransformerBaseLayer): + def __init__(self, flava_layer, config): + r""" + A simple conversion of the FlavaLayer to its `BetterTransformer` implementation. + + Args: + flava_layer (`torch.nn.Module`): + The original `FlavaLayer` where the weights needs to be retrieved. + """ + super().__init__(config.image_config) + # In_proj layer + self.in_proj_weight = nn.Parameter( + torch.cat( + [ + flava_layer.attention.attention.query.weight, + flava_layer.attention.attention.key.weight, + flava_layer.attention.attention.value.weight, + ] + ) + ) + self.in_proj_bias = nn.Parameter( + torch.cat( + [ + flava_layer.attention.attention.query.bias, + flava_layer.attention.attention.key.bias, + flava_layer.attention.attention.value.bias, + ] + ) + ) + + # Out proj layer + self.out_proj_weight = flava_layer.attention.output.dense.weight + self.out_proj_bias = flava_layer.attention.output.dense.bias + + # Linear layer 1 + self.linear1_weight = flava_layer.intermediate.dense.weight + self.linear1_bias = flava_layer.intermediate.dense.bias + + # Linear layer 2 + self.linear2_weight = flava_layer.output.dense.weight + self.linear2_bias = flava_layer.output.dense.bias + + # Layer norm 1 + self.norm1_eps = flava_layer.layernorm_before.eps + self.norm1_weight = flava_layer.layernorm_before.weight + self.norm1_bias = flava_layer.layernorm_before.bias + + # Layer norm 2 + self.norm2_eps = flava_layer.layernorm_after.eps + self.norm2_weight = flava_layer.layernorm_after.weight + self.norm2_bias = flava_layer.layernorm_after.bias + + # Model hyper parameters + self.num_heads = flava_layer.attention.attention.num_attention_heads + self.embed_dim = int(flava_layer.attention.attention.attention_head_size * self.num_heads) + + # Last step: set the last layer to `False` -> this will be set to `True` when converting the model + self.is_last_layer = False + self.norm_first = True + + self.validate_bettertransformer() + + def forward(self, hidden_states, *_, **__): + r""" + This is just a wrapper around the forward function proposed in: + https://github.com/huggingface/transformers/pull/19553 + """ + super().forward_checker() + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + return (hidden_states,) diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index 48bc5decc4..09e7b166b0 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -16,15 +16,16 @@ import torch -from optimum.utils import check_if_pytorch_greater, is_accelerate_available - 
-from .models import BETTER_TRANFORMER_LAYERS_MAPPING_DICT, warn_uncompatible_save +from ..utils import check_if_pytorch_greater, is_accelerate_available +from .models import BETTER_TRANFORMER_LAYERS_MAPPING_DICT, EXCLUDE_FROM_TRANSFORM, warn_uncompatible_save if is_accelerate_available(): from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import remove_hook_from_module +ERROR_MESSAGE = r"The Better Transformers implementation for the model {model_name} has not been implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer` implementation." + def replace_to_bettertransformer(model, config): r""" @@ -49,7 +50,11 @@ def replace_to_bettertransformer(model, config): for name, module in model.named_children(): if len(list(module.children())) > 0: - replace_to_bettertransformer(module, config) + # we may explicitly exclude part of the model to use BetterTransformer + if config.model_type not in EXCLUDE_FROM_TRANSFORM or ( + config.model_type in EXCLUDE_FROM_TRANSFORM and name not in EXCLUDE_FROM_TRANSFORM[config.model_type] + ): + replace_to_bettertransformer(module, config) if hasattr(module, "is_decoder"): # Decoders are not supported yet on Better Transformers @@ -72,16 +77,16 @@ def replace_to_bettertransformer(model, config): return model -def set_last_layer(model): +def set_last_layer(model: torch.nn.Module): r""" Iterates over the module list containing the `LayerBetterTransformer` modules. Sets the last layer's `is_last_layer` attribute to `True` Args: - `model` (`torch.nn.Module`, **required**): + `model` (`torch.nn.Module`): The input converted model - Returns: - Returns `True` if it has succesfully set the attribute to `True`, otherwise return `False`. + Raises: + `NotImplementedError`: Raised if this method fails, in which case the model is not supported. 
""" dict_named_module = dict(model.named_modules()) sort_fn = lambda list_modules: [module.__class__.__name__ for module in list_modules] # noqa: E731 @@ -89,7 +94,17 @@ def set_last_layer(model): modulelist_lengths = [] for key in dict_named_module.keys(): - if isinstance(dict_named_module[key], torch.nn.ModuleList) and "encoder" in key: + if ( + isinstance(dict_named_module[key], torch.nn.ModuleList) + and "encoder" in key + and ( + model.config.model_type not in EXCLUDE_FROM_TRANSFORM + or ( + model.config.model_type in EXCLUDE_FROM_TRANSFORM + and all(name not in key for name in EXCLUDE_FROM_TRANSFORM[model.config.model_type]) + ) + ) + ): modulelist_lengths.append((len(dict_named_module[key]), key)) # For Albert, each transformer layer is wrapped @@ -101,16 +116,16 @@ def set_last_layer(model): for module in largest_module_list[-1].modules(): if "LayerBetterTransformer" in module.__class__.__name__: setattr(module, "is_last_layer", True) - return True - return False + return + raise NotImplementedError(ERROR_MESSAGE.format(model_name=model.__class__.__name__)) else: for key in dict_named_module.keys(): if isinstance(dict_named_module[key], torch.nn.ModuleList) and all( "LayerBetterTransformer" in module_name for module_name in sort_fn(dict_named_module[key]) ): setattr(dict_named_module[key][-1], "is_last_layer", True) - return True - return False + return + raise NotImplementedError(ERROR_MESSAGE.format(model_name=model.__class__.__name__)) class BetterTransformer(object): @@ -176,13 +191,7 @@ def transform( model_fast = replace_to_bettertransformer(model, hf_config).eval() model = None - successfully_converted_model = set_last_layer(model_fast) - if not successfully_converted_model: - raise NotImplementedError( - f"The Better Transformers implementation for the model {model_fast.__class__.__name__} has not been" - f"implemented yet. Please open an issue requesting the addition of this model with its `BetterTransformer`" - f"implementation." - ) + set_last_layer(model_fast) # Step 6: Add a class arguments, we might need to identify whether the model # has been correctly converted to its `BetterTransformer` version. 
diff --git a/optimum/exporters/onnx/__init__.py b/optimum/exporters/onnx/__init__.py
index b9dadaa5e6..c37d44fabd 100644
--- a/optimum/exporters/onnx/__init__.py
+++ b/optimum/exporters/onnx/__init__.py
@@ -15,9 +15,5 @@
 from .base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast  # noqa
 from .config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig  # noqa
-from .convert import (  # noqa
-    export,
-    export_encoder_decoder_model,
-    validate_encoder_decoder_model_outputs,
-    validate_model_outputs,
-)
+from .convert import export, export_models, validate_model_outputs, validate_models_outputs  # noqa
+from .utils import get_decoder_models_for_export, get_encoder_decoder_models_for_export
 
diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 5b1c1d4239..59aec6ec17 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -17,20 +17,17 @@
 from argparse import ArgumentParser
 from pathlib import Path
 
-from transformers import AutoFeatureExtractor, AutoTokenizer
+from transformers import AutoTokenizer
 
 from ...utils import logging
+from ...utils.save_utils import maybe_save_preprocessors
 from ..tasks import TasksManager
 from .base import OnnxConfigWithPast
-from .convert import (
-    export,
-    export_encoder_decoder_model,
-    validate_encoder_decoder_model_outputs,
-    validate_model_outputs,
-)
+from .convert import export, export_models, validate_model_outputs, validate_models_outputs
+from .utils import get_decoder_models_for_export, get_encoder_decoder_models_for_export
 
 
-logger = logging.get_logger()  # pylint: disable=invalid-name
+logger = logging.get_logger()
 logger.setLevel(logging.INFO)
 
 
@@ -127,15 +124,22 @@
             f"Opset {args.opset} is not sufficient to export {model.config.model_type}. "
             f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
         )
-
-    if model.config.is_encoder_decoder and args.for_ort:
-        onnx_inputs, onnx_outputs = export_encoder_decoder_model(
-            model,
-            onnx_config,
-            args.opset,
-            args.output.parent.joinpath("encoder_model.onnx"),
-            args.output.parent.joinpath("decoder_model.onnx"),
-            args.output.parent.joinpath("decoder_with_past_model.onnx"),
-        )
+    if args.for_ort and (model.config.is_encoder_decoder or task.startswith("causal-lm")):
+        if model.config.is_encoder_decoder and task.startswith("causal-lm"):
+            raise ValueError(
+                f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please file a bug report "
+                f"at https://github.com/huggingface/optimum. If --task was explicitly passed, make sure you selected the right task for the model,"
+                f" referring to `optimum.exporters.tasks.TasksManager`'s `_TASKS_TO_AUTOMODELS`."
+            )
+        fn_get_models_from_config = (
+            get_encoder_decoder_models_for_export if model.config.is_encoder_decoder else get_decoder_models_for_export
+        )
+        onnx_inputs, onnx_outputs = export_models(
+            model=model,
+            onnx_config=onnx_config,
+            opset=args.opset,
+            output_dir=args.output.parent,
+            fn_get_models_from_config=fn_get_models_from_config,
+        )
     else:
         onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output)
 
@@ -143,18 +147,7 @@ def main():
     # Saving the model config as this is needed sometimes.
     model.config.save_pretrained(args.output.parent)
 
-    # Saving the tokenizer / feature extractor as well.
- try: - tokenizer = AutoTokenizer.from_pretrained(args.model) - tokenizer.save_pretrained(args.output.parent) - except Exception: - pass - - try: - feature_extractor = AutoFeatureExtractor.from_pretrained(args.model) - feature_extractor.save_pretrained(args.output.parent) - except Exception: - pass + maybe_save_preprocessors(args.model, args.output.parent) if args.atol is None: args.atol = onnx_config.ATOL_FOR_VALIDATION @@ -162,15 +155,19 @@ def main(): args.atol = args.atol[task.replace("-with-past", "")] try: - if model.config.is_encoder_decoder and args.for_ort: - validate_encoder_decoder_model_outputs( - onnx_config, - model, - onnx_outputs, - args.atol, - args.output.parent.joinpath("encoder_model.onnx"), - args.output.parent.joinpath("decoder_model.onnx"), - args.output.parent.joinpath("decoder_with_past_model.onnx"), + if args.for_ort and (model.config.is_encoder_decoder or task.startswith("causal-lm")): + fn_get_models_from_config = ( + get_encoder_decoder_models_for_export + if model.config.is_encoder_decoder + else get_decoder_models_for_export + ) + validate_models_outputs( + onnx_config=onnx_config, + reference_model=model, + onnx_named_outputs=onnx_outputs, + atol=args.atol, + output_dir=args.output.parent, + fn_get_models_from_config=fn_get_models_from_config, ) else: validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index a4ae0773c6..fa20f5026c 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -344,7 +344,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC): Inherits from [`~exporters.onnx.OnnxConfig`]. A base class to handle the ONNX configuration of decoder-only models. """ - PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True + PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH: bool = True + USE_PAST_IN_INPUTS: Optional[bool] = None + USE_PRESENT_IN_OUTPUTS: Optional[bool] = None def __init__( self, @@ -352,8 +354,26 @@ def __init__( task: str = "default", patching_specs: List[PatchingSpec] = None, use_past: bool = False, + use_past_in_inputs: Optional[bool] = None, + use_present_in_outputs: Optional[bool] = None, ): self.use_past = use_past + if use_past_in_inputs is None: + use_past_in_inputs = self.USE_PAST_IN_INPUTS + if use_present_in_outputs is None: + use_present_in_outputs = self.USE_PRESENT_IN_OUTPUTS + self.use_past_in_inputs = use_past if use_past_in_inputs is None else use_past_in_inputs + self.use_present_in_outputs = use_past if use_present_in_outputs is None else use_present_in_outputs + if use_past != self.use_past_in_inputs: + logger.warning( + f"use_past = {use_past} is different than use_past_in_inputs = {use_past_in_inputs}, the value of " + "use_past_in_inputs will used for the inputs." + ) + if use_past != self.use_present_in_outputs: + logger.warning( + f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value " + "of use_present_in_outputs value will be used for the outputs." 
+ ) super().__init__(config, task=task, patching_specs=patching_specs) @classmethod @@ -375,15 +395,14 @@ def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxCo @property def outputs(self) -> Mapping[str, Mapping[int, str]]: common_outputs = super().outputs - if self.use_past: + if self.use_present_in_outputs: self.add_past_key_values(common_outputs, direction="outputs") - return common_outputs @property def values_override(self) -> Optional[Mapping[str, Any]]: if hasattr(self._config, "use_cache"): - return {"use_cache": self.use_past} + return {"use_cache": self.use_past_in_inputs or self.use_present_in_outputs} @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES) def generate_dummy_inputs(self, framework: str = "pt", **kwargs): @@ -407,7 +426,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): if ( self.PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH - and self.use_past + and self.use_past_in_inputs and "attention_mask" in dummy_inputs ): past_length = dummy_inputs["past_key_values"][0][0].shape[2] @@ -473,7 +492,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: # We reset the value as the order in common_outputs (OrderedDict) is lost otherwise else: axes_names[axis_idx] = name - if self.use_past: + if self.use_present_in_outputs: self.add_past_key_values(common_outputs, direction="outputs") return common_outputs diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 8e0be8ec97..5ec28fccfe 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -53,7 +53,7 @@ class TextDecoderOnnxConfig(OnnxConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"} else: @@ -79,7 +79,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length" common_inputs["decoder_input_ids"] = {0: "batch_size"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} @@ -87,7 +87,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -97,7 +97,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen self.task, self._normalized_config, **kwargs ) - if self.use_past is True: + if self.use_past_in_inputs is True: if "sequence_length" in kwargs and kwargs["sequence_length"] != 1: logger.warning( f"Asked a sequence length of {kwargs['sequence_length']}, but expecting a sequence length of 1 with use_past == True. Overriding the sequence length to 1." 
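To make the new multi-model export path more concrete before the `convert.py` changes below, here is a hedged sketch of how `export_models` might be driven from Python with `get_encoder_decoder_models_for_export`. The `T5OnnxConfig` construction, the task name and the output directory are illustrative assumptions, not something this diff prescribes (the CLI obtains its configuration through `TasksManager` instead).

```python
from pathlib import Path

from transformers import AutoModelForSeq2SeqLM

from optimum.exporters.onnx import export_models, get_encoder_decoder_models_for_export
from optimum.exporters.onnx.model_configs import T5OnnxConfig  # assumed config class for this sketch

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
# Assumed constructor arguments; use_past=True also exports the decoder with past key values.
onnx_config = T5OnnxConfig(model.config, task="seq2seq-lm", use_past=True)

output_dir = Path("t5_onnx/")
output_dir.mkdir(parents=True, exist_ok=True)

# Writes encoder_model.onnx, decoder_model.onnx and decoder_with_past_model.onnx
# into the output directory, mirroring the --for-ort branch in __main__.py.
onnx_inputs, onnx_outputs = export_models(
    model=model,
    onnx_config=onnx_config,
    opset=onnx_config.DEFAULT_ONNX_OPSET,
    output_dir=output_dir,
    fn_get_models_from_config=get_encoder_decoder_models_for_export,
)
```
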
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index b01a946cd9..6f5e36e56c 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -17,14 +17,19 @@ from inspect import signature from itertools import chain from pathlib import Path -from typing import Iterable, List, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np from transformers.utils import is_tf_available, is_torch_available from ...utils import logging from .base import OnnxConfig -from .utils import MIN_TORCH_VERSION, get_encoder_decoder_models_for_export, is_torch_onnx_support_available +from .utils import ( + MIN_TORCH_VERSION, + get_decoder_models_for_export, + get_encoder_decoder_models_for_export, + is_torch_onnx_support_available, +) if is_torch_available(): @@ -61,18 +66,21 @@ def check_dummy_inputs_are_allowed( ) -def validate_encoder_decoder_model_outputs( - config: OnnxConfig, +def validate_models_outputs( + onnx_config: OnnxConfig, reference_model: Union["PreTrainedModel", "TFPreTrainedModel"], onnx_named_outputs: List[str], atol: float, - encoder_onnx_model: Path, - decoder_onnx_model: Path, - decoder_with_past_onnx_model: Optional[Path] = None, + output_dir: Path, + fn_get_models_from_config: Callable[ + [Union["PreTrainedModel", "TFPreTrainedModel"], OnnxConfig], + Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], OnnxConfig]], + ], + output_names: Optional[List[str]] = None, ): """ - Validates the export by checking that the outputs from both the reference and the exported model match. - The following method validates the ONNX models exported using the `export_encoder_decoder_model` method. + Validates the export of several models, by checking that the outputs from both the reference and the exported model match. + The following method validates the ONNX models exported using the `export_models` method. Args: config ([`~OnnxConfig`]: @@ -83,34 +91,43 @@ def validate_encoder_decoder_model_outputs( The names of the outputs to check. atol (`float`): The absolute tolerance in terms of outputs difference between the reference and the exported model. - encoder_onnx_model (`Path`): - The path to the exported encoder ONNX model. - decoder_onnx_model (`Path`): - The path to the exported decoder ONNX model. - decoder_with_past_onnx_model (`Optional[Path]`, defaults to `None`): - The path to the exported decoder with past ONNX model. Required when `past_key_values` are exported. + output_dir (`Path`): + Output directory where the exported ONNX models are stored. + fn_get_models_from_config (`Callable[[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig], Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]]`): + Function outputing a dictionary of submodels and downstream onnx configurations, for example to export the `model` in several pieces, + as it is the case with encoder-decoder models. + output_names (`Optional[List[str]]`, defaults to `None`): + The names to use for the exported ONNX files. The order must be the same as the order of submodels in the ordered dict returned by `fn_get_models_from_config`. + If None, will use the keys from the output of `fn_get_models_from_config` as names. Raises: ValueError: If the outputs shapes or values do not match between the reference and the exported model. 
""" - models_for_validation = get_encoder_decoder_models_for_export(reference_model, config) + models_for_validation = fn_get_models_from_config(reference_model, onnx_config) if len(onnx_named_outputs) != len(models_for_validation.keys()): raise ValueError( f"Invalid number of ONNX named outputs. Required {len(models_for_validation.keys())}, Provided {len(onnx_named_outputs)}" ) - # Validate encoder - model, onnx_config = models_for_validation["encoder"] - validate_model_outputs(onnx_config, model, encoder_onnx_model, onnx_named_outputs[0], atol) - - # Validate decoder - model, onnx_config = models_for_validation["decoder"] - validate_model_outputs(onnx_config, model, decoder_onnx_model, onnx_named_outputs[1], atol) + if output_names is not None and len(output_names) != len(models_for_validation): + raise ValueError( + f"Provided custom names {output_names} for the validation of {len(models_for_validation)} models. Please provide the same number of ONNX file names as models to export." + ) - if config.use_past: - # Validate decoder with past - model, onnx_config = models_for_validation["decoder_with_past"] - validate_model_outputs(onnx_config, model, decoder_with_past_onnx_model, onnx_named_outputs[2], atol) + for i, model_name in enumerate(models_for_validation.keys()): + submodel, sub_onnx_config = models_for_validation[model_name] + onnx_model_path = ( + output_dir.joinpath(output_names[i]) + if output_names is not None + else output_dir.joinpath(model_name + ".onnx") + ) + validate_model_outputs( + sub_onnx_config, + submodel, + onnx_model_path, + onnx_named_outputs[i], + atol, + ) def validate_model_outputs( @@ -378,13 +395,16 @@ def export_tensorflow( return input_names, output_names -def export_encoder_decoder_model( +def export_models( model: Union["PreTrainedModel", "TFPreTrainedModel"], - config: OnnxConfig, + onnx_config: OnnxConfig, opset: int, - encoder_output: Path, - decoder_output: Path, - decoder_with_past_output: Optional[Path] = None, + output_dir: Path, + fn_get_models_from_config: Callable[ + [Union["PreTrainedModel", "TFPreTrainedModel"], OnnxConfig], + Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], OnnxConfig]], + ], + output_names: Optional[List[str]] = None, device: str = "cpu", ) -> Tuple[List[List[str]], List[List[str]]]: """ @@ -399,12 +419,14 @@ def export_encoder_decoder_model( The ONNX configuration associated with the exported model. opset (`int`): The version of the ONNX operator set to use. - encoder_output (`Path`): - Directory to store the exported encoder ONNX model. - decoder_output (`Path`): - Directory to store the exported decoder ONNX model. - decoder_with_past_output (`Optional[Path]`, defaults to `None`): - Directory to store the exported decoder with past ONNX model. Required when `past_key_values` are exported. + output_dir (`Path`): + Output directory to store the exported ONNX models. + fn_get_models_from_config (`Callable[[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig], Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]]`): + Function outputing a dictionary of submodels and downstream onnx configurations, for example to export the `model` in several pieces, + as it is the case with encoder-decoder models. + output_names (`Optional[List[str]]`, defaults to `None`): + The names to use for the exported ONNX files. The order must be the same as the order of submodels in the ordered dict returned by `fn_get_models_from_config`. 
+ If None, will use the keys from the output of `fn_get_models_from_config` as names. device (`str`, *optional*, defaults to `cpu`): The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for export on CUDA devices. @@ -412,21 +434,30 @@ def export_encoder_decoder_model( `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named inputs from the ONNX configuration. """ - models_for_export = get_encoder_decoder_models_for_export(model, config) + models_for_export = fn_get_models_from_config(model, onnx_config) outputs = [] - # export encoder - model, onnx_config = models_for_export["encoder"] - outputs.append(export(model, onnx_config, opset, encoder_output, device=device)) - - # export decoder - model, onnx_config = models_for_export["decoder"] - outputs.append(export(model, onnx_config, opset, decoder_output, device=device)) + if output_names is not None and len(output_names) != len(models_for_export): + raise ValueError( + f"Provided custom names {output_names} for the export of {len(models_for_export)} models. Please provide the same number of names as models to export." + ) - if config.use_past: - # export decoder with past - model, onnx_config = models_for_export["decoder_with_past"] - outputs.append(export(model, onnx_config, opset, decoder_with_past_output, device=device)) + for i, model_name in enumerate(models_for_export.keys()): + submodel, sub_onnx_config = models_for_export[model_name] + output_path = ( + output_dir.joinpath(output_names[i]) + if output_names is not None + else output_dir.joinpath(model_name + ".onnx") + ) + outputs.append( + export( + submodel, + sub_onnx_config, + opset, + output_path, + device=device, + ) + ) outputs = list(map(list, zip(*outputs))) return outputs diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c5789cb6d1..f6d248a31f 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -73,6 +73,8 @@ class Seq2SeqDecoderOnnxConfig(TextSeq2SeqOnnxConfig): DummySeq2SeqPastKeyValuesGenerator, ) + USE_PRESENT_IN_OUTPUTS = True + @property def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = { @@ -81,7 +83,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -94,21 +96,6 @@ def torch_to_onnx_input_map(self) -> Mapping[str, str]: "attention_mask": "encoder_attention_mask", } - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - self.add_past_key_values(common_outputs, direction="outputs") - return common_outputs - - @property - def values_override(self) -> Optional[Mapping[str, Any]]: - # Needed here because the configuration will actually be used with both use_past = True and use_past = False, - # but the cache must always be used regardless. 
- if hasattr(self._config, "use_cache"): - return {"use_cache": True} - - return None - def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids") reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] @@ -413,7 +400,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen self.task, self._normalized_config, **kwargs ) - if self.use_past is True: + if self.use_past_in_inputs is True: if "sequence_length" in kwargs and kwargs["sequence_length"] != 1: logger.warning( f"Asked a sequence length of {kwargs['sequence_length']}, but expecting a sequence length of 1 with use_past == True. Overriding the sequence length to 1." @@ -445,14 +432,14 @@ def inputs_for_default_and_seq2seq_lm(self): "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -462,7 +449,7 @@ def inputs_for_causal_lm(self): "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: for i in range(self._normalized_config.decoder_num_layers): common_inputs[f"past_key_values.{i}.key"] = { 0: "batch_size", @@ -498,7 +485,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: common_outputs = super().outputs else: common_outputs = super(OnnxConfigWithPast, self).outputs - if self.use_past: + if self.use_present_in_outputs: for i in range(self._normalized_config.encoder_num_layers): common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"} common_outputs[f"present.{i}.value"] = { @@ -796,7 +783,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -817,12 +804,12 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = { "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 247f1e6729..47b08aa146 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -14,8 +14,6 @@ # limitations under the License. 
"""Utility functions.""" -from ctypes import c_float, sizeof -from enum import Enum from typing import TYPE_CHECKING, Dict, Tuple, Union import packaging @@ -101,16 +99,47 @@ def get_encoder_decoder_models_for_export( encoder_model = model.get_encoder() encoder_onnx_config = config.get_encoder_onnx_config(encoder_model.config) - models_for_export["encoder"] = (encoder_model, encoder_onnx_config) + models_for_export["encoder_model"] = (encoder_model, encoder_onnx_config) decoder_model = model.get_decoder() decoder_onnx_config = config.get_decoder_onnx_config(decoder_model.config, config.task, use_past=False) - models_for_export["decoder"] = (model, decoder_onnx_config) + models_for_export["decoder_model"] = (model, decoder_onnx_config) if config.use_past: decoder_onnx_config_with_past = config.get_decoder_onnx_config( decoder_model.config, config.task, use_past=True ) - models_for_export["decoder_with_past"] = (model, decoder_onnx_config_with_past) + models_for_export["decoder_with_past_model"] = (model, decoder_onnx_config_with_past) + + return models_for_export + + +def get_decoder_models_for_export( + model: Union["PreTrainedModel", "TFPreTrainedModel"], + config: "OnnxConfig", +) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]: + """ + Returns the encoder and decoder parts of the model and their subsequent onnx configs. + + Args: + model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + The model to export. + config ([`~exporters.onnx.config.OnnxConfig`]): + The ONNX configuration associated with the exported model. + + Returns: + `Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `OnnxConfig`]: A Dict containing the model and + onnx configs for the encoder and decoder parts of the model. + """ + models_for_export = dict() + + models_for_export["decoder_model"] = ( + model, + config.__class__(model.config, task=config.task, use_past=False, use_present_in_outputs=True), + ) + + if config.use_past: + onnx_config_with_past = config.__class__.with_past(model.config, task=config.task) + models_for_export["decoder_with_past_model"] = (model, onnx_config_with_past) return models_for_export diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 19f80b8a00..285843deda 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -75,6 +75,7 @@ def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: super().__init__() self.model = model self.config = config + self._preprocessors = [] def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -117,6 +118,8 @@ def save_pretrained( os.makedirs(save_directory, exist_ok=True) self.config.save_pretrained(save_directory) + for preprocessor in self._preprocessors: + preprocessor.save_pretrained(save_directory) self._save_pretrained(save_directory, **kwargs) if push_to_hub: @@ -132,10 +135,10 @@ def _save_pretrained(self, save_directory, **kwargs): def push_to_hub( self, - save_directory: str = None, - repository_id: Optional[str] = None, + save_directory: str, + repository_id: str, private: Optional[bool] = None, - use_auth_token: Optional[Union[bool, str]] = None, + use_auth_token: Union[bool, str] = True, ) -> str: if isinstance(use_auth_token, str): huggingface_token = use_auth_token diff --git a/optimum/onnx/configuration.py b/optimum/onnx/configuration.py index 40c6877feb..01a51ad500 100644 --- a/optimum/onnx/configuration.py +++ b/optimum/onnx/configuration.py @@ -340,61 +340,6 @@ def fill_with_past_key_values_(self, inputs_or_outputs: 
Mapping[str, Mapping[int inputs_or_outputs[f"{name}_key_values_{i}"] = {0: "batch", 2: decoder_sequence} -class DecoderOnnxConfigWithPast(OnnxConfigWithPast): - @property - def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = OrderedDict([("input_ids", {0: "batch", 1: "sequence"})]) - if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") - common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} - else: - common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} - - return common_inputs - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - if not self.use_past: - self.fill_with_past_key_values_(common_outputs, direction="outputs") - return common_outputs - - @property - def num_layers(self) -> Tuple[int]: - num_layers_names = {"decoder_layers", "n_layer", "num_layers"} - for num_layers_name in num_layers_names: - if hasattr(self._config, num_layers_name): - return getattr(self._config, num_layers_name) - raise AttributeError( - "Could not find the number of decoder layers attributes in the model configuration, override the " - "num_layers property to solve this" - ) - - @property - def num_attention_heads(self) -> int: - num_heads_names = {"num_attention_head", "n_head", "num_heads"} - for num_heads_name in num_heads_names: - if hasattr(self._config, num_heads_name): - return getattr(self._config, num_heads_name) - raise AttributeError( - "Could not find the number of decoder attention heads attributes in the model configuration, override the " - "num_heads property to solve this" - ) - - def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): - num_pkv_per_layer = 2 - name = "past" if direction == "inputs" else "present" - decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence" - for i in range(self.num_layers * num_pkv_per_layer): - inputs_or_outputs[f"{name}_key_values_{i}"] = {0: "batch", 2: decoder_sequence} - - @property - def values_override(self) -> Optional[Mapping[str, Any]]: - if hasattr(self._config, "use_cache"): - return {"use_cache": True} - return None - - class OnnxSeq2SeqConfigWithPastAndLoss(DecoderOnnxConfig): def __init__(self, config: DecoderOnnxConfig): self.__dict__ = copy.deepcopy(config.__dict__) diff --git a/optimum/onnxruntime/io_binding/__init__.py b/optimum/onnxruntime/io_binding/__init__.py index e0810d5e80..d218d7a700 100644 --- a/optimum/onnxruntime/io_binding/__init__.py +++ b/optimum/onnxruntime/io_binding/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .io_binding_helper import TypeHelper +from .io_binding_helper import IOBindingHelper, TypeHelper diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py index e8005188be..1911b1f879 100644 --- a/optimum/onnxruntime/io_binding/io_binding_helper.py +++ b/optimum/onnxruntime/io_binding/io_binding_helper.py @@ -11,11 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
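The helpers above (`get_decoder_models_for_export` together with the renamed `"decoder_model"` / `"decoder_with_past_model"` keys) are wired together later in this patch by `ORTModelDecoder._from_transformers`. A hedged sketch of that flow, not part of the patch itself; the model id and output directory are placeholders:

```python
# Sketch only -- mirrors how ORTModelDecoder._from_transformers (further down in this
# patch) combines the new export helpers. "gpt2" and "gpt2_onnx" are placeholders.
from pathlib import Path

from transformers import AutoModelForCausalLM

from optimum.exporters import TasksManager
from optimum.exporters.onnx import export_models, get_decoder_models_for_export
from optimum.onnxruntime.utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Build an ONNX config for the causal-lm task, with past key/values enabled.
model_type = model.config.model_type.replace("_", "-")
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
    model_type, "onnx", task="causal-lm", model_name=getattr(model, "name", None)
)
onnx_config = onnx_config_constructor(model.config, use_past=True)

# get_decoder_models_for_export returns {"decoder_model": ..., "decoder_with_past_model": ...},
# which export_models writes out under the two output names below.
export_models(
    model=model,
    onnx_config=onnx_config,
    opset=onnx_config.DEFAULT_ONNX_OPSET,
    output_dir=Path("gpt2_onnx"),
    fn_get_models_from_config=get_decoder_models_for_export,
    output_names=[ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME],
)
```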
+import logging +import traceback + import numpy as np import torch +import onnxruntime as ort +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.transformers.io_binding_helper import TypeHelper as ORTTypeHelper +from ..utils import is_cupy_available, is_onnxruntime_training_available + + +if is_cupy_available(): + import cupy as cp + # Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12 class TypeHelper(ORTTypeHelper): @@ -58,3 +69,81 @@ def ort_type_to_torch_type(ort_type: str): raise ValueError( f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_torch_type_map.keys()}" ) + + +# Adapted from https://github.com/microsoft/onnxruntime/blob/1ab11a111ce0717bfbfaca964d04a017cb9b1752/onnxruntime/python/tools/transformers/io_binding_helper.py#L97 +class IOBindingHelper: + """ + A helper class to enable `ORTModel` instances to prepare IO binding with dynamic shaped outputs for an inference session and transfer + tensors from ONNX Runtime to other frameworks on device. It helps reduce memory copy between the host and device. + """ + + def __init__(self, model: ort.InferenceSession, device, **kwargs): + self.model = model + self.device = device + # Create {name:idx} dict for model inputs and outputs + self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(model.get_inputs())} + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())} + self.model_input_names = list(self.model_inputs.keys()) + self.model_output_names = list(self.model_outputs.keys()) + + @staticmethod + def to_pytorch(ort_value: OrtValue) -> torch.Tensor: + """ + Converts tensors held by OrtValues in ONNX runtime memory buffer to torch tensor. + """ + + if is_onnxruntime_training_available(): + return IOBindingHelper.to_pytorch_via_dlpack(ort_value) + else: + try: + return IOBindingHelper.to_pytorch_via_cupy(ort_value) + except Exception as e: + logging.error(traceback.format_exc()) + logging.info("Unable to access output memory in CUDA, will offload to CPU") + return IOBindingHelper.to_pytorch_via_numpy(ort_value) + + @staticmethod + def to_pytorch_via_numpy(ort_value: OrtValue) -> torch.Tensor: + ort_device = ort_value.device_name().lower() + return torch.tensor(ort_value.numpy()).to(ort_device) + + @staticmethod + def to_pytorch_via_cupy(ort_value: OrtValue) -> torch.Tensor: + ort_device = ort_value.device_name().lower() + if ort_device != "cuda": + raise RuntimeError(f"Exchange tensors to PyTorch via CuPy only when device is CUDA, got: {ort_device}") + + ort_type = ort_value.data_type() + numpy_type = TypeHelper.ort_type_to_numpy_type(ort_type) + + # Access CUDA memory via CuPy + memory = cp.cuda.UnownedMemory(ort_value.data_ptr(), 0, None) + memory_ptr = cp.cuda.MemoryPointer(memory, 0) + cp_array = cp.ndarray(shape=ort_value.shape(), memptr=memory_ptr, dtype=numpy_type) + torch_tensor = torch.from_dlpack(cp_array.toDlpack()) + + # If is boolean, the dtype will be uint8 and need to be convert back to bool. 
+ if "bool" in ort_type: + torch_tensor = torch_tensor.to(torch.bool) + + torch_tensor = torch_tensor.clone() + + return torch_tensor + + @staticmethod + # dlpack support is available for OrtValue only when `onnxruntime-training` is installed + def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor: + from torch._C import _from_dlpack + + torch_tensor = ort_value.to_dlpacks(_from_dlpack) + return torch_tensor + + @staticmethod + def get_device_index(device): + if isinstance(device, str): + # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 + device = torch.device(device) + elif isinstance(device, int): + return device + return 0 if device.index is None else device.index diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 38720f3d95..1ecdbaac4a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -14,30 +14,34 @@ """Classes handling causal-lm related architectures in ONNX Runtime.""" import logging -import os import shutil from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -import transformers -from transformers import AutoModelForCausalLM, PretrainedConfig -from transformers.file_utils import add_start_docstrings_to_model_forward, default_cache_path +from transformers import AutoModelForCausalLM +from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from transformers.onnx import FeaturesManager, export -from transformers.onnx.utils import get_preprocessor import onnxruntime from huggingface_hub import hf_hub_download -from ..onnx.configuration import DecoderOnnxConfigWithPast +from ..exporters import TasksManager +from ..exporters.onnx import export_models, get_decoder_models_for_export from ..utils import NormalizedConfigManager, check_if_transformers_greater +from ..utils.file_utils import validate_file_exists +from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .io_binding import TypeHelper from .modeling_ort import ORTModel from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, get_provider_for_device, parse_device +if TYPE_CHECKING: + from transformers import PretrainedConfig + + if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: @@ -106,6 +110,9 @@ ``` """ +DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx" +DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" + class ORTDecoder: """ @@ -115,7 +122,7 @@ class ORTDecoder: def __init__( self, session: onnxruntime.InferenceSession, - config: transformers.PretrainedConfig, + config: "PretrainedConfig", device: torch.device, use_io_binding: bool = True, ): @@ -130,8 +137,11 @@ def __init__( self.session_outputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} self.session_input_names = list(self.session_inputs.keys()) self.session_output_names = list(self.session_outputs.keys()) - self.key_value_input_names = [key for key in self.session_input_names if "key_values" in key] - self.key_value_output_names = [key for key in self.session_output_names if "key_values" in key] + # TODO: make this less hacky. 
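The two file-name patterns defined just above in `modeling_decoder.py` are what `infer_onnx_filename` (introduced later in this diff) relies on to tell the plain decoder export apart from the decoder-with-past export. A small self-contained check, illustrative only and assuming `re.search`-style matching inside `find_files_matching_pattern`:

```python
# Illustrative check of the patterns introduced above; not part of the patch.
import re

DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx"
DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx"

files = ["decoder_model.onnx", "decoder_with_past_model.onnx", "encoder_model.onnx"]

print([f for f in files if re.search(DECODER_ONNX_FILE_PATTERN, f)])
# ['decoder_model.onnx']  -- the negative lookahead rejects the with_past variant
print([f for f in files if re.search(DECODER_WITH_PAST_ONNX_FILE_PATTERN, f)])
# ['decoder_with_past_model.onnx']
```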
+ self.key_value_input_names = [key for key in self.session_input_names if (".key" in key) or (".value" in key)] + self.key_value_output_names = [ + key for key in self.session_output_names if (".key" in key) or (".value" in key) + ] self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None def prepare_output_buffer( @@ -149,7 +159,7 @@ def prepare_output_buffer( if output_name == "logits": output_shape = (batch_size, sequence_length, self.normalized_config.vocab_size) output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous() - elif "key_values" in output_name: + elif ".key" in output_name or ".value" in output_name: num_attention_heads = self.normalized_config.num_attention_heads hidden_size = self.normalized_config.hidden_size embed_size_per_head = hidden_size // num_attention_heads @@ -321,19 +331,19 @@ class ORTModelDecoder(ORTModel): def __init__( self, - config: transformers.PretrainedConfig, decoder_session: onnxruntime.InferenceSession, + config: "PretrainedConfig", decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None, use_io_binding: bool = True, - model_save_dir: str = "", - last_decoder_model_name: str = ONNX_DECODER_NAME, - last_decoder_with_past_model_name: str = ONNX_DECODER_WITH_PAST_NAME, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + **kwargs ): """ Args: decoder_session (`onnxruntime.InferenceSession`): The ONNX Runtime inference session associated to the decoder. - config (`transformers.PretrainedConfig`): + config ([~`transformers.PretrainedConfig`]): An instance of the configuration associated to the model. Initializing with a config file does not load the weights associated with the model, only the configuration. decoder_with_past_session (`Optional[onnxruntime.InferenceSession]`, *optional*): @@ -343,26 +353,40 @@ def __init__( `True` if the device is CUDA, otherwise defaults to `False`. model_save_dir (`str`, *optional*, defaults to `""`): The directory under which the model exported to ONNX was saved. - last_decoder_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): - The name of the ONNX file containing the decoder part of the model. - last_decoder_with_past_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): - The name of the ONNX file containing the decoder with past key/values part of the model. + preprocessors (`Optional[List]`, defaults to `None`): + The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. """ + # TODO: remove at version 2.0 + def show_deprecated_argument(arg_name): + if kwargs.pop(arg_name, None) is not None: + logger.warning( + f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used " + "anymore." + ) + + show_deprecated_argument("last_decoder_model_name") + show_deprecated_argument("last_decoder_with_past_model_name") + if kwargs: + raise ValueError( + f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." 
+ ) + super().__init__( decoder_session, config, use_io_binding=use_io_binding, model_save_dir=model_save_dir, - latest_model_name=last_decoder_model_name, ) - self.decoder_file_name = last_decoder_model_name - self.decoder_file_with_past_name = last_decoder_with_past_model_name - self.use_cache = decoder_with_past_session is not None self.decoder = ORTDecoder( - session=self.model, config=self.config, device=self._device, use_io_binding=self.use_io_binding + session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding ) + self.decoder_model_path = Path(decoder_session._model_path) + self.decoder_model_name = self.decoder_model_path.name + self.decoder_with_past = None + self.decoder_with_past_model_path = None + self.decoder_with_past_model_name = None if self.use_cache: self.decoder_with_past = ORTDecoder( session=decoder_with_past_session, @@ -370,6 +394,8 @@ def __init__( device=self._device, use_io_binding=self.use_io_binding, ) + self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) + self.decoder_with_past_model_name = self.decoder_with_past_model_path.name @staticmethod def load_model( @@ -447,14 +473,13 @@ def _save_pretrained( The decoder with past key values model file name overwriting the default file name, allowing to save the decoder model with a different name. """ - src_file_names = [self.decoder_file_name] + src_paths = [self.decoder_model_path] dst_file_names = [decoder_file_name] if self.use_cache: - src_file_names.append(self.decoder_file_with_past_name) + src_paths.append(self.decoder_with_past_model_path) dst_file_names.append(decoder_with_past_file_name) - for src_file_name, dst_file_name in zip(src_file_names, dst_file_names): - src_path = self.model_save_dir.joinpath(src_file_name) + for src_path, dst_file_name in zip(src_paths, dst_file_names): dst_path = Path(save_directory).joinpath(dst_file_name) shutil.copyfile(src_path, dst_path) @@ -465,74 +490,118 @@ def _from_pretrained( config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, - force_download: bool = True, + force_download: bool = False, cache_dir: Optional[str] = None, decoder_file_name: str = ONNX_DECODER_NAME, decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, subfolder: str = "", local_files_only: bool = False, use_cache: bool = True, - use_io_binding: bool = True, provider: str = "CPUExecutionProvider", session_options: Optional[onnxruntime.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, ): - file_names = {} - # Load model from a local directory - if os.path.isdir(os.path.join(model_id, subfolder)): - decoder_with_past_path = ( - os.path.join(model_id, subfolder, decoder_with_past_file_name) if use_cache else None + model_path = Path(model_id) + + if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): + decoder_file_name = ORTModelDecoder.infer_onnx_filename( + model_id, + DECODER_ONNX_FILE_PATTERN, + "decoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME) + if decoder_file_name not in decoder_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_file_name} is not a regular name used in optimum.onnxruntime that are 
{decoder_regular_onnx_filenames}, the " + f"{cls.__name__} might not behave as expected." + ) + + decoder_with_past_path = None + if use_cache is True: + if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision): + decoder_with_past_file_name = ORTModelDecoder.infer_onnx_filename( + model_id, + DECODER_WITH_PAST_ONNX_FILE_PATTERN, + "decoder_with_past_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + fail_if_not_found=use_cache, + ) + + decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename( + ONNX_DECODER_WITH_PAST_NAME ) + + if decoder_with_past_file_name not in decoder_with_past_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_with_past_file_name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, " + f"the {cls.__name__} might not behave as expected." + ) + + decoder_with_past_path = model_path / decoder_with_past_file_name if use_cache else None + + preprocessors = None + if model_path.is_dir(): model = cls.load_model( - decoder_path=os.path.join(model_id, subfolder, decoder_file_name), + decoder_path=model_path / decoder_file_name, decoder_with_past_path=decoder_with_past_path, provider=provider, session_options=session_options, provider_options=provider_options, ) - model_save_dir = Path(model_id).joinpath(subfolder) - file_names["last_decoder_model_name"] = decoder_file_name - file_names["last_decoder_with_past_model_name"] = decoder_with_past_file_name - # Load model from hub + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) else: - default_file_names = [ONNX_DECODER_NAME] - model_file_names = [decoder_file_name] - if use_cache: - default_file_names.append(ONNX_DECODER_WITH_PAST_NAME) - model_file_names.append(decoder_with_past_file_name) - # Download the decoder and decoder_with_past forming the model - for file_name, default_file_name in zip(model_file_names, default_file_names): + attribute_name_to_filename = { + "last_decoder_model_name": decoder_file_name, + "last_decoder_with_past_model_name": decoder_with_past_file_name if use_cache else None, + } + paths = {} + for attr_name, filename in attribute_name_to_filename.items(): + if filename is None: + continue model_cache_path = hf_hub_download( repo_id=model_id, subfolder=subfolder, - filename=file_name, + filename=filename, use_auth_token=use_auth_token, revision=revision, cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, ) - file_names[f"last_{default_file_name.split('.')[0]}_name"] = Path(model_cache_path).name - model_save_dir = Path(model_cache_path).parent + paths[attr_name] = Path(model_cache_path).name + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) - last_decoder_with_past_name = file_names.get("last_decoder_with_past_model_name", None) + last_decoder_with_past_name = paths.get("last_decoder_with_past_model_name", None) if last_decoder_with_past_name is not None: - last_decoder_with_past_name = model_save_dir.joinpath(last_decoder_with_past_name) + last_decoder_with_past_name = new_model_save_dir / last_decoder_with_past_name + model = cls.load_model( - decoder_path=model_save_dir.joinpath(file_names["last_decoder_model_name"]), + decoder_path=new_model_save_dir / paths["last_decoder_model_name"], decoder_with_past_path=last_decoder_with_past_name, provider=provider, 
session_options=session_options, provider_options=provider_options, ) + if model_save_dir is None: + model_save_dir = new_model_save_dir + return cls( + model[0], config, - *model, + decoder_with_past_session=model[1], use_io_binding=use_io_binding, model_save_dir=model_save_dir, - last_decoder_model_name=file_names["last_decoder_model_name"], - last_decoder_with_past_model_name=file_names.get("last_decoder_with_past_model_name", None), + preprocessors=preprocessors, ) @classmethod @@ -540,46 +609,70 @@ def _from_transformers( cls, model_id: str, config: "PretrainedConfig", - subfolder: Optional[str] = "", - save_dir: Union[str, Path] = default_cache_path, use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + revision: str = "main", force_download: bool = True, cache_dir: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, use_cache: bool = True, - **kwargs, - ): - # Create local save dir in cache dir - save_dir = Path(save_dir).joinpath(model_id) - save_dir.mkdir(parents=True, exist_ok=True) - preprocessor = get_preprocessor(model_id) - framework = FeaturesManager.determine_framework(os.path.join(model_id, subfolder)) - model_class = FeaturesManager.get_model_class_for_feature(cls.export_feature, framework) - model = model_class.from_pretrained(model_id, subfolder=subfolder, config=config, cache_dir=cache_dir) - - # Export the decoder without the past key values - onnx_config = DecoderOnnxConfigWithPast(model.config, task=cls.export_feature, use_past=False) - onnx_opset = onnx_config.default_onnx_opset - export( - preprocessor=preprocessor, + provider: str = "CPUExecutionProvider", + session_options: Optional[onnxruntime.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + task: Optional[str] = None, + ) -> "ORTModelDecoder": + if task is None: + task = cls._AUTOMODELS_TO_TASKS[cls.auto_model_class] + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + model = TasksManager.get_model_from_task( + task, + model_id, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + config=config, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + ) + + model_type = model.config.model_type.replace("_", "-") + model_name = getattr(model, "name", None) + + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model_type, "onnx", task=task, model_name=model_name + ) + onnx_config = onnx_config_constructor(model.config, use_past=use_cache) + + output_names = [ONNX_DECODER_NAME] + if use_cache is True: + output_names.append(ONNX_DECODER_WITH_PAST_NAME) + export_models( model=model, - config=onnx_config, - opset=onnx_opset, - output=save_dir.joinpath(ONNX_DECODER_NAME), + onnx_config=onnx_config, + opset=onnx_config.DEFAULT_ONNX_OPSET, + output_dir=save_dir_path, + fn_get_models_from_config=get_decoder_models_for_export, + output_names=output_names, ) - # Export the decoder with the past key values - if use_cache: - onnx_config_with_past = DecoderOnnxConfigWithPast(model.config, task=cls.export_feature, use_past=True) - export( - preprocessor=preprocessor, - model=model, - config=onnx_config_with_past, - opset=onnx_opset, - output=save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME), - ) + config.save_pretrained(save_dir_path) + maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) - return cls._from_pretrained(save_dir, config=config, use_cache=use_cache, **kwargs) + 
return cls._from_pretrained( + save_dir_path, + config, + use_cache=use_cache, + provider=provider, + session_options=session_options, + provider_options=provider_options, + use_io_binding=use_io_binding, + model_save_dir=save_dir, + ) def to(self, device: Union[torch.device, str, int]): """ @@ -612,8 +705,6 @@ class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin): ONNX model with a causal language modeling head for ONNX Runtime inference. """ - # Used to export the model to ONNX - export_feature = "causal-lm" auto_model_class = AutoModelForCausalLM main_input_name = "input_ids" diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 94a6668007..f696346d1d 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -14,10 +14,10 @@ """ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers.""" import logging -import os import shutil from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import numpy as np import torch @@ -30,7 +30,7 @@ AutoModelForSequenceClassification, AutoModelForTokenClassification, ) -from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, default_cache_path +from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import ( BaseModelOutput, ImageClassifierOutput, @@ -42,12 +42,14 @@ ) import onnxruntime as ort -from huggingface_hub import hf_hub_download +from huggingface_hub import HfApi, HfFolder, hf_hub_download from ..exporters import TasksManager from ..exporters.onnx import export from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel -from .io_binding import TypeHelper +from ..utils.file_utils import find_files_matching_pattern +from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors +from .io_binding import IOBindingHelper, TypeHelper from .utils import ( ONNX_WEIGHTS_NAME, get_device_for_provider, @@ -107,6 +109,14 @@ """ +class classproperty: + def __init__(self, getter): + self.getter = getter + + def __get__(self, instance, owner): + return self.getter(owner) + + class ORTModel(OptimizedModel): """ Base class for implementing models using ONNX Runtime. @@ -125,32 +135,63 @@ class ORTModel(OptimizedModel): - config ([`~transformers.PretrainedConfig`] -- The configuration of the model. - use_io_binding (`bool`, *optional*, defaults to `True`) -- Whether to use I/O bindings with **ONNX Runtime with the CUDAExecutionProvider**, this can significantly speedup inference depending on the task. - - model_save_dir (`Optional[str]`, *optional*) -- The directory where the model exported to ONNX will be saved. + - model_save_dir (`Path`) -- The directory where the model exported to ONNX is saved. By defaults, if the loaded model is local, the directory where the original model will be used. Otherwise, the cache directory is used. - - latest_model_name (`str`, *optional*, defaults to `"model.onnx"` -- The name of the last ONNX model file. - providers (`List[str]) -- The list of execution providers available to ONNX Runtime. 
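From a user's perspective, the reworked `_from_transformers` / `_from_pretrained` path above means the on-the-fly export now lands in a `TemporaryDirectory` that the returned model keeps a reference to, and `save_pretrained` copies the resulting files out of it. A hedged usage sketch (model id and target directory are placeholders):

```python
# Usage sketch, not part of the patch. from_transformers=True exports to a temporary
# directory kept alive by the model; save_pretrained copies the ONNX file(s), config
# and preprocessors out of it. "gpt2" and "gpt2_onnx" are placeholders.
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

# Copies decoder_model.onnx (and decoder_with_past_model.onnx, since use_cache=True),
# config.json and the tokenizer files into the target directory.
model.save_pretrained("gpt2_onnx")
```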
""" + _AUTOMODELS_TO_TASKS = {cls_: task for task, cls_ in TasksManager._TASKS_TO_AUTOMODELS.items()} model_type = "onnx_model" auto_model_class = AutoModel + @classproperty + def export_feature(cls): + logger.warning(f"{cls.__name__}.export_feature is deprecated, and will be removed in optimum 2.0.") + return cls._AUTOMODELS_TO_TASKS.get(cls.auto_model_class, None) + def __init__( self, model: ort.InferenceSession, config: "PretrainedConfig", use_io_binding: bool = True, - model_save_dir: Optional[str] = None, - latest_model_name: str = "model.onnx", + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + **kwargs, ): - self.model = model - self.config = config + # TODO: remove at version 2.0 + if kwargs.pop("latest_model_name", None) is not None: + logger.warning( + f"The latest_model_name argument to create an {self.__class__.__name__} is deprecated, and not used " + "anymore." + ) + if kwargs: + raise ValueError( + f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." + ) + + super().__init__(model, config) self.use_io_binding = use_io_binding - self.model_save_dir = model_save_dir - self.latest_model_name = latest_model_name self.providers = model.get_providers() self._device = get_device_for_provider(self.providers[0]) + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it + # would end-up removing the directory containing the underlying ONNX model. + self._model_save_dir_tempdirectory_instance = None + if model_save_dir is None: + self.model_save_dir = Path(model._model_path).parent + elif isinstance(model_save_dir, TemporaryDirectory): + self._model_save_dir_tempdirectory_instance = model_save_dir + self.model_save_dir = Path(model_save_dir.name) + elif isinstance(model_save_dir, str): + self.model_save_dir = Path(model_save_dir) + else: + self.model_save_dir = model_save_dir + self.model_path = Path(model._model_path) + self.model_name = self.model_path.name + + self._preprocessors = preprocessors if preprocessors is not None else [] + if self._device is None: logger.warning( f"ORTModel outputs will be sent to CPU as the device could not be inferred from the execution provider {self.providers[0]}." @@ -237,6 +278,9 @@ def load_model( # Follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python providers.append("CUDAExecutionProvider") + if not isinstance(path, str): + path = str(path) + # `providers` list must of be of the same length as `provider_options` list return ort.InferenceSession( path, @@ -257,9 +301,49 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ON file_name (`str`, *optional*, defaults to the value of `optimum.onnxruntime.utils.ONNX_WEIGHTS_NAME`): The filename to use when saving the model. 
""" - src_path = self.model_save_dir.joinpath(self.latest_model_name) + # TODO: support models with external data dst_path = Path(save_directory).joinpath(file_name) - shutil.copyfile(src_path, dst_path) + shutil.copyfile(self.model_path, dst_path) + + @staticmethod + def _generate_regular_names_for_filename(filename: str): + name, extension = filename.rsplit(".", maxsplit=1) + return [filename, f"{name}_quantized.{extension}", f"{name}_optimized.{extension}"] + + @staticmethod + def infer_onnx_filename( + model_name_or_path: Union[str, Path], + pattern: str, + argument_name: str, + subfolder: str = "", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + fail_if_not_found: bool = True, + ) -> str: + onnx_files = find_files_matching_pattern( + model_name_or_path, + pattern, + glob_pattern="**/*.onnx", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + + path = model_name_or_path + if subfolder != "": + path = f"{path}/{subfolder}" + + if len(onnx_files) == 0: + if fail_if_not_found: + raise FileNotFoundError(f"Could not find any ONNX model file in {path}") + return None + elif len(onnx_files) > 1: + if argument_name is not None: + raise RuntimeError( + f"Too many ONNX model files were found in {path}, specify which one to load by using the " + f"{argument_name} argument." + ) + return onnx_files[0].name @classmethod def _from_pretrained( @@ -270,23 +354,56 @@ def _from_pretrained( revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, - file_name: str = ONNX_WEIGHTS_NAME, + file_name: Optional[str] = None, subfolder: str = "", local_files_only: bool = False, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, - **kwargs, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, ) -> "ORTModel": - if os.path.isdir(os.path.join(model_id, subfolder)): + model_path = Path(model_id) + regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_WEIGHTS_NAME) + + if file_name is None: + if model_path.is_dir(): + onnx_files = list(model_path.glob("*.onnx")) + else: + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)) + pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx" + onnx_files = [p for p in repo_files if p.match(pattern)] + + if len(onnx_files) == 0: + raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}") + elif len(onnx_files) > 1: + raise RuntimeError( + f"Too many ONNX model files were found in {model_path}, specify which one to load by using the " + "file_name argument." + ) + else: + file_name = onnx_files[0].name + + if file_name not in regular_onnx_filenames: + logger.warning( + f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime, the ORTModel might " + "not behave as expected." 
+ ) + + preprocessors = None + if model_path.is_dir(): model = ORTModel.load_model( - os.path.join(model_id, subfolder, file_name), + model_path / file_name, provider=provider, session_options=session_options, provider_options=provider_options, ) - kwargs["model_save_dir"] = Path(model_id).joinpath(subfolder) - kwargs["latest_model_name"] = file_name + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) else: model_cache_path = hf_hub_download( repo_id=model_id, @@ -301,17 +418,27 @@ def _from_pretrained( model = ORTModel.load_model( model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options ) - kwargs["model_save_dir"] = Path(model_cache_path).parent - kwargs["latest_model_name"] = Path(model_cache_path).name + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) - return cls(model=model, config=config, **kwargs) + # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it + # instead of the path only. + if model_save_dir is None: + model_save_dir = new_model_save_dir + + return cls( + model=model, + config=config, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + ) @classmethod def _from_transformers( cls, model_id: str, config: "PretrainedConfig", - save_dir: Union[str, Path] = default_cache_path, use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -321,23 +448,11 @@ def _from_transformers( provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, - **kwargs, + use_io_binding: bool = True, + task: Optional[str] = None, ) -> "ORTModel": - save_dir = Path(save_dir).joinpath(model_id, subfolder) - save_dir.mkdir(parents=True, exist_ok=True) - - # Reads pipeline task from ORTModelForXXX class if available else tries to extract from hub - if cls.export_feature is not None: - task = cls.export_feature - else: - # TODO: Do we want to actually support that? - # TODO: load from subfolder? - task = TasksManager.infer_task_from_model(model_id, revision=revision) - # TODO: is it still needed? - if task in ["sentiment-analysis", "text-classification", "zero-shot-classification"]: - task = "sequence-classification" - elif task in ["feature-extraction", "fill-mask"]: - task = "default" + if task is None: + task = cls._AUTOMODELS_TO_TASKS[cls.auto_model_class] kwargs_to_get_model = { "subfolder": subfolder, @@ -352,14 +467,26 @@ def _from_transformers( onnx_config = onnx_config_class(model.config) + tmp_dir = TemporaryDirectory() + tmp_dir_path = Path(tmp_dir.name) export( model=model, config=onnx_config, opset=onnx_config.DEFAULT_ONNX_OPSET, - output=save_dir.joinpath(ONNX_WEIGHTS_NAME), + output=tmp_dir_path / ONNX_WEIGHTS_NAME, + ) + config.save_pretrained(tmp_dir_path) + maybe_save_preprocessors(model_id, tmp_dir_path, src_subfolder=subfolder) + + return cls._from_pretrained( + tmp_dir_path, + config, + use_io_binding=use_io_binding, + model_save_dir=tmp_dir, + provider=provider, + session_options=session_options, + provider_options=provider_options, ) - - return cls._from_pretrained(save_dir.as_posix(), config, **kwargs) @classmethod @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING) @@ -454,8 +581,6 @@ class ORTModelForFeatureExtraction(ORTModel): Feature Extraction model for ONNX. 
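The ONNX file discovery introduced in `ORTModel._from_pretrained` above changes what callers have to pass. A sketch of the resulting behaviour, assuming `file_name` is forwarded from `from_pretrained` down to `_from_pretrained`; directory names are placeholders:

```python
# Sketch of the new loading behaviour; paths are placeholders.
from optimum.onnxruntime import ORTModelForSequenceClassification

# A single *.onnx file in the directory (or hub repo) is discovered automatically.
model = ORTModelForSequenceClassification.from_pretrained("./distilbert_onnx")

# With several ONNX files (e.g. model.onnx next to model_quantized.onnx), the file
# must be named explicitly, otherwise the "Too many ONNX model files" error is raised.
quantized_model = ORTModelForSequenceClassification.from_pretrained(
    "./distilbert_onnx", file_name="model_quantized.onnx"
)
```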
""" - # used in from_transformers to export model to onnx - export_feature = "default" auto_model_class = AutoModel def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -627,7 +752,6 @@ class ORTModelForQuestionAnswering(ORTModel): Question Answering model for ONNX. """ - export_feature = "question-answering" auto_model_class = AutoModelForQuestionAnswering def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -827,7 +951,6 @@ class ORTModelForSequenceClassification(ORTModel): Sequence Classification model for ONNX. """ - export_feature = "sequence-classification" auto_model_class = AutoModelForSequenceClassification def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -998,7 +1121,6 @@ class ORTModelForTokenClassification(ORTModel): Token Classification model for ONNX. """ - export_feature = "token-classification" auto_model_class = AutoModelForTokenClassification def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -1164,7 +1286,6 @@ class ORTModelForMultipleChoice(ORTModel): Multiple choice model for ONNX. """ - export_feature = "multiple-choice" auto_model_class = AutoModelForMultipleChoice def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -1334,7 +1455,6 @@ class ORTModelForImageClassification(ORTModel): Image Classification model for ONNX. """ - export_feature = "image-classification" auto_model_class = AutoModelForImageClassification def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -1462,21 +1582,47 @@ def forward( ) class ORTModelForCustomTasks(ORTModel): """ - Onnx Model for any custom tasks. + Onnx Model for any custom tasks using encoder or decoder-only models. """ - export_feature = "default" - auto_model_class = AutoModel + def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): + super().__init__(model, config, use_io_binding=True, **kwargs) + self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_inputs())} + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + self.model_input_names = list(self.model_inputs.keys()) + self.model_output_names = list(self.model_outputs.keys()) - def __init__(self, model=None, config=None, **kwargs): - super().__init__(model, config, **kwargs) - if kwargs.pop("use_io_binding", False): - logger.warning( - "ORTModelForCustomTasks doesn't support IO Binding yet, and the inference will be done without IO binding which could cause" - " significant overhead on data copying. If you want us to enable IO binding for custom use case, please open an issue in " - "Optimum: https://github.com/huggingface/optimum." + def prepare_io_binding(self, **kwargs) -> ort.IOBinding: + """ + Returns IOBinding object for an inference session. This method is created for general purpose, if the inputs and outputs + are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks. 
+ """ + + name_to_np_type = TypeHelper.get_io_numpy_type_map(self.model) + + # Bind inputs and outputs to onnxruntime session + io_binding = self.model.io_binding() + + # Bind inputs + for input_name in self.model_input_names: + onnx_input = kwargs.pop(input_name) + onnx_input = onnx_input.contiguous() + + io_binding.bind_input( + input_name, + onnx_input.device.type, + self.device.index, + name_to_np_type[input_name], + list(onnx_input.size()), + onnx_input.data_ptr(), ) + # Bind outputs + for name in self.model_output_names: + io_binding.bind_output(name, self.device.type, device_id=self.device.index) + + return io_binding + @add_start_docstrings_to_model_forward( CUSTOM_TASKS_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, @@ -1485,13 +1631,30 @@ def __init__(self, model=None, config=None, **kwargs): ) ) def forward(self, **kwargs): - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = self._prepare_onnx_inputs(**kwargs) - # run inference - onnx_outputs = self.model.run(None, onnx_inputs) - outputs = self._prepare_onnx_outputs(onnx_outputs) - # converts outputs to namedtuple for pipelines post-processing if applicable - return ModelOutput(outputs) + if self.device.type == "cuda" and self.use_io_binding: + io_binding = self.prepare_io_binding(**kwargs) + + # run inference with binding + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + outputs = {} + for name, output in zip(self.model_output_names, io_binding._iobinding.get_outputs()): + outputs[name] = IOBindingHelper.to_pytorch(output) + + # converts output to namedtuple for pipelines post-processing + return ModelOutput(**outputs) + else: + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = self._prepare_onnx_inputs(**kwargs) + + # run inference + onnx_outputs = self.model.run(None, onnx_inputs) + outputs = self._prepare_onnx_outputs(onnx_outputs) + + # converts output to namedtuple for pipelines post-processing + return ModelOutput(outputs) def _prepare_onnx_inputs(self, **kwargs): model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())} diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index b7a5155807..eaa09ec3c1 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -17,32 +17,34 @@ """ import logging -import os +import re import shutil from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoTokenizer -from transformers.file_utils import add_start_docstrings_to_model_forward, default_cache_path -from transformers.generation_utils import GenerationMixin +from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq +from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput import onnxruntime as ort from huggingface_hub import hf_hub_download -from ..exporters.onnx.convert import export_encoder_decoder_model as export +from ..exporters.onnx import export_models, get_encoder_decoder_models_for_export from ..exporters.tasks import TasksManager from ..utils import NormalizedConfigManager, 
check_if_transformers_greater +from ..utils.file_utils import validate_file_exists +from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .io_binding import TypeHelper +from .modeling_decoder import ORTDecoder from .modeling_ort import ORTModel from .utils import ( ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME, - get_device_for_provider, get_provider_for_device, parse_device, validate_provider_availability, @@ -197,475 +199,161 @@ ``` """ +ENCODER_ONNX_FILE_PATTERN = r"(.*)?encoder(.*)?\.onnx" +DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx" +DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" -class ORTModelForConditionalGeneration(ORTModel, ABC): - """ - Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. - - Important attributes: - config ([`PretrainedConfig`]): - Instance of the configuration associated to the model. Initializing with a config file does - not load the weights associated with the model, only the configuration. - use_io_binding (`bool`): - Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` - if the device is CUDA, otherwise defaults to `False`. - use_cache (`bool`): - Whether or not past key/values cache should be used. It is determined by whether an InferenceSession for - that was provided or not. - providers (`List[str`]): - The list of execution providers the model is running on. - encoder (`ORTEncoder`): - The encoder model. - decoder (`ORTDecoder`): - The decoder model. - decoder_with_past (`Optional[ORTDecoder]`): - The decoder model handling the past key/values if `use_cache=True`, else `None`. - - Other attributes: - encoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): - The name of the ONNX file containing the encoder part of the model. - decoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): - The name of the ONNX file containing the decoder part of the model. - decoder_file_with_past_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): - The name of the ONNX file containing the decoder with past key/values part of the model. - model_save_dir (`str`, defaults to `""`): - The directory under which the model exported to ONNX was saved. +class ORTEncoder: + """ + Encoder part of the encoder-decoder model for ONNX Runtime inference. """ - - # Used in from_transformers to export model to onnxORTEncoder - base_model_prefix = "onnx_model" def __init__( self, - encoder_session: ort.InferenceSession, - decoder_session: ort.InferenceSession, + session: ort.InferenceSession, config: "PretrainedConfig", - decoder_with_past_session: Optional[ort.InferenceSession] = None, + device: torch.device, use_io_binding: bool = True, - model_save_dir: str = "", - last_encoder_model_name: str = ONNX_ENCODER_NAME, - last_decoder_model_name: str = ONNX_DECODER_NAME, - last_decoder_with_past_model_name: str = ONNX_DECODER_WITH_PAST_NAME, + main_input_name: str = "input_ids", ): - """ - Args: - encoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the encoder. - decoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the decoder. - config ([`PretrainedConfig`]): - `config` is an instance of the configuration associated to the model. 
Initializing with a config file - does not load the weights associated with the model, only the configuration. - decoder_with_past_session (`Optional[ort.InferenceSession]`, *optional*): - The ONNX Runtime inference session associated to the decoder with past key values. - use_io_binding (`bool`, *optional*, defaults to `True`): - Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to - `True` if the device is CUDA, otherwise defaults to `False`. - model_save_dir (`str`, *optional*, defaults to `""`): - The directory under which the model exported to ONNX was saved. - last_encoder_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): - The name of the ONNX file containing the encoder part of the model. - last_decoder_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): - The name of the ONNX file containing the decoder part of the model. - last_decoder_with_past_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): - The name of the ONNX file containing the decoder with past key/values part of the model. - """ - ABC.__init__(self) + self.session = session + self.config = config + self._device = device + self.use_io_binding = use_io_binding + self.main_input_name = main_input_name + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)( + self.config + ) + self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} + self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None - self.encoder_file_name = last_encoder_model_name - self.decoder_file_name = last_decoder_model_name - self.decoder_file_with_past_name = last_decoder_with_past_model_name + def prepare_output_buffer(self, batch_size, sequence_length): + """Prepare the buffer of output(`last_hidden_state`) with a 1D tensor on shape: (batch_size, sequence_length, hidden_size).""" + ort_type = TypeHelper.get_output_type(self.session, "last_hidden_state") + torch_type = TypeHelper.ort_type_to_torch_type(ort_type) - self.config = config + hidden_size = self.normalized_config.hidden_size + output_shape = (batch_size, sequence_length, hidden_size) + output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous() - self.use_io_binding = use_io_binding - self.model_save_dir = model_save_dir + return output_shape, output_buffer - self.providers = encoder_session.get_providers() - self._device = get_device_for_provider(encoder_session.get_providers()[0]) + def prepare_io_binding( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + io_binding = self.session.io_binding() - if "TensorrtExecutionProvider" in self.providers and self.use_io_binding: - logger.warning( - "There is no need to do IO binding for TensorrtExecutionProvider, `use_io_binding` will be set to False." 
+ # bind input ids + input_ids = input_ids.contiguous() + io_binding.bind_input( + "input_ids", + input_ids.device.type, + self._device.index, + self.name_to_np_type["input_ids"], + tuple(input_ids.shape), + input_ids.data_ptr(), + ) + if "attention_mask" in self.input_names: + # bind attention mask + attention_mask = attention_mask.contiguous() + io_binding.bind_input( + "attention_mask", + attention_mask.device.type, + self._device.index, + self.name_to_np_type["attention_mask"], + tuple(attention_mask.shape), + attention_mask.data_ptr(), ) - self.use_io_binding = False - self.encoder = self._initialize_encoder( - session=encoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + # bind last_hidden_state + output_shape, output_buffer = self.prepare_output_buffer( + batch_size=input_ids.size(0), + sequence_length=input_ids.size(1), ) - self.decoder = ORTDecoder( - session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + io_binding.bind_output( + "last_hidden_state", + output_buffer.device.type, + self._device.index, + self.name_to_np_type["last_hidden_state"], + output_shape, + output_buffer.data_ptr(), ) + output_shapes = {"last_hidden_state": output_shape} + output_buffers = {"last_hidden_state": output_buffer} - self.use_cache = decoder_with_past_session is not None + return io_binding, output_shapes, output_buffers - # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs - # will be enabled - self.decoder_with_past = None - if self.use_cache: - self.decoder_with_past = ORTDecoder( - session=decoder_with_past_session, - config=self.config, - device=self._device, - use_io_binding=self.use_io_binding, - ) + @add_start_docstrings_to_model_forward(SEQ2SEQ_ENCODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + **kwargs, + ) -> BaseModelOutput: - # Registers the ORTModelForXXX classes into the transformers AutoModel classes - # to avoid warnings when create a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 - AutoConfig.register(self.base_model_prefix, AutoConfig) - self.auto_model_class.register(AutoConfig, self.__class__) + if self._device.type == "cuda" and self.use_io_binding: + io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_ids, attention_mask) - @abstractmethod - def _initialize_encoder( - self, - session: ort.InferenceSession, - config: "PretrainedConfig", - device: torch.device, - use_io_binding: bool = True, - ) -> "ORTEncoder": - pass + # run inference with binding & synchronize in case of multiple CUDA streams + io_binding.synchronize_inputs() + self.session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() - @staticmethod - def load_model( - encoder_path: Union[str, Path], - decoder_path: Union[str, Path], - decoder_with_past_path: Optional[Union[str, Path]] = None, - provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict] = None, - ): - """ - Creates an instance of [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForConditionalGeneration`]. - Three inference sessions will be created for respectively the encoder, decoder and decoder with past key values - models. The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. 
+ # converts output to namedtuple for pipelines post-processing + return BaseModelOutput( + last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) + ) + else: + onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - Args: - encoder_path (`Union[str, Path]`): - The path of the encoder ONNX model. - decoder_path (`Union[str, Path]`): - The path of the decoder ONNX model. - decoder_with_past_path (`Optional[Union[str, Path]]`, *optional*): - The path of the decoder with past key values ONNX model. - provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`): - ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ - for possible providers. - session_options (`Optional[ort.SessionOptions]`, *optional*),: - ONNX Runtime session options to use for loading the model. Defaults to `None`. - provider_options (`Optional[Dict]`, *optional*): - Provider option dictionary corresponding to the provider used. See available options - for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. - """ - validate_provider_availability(provider) # raise error if the provider is not available + # Add the attention_mask inputs when needed + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() - providers = [provider] - if provider == "TensorrtExecutionProvider": - # follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python - providers.append("CUDAExecutionProvider") + # Run inference + outputs = self.session.run(None, onnx_inputs) + last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self._device) - encoder_session = ort.InferenceSession( - str(encoder_path), - providers=providers, - sess_options=session_options, - provider_options=None if provider_options is None else [provider_options], - ) - decoder_session = ort.InferenceSession( - str(decoder_path), - providers=providers, - sess_options=session_options, - provider_options=None if provider_options is None else [provider_options], - ) + return BaseModelOutput(last_hidden_state=last_hidden_state) - decoder_with_past_session = None - # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs - # will be enabled - if decoder_with_past_path is not None: - decoder_with_past_session = ort.InferenceSession( - str(decoder_with_past_path), - providers=providers, - sess_options=session_options, - provider_options=None if provider_options is None else [provider_options], - ) + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) - return encoder_session, decoder_session, decoder_with_past_session - def _save_pretrained( +class ORTEncoderForWhisper(ORTEncoder): + """ + Encoder model for ONNX Runtime inference for Whisper model. + + Args: + session (`ort.InferenceSession`): + The ONNX Runtime inference session associated to the encoder. + """ + + def prepare_io_binding( self, - save_directory: Union[str, Path], - # TODO: should we make the default values available here? 
- encoder_file_name: Optional[str] = None, - decoder_file_name: Optional[str] = None, - decoder_with_past_file_name: Optional[str] = None, + input_features: torch.FloatTensor = None, ): - """ - Saves the model encoder, decoder and decoder with past key values as well as its configuration file to a - directory, so that it can be re-loaded using the - [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForSeq2SeqLM.from_pretrained`] class method. + io_binding = self.session.io_binding() - Args: - save_directory (`Union[str, Path`]): - The directory where to save the model files. - encoder_file_name(`Optional[str]`, *optional*): - The encoder model file name. Overwrites the default file name and allows one to save the encoder model - with a different name. - decoder_file_name(`Optional[str]`, *optional*): - The decoder model file name. Overwrites the default file name and allows one to save the decoder model - with a different name. - decoder_with_past_file_name(`Optional[str]`, *optional*): - The decoder with past key values model file name overwriting the default file name, allowing to save - the decoder model with a different name. - """ - src_file_names = [self.encoder_file_name, self.decoder_file_name] - dst_file_names = [encoder_file_name or ONNX_ENCODER_NAME, decoder_file_name or ONNX_DECODER_NAME] - if self.use_cache: - src_file_names.append(self.decoder_file_with_past_name) - dst_file_names.append(decoder_with_past_file_name or ONNX_DECODER_WITH_PAST_NAME) - - for src_file_name, dst_file_name in zip(src_file_names, dst_file_names): - src_path = self.model_save_dir.joinpath(src_file_name) - dst_path = Path(save_directory).joinpath(dst_file_name) - shutil.copyfile(src_path, dst_path) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - encoder_file_name: str = ONNX_ENCODER_NAME, - decoder_file_name: str = ONNX_DECODER_NAME, - decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, - subfolder: str = "", - local_files_only: bool = False, - use_cache: bool = True, - provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: bool = True, - ): - kwargs = {"use_io_binding": use_io_binding} - - # Load model from a local directory - if os.path.isdir(os.path.join(model_id, subfolder)): - decoder_with_past_path = ( - os.path.join(model_id, subfolder, decoder_with_past_file_name) if use_cache else None - ) - model = cls.load_model( - encoder_path=os.path.join(model_id, subfolder, encoder_file_name), - decoder_path=os.path.join(model_id, subfolder, decoder_file_name), - decoder_with_past_path=decoder_with_past_path, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - kwargs["model_save_dir"] = Path(model_id).joinpath(subfolder) - kwargs["last_encoder_model_name"] = encoder_file_name - kwargs["last_decoder_model_name"] = decoder_file_name - kwargs["last_decoder_with_past_model_name"] = decoder_with_past_file_name - # Load model from hub - else: - default_file_names = [ONNX_ENCODER_NAME, ONNX_DECODER_NAME] - model_file_names = [encoder_file_name, decoder_file_name] - if use_cache: - default_file_names.append(ONNX_DECODER_WITH_PAST_NAME) - model_file_names.append(decoder_with_past_file_name) - # Download the encoder, decoder and 
decoder_with_past forming the model - for file_name, default_file_name in zip(model_file_names, default_file_names): - model_cache_path = hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - kwargs[f"last_{default_file_name.split('.')[0]}_name"] = Path(model_cache_path).name - kwargs["model_save_dir"] = Path(model_cache_path).parent - - last_decoder_with_past_name = kwargs.get("last_decoder_with_past_model_name", None) - if last_decoder_with_past_name is not None: - last_decoder_with_past_name = kwargs["model_save_dir"].joinpath(last_decoder_with_past_name) - model = cls.load_model( - encoder_path=kwargs["model_save_dir"].joinpath(kwargs["last_encoder_model_name"]), - decoder_path=kwargs["model_save_dir"].joinpath(kwargs["last_decoder_model_name"]), - decoder_with_past_path=last_decoder_with_past_name, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - - return cls(*model[:2], config, decoder_with_past_session=model[2], **kwargs) - - @classmethod - def _from_transformers( - cls, - model_id: str, - config: "PretrainedConfig", - save_dir: Union[str, Path] = default_cache_path, - use_auth_token: Optional[Union[bool, str]] = None, - revision: str = "main", - force_download: bool = True, - cache_dir: Optional[str] = None, - subfolder: str = "", - local_files_only: bool = False, - use_cache: bool = True, - provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: bool = True, - ): - # Create local save dir in cache dir - save_dir = Path(save_dir).joinpath(model_id, subfolder) - save_dir.mkdir(parents=True, exist_ok=True) - - model = TasksManager.get_model_from_task( - cls.export_feature, - model_id, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - config=config, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - force_download=force_download, - ) - - model_type = model.config.model_type.replace("_", "-") - model_name = getattr(model, "name", None) - - onnx_config_constructor = TasksManager.get_exporter_config_constructor( - model_type, "onnx", task=cls.export_feature, model_name=model_name - ) - onnx_config = onnx_config_constructor(model.config, use_past=use_cache) - onnx_opset = onnx_config.DEFAULT_ONNX_OPSET - - export( - model, - onnx_config, - onnx_opset, - save_dir.joinpath(ONNX_ENCODER_NAME), - save_dir.joinpath(ONNX_DECODER_NAME), - save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME), - ) - - return cls._from_pretrained( - save_dir, - config=config, - use_cache=use_cache, - provider=provider, - session_options=session_options, - provider_options=provider_options, - use_io_binding=use_io_binding, - ) - - def to(self, device: Union[torch.device, str, int]): - """ - Changes the ONNX Runtime provider according to the device. - - Args: - device (`torch.device` or `str` or `int`): - Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run - the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. - - Returns: - `ORTModel`: the model placed on the requested device. 
- """ - device, provider_options = parse_device(device) - - provider = get_provider_for_device(device) - validate_provider_availability(provider) # raise error if the provider is not available - - self.device = device - self.encoder._device = device - self.encoder.session.set_providers([provider], provider_options=[provider_options]) - self.decoder._device = device - self.decoder.session.set_providers([provider], provider_options=[provider_options]) - if self.decoder_with_past is not None: - self.decoder_with_past._device = device - self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) - self.providers = self.encoder.session.get_providers() - - return self - - -class ORTEncoder: - """ - Encoder model for ONNX Runtime inference. - - Args: - session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the encoder. - """ - - def __init__( - self, - session: ort.InferenceSession, - config: "PretrainedConfig", - device: torch.device, - use_io_binding: bool = True, - main_input_name: str = "input_ids", - ): - self.session = session - self.config = config - self._device = device - self.use_io_binding = use_io_binding - self.main_input_name = main_input_name - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)( - self.config - ) - self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} - self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} - self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None - - def prepare_output_buffer(self, batch_size, sequence_length): - """Prepare the buffer of output(`last_hidden_state`) with a 1D tensor on shape: (batch_size, sequence_length, hidden_size).""" - ort_type = TypeHelper.get_output_type(self.session, "last_hidden_state") - torch_type = TypeHelper.ort_type_to_torch_type(ort_type) - - hidden_size = self.normalized_config.hidden_size - output_shape = (batch_size, sequence_length, hidden_size) - output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous() - - return output_shape, output_buffer - - def prepare_io_binding( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - io_binding = self.session.io_binding() - - # bind input ids - input_ids = input_ids.contiguous() + # bind input features + input_features = input_features.contiguous() io_binding.bind_input( - "input_ids", - input_ids.device.type, + "input_features", + input_features.device.type, self._device.index, - self.name_to_np_type["input_ids"], - tuple(input_ids.shape), - input_ids.data_ptr(), + self.name_to_np_type["input_features"], + tuple(input_features.shape), + input_features.data_ptr(), ) - if "attention_mask" in self.input_names: - # bind attention mask - attention_mask = attention_mask.contiguous() - io_binding.bind_input( - "attention_mask", - attention_mask.device.type, - self._device.index, - self.name_to_np_type["attention_mask"], - tuple(attention_mask.shape), - attention_mask.data_ptr(), - ) - # bind last_hidden_state + # bind logits output_shape, output_buffer = self.prepare_output_buffer( - batch_size=input_ids.size(0), - sequence_length=input_ids.size(1), + batch_size=input_features.size(0), + sequence_length=input_features.size(2) // 2, ) io_binding.bind_output( "last_hidden_state", @@ -680,16 +368,14 @@ def prepare_io_binding( return 
io_binding, output_shapes, output_buffers - @add_start_docstrings_to_model_forward(SEQ2SEQ_ENCODER_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, + input_features: torch.FloatTensor, **kwargs, ) -> BaseModelOutput: - if self._device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_ids, attention_mask) + io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_features) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -701,11 +387,7 @@ def forward( last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) ) else: - onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() + onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} # Run inference outputs = self.session.run(None, onnx_inputs) @@ -713,121 +395,20 @@ def forward( return BaseModelOutput(last_hidden_state=last_hidden_state) - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - -class ORTEncoderForWhisper(ORTEncoder): +class ORTDecoderForSeq2Seq(ORTDecoder): """ - Encoder model for ONNX Runtime inference for Whisper model. - - Args: - session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the encoder. + Decoder model with a language modeling head on top for ONNX Runtime inference. """ - def prepare_io_binding( + def prepare_output_buffer( self, - input_features: torch.FloatTensor = None, - ): - io_binding = self.session.io_binding() - - # bind input features - input_features = input_features.contiguous() - io_binding.bind_input( - "input_features", - input_features.device.type, - self._device.index, - self.name_to_np_type["input_features"], - tuple(input_features.shape), - input_features.data_ptr(), - ) - - # bind logits - output_shape, output_buffer = self.prepare_output_buffer( - batch_size=input_features.size(0), - sequence_length=input_features.size(2) // 2, - ) - io_binding.bind_output( - "last_hidden_state", - output_buffer.device.type, - self._device.index, - self.name_to_np_type["last_hidden_state"], - output_shape, - output_buffer.data_ptr(), - ) - output_shapes = {"last_hidden_state": output_shape} - output_buffers = {"last_hidden_state": output_buffer} - - return io_binding, output_shapes, output_buffers - - @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) - def forward( - self, - input_features: torch.FloatTensor, - **kwargs, - ) -> BaseModelOutput: - if self._device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_features) - - # run inference with binding & synchronize in case of multiple CUDA streams - io_binding.synchronize_inputs() - self.session.run_with_iobinding(io_binding) - io_binding.synchronize_outputs() - - # converts output to namedtuple for pipelines post-processing - return BaseModelOutput( - last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) - ) - else: - onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} - - # Run inference - outputs = self.session.run(None, onnx_inputs) - last_hidden_state = 
torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self._device) - - return BaseModelOutput(last_hidden_state=last_hidden_state) - - -class ORTDecoder: - """ - Decoder model with a language modeling head on top for ONNX Runtime inference. - - Args: - session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the decoder. - """ - - def __init__( - self, - session: ort.InferenceSession, - config: "PretrainedConfig", - device: torch.device, - use_io_binding: bool = True, - ): - self.session = session - self.config = config - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)( - self.config - ) - self._device = device - self.use_io_binding = use_io_binding - self.session_inputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_inputs())} - self.session_outputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} - self.session_input_names = list(self.session_inputs.keys()) - self.session_output_names = list(self.session_outputs.keys()) - self.key_value_input_names = [key for key in self.session_input_names if (".key" in key or ".value" in key)] - self.key_value_output_names = [key for key in self.session_output_names if (".key" in key or ".value" in key)] - self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None - - def prepare_output_buffer( - self, - output_name, - batch_size=None, - sequence_length=None, - encoder_sequence_length=None, - past_sequence_length=None, - is_self_attn=False, + output_name, + batch_size=None, + sequence_length=None, + encoder_sequence_length=None, + past_sequence_length=None, + is_self_attn=False, ): """ Prepare the buffer of outputs(`logits`/`key_values`/`loss`) with 1D tensors. @@ -1072,6 +653,7 @@ def forward( # Run inference outputs = self.session.run(None, onnx_inputs) + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the # self-attention layer and 2 to the cross-attention layer) past_key_values = tuple( @@ -1096,12 +678,485 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) +class ORTModelForConditionalGeneration(ORTModel, ABC): + """ + Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. + + Important attributes: + config ([`PretrainedConfig`]): + Instance of the configuration associated to the model. Initializing with a config file does + not load the weights associated with the model, only the configuration. + use_io_binding (`bool`): + Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` + if the device is CUDA, otherwise defaults to `False`. + use_cache (`bool`): + Whether or not past key/values cache should be used. It is determined by whether an InferenceSession for + that was provided or not. + providers (`List[str`]): + The list of execution providers the model is running on. + encoder (`ORTEncoder`): + The encoder model. + decoder (`ORTDecoderForSeq2Seq`): + The decoder model. + decoder_with_past (`Optional[ORTDecoderForSeq2Seq]`): + The decoder model handling the past key/values if `use_cache=True`, else `None`. + + Other attributes: + encoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): + The name of the ONNX file containing the encoder part of the model. 
+        decoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`):
+            The name of the ONNX file containing the decoder part of the model.
+        decoder_file_with_past_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`):
+            The name of the ONNX file containing the decoder with past key/values part of the model.
+        model_save_dir (`str`, defaults to `""`):
+            The directory under which the model exported to ONNX was saved.
+
+    """
+
+    # Used in from_transformers to export the model to ONNX
+    base_model_prefix = "onnx_model"
+
+    def __init__(
+        self,
+        encoder_session: ort.InferenceSession,
+        decoder_session: ort.InferenceSession,
+        config: "PretrainedConfig",
+        decoder_with_past_session: Optional[ort.InferenceSession] = None,
+        use_io_binding: bool = True,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        preprocessors: Optional[List] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            encoder_session (`ort.InferenceSession`):
+                The ONNX Runtime inference session associated to the encoder.
+            decoder_session (`ort.InferenceSession`):
+                The ONNX Runtime inference session associated to the decoder.
+            config ([`PretrainedConfig`]):
+                `config` is an instance of the configuration associated to the model. Initializing with a config file
+                does not load the weights associated with the model, only the configuration.
+            decoder_with_past_session (`Optional[ort.InferenceSession]`, *optional*):
+                The ONNX Runtime inference session associated to the decoder with past key values.
+            use_io_binding (`bool`, *optional*, defaults to `True`):
+                Whether to use IOBinding during inference to avoid memory copy between the host and devices. Defaults to
+                `True` if the device is CUDA, otherwise defaults to `False`.
+            model_save_dir (`str`, *optional*, defaults to `""`):
+                The directory under which the model exported to ONNX was saved.
+            preprocessors (`Optional[List]`, defaults to `None`):
+                The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
+        """
+        # TODO: remove at version 2.0
+        def show_deprecated_argument(arg_name):
+            if kwargs.pop(arg_name, None) is not None:
+                logger.warning(
+                    f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used "
+                    "anymore."
+                )
+
+        show_deprecated_argument("last_encoder_model_name")
+        show_deprecated_argument("last_decoder_model_name")
+        show_deprecated_argument("last_decoder_with_past_model_name")
+        if kwargs:
+            raise ValueError(
+                f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but does not accept those arguments."
+ ) + + ABC.__init__(self) + + ORTModel.__init__( + self, + encoder_session, + config, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + ) + self.encoder = self._initialize_encoder( + session=encoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + ) + self.encoder_model_path = Path(encoder_session._model_path) + self.encoder_model_name = self.encoder_model_path.name + + self.decoder = ORTDecoderForSeq2Seq( + session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + ) + self.decoder_model_path = Path(decoder_session._model_path) + self.decoder_model_name = self.decoder_model_path.name + + self.use_cache = decoder_with_past_session is not None + + # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs + # will be enabled + self.decoder_with_past = None + self.decoder_with_past_model_path = None + self.decoder_with_past_model_name = None + if self.use_cache: + self.decoder_with_past = ORTDecoderForSeq2Seq( + session=decoder_with_past_session, + config=self.config, + device=self._device, + use_io_binding=self.use_io_binding, + ) + self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) + self.decoder_with_past_model_name = self.decoder_with_past_model_path.name + + @abstractmethod + def _initialize_encoder( + self, + session: ort.InferenceSession, + config: "PretrainedConfig", + device: torch.device, + use_io_binding: bool = True, + ) -> "ORTEncoder": + pass + + @staticmethod + def load_model( + encoder_path: Union[str, Path], + decoder_path: Union[str, Path], + decoder_with_past_path: Optional[Union[str, Path]] = None, + provider: str = "CPUExecutionProvider", + session_options: Optional[ort.SessionOptions] = None, + provider_options: Optional[Dict] = None, + ): + """ + Creates an instance of [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForConditionalGeneration`]. + Three inference sessions will be created for respectively the encoder, decoder and decoder with past key values + models. The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. + + Args: + encoder_path (`Union[str, Path]`): + The path of the encoder ONNX model. + decoder_path (`Union[str, Path]`): + The path of the decoder ONNX model. + decoder_with_past_path (`Optional[Union[str, Path]]`, *optional*): + The path of the decoder with past key values ONNX model. + provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`): + ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ + for possible providers. + session_options (`Optional[ort.SessionOptions]`, *optional*),: + ONNX Runtime session options to use for loading the model. Defaults to `None`. + provider_options (`Optional[Dict]`, *optional*): + Provider option dictionary corresponding to the provider used. See available options + for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. 
+ """ + validate_provider_availability(provider) # raise error if the provider is not available + + providers = [provider] + if provider == "TensorrtExecutionProvider": + # follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python + providers.append("CUDAExecutionProvider") + + encoder_session = ort.InferenceSession( + str(encoder_path), + providers=providers, + sess_options=session_options, + provider_options=None if provider_options is None else [provider_options], + ) + decoder_session = ort.InferenceSession( + str(decoder_path), + providers=providers, + sess_options=session_options, + provider_options=None if provider_options is None else [provider_options], + ) + + decoder_with_past_session = None + # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs + # will be enabled + if decoder_with_past_path is not None: + decoder_with_past_session = ort.InferenceSession( + str(decoder_with_past_path), + providers=providers, + sess_options=session_options, + provider_options=None if provider_options is None else [provider_options], + ) + + return encoder_session, decoder_session, decoder_with_past_session + + def _save_pretrained( + self, + save_directory: Union[str, Path], + # TODO: should we make the default values available here? + encoder_file_name: str = ONNX_ENCODER_NAME, + decoder_file_name: str = ONNX_DECODER_NAME, + decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, + ): + """ + Saves the model encoder, decoder and decoder with past key values as well as its configuration file to a + directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForSeq2SeqLM.from_pretrained`] class method. + + Args: + save_directory (`Union[str, Path`]): + The directory where to save the model files. + encoder_file_name(`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): + The encoder model file name. Overwrites the default file name and allows one to save the encoder model + with a different name. + decoder_file_name(`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): + The decoder model file name. Overwrites the default file name and allows one to save the decoder model + with a different name. + decoder_with_past_file_name(`str`, defaults to `optimum.onnxruntime.ONNX_DECODER_WITH_PAST_NAME`): + The decoder with past key values model file name overwriting the default file name, allowing to save + the decoder model with a different name. 
+ """ + src_file_names = [self.encoder_model_path, self.decoder_model_path] + dst_file_names = [encoder_file_name, decoder_file_name] + if self.use_cache: + src_file_names.append(self.decoder_with_past_model_path) + dst_file_names.append(decoder_with_past_file_name) + + for src_path, dst_file_name in zip(src_file_names, dst_file_names): + dst_path = Path(save_directory) / dst_file_name + shutil.copyfile(src_path, dst_path) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + encoder_file_name: str = ONNX_ENCODER_NAME, + decoder_file_name: str = ONNX_DECODER_NAME, + decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, + subfolder: str = "", + local_files_only: bool = False, + use_cache: bool = True, + provider: str = "CPUExecutionProvider", + session_options: Optional[ort.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + ): + model_path = Path(model_id) + + if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision): + encoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + ENCODER_ONNX_FILE_PATTERN, + "encoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): + decoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + DECODER_ONNX_FILE_PATTERN, + "decoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision): + decoder_with_past_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + DECODER_WITH_PAST_ONNX_FILE_PATTERN, + "decoder_with_past_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + fail_if_not_found=use_cache, + ) + + encoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename( + ONNX_ENCODER_NAME + ) + decoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename( + ONNX_DECODER_NAME + ) + decoder_with_past_regular_onnx_filenames = ( + ORTModelForConditionalGeneration._generate_regular_names_for_filename(ONNX_DECODER_WITH_PAST_NAME) + ) + + if encoder_file_name not in encoder_regular_onnx_filenames: + logger.warning( + f"The ONNX file {encoder_file_name} is not a regular name used in optimum.onnxruntime, the " + "ORTModelForConditionalGeneration might not behave as expected." + ) + + if decoder_file_name not in decoder_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_file_name} is not a regular name used in optimum.onnxruntime, the " + "ORTModelForConditionalGeneration might not behave as expected." + ) + if decoder_with_past_file_name not in decoder_with_past_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_with_past_file_name} is not a regular name used in optimum.onnxruntime, " + "the ORTModelForConditionalGeneration might not behave as expected." 
+ ) + + decoder_with_past_path = model_path / decoder_with_past_file_name if use_cache else None + + preprocessors = None + if model_path.is_dir(): + model = cls.load_model( + encoder_path=model_path / encoder_file_name, + decoder_path=model_path / decoder_file_name, + decoder_with_past_path=decoder_with_past_path, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) + else: + attribute_name_to_filename = { + "last_encoder_model_name": encoder_file_name, + "last_decoder_model_name": decoder_file_name, + "last_decoder_with_past_model_name": decoder_with_past_file_name if use_cache else None, + } + paths = {} + for attr_name, filename in attribute_name_to_filename.items(): + if filename is None: + continue + model_cache_path = hf_hub_download( + repo_id=model_id, + subfolder=subfolder, + filename=filename, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + paths[attr_name] = Path(model_cache_path).name + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + + last_decoder_with_past_name = paths.get("last_decoder_with_past_model_name", None) + if last_decoder_with_past_name is not None: + last_decoder_with_past_name = new_model_save_dir / last_decoder_with_past_name + + model = cls.load_model( + encoder_path=new_model_save_dir / paths["last_encoder_model_name"], + decoder_path=new_model_save_dir / paths["last_decoder_model_name"], + decoder_with_past_path=last_decoder_with_past_name, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + + if model_save_dir is None: + model_save_dir = new_model_save_dir + + return cls( + *model[:2], + config, + decoder_with_past_session=model[2], + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + ) + + @classmethod + def _from_transformers( + cls, + model_id: str, + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: str = "main", + force_download: bool = True, + cache_dir: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + use_cache: bool = True, + provider: str = "CPUExecutionProvider", + session_options: Optional[ort.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + task: Optional[str] = None, + ) -> "ORTModelForConditionalGeneration": + if task is None: + task = cls._AUTOMODELS_TO_TASKS[cls.auto_model_class] + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + model = TasksManager.get_model_from_task( + task, + model_id, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + config=config, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + ) + + model_type = model.config.model_type.replace("_", "-") + model_name = getattr(model, "name", None) + + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model_type, "onnx", task=task, model_name=model_name + ) + onnx_config = onnx_config_constructor(model.config, use_past=use_cache) + + output_names = [ONNX_ENCODER_NAME, ONNX_DECODER_NAME] + if use_cache is True: + output_names.append(ONNX_DECODER_WITH_PAST_NAME) + export_models( + model=model, + onnx_config=onnx_config, + 
opset=onnx_config.DEFAULT_ONNX_OPSET, + output_dir=save_dir_path, + fn_get_models_from_config=get_encoder_decoder_models_for_export, + output_names=output_names, + ) + + config.save_pretrained(save_dir_path) + maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) + + return cls._from_pretrained( + save_dir_path, + config, + use_cache=use_cache, + provider=provider, + session_options=session_options, + provider_options=provider_options, + use_io_binding=use_io_binding, + model_save_dir=save_dir, + ) + + def to(self, device: Union[torch.device, str, int]): + """ + Changes the ONNX Runtime provider according to the device. + + Args: + device (`torch.device` or `str` or `int`): + Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run + the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. + + Returns: + `ORTModel`: the model placed on the requested device. + """ + device, provider_options = parse_device(device) + + provider = get_provider_for_device(device) + validate_provider_availability(provider) # raise error if the provider is not available + + self.device = device + self.encoder._device = device + self.encoder.session.set_providers([provider], provider_options=[provider_options]) + self.decoder._device = device + self.decoder.session.set_providers([provider], provider_options=[provider_options]) + if self.decoder_with_past is not None: + self.decoder_with_past._device = device + self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) + self.providers = self.encoder.session.get_providers() + + return self + + class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin): """ Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. """ - export_feature = "seq2seq-lm" auto_model_class = AutoModelForSeq2SeqLM main_input_name = "input_ids" @@ -1210,7 +1265,6 @@ class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin Speech Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. """ - export_feature = "speech2seq-lm" auto_model_class = AutoModelForSpeechSeq2Seq main_input_name = "input_features" diff --git a/optimum/onnxruntime/optimization.py b/optimum/onnxruntime/optimization.py index 015e7e765d..527ec56115 100644 --- a/optimum/onnxruntime/optimization.py +++ b/optimum/onnxruntime/optimization.py @@ -25,6 +25,7 @@ from onnxruntime.transformers.optimizer import optimize_model from ..utils import CONFIG_NAME, NormalizedConfigManager +from ..utils.save_utils import maybe_save_preprocessors from .configuration import OptimizationConfig, ORTConfig from .modeling_ort import ORTModel from .modeling_seq2seq import ORTModelForSeq2SeqLM @@ -48,7 +49,7 @@ def __init__(self, onnx_model_path: List[os.PathLike], config: "PretrainedConfig Args: onnx_model_path (`List[os.PathLike]`): The paths of the onnx models to optimize. - config ([`~PretrainedConfig`]): + config ([`~transformers.PretrainedConfig`]): An instance of the configuration associated to the model to optimize. """ super().__init__() @@ -67,24 +68,23 @@ def from_pretrained( The path to a local directory hosting the model to optimize or an instance of an `ORTModel` to quantize. Can be either: - A path to a local *directory* containing the model to optimize. - - An instance of ORTModel. - file_names(`List[str]`, *optional*): + - An instance of [`~optimum.onnxruntime.ORTModel`]. 
+ file_names(`Optional[List[str]]`, *optional*): The list of file names of the models to optimize. """ onnx_model_path = [] config = None if isinstance(model_or_path, ORTModel): if isinstance(model_or_path, ORTModelForSeq2SeqLM): - model_save_dir = model_or_path.model_save_dir - onnx_model_path = [ - model_save_dir.joinpath(model_or_path.encoder_file_name), - model_save_dir.joinpath(model_or_path.decoder_file_name), + onnx_model_path += [ + model_or_path.encoder_model_path, + model_or_path.decoder_model_path, ] # Add the decoder with past key/values if present if model_or_path.use_cache: - onnx_model_path.append(model_save_dir.joinpath(model_or_path.decoder_file_with_past_name)) + onnx_model_path.append(model_or_path.decoder_with_past_model_path) else: - onnx_model_path = [model_or_path.model_save_dir.joinpath(model_or_path.latest_model_name)] + onnx_model_path.append(model_or_path.model_path) config = model_or_path.config elif os.path.isdir(model_or_path): file_names = [ONNX_WEIGHTS_NAME] if file_names is None else file_names @@ -110,7 +110,7 @@ def optimize( Optimizes a model given the optimization specifications defined in `optimization_config`. Args: - optimization_config (`OptimizationConfig`): + optimization_config ([`~optimum.onnxruntime.OptimizationConfig`]): The configuration containing the parameters related to optimization. save_dir (`Union[str, os.PathLike]`): The path used to save the optimized model. @@ -127,6 +127,9 @@ def optimize( save_dir.mkdir(parents=True, exist_ok=True) ORTConfigManager.check_optimization_supported_model(self.model_type) + self.config.save_pretrained(save_dir) + maybe_save_preprocessors(self.onnx_model_path[0].parent, save_dir) + # Create and save the configuration summarizing all the parameters related to optimization ort_config = ORTConfig( optimization=optimization_config, diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index 39d64e0e18..ddf4fdd7f6 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -11,37 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
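
For readers following the `ORTOptimizer` changes above, here is a minimal sketch of the intended flow after this patch. The checkpoint name and the `OptimizationConfig` settings are illustrative placeholders, not values taken from this diff:

```python
from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# Export a seq2seq checkpoint to ONNX and wrap it in an ORTModel (model name is a placeholder).
model = ORTModelForSeq2SeqLM.from_pretrained("t5-small", from_transformers=True)

# The optimizer now reads the encoder/decoder ONNX paths directly from the ORTModel instance.
optimizer = ORTOptimizer.from_pretrained(model)

# optimize() now also copies the model config and any preprocessors into save_dir,
# so the optimized model can be reloaded without extra bookkeeping.
optimizer.optimize(
    optimization_config=OptimizationConfig(optimization_level=2),  # level picked for illustration
    save_dir="t5_small_onnx_optimized",
)
```
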
+"""Classes handling quantization with ONNX Runtime.""" import logging import os -from abc import ABC from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union from datasets import Dataset, load_dataset from packaging.version import Version, parse +from transformers import AutoConfig import onnx from onnxruntime import __version__ as ort_version from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantizationMode, QuantType from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer from onnxruntime.quantization.qdq_quantizer import QDQQuantizer -from optimum.onnxruntime import ORTQuantizableOperator -from optimum.onnxruntime.configuration import CalibrationConfig, NodeName, NodeType, ORTConfig, QuantizationConfig -from optimum.onnxruntime.modeling_ort import ORTModel -from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration -from optimum.onnxruntime.preprocessors import QuantizationPreprocessor -from optimum.onnxruntime.utils import ONNX_WEIGHTS_NAME -from optimum.quantization_base import OptimumQuantizer +from ..quantization_base import OptimumQuantizer +from ..utils.save_utils import maybe_save_preprocessors +from . import ORTQuantizableOperator +from .configuration import CalibrationConfig, NodeName, NodeType, ORTConfig, QuantizationConfig +from .modeling_ort import ORTModel +from .modeling_seq2seq import ORTModelForConditionalGeneration +from .preprocessors import QuantizationPreprocessor + + +if TYPE_CHECKING: + from transformers import PretrainedConfig LOGGER = logging.getLogger(__name__) class ORTCalibrationDataReader(CalibrationDataReader): - """ """ - __slots__ = ["batch_size", "dataset", "_dataset_iter"] def __init__(self, dataset: Dataset, batch_size: int = 1): @@ -83,65 +86,79 @@ class ORTQuantizer(OptimumQuantizer): Handles the ONNX Runtime quantization process for models shared on huggingface.co/models. """ - def __init__(self, onnx_model_path: List[Path]): + def __init__(self, onnx_model_path: Path, config: Optional["PretrainedConfig"] = None): """ Args: onnx_model_path (`Path`): Path to the onnx model files you want to quantize. + config (`Optional[PretrainedConfig]`, *optional*): + The configuration of the model. """ super().__init__() self.onnx_model_path = onnx_model_path + self.config = config + if self.config is None: + try: + self.config = AutoConfig.from_pretrained(self.onnx_model_path.parent) + except OSError: + LOGGER.warning( + f"Could not load the config for {self.onnx_model_path} automatically, this might make " + "the quantized model harder to use because it will not be able to be loaded by an ORTModel without " + "having to specify the configuration explicitly." + ) self._calibrator = None @classmethod def from_pretrained( cls, - model_or_path: Union[str, Path], + model_or_path: Union["ORTModel", str, Path], file_name: Optional[str] = None, ) -> "ORTQuantizer": """ - Instantiate a `ORTQuantizer` from a pretrained pytorch model and preprocessor. + Instantiates a `ORTQuantizer` from a an ONNX model file or an `ORTModel`. Args: - model_or_path (`Union[str, Path]`): + model_or_path (`Union[ORTModel, str, Path]`): Can be either: - A path to a saved exported ONNX Intermediate Representation (IR) model, e.g., `./my_model_directory/. - - Or a `ORTModelForXX` class, e.g., `ORTModelForQuestionAnswering`. 
- file_name(`Union[str, List[str]]`, *optional*): + - Or an `ORTModelForXX` class, e.g., `ORTModelForQuestionAnswering`. + file_name(`Optional[str]`, *optional*): Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load different model files from the same repository or directory. Returns: An instance of `ORTQuantizer`. """ - # define the file name for the quantizable models - if file_name is None: - if isinstance(model_or_path, ORTModel): - if isinstance(model_or_path, ORTModelForConditionalGeneration): - raise ValueError( - "ORTQuantizer does not support multi-file quantization. Please create separate ORTQuantizer instances for each model/file." - ) - model_file_name = model_or_path.latest_model_name - else: - model_file_name = ONNX_WEIGHTS_NAME - else: - model_file_name = file_name + ort_quantizer_error_message = "ORTQuantizer does not support multi-file quantization. Please create separate ORTQuantizer instances for each model/file." + + if isinstance(model_or_path, str): + model_or_path = Path(model_or_path) + + if isinstance(model_or_path, ORTModelForConditionalGeneration): + raise ValueError(ort_quantizer_error_message) + elif isinstance(model_or_path, Path): + onnx_files = list(model_or_path.glob("*.onnx")) + if len(onnx_files) == 0: + raise FileNotFoundError(f"Could not find any ONNX model file in {model_or_path}") + elif len(onnx_files) > 1: + raise RuntimeError( + f"Found too many ONNX model files in {model_or_path}. {ort_quantizer_error_message}" + ) + file_name = onnx_files[0].name - # create ORTQuantizer based on the provided input + path = None if isinstance(model_or_path, ORTModel): - return cls(model_or_path.model_save_dir.joinpath(model_file_name)) - # load from local path + path = Path(model_or_path.model._model_path) elif os.path.isdir(model_or_path): - if not isinstance(model_or_path, Path): - model_or_path = Path(model_or_path) - return cls(model_or_path.joinpath(model_file_name)) + path = Path(model_or_path) / file_name else: raise ValueError(f"Unable to load model from {model_or_path}.") + return cls(path) def fit( self, dataset: Dataset, calibration_config: CalibrationConfig, - onnx_augmented_model_name: str = "augmented_model.onnx", + onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx", operators_to_quantize: Optional[List[NodeType]] = None, batch_size: int = 1, use_external_data_format: bool = False, @@ -149,24 +166,24 @@ def fit( force_symmetric_range: bool = False, ) -> Dict[str, Tuple[float, float]]: """ - Perform the calibration step and collect the quantization ranges. + Performs the calibration step and collect the quantization ranges. Args: dataset (`Dataset`): The dataset to use when performing the calibration step. calibration_config (`CalibrationConfig`): The configuration containing the parameters related to the calibration step. - onnx_augmented_model_name (`Union[str, os.PathLike]`): + onnx_augmented_model_name (`Union[str, Path]`, *optional*, defaults to `"augmented_model.onnx"`): The path used to save the augmented model used to collect the quantization ranges. - operators_to_quantize (`list`, *optional*): + operators_to_quantize (`Optional[List[NodeType]]`, *optional*): List of the operators types to quantize. - batch_size (`int`, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The batch size to use when collecting the quantization ranges values. use_external_data_format (`bool`, defaults to `False`): Whether to use external data format to store model which size is >= 2Gb. 
use_gpu (`bool`, defaults to `False`): Whether to use the GPU when collecting the quantization ranges values. - force_symmetric_range (`bool`, defaults to `False`): + force_symmetric_range (`bool`, *optional*, defaults to `False`): Whether to make the quantization ranges symmetric. Returns: @@ -195,7 +212,7 @@ def partial_fit( self, dataset: Dataset, calibration_config: CalibrationConfig, - onnx_augmented_model_name: str = "augmented_model.onnx", + onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx", operators_to_quantize: Optional[List[NodeType]] = None, batch_size: int = 1, use_external_data_format: bool = False, @@ -203,24 +220,24 @@ def partial_fit( force_symmetric_range: bool = False, ): """ - Perform the calibration step and collect the quantization ranges. + Performs the calibration step and collect the quantization ranges. Args: dataset (`Dataset`): The dataset to use when performing the calibration step. calibration_config (`CalibrationConfig`): The configuration containing the parameters related to the calibration step. - onnx_augmented_model_name (`Union[str, os.PathLike]`): + onnx_augmented_model_name (`Union[str, Path]`, *optional*, defaults to `"augmented_model.onnx"`): The path used to save the augmented model used to collect the quantization ranges. - operators_to_quantize (`list`, *optional*): + operators_to_quantize (`Optional[List[NodeType]]`, *optional*): List of the operators types to quantize. - batch_size (`int`, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The batch size to use when collecting the quantization ranges values. - use_external_data_format (`bool`, defaults to `False`): + use_external_data_format (`bool`, *optional*, defaults to `False`): Whether uto se external data format to store model which size is >= 2Gb. - use_gpu (`bool`, defaults to `False`): + use_gpu (`bool`, *optional*, defaults to `False`): Whether to use the GPU when collecting the quantization ranges values. - force_symmetric_range (`bool`, defaults to `False`): + force_symmetric_range (`bool`, *optional*, defaults to `False`): Whether to make the quantization ranges symmetric. Returns: @@ -267,21 +284,21 @@ def quantize( preprocessor: Optional[QuantizationPreprocessor] = None, ) -> Path: """ - Quantize a model given the optimization specifications defined in `quantization_config`. + Quantizes a model given the optimization specifications defined in `quantization_config`. Args: quantization_config (`QuantizationConfig`): The configuration containing the parameters related to quantization. save_dir (`Union[str, Path]`): The directory where the quantized model should be saved. - file_suffix (`str`, *optional*, defaults to `"quantized"`): + file_suffix (`Optional[str]`, *optional*, defaults to `"quantized"`): The file_suffix used to save the quantized model. - calibration_tensors_range (`Dict[NodeName, Tuple[float, float]]`, *optional*): + calibration_tensors_range (`Optional[Dict[NodeName, Tuple[float, float]]]`, *optional*): The dictionary mapping the nodes name to their quantization ranges, used and required only when applying static quantization. - use_external_data_format (`bool`, defaults to `False`): + use_external_data_format (`bool`, *optional*, defaults to `False`): Whether to use external data format to store model which size is >= 2Gb. - preprocessor (`QuantizationPreprocessor`, *optional*): + preprocessor (`Optional[QuantizationPreprocessor]`, *optional*): The preprocessor to use to collect the nodes to include or exclude from quantization. 
Returns: @@ -388,6 +405,11 @@ def quantize( ort_config = ORTConfig(quantization=quantization_config, use_external_data_format=use_external_data_format) ort_config.save_pretrained(save_dir) + if self.config is not None: + self.config.save_pretrained(save_dir) + + maybe_save_preprocessors(self.onnx_model_path.parent, save_dir) + return Path(save_dir) def get_calibration_dataset( @@ -402,25 +424,25 @@ def get_calibration_dataset( use_auth_token: bool = False, ) -> Dataset: """ - Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step + Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step. Args: dataset_name (`str`): The dataset repository name on the Hugging Face Hub or path to a local directory containing data files to load to use for the calibration step. - num_samples (`int`, defaults to 100): + num_samples (`int`, *optional*, defaults to 100): The maximum number of samples composing the calibration dataset. - dataset_config_name (`str`, *optional*): + dataset_config_name (`Optional[str]`, *optional*): The name of the dataset configuration. - dataset_split (`str`, *optional*): + dataset_split (`Optional[str]`, *optional*): Which split of the dataset to use to perform the calibration step. - preprocess_function (`Callable`, *optional*): + preprocess_function (`Optional[Callable]`, *optional*): Processing function to apply to each example after loading dataset. - preprocess_batch (`bool`, defaults to `True`): + preprocess_batch (`bool`, *optional*, defaults to `True`): Whether the `preprocess_function` should be batched. - seed (`int`, defaults to 2016): + seed (`int`, *optional*, defaults to 2016): The random seed to use when shuffling the calibration dataset. - use_auth_token (`bool`, defaults to `False`): + use_auth_token (`bool`, *optional*, defaults to `False`): Whether to use the token generated when running `transformers-cli login` (necessary for some datasets like ImageNet). 
Returns: diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index edc4181dc1..918995632d 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -104,7 +104,7 @@ ) from .modeling_seq2seq import ORTModelForSeq2SeqLM from .training_args import ORTOptimizerNames, ORTTrainingArguments -from .utils import wrap_onnx_config_for_loss +from .utils import is_onnxruntime_training_available, wrap_onnx_config_for_loss if is_apex_available(): @@ -119,6 +119,7 @@ if TYPE_CHECKING: import optuna + logger = logging.get_logger(__name__) # Name of the files used for checkpointing @@ -237,12 +238,12 @@ def __init__( args: ORTTrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Dataset] = None, - model_init: Callable[[], PreTrainedModel] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, onnx_model_path: Union[str, os.PathLike] = None, ): @@ -289,6 +290,11 @@ def train( kwargs: Additional keyword arguments used to hide deprecated arguments """ + if not is_onnxruntime_training_available(): + raise ImportError( + "You need to install `onnxruntime-training` to use `ORTTrainer` for training. Check out " + "https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime." + ) if resume_from_checkpoint is False: resume_from_checkpoint = None @@ -441,10 +447,15 @@ def _inner_training_loop( deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint ) - self.model = deepspeed_engine.module + self.model = unwrap_model(deepspeed_engine) self.model_wrapped = deepspeed_engine self.deepspeed = deepspeed_engine - self.optimizer = optimizer + if args.fp16: + from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer + + self.optimizer = FP16_Optimizer(optimizer) + else: + self.optimizer = optimizer self.lr_scheduler = lr_scheduler elif not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) @@ -472,7 +483,7 @@ def _inner_training_loop( self._load_optimizer_and_scheduler(resume_from_checkpoint) # Important: at this point if enabled distributed training features: - # self.model is the ORTModule(Transformers Model) + # self.model is the Transformers Model # self.model_wrapped is DDP(ORTModule(Transformers Model)), Deepspeed(ORTModule(Transformers Model)), etc. # Train! 
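
The new guard in `ORTTrainer.train()` depends on the `is_onnxruntime_training_available()` helper that this patch adds to `optimum/onnxruntime/utils.py` (shown further below). Here is a short, hypothetical pre-flight check a training script could run before building a trainer, mirroring that guard; the error message wording is illustrative:

```python
from optimum.onnxruntime.utils import is_onnxruntime_training_available

# Fail early with an actionable message, as train() now does internally.
if not is_onnxruntime_training_available():
    raise ImportError(
        "Training with ORTTrainer requires the onnxruntime-training package. See "
        "https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime."
    )
```
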
@@ -807,6 +818,8 @@ def evaluate( raise total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] output.metrics.update( speed_metrics( metric_key_prefix, @@ -894,6 +907,8 @@ def predict( raise total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] output.metrics.update( speed_metrics( metric_key_prefix, @@ -1123,6 +1138,8 @@ def evaluation_loop_ort( if all_losses is not None: metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "jit_compilation_time"): + metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): @@ -1319,7 +1336,15 @@ def prediction_step_ort( logits and labels (each being optional). """ - has_labels = all(inputs.get(k) is not None for k in self.label_names) + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + inputs = self._prepare_inputs(inputs) if ignore_keys is None: @@ -1329,7 +1354,7 @@ def prediction_step_ort( ignore_keys = [] # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. - if has_labels: + if has_labels or loss_without_labels: labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) if len(labels) == 1: labels = labels[0] @@ -1342,7 +1367,7 @@ def prediction_step_ort( "Sagemaker's distributed data parallel features are not supported by `ORTTrainer` yet." ) else: - if has_labels: + if has_labels or loss_without_labels: with self.compute_loss_context_manager(): loss, outputs = self.compute_loss_ort(model, inputs, return_outputs=True) loss = torch.tensor(loss).mean() @@ -1455,6 +1480,12 @@ def _export( ) def _wrap_model(self, model, training=True, dataloader=None): + # TODO: torchdynamo works for inference with PyTorch in ORTTrainer, will move `inference_with_ort` to training arguments and + # whether be able to use ipex will depend on both `self.args.torchdynamo` and `self.args.ort_mode_eval`. + if self.args.torchdynamo is not None: + import torch._dynamo as dynamo + + model = dynamo.optimize(self.args.torchdynamo)(model) # TODO: ipex only works with inference with PyTorch, will move `inference_with_ort` to training arguments and # whether be able to use ipex will depend on both `self.args.use_ipex` and `self.args.ort_mode_eval`. @@ -1462,11 +1493,6 @@ def _wrap_model(self, model, training=True, dataloader=None): dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 model = self.ipex_optimize_model(model, training, dtype=dtype) - # TODO: jit_mode_eval only works with inference with PyTorch, will move `inference_with_ort` to training arguments and - # whether be able to use jit_mode_eval will depend on both `self.args.jit_mode_eval` and `self.args.ort_mode_eval`. 
- if self.args.jit_mode_eval: - model = self.torch_jit_model_eval(model, dataloader, training) - if is_sagemaker_mp_enabled(): raise NotImplementedError( "Sagemaker's distrubuted data parallel features are not supported by `ORTTrainer`." @@ -1487,10 +1513,20 @@ def _wrap_model(self, model, training=True, dataloader=None): if self.use_apex and training: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + if args.fp16: + from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer + + self.optimizer = FP16_Optimizer(self.optimizer) + # Multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = nn.DataParallel(model) + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. if not training: diff --git a/optimum/onnxruntime/training_args.py b/optimum/onnxruntime/training_args.py index c0b5c5f984..85456c7ad7 100644 --- a/optimum/onnxruntime/training_args.py +++ b/optimum/onnxruntime/training_args.py @@ -138,23 +138,32 @@ def __post_init__(self): self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] if self.run_name is None: self.run_name = self.output_dir + if self.framework == "pt" and is_torch_available(): + if self.fp16_backend and self.fp16_backend != "auto": + warnings.warn( + "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" + " `half_precision_backend` instead", + FutureWarning, + ) + self.half_precision_backend = self.fp16_backend - if self.fp16_backend and self.fp16_backend != "auto": - warnings.warn( - "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" - " `half_precision_backend` instead", - FutureWarning, - ) - self.half_precision_backend = self.fp16_backend + if self.bf16 or self.bf16_full_eval: - if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available() and not self.no_cuda: - raise ValueError( - "Your setup doesn't support bf16. You need torch>=1.10, using Ampere GPU with cuda>=11.0 or using CPU" - " (no_cuda)" - ) + if self.no_cuda and not is_torch_bf16_cpu_available(): + # cpu + raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10") + elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available(): + # gpu + raise ValueError( + "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0" + ) if self.fp16 and self.bf16: raise ValueError("At most one of fp16 and bf16 can be True, but not both") + + if self.fp16_full_eval and self.bf16_full_eval: + raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") + if self.bf16: if self.half_precision_backend == "apex": raise ValueError( @@ -199,6 +208,15 @@ def __post_init__(self): " (`--bf16_full_eval`) can only be used on CUDA or CPU devices." 
                )
+        if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
+            if is_torch_tf32_available():
+                if self.tf32 is None and not self.fp16 or self.bf16:
+                    logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
+                    torch.backends.cuda.matmul.allow_tf32 = True
+            else:
+                logger.warning(
+                    "The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here."
+                )
         if is_torch_available() and self.tf32 is not None:
             if self.tf32:
                 if is_torch_tf32_available():
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 81375c9790..a0cfa2e2ae 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 """Utility functions, classes and constants for ONNX Runtime."""
 
+import importlib.util
 import os
 from enum import Enum
-from typing import Dict, Tuple, Type, Union
+from typing import Dict, Tuple, Union
 
 import torch
 from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
@@ -23,16 +24,14 @@
 import onnx
 import onnxruntime as ort
+import pkg_resources
 
 from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss
-from ..utils import NormalizedTextConfig
 
 
 logger = logging.get_logger(__name__)
 
 ONNX_WEIGHTS_NAME = "model.onnx"
-OPTIMIZED_ONNX_WEIGHTS_NAME = "optimized_model.onnx"
-QUANTIZED_ONNX_WEIGHTS_NAME = "q8_model.onnx"
 ONNX_ENCODER_NAME = "encoder_model.onnx"
 ONNX_DECODER_NAME = "decoder_model.onnx"
@@ -41,7 +40,7 @@ def _is_gpu_available():
     """
-    checks if a gpu is available.
+    Checks if a gpu is available.
     """
     available_providers = ort.get_available_providers()
     if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available():
@@ -50,6 +49,24 @@ def _is_gpu_available():
         return False
 
 
+def is_onnxruntime_training_available():
+    """
+    Checks if onnxruntime-training is available.
+    """
+    path_training_dependency = os.path.join(ort.__path__[0], "training")
+    if os.path.exists(path_training_dependency):
+        return True
+    else:
+        return False
+
+
+def is_cupy_available():
+    """
+    Checks if cupy is available.
+    """
+    return importlib.util.find_spec("cupy") is not None
+
+
 class ORTConfigManager:
     """
     A class that contains all the information needed by ONNX Runtime optimization for a given model type.
diff --git a/optimum/utils/file_utils.py b/optimum/utils/file_utils.py
new file mode 100644
index 0000000000..bfd62d9f3c
--- /dev/null
+++ b/optimum/utils/file_utils.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility functions related to both local files and files on the Hugging Face Hub.""" + +import re +from pathlib import Path +from typing import List, Optional, Union + +from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_url + + +def validate_file_exists( + model_name_or_path: Union[str, Path], filename: str, subfolder: str = "", revision: Optional[str] = None +) -> bool: + """ + Checks that the file called `filename` exists in the `model_name_or_path` directory or model repo. + """ + model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path + if model_path.is_dir(): + return (model_path / subfolder / filename).is_file() + succeeded = True + try: + get_hf_file_metadata(hf_hub_url(model_name_or_path, filename, subfolder=subfolder, revision=revision)) + except Exception: + succeeded = False + return succeeded + + +def find_files_matching_pattern( + model_name_or_path: Union[str, Path], + pattern: str, + glob_pattern: str = "**/*", + subfolder: str = "", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, +) -> List[Path]: + """ + Scans either a model repo or a local directory to find filenames matching the pattern. + + Args: + model_name_or_path (`Union[str, Path]`): + The name of the model repo on the Hugging Face Hub or the path to a local directory. + pattern (`str`): + The pattern to use to look for files. + glob_pattern (`str`, defaults to `"**/*"`): + The pattern to use to list all the files that need to be checked. + subfolder (`str`, defaults to `""`): + In case the model files are located inside a subfolder of the model directory / repo on the Hugging + Face Hub, you can specify the subfolder name here. + use_auth_token (`Optional[bool, str]`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`Optional[str]`, defaults to `None`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + + Returns: + `List[Path]` + """ + model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path + pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern) + if model_path.is_dir(): + path = model_path + files = model_path.glob("**/*.onnx") + files = [p for p in files if re.search(pattern, str(p))] + else: + path = model_name_or_path + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token)) + if subfolder != "": + path = f"{path}/{subfolder}" + files = [Path(p) for p in repo_files if re.match(pattern, str(p))] + + return files diff --git a/optimum/utils/save_utils.py b/optimum/utils/save_utils.py new file mode 100644 index 0000000000..3d5550a2fd --- /dev/null +++ b/optimum/utils/save_utils.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities related to saving files."""
+
+import logging
+from pathlib import Path
+from typing import List, Union
+
+from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+def maybe_load_preprocessors(src_name_or_path: Union[str, Path], subfolder: str = "") -> List:
+    preprocessors = []
+    try:
+        preprocessors.append(AutoTokenizer.from_pretrained(src_name_or_path, subfolder=subfolder))
+    except Exception:
+        pass
+
+    try:
+        preprocessors.append(AutoProcessor.from_pretrained(src_name_or_path, subfolder=subfolder))
+    except Exception:
+        pass
+
+    try:
+        preprocessors.append(AutoFeatureExtractor.from_pretrained(src_name_or_path, subfolder=subfolder))
+    except Exception:
+        pass
+    return preprocessors
+
+
+def maybe_save_preprocessors(src_name_or_path: Union[str, Path], dest_dir: Union[str, Path], src_subfolder: str = ""):
+    """
+    Saves the tokenizer, the processor and the feature extractor when found in `src_name_or_path` in `dest_dir`.
+
+    Args:
+        src_name_or_path (`Union[str, Path]`):
+            The source directory or model repo from which to copy the files.
+        dest_dir (`Union[str, Path]`):
+            The destination directory to copy the files to.
+        src_subfolder (`str`, defaults to `""`):
+            In case the preprocessor files are located inside a subfolder of the model directory / repo on the Hugging
+            Face Hub, you can specify the subfolder name here.
+    """
+    if not isinstance(dest_dir, Path):
+        dest_dir = Path(dest_dir)
+
+    dest_dir.mkdir(exist_ok=True)
+    for preprocessor in maybe_load_preprocessors(src_name_or_path, subfolder=src_subfolder):
+        preprocessor.save_pretrained(dest_dir)
diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py
index 15c56d863b..1a26707461 100644
--- a/optimum/utils/testing_utils.py
+++ b/optimum/utils/testing_utils.py
@@ -1,3 +1,4 @@
+import collections.abc
 import importlib.util
 import itertools
 import os
@@ -11,6 +12,20 @@
 from optimum.utils import is_accelerate_available
 
 
+def flatten_dict(dictionary: Dict):
+    """
+    Flattens a nested dictionary into a flat dictionary.
+    """
+    items = []
+    for k, v in dictionary.items():
+        new_key = k
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten_dict(v).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 def require_accelerate(test_case):
     """
     Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
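As a quick orientation for the new `optimum.utils` helpers added above, here is a small usage sketch. It is illustrative only and not part of the diff: the repo id is borrowed from the ONNX Runtime tests further down, while the regular expression and the destination directory are arbitrary choices made for the example.

```python
from optimum.utils.file_utils import find_files_matching_pattern, validate_file_exists
from optimum.utils.save_utils import maybe_save_preprocessors
from optimum.utils.testing_utils import flatten_dict

# True if the file exists in the local directory or in the model repo on the Hub.
has_encoder = validate_file_exists("optimum/t5-small", "encoder_model.onnx")

# All files in the repo whose path matches the regular expression, returned as pathlib.Path objects.
onnx_files = find_files_matching_pattern("optimum/t5-small", r".*\.onnx")

# Saves whichever tokenizer / processor / feature extractor can be loaded from the source repo.
maybe_save_preprocessors("optimum/t5-small", "exported_t5_small")

# flatten_dict drops the nesting but keeps leaf keys and values; the BetterTransformer tests below
# use it to look up activation function names in nested (e.g. text/vision) configs.
assert flatten_dict({"text_config": {"hidden_act": "quick_gelu"}}) == {"hidden_act": "quick_gelu"}
```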
diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py index bd1a8d78d7..980458bc51 100644 --- a/tests/bettertransformer/test_bettertransformer_encoder.py +++ b/tests/bettertransformer/test_bettertransformer_encoder.py @@ -309,7 +309,10 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size) stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size) lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int) - lengths = torch.clamp(lengths, min=0, max=max_sequence_length) + + # need at least a sequence length of 1 for BetterTransformer to work + lengths = torch.clamp(lengths, min=1, max=max_sequence_length) + tokens = torch.full( (batch_size, max_sequence_length), pad_idx, diff --git a/tests/bettertransformer/test_bettertransformer_vision.py b/tests/bettertransformer/test_bettertransformer_vision.py index 0f860e0a6c..6f0b2c56fd 100644 --- a/tests/bettertransformer/test_bettertransformer_vision.py +++ b/tests/bettertransformer/test_bettertransformer_vision.py @@ -15,9 +15,12 @@ import unittest from PIL import Image -from transformers import AutoFeatureExtractor, AutoProcessor +from transformers import AutoFeatureExtractor, AutoModel, AutoProcessor import requests +from optimum.bettertransformer import BetterTransformer +from optimum.utils.testing_utils import grid_parameters +from parameterized import parameterized from testing_bettertransformer_utils import BetterTransformersTestMixin @@ -32,6 +35,12 @@ ALL_VISION_TEXT_MODELS_TO_TEST = [ "hf-internal-testing/tiny-vilt-random-vqa", + "ybelkada/tiny-random-flava", +] + +ALL_ZERO_SHOT_IMAGE_CLASSIFICATION = [ + "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", # with quick_gelu + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", # with gelu ] @@ -57,6 +66,83 @@ class BetterTransformersViLTTest(BetterTransformersTestMixin, unittest.TestCase) """ all_models_to_test = ALL_VISION_TEXT_MODELS_TO_TEST + def prepare_inputs_for_class(self, model_id=None): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + text = "How many cats are there?" 
+ + # Model takes image and text as input + processor = AutoProcessor.from_pretrained(model_id) + inputs = processor(images=image, text=text, return_tensors="pt") + return inputs + + +class BetterTransformersCLIPTest(BetterTransformersTestMixin, unittest.TestCase): + r""" + Testing suite for Vision and Text Models - tests all the tests defined in `BetterTransformersTestMixin` + """ + all_models_to_test = ALL_ZERO_SHOT_IMAGE_CLASSIFICATION + + def prepare_inputs_for_class(self, model_id, **preprocessor_kwargs): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + text = ["a photo", "a photo of dog", "a photo of two big cats"] + + # Model takes image and text as input + processor = AutoProcessor.from_pretrained(model_id) + inputs = processor(images=image, text=text, return_tensors="pt", **preprocessor_kwargs) + return inputs + + def compare_outputs(self, hf_hidden_states, bt_hidden_states, atol: float, model_name: str): + # CLIP returns a 2D tensor + self.assert_equal( + tensor1=hf_hidden_states, + tensor2=bt_hidden_states, + atol=atol, + model_name=model_name, + ) + + # run the test over all possible combinations of `model_id` and `padding` + @parameterized.expand( + grid_parameters( + { + "model_id": ALL_ZERO_SHOT_IMAGE_CLASSIFICATION, + "padding": ["max_length", True], + } + ) + ) + def test_logits(self, model_id, padding, max_length=20): + super().test_logits([model_id], padding=padding, max_length=max_length) + + @parameterized.expand( + grid_parameters( + { + "model_id": ALL_ZERO_SHOT_IMAGE_CLASSIFICATION, + "padding": ["max_length", True], + } + ) + ) + def test_raise_autocast(self, model_id, padding, max_length=20): + super().test_raise_autocast([model_id], padding=padding, max_length=max_length) + + @parameterized.expand( + grid_parameters( + { + "model_id": ALL_ZERO_SHOT_IMAGE_CLASSIFICATION, + "padding": ["max_length", True], + } + ) + ) + def test_raise_train(self, model_id, padding, max_length=20): + super().test_raise_train([model_id], padding=padding, max_length=max_length) + + +class BetterTransformersFlavaTest(BetterTransformersTestMixin, unittest.TestCase): + r""" + Testing suite for Vision and Text Models - tests all the tests defined in `BetterTransformersTestMixin` + """ + all_models_to_test = ["ybelkada/tiny-random-flava"] + def prepare_inputs_for_class(self, model_id=None): url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) @@ -66,3 +152,18 @@ def prepare_inputs_for_class(self, model_id=None): processor = AutoProcessor.from_pretrained(model_id) inputs = processor(image, text, return_tensors="pt") return inputs + + def test_raise_activation_fun(self): + r""" + A tests that checks if the conversion raises an error if the model contains an activation function + that is not supported by `BetterTransformer`. 
Here we need to loop over the config files + """ + from transformers import FlavaConfig + + hf_random_config = FlavaConfig() + + hf_random_config.image_config.hidden_act = "silu" + + hf_random_model = AutoModel.from_config(hf_random_config).eval() + with self.assertRaises(ValueError): + _ = BetterTransformer.transform(hf_random_model, keep_original_model=True) diff --git a/tests/bettertransformer/testing_bettertransformer_utils.py b/tests/bettertransformer/testing_bettertransformer_utils.py index e1876442d7..7fdbc82499 100644 --- a/tests/bettertransformer/testing_bettertransformer_utils.py +++ b/tests/bettertransformer/testing_bettertransformer_utils.py @@ -20,6 +20,7 @@ from transformers import AutoModel from optimum.bettertransformer import BetterTransformer +from optimum.utils.testing_utils import flatten_dict class BetterTransformersTestMixin: @@ -74,14 +75,21 @@ def test_logits(self, models_to_test: Optional[List] = None, **preprocessor_kwar torch.manual_seed(0) bt_hidden_states = converted_model(**inputs)[0] - if "gelu_new" in random_config.to_dict().values(): + if "quick_gelu" in flatten_dict(random_config.to_dict()).values(): + # Since `quick_gelu` is a rather slightly modified version of `GeLU` we expect a discrepency. + tol = 3e-1 + elif "gelu_new" in flatten_dict(random_config.to_dict()).values(): # Since `gelu_new` is a slightly modified version of `GeLU` we expect a small # discrepency. tol = 4e-2 else: tol = 2e-3 - if "attention_mask" in inputs: + if hasattr(self, "compare_outputs"): + self.compare_outputs( + hf_hidden_states, bt_hidden_states, atol=tol, model_name=hf_random_model.__class__.__name__ + ) + elif "attention_mask" in inputs: for i, attention_mask in enumerate(inputs["attention_mask"]): length = torch.argwhere(attention_mask != 0).max().item() self.assert_equal( @@ -115,14 +123,16 @@ def test_raise_on_save(self): bt_model = BetterTransformer.transform(hf_model, keep_original_model=False) bt_model.save_pretrained(tmpdirname) - def test_raise_autocast(self): + def test_raise_autocast(self, models_to_test=None, **preprocessor_kwargs): r""" A tests that checks if the conversion raises an error if the model is run under `torch.cuda.amp.autocast`. """ + if models_to_test is None: + models_to_test = self.all_models_to_test - for model_id in self.all_models_to_test: - inputs = self.prepare_inputs_for_class(model_id) + for model_id in models_to_test: + inputs = self.prepare_inputs_for_class(model_id=model_id, **preprocessor_kwargs) hf_random_model = AutoModel.from_pretrained(model_id).eval() # Check for the autocast on CPU @@ -130,13 +140,16 @@ def test_raise_autocast(self): bt_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) _ = bt_model(**inputs) - def test_raise_train(self): + def test_raise_train(self, models_to_test=None, **preprocessor_kwargs): r""" A tests that checks if the conversion raises an error if the model is run under `model.train()`. 
""" - for model_id in self.all_models_to_test: - inputs = self.prepare_inputs_for_class(model_id) + if models_to_test is None: + models_to_test = self.all_models_to_test + + for model_id in models_to_test: + inputs = self.prepare_inputs_for_class(model_id=model_id, **preprocessor_kwargs) hf_random_model = AutoModel.from_pretrained(model_id).eval() # Check for training mode @@ -173,7 +186,10 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size) stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size) lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int) - lengths = torch.clamp(lengths, min=0, max=max_sequence_length) + + # need at least a sequence length of 1 for BetterTransformer to work + lengths = torch.clamp(lengths, min=1, max=max_sequence_length) + tokens = torch.full( (batch_size, max_sequence_length), pad_idx, diff --git a/tests/exporters/test_onnx_export.py b/tests/exporters/test_onnx_export.py index 22fc69691d..d41911a642 100644 --- a/tests/exporters/test_onnx_export.py +++ b/tests/exporters/test_onnx_export.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pathlib import Path -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, TemporaryDirectory from unittest import TestCase from unittest.mock import patch @@ -25,9 +25,11 @@ OnnxConfig, OnnxConfigWithPast, export, - export_encoder_decoder_model, - validate_encoder_decoder_model_outputs, + export_models, + get_decoder_models_for_export, + get_encoder_decoder_models_for_export, validate_model_outputs, + validate_models_outputs, ) from parameterized import parameterized @@ -283,33 +285,34 @@ def _onnx_export( if isinstance(atol, dict): atol = atol[task.replace("-with-past", "")] - if for_ort: - with NamedTemporaryFile("w") as encoder_output, NamedTemporaryFile( - "w" - ) as decoder_output, NamedTemporaryFile("w") as decoder_with_past_output: + if for_ort is True and (model.config.is_encoder_decoder or task.startswith("causal-lm")): + fn_get_models_from_config = ( + get_encoder_decoder_models_for_export + if model.config.is_encoder_decoder + else get_decoder_models_for_export + ) + + with TemporaryDirectory() as tmpdirname: try: - onnx_inputs, onnx_outputs = export_encoder_decoder_model( + onnx_inputs, onnx_outputs = export_models( model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, - Path(encoder_output.name), - Path(decoder_output.name), - Path(decoder_with_past_output.name), + output_dir=Path(tmpdirname), + fn_get_models_from_config=fn_get_models_from_config, device=device, ) - validate_encoder_decoder_model_outputs( + validate_models_outputs( onnx_config, model, onnx_outputs, atol, - Path(encoder_output.name), - Path(decoder_output.name), - Path(decoder_with_past_output.name), + output_dir=Path(tmpdirname), + fn_get_models_from_config=fn_get_models_from_config, ) except (RuntimeError, ValueError) as e: self.fail(f"{name}, {task} -> {e}") - else: with NamedTemporaryFile("w") as output: try: diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 380ff99d80..4841d01efa 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -41,7 +41,6 @@ import onnxruntime import requests from huggingface_hub.constants import default_cache_path -from huggingface_hub.utils import EntryNotFoundError from optimum.onnxruntime import ( ONNX_DECODER_NAME, 
ONNX_DECODER_WITH_PAST_NAME, @@ -226,7 +225,7 @@ def test_load_seq2seq_model_unknown_provider(self): ORTModelForSeq2SeqLM.from_pretrained(self.ONNX_SEQ2SEQ_MODEL_ID, provider="FooExecutionProvider") def test_load_model_from_hub_without_onnx_model(self): - with self.assertRaises(EntryNotFoundError): + with self.assertRaises(FileNotFoundError): ORTModel.from_pretrained(self.FAIL_ONNX_MODEL_ID) def test_model_on_cpu(self): @@ -451,7 +450,7 @@ def test_save_model_with_different_name(self): model = ORTModel.from_pretrained(tmpdirname, file_name=test_model_name) - self.assertEqual(model.latest_model_name, test_model_name) + self.assertEqual(model.model_name, test_model_name) @require_hf_token def test_save_model_from_hub(self): @@ -1685,3 +1684,24 @@ def test_default_pipeline_and_model_device(self, *args, **kwargs): tokenizer = get_preprocessor(model_id) pipe = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pipe.device, onnx_model.device) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + @require_torch_gpu + def test_compare_to_io_binding(self, *args, **kwargs): + model_arch, model_id = args + set_seed(SEED) + onnx_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=False) + set_seed(SEED) + io_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=True) + tokenizer = get_preprocessor(model_id) + tokens = tokenizer("This is a sample output", return_tensors="pt") + onnx_outputs = onnx_model(**tokens) + io_outputs = io_model(**tokens) + + self.assertTrue("pooler_output" in io_outputs) + self.assertIsInstance(io_outputs.pooler_output, torch.Tensor) + + # compare tensor outputs + self.assertTrue(torch.equal(onnx_outputs.pooler_output, io_outputs.pooler_output)) + + gc.collect() diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index 9f05cba905..47aa65ab19 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -18,15 +18,12 @@ import unittest from pathlib import Path -import numpy as np import torch -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoTokenizer import onnx -from onnxruntime import InferenceSession from optimum.onnxruntime import ORTConfig, ORTModelForSequenceClassification, ORTOptimizer -from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig -from optimum.onnxruntime.modeling_ort import ORTModelForSequenceClassification +from optimum.onnxruntime.configuration import OptimizationConfig from optimum.onnxruntime.modeling_seq2seq import ORTModelForSeq2SeqLM from parameterized import parameterized @@ -96,9 +93,6 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo optimizer.optimize(optimization_config=optimization_config, save_dir=tmp_dir) optimized_model = model_cls.from_pretrained( tmp_dir, - encoder_file_name="encoder_model_optimized.onnx", - decoder_file_name="decoder_model_optimized.onnx", - decoder_with_past_file_name="decoder_with_past_model_optimized.onnx" if use_cache else None, from_transformers=False, use_cache=use_cache, ) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index 8ebc0fdacc..b3c83acafa 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -54,7 +54,7 @@ def test_from_pretrained_method(self, *args): def test_fail_from_pretrained_method(self): with self.assertRaises(Exception) as context: 
ORTQuantizer.from_pretrained("bert-base-cased") - self.assertIn("Unable to load model from bert-base-cased", str(context.exception)) + self.assertIn("Could not find any ONNX model file in bert-base-cased", str(context.exception)) with self.assertRaises(Exception) as context: model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
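To make the intent of the new IOBinding test in `tests/onnxruntime/test_modeling.py` easier to follow, here is a rough sketch of the behaviour it checks. This is not part of the diff; the checkpoint name is an assumption chosen for illustration (any model usable with `ORTModelForCustomTasks` that returns a `pooler_output` would do), and the IOBinding path expects a CUDA-capable `onnxruntime-gpu` setup, mirroring the `@require_torch_gpu` marker on the test.

```python
import torch
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCustomTasks

# Assumed example checkpoint; substitute any ONNX model exposing a pooler_output.
model_id = "optimum/sbert-all-MiniLM-L6-with-pooler"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt")

# The same model loaded with and without IOBinding.
onnx_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=False)
io_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=True)

# Both execution paths are expected to produce identical tensors.
assert torch.equal(onnx_model(**tokens).pooler_output, io_model(**tokens).pooler_output)
```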