SpeechT5 ONNX support (#1404)
* wip

* wip bis

* nit

* nit^2

* working export

* working with-past version

* add test

* add doc

* working merged onnx

* fix dropout with training=True export

* test fix

* fix custom models

* some cleaning

* merge mess

* address review comments

* fix tests
fxmarty authored Oct 18, 2023
1 parent 1ae95a7 commit 554a83a
Showing 16 changed files with 612 additions and 96 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -83,6 +83,7 @@ Supported architectures:
- SEW
- SEW-D
- Speech2Text
- SpeechT5
- Splinter
- SqueezeBert
- Stable Diffusion
24 changes: 16 additions & 8 deletions optimum/commands/export/onnx.py
@@ -14,6 +14,7 @@
"""Defines the command line for the export with ONNX."""

import argparse
import json
from pathlib import Path
from typing import TYPE_CHECKING

@@ -136,6 +137,20 @@ def parse_args_onnx(parser):
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)
optional_group.add_argument(
"--model-kwargs",
type=json.loads,
help=("Any kwargs passed to the model forward, or used to customize the export for a given model."),
)
optional_group.add_argument(
"--legacy",
action="store_true",
help=(
"Export decoder only models in three files (without + with past and the resulting merged model)."
"Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
)
@@ -209,14 +224,6 @@ def parse_args_onnx(parser):
default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
help="For Segment Anything. It corresponds to the number of points per segmentation masks.",
)
optional_group.add_argument(
"--legacy",
action="store_true",
help=(
"Export decoder only models in three files (without + with past and the resulting merged model)."
"Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

# deprecated argument
parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
@@ -256,5 +263,6 @@ def run(self):
_variant=self.args.variant,
library_name=self.args.library_name,
legacy=self.args.legacy,
model_kwargs=self.args.model_kwargs,
**input_shapes,
)
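The new `--model-kwargs` option takes a JSON string (parsed with `json.loads`) and is forwarded to `main_export` as the `model_kwargs` dictionary; `--legacy` is simply moved earlier among the optional arguments. Below is a minimal programmatic sketch of the equivalent export call; the `text-to-audio` task name and the `vocoder` entry are illustrative assumptions, not something this diff confirms.

```python
# Minimal sketch of the programmatic equivalent of
# `optimum-cli export onnx ... --model-kwargs '{...}'`.
# The task name and the "vocoder" key are assumptions for illustration only.
from optimum.exporters.onnx import main_export

main_export(
    "microsoft/speecht5_tts",        # checkpoint to export (assumed example)
    output="speecht5_onnx/",         # directory that receives the exported ONNX files
    task="text-to-audio",            # assumed task name for SpeechT5 text-to-speech
    model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},  # assumed kwarg consumed during export
)
```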
42 changes: 35 additions & 7 deletions optimum/exporters/onnx/__main__.py
@@ -19,7 +19,7 @@
from pathlib import Path

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import is_torch_available

from ...commands.export.onnx import parse_args_onnx
@@ -38,6 +38,7 @@
get_decoder_models_for_export,
get_encoder_decoder_models_for_export,
get_sam_models_for_export,
get_speecht5_models_for_export,
get_stable_diffusion_models_for_export,
)

@@ -69,6 +70,7 @@ def _get_submodels_and_onnx_configs(
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
model_kwargs: Optional[Dict] = None,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
@@ -95,10 +97,11 @@

onnx_config.variant = _variant
all_variants = "\n".join(
[f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
[f" - {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
)
logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

# TODO: this succession of if/else strongly suggests a refactor is needed.
if (
model.config.is_encoder_decoder
and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
@@ -109,6 +112,8 @@
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
elif model.config.model_type == "speecht5":
models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}

@@ -333,6 +338,30 @@ def main_export(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

if library_name == "transformers":
config = AutoConfig.from_pretrained(
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
)
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
if original_task == "auto":
autodetected_message = " (auto-detected)"
else:
autodetected_message = ""
model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
raise ValueError(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

model = TasksManager.get_model_from_task(
task,
model_name_or_path,
@@ -361,18 +390,16 @@
if not is_stable_diffusion:
if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
raise ValueError(
f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. "
f"{model_type} is not supported yet. Only {list(TasksManager._SUPPORTED_CLI_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
task, exporter="onnx"
):
if model.config.model_type.replace("_", "-") not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
raise ValueError(
f"Trying to export a {model.config.model_type.replace('-', '_')} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. For the task {task}, the Optimum ONNX exporter supports natively the architectures: {TasksManager.get_supported_model_type_for_task(task, exporter='onnx')}."
f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export."
)

if custom_architecture and original_task == "auto":
@@ -425,6 +452,7 @@ def main_export(
preprocessors=preprocessors,
_variant=_variant,
legacy=legacy,
model_kwargs=model_kwargs,
)

if not is_stable_diffusion:
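Besides dispatching `speecht5` models to `get_speecht5_models_for_export`, the main functional change above is an early, config-only check that the model type and task are supported by the ONNX exporter before any weights are downloaded or loaded. A standalone sketch of that check, assuming it is run outside `main_export`:

```python
# Standalone sketch of the early support check added to main_export above.
# TasksManager keys use dashes, hence the underscore-to-dash normalization.
from transformers import AutoConfig
from optimum.exporters.tasks import TasksManager

config = AutoConfig.from_pretrained("microsoft/speecht5_tts")  # loads the config only, no weights
model_type = config.model_type.replace("_", "-")

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
    print(f"{model_type}: no built-in ONNX config, a `custom_onnx_configs` dict is required")
else:
    supported = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
    print(f"ONNX export tasks supported for {model_type}: {list(supported.keys())}")
```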
23 changes: 14 additions & 9 deletions optimum/exporters/onnx/base.py
@@ -140,6 +140,7 @@ class OnnxConfig(ExportConfig, ABC):
MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
VARIANTS = {"default": "The default ONNX variant."}
DEFAULT_VARIANT = "default"
_TASK_TO_COMMON_OUTPUTS = {
"audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
@@ -200,17 +201,14 @@ def __init__(
int_dtype: str = "int64",
float_dtype: str = "fp32",
):
if task not in self._TASK_TO_COMMON_OUTPUTS:
raise ValueError(
f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}"
)
self.task = task
self.int_dtype = int_dtype
self.float_dtype = float_dtype

self._config = config
self._preprocessors = preprocessors
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.variant = "default"

def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
"""
@@ -808,7 +806,8 @@ def with_behavior(
"""
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
behavior = ConfigBehavior(behavior)
return self.__class__(

onnx_config = self.__class__(
self._config,
task=self.task,
int_dtype=self.int_dtype,
@@ -818,6 +817,8 @@
behavior=behavior,
preprocessors=self._preprocessors,
)
onnx_config.variant = self.variant
return onnx_config

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
@@ -902,8 +903,8 @@ def post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder was exported without/with past
if self.use_past is True and len(models_and_onnx_configs) == 3:
# Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
if len(onnx_files_subpaths) >= 3 and self.use_past is True:
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
@@ -922,7 +923,8 @@
# In order to do the validation of the two branches on the same file
encoder_path = onnx_files_subpaths[0]

onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]
onnx_files_subpaths_new = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]
onnx_files_subpaths_new.extend(onnx_files_subpaths[3:])

# We validate the two branches of the decoder model then
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
@@ -933,8 +935,10 @@

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True
else:
onnx_files_subpaths_new = onnx_files_subpaths

return models_and_onnx_configs, onnx_files_subpaths
return models_and_onnx_configs, onnx_files_subpaths_new

def generate_dummy_inputs_for_validation(
self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
@@ -1006,6 +1010,7 @@ def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: st
self.float_dtype = float_dtype
self._normalized_config = self._onnx_config._normalized_config
self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS
self.variant = "default"

@classmethod
def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss":
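The `base.py` changes formalize export variants and make post-processing robust to extra submodels: `DEFAULT_VARIANT` becomes a class attribute next to `VARIANTS`, `with_behavior` now propagates the selected variant to the derived config, and the decoder merge keeps any ONNX subpaths beyond the first three (which SpeechT5's multi-file export can produce) instead of dropping them. A hypothetical sketch of a config advertising variants through these attributes follows; the class name, variant names and dummy `inputs` are placeholders, not part of this commit.

```python
# Hypothetical sketch: a config exposing two export variants via the class attributes above.
# Variant names, descriptions and the dummy `inputs` are placeholders for illustration.
from typing import Dict

from optimum.exporters.onnx.base import OnnxConfig
from optimum.utils import NormalizedTextConfig


class MyVariantedOnnxConfig(OnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    VARIANTS = {
        "with-past": "Decoder exported with a KV-cache (past key/values) branch.",
        "without-past": "Decoder recomputes attention keys/values at every step.",
    }
    DEFAULT_VARIANT = "with-past"

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {"input_ids": {0: "batch_size", 1: "sequence_length"}}
```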
19 changes: 10 additions & 9 deletions optimum/exporters/onnx/convert.py
@@ -38,6 +38,7 @@
)
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from .base import OnnxConfig
from .model_configs import SpeechT5OnnxConfig
from .utils import PickableInferenceSession, recursive_to_device


@@ -142,15 +143,13 @@ def validate_models_outputs(
if use_subprocess:
logger.info("Validating models in subprocesses...")
exceptions = [] # run all validations before raising
onnx_paths = []
for i, model_name in enumerate(models_and_onnx_configs.keys()):
submodel, sub_onnx_config = models_and_onnx_configs[model_name]
onnx_model_path = (
output_dir.joinpath(onnx_files_subpaths[i])
if onnx_files_subpaths is not None
else output_dir.joinpath(model_name + ".onnx")
)
onnx_paths.append(onnx_model_path)
try:
# Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
# not releasing memory once an InferenceSession is initialized.
@@ -168,12 +167,12 @@
model_kwargs=model_kwargs,
)
except Exception as e:
exceptions.append(e)
exceptions.append((onnx_model_path, e))

if len(exceptions) != 0:
for i, exception in enumerate(exceptions[:-1]):
logger.error(f"Validation {i} for the model {onnx_paths[i].as_posix()} raised: {exception}")
raise exceptions[-1]
logger.error(f"Validation for the model {exception[0].as_posix()} raised: {exception[1]}")
raise exceptions[-1][1]


def validate_model_outputs(
@@ -423,9 +422,11 @@ def _run_validation(

if value_failures:
msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
raise AtolError(
f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"
)
atol_msg = f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"

if isinstance(config, SpeechT5OnnxConfig):
atol_msg += "\nIMPORTANT NOTE: SpeechT5 uses a dropout at inference and the output validation of ONNX Runtime inference vs PyTorch is expected to fail. Reference: https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L727"
raise AtolError(atol_msg)


class ValidationProcess(mp.Process):
@@ -526,7 +527,7 @@ def export_pytorch(

with torch.no_grad():
model.config.return_dict = True
model.eval()
model = model.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
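The validation changes pair each failure with the ONNX file that produced it and, for SpeechT5, extend the `AtolError` message to say that a mismatch is expected: the model applies dropout with `training=True` even at inference (see the linked `modeling_speecht5.py` reference). A tiny illustration of why that rules out an exact PyTorch vs. ONNX Runtime match:

```python
# Illustration of the note above: dropout forced with training=True is stochastic,
# so two passes over identical inputs generally differ, and an absolute-tolerance
# comparison between the PyTorch model and the exported ONNX model cannot be relied on.
import torch
import torch.nn.functional as F

x = torch.ones(8)
a = F.dropout(x, p=0.5, training=True)  # applied regardless of model.eval(), as in SpeechT5
b = F.dropout(x, p=0.5, training=True)
print(torch.equal(a, b))  # almost always False
```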