Move & rename onnx_export (#1685)
* move & rename onnx_export

* fix test

* Update optimum/exporters/onnx/convert.py

---------

Co-authored-by: Ella Charlaix <[email protected]>
fxmarty and echarlaix authored Feb 8, 2024
1 parent 3988bbd commit 32a51af
Showing 9 changed files with 474 additions and 390 deletions.
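In short: the programmatic entry point previously living in `optimum/exporters/onnx/__main__.py` as `onnx_export` moves into `convert.py`, is renamed `onnx_export_from_model`, and is re-exported from the package root. A minimal sketch of the renamed API, assuming only the arguments visible in the updated test at the end of this diff (the checkpoint name is illustrative):

from pathlib import Path
from tempfile import TemporaryDirectory

from transformers import AutoModelForSequenceClassification

from optimum.exporters.onnx import onnx_export_from_model

# Export an already-instantiated PyTorch model, mirroring the call updated
# in tests/exporters/onnx/test_onnx_export.py below.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)

with TemporaryDirectory() as tmpdirname:
    onnx_export_from_model(model=model, output=Path(tmpdirname))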
2 changes: 2 additions & 0 deletions docs/source/exporters/onnx/package_reference/export.mdx
@@ -18,6 +18,8 @@ You can export models to ONNX from two frameworks in 🤗 Optimum: PyTorch and TensorFlow

 [[autodoc]] exporters.onnx.main_export

+[[autodoc]] exporters.onnx.onnx_export_from_model
+
 [[autodoc]] exporters.onnx.convert.export

 [[autodoc]] exporters.onnx.convert.export_pytorch
2 changes: 1 addition & 1 deletion optimum/commands/export/onnx.py
@@ -250,7 +250,7 @@ def parse_args(parser: "ArgumentParser"):
         return parse_args_onnx(parser)

     def run(self):
-        from ...exporters.onnx.__main__ import main_export
+        from ...exporters.onnx import main_export

         # Get the shapes to be used to generate dummy inputs
         input_shapes = {}
16 changes: 14 additions & 2 deletions optimum/exporters/onnx/__init__.py
@@ -21,7 +21,13 @@
 _import_structure = {
     "base": ["OnnxConfig", "OnnxConfigWithLoss", "OnnxConfigWithPast", "OnnxSeq2SeqConfigWithPast"],
     "config": ["TextDecoderOnnxConfig", "TextEncoderOnnxConfig", "TextSeq2SeqOnnxConfig"],
-    "convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"],
+    "convert": [
+        "export",
+        "export_models",
+        "validate_model_outputs",
+        "validate_models_outputs",
+        "onnx_export_from_model",
+    ],
     "utils": [
         "get_decoder_models_for_export",
         "get_encoder_decoder_models_for_export",
@@ -34,7 +40,13 @@
 if TYPE_CHECKING:
     from .base import OnnxConfig, OnnxConfigWithLoss, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast  # noqa
     from .config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig  # noqa
-    from .convert import export, export_models, validate_model_outputs, validate_models_outputs  # noqa
+    from .convert import (
+        export,
+        export_models,
+        validate_model_outputs,
+        validate_models_outputs,
+        onnx_export_from_model,
+    )  # noqa
     from .utils import (
         get_decoder_models_for_export,
         get_encoder_decoder_models_for_export,
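Both hunks keep `_import_structure` and the `TYPE_CHECKING` block in sync, which is what makes `from optimum.exporters.onnx import onnx_export_from_model` resolve lazily. For readers unfamiliar with the pattern, a stripped-down sketch of how such an `__init__.py` is typically wired (`_LazyModule` is the real helper from `transformers.utils`; everything else here is schematic):

import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

# Maps each submodule to the names it exports; nothing is imported yet.
_import_structure = {"convert": ["export", "onnx_export_from_model"]}

if TYPE_CHECKING:
    # Static type checkers see the real imports...
    from .convert import export, onnx_export_from_model  # noqa
else:
    # ...while at runtime attribute access lazily triggers the submodule import.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
    )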
389 changes: 8 additions & 381 deletions optimum/exporters/onnx/__main__.py

Large diffs are not rendered by default.

345 changes: 343 additions & 2 deletions optimum/exporters/onnx/convert.py

Large diffs are not rendered by default.

101 changes: 100 additions & 1 deletion optimum/exporters/onnx/utils.py
@@ -15,7 +15,7 @@
 """Utility functions."""

 import copy
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from packaging import version
@@ -560,3 +560,102 @@ def __setstate__(self, values):

         self.model_path = values["model_path"]
         self.sess = ort.InferenceSession(self.model_path, sess_options=self.sess_options, providers=self.providers)
+
+
+def _get_submodels_and_onnx_configs(
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    task: str,
+    monolith: bool,
+    custom_onnx_configs: Dict,
+    custom_architecture: bool,
+    _variant: str,
+    library_name: str,
+    int_dtype: str = "int64",
+    float_dtype: str = "fp32",
+    fn_get_submodels: Optional[Callable] = None,
+    preprocessors: Optional[List[Any]] = None,
+    legacy: bool = False,
+    model_kwargs: Optional[Dict] = None,
+):
+    if not custom_architecture:
+        if library_name == "diffusers":
+            onnx_config = None
+            models_and_onnx_configs = get_stable_diffusion_models_for_export(
+                model, int_dtype=int_dtype, float_dtype=float_dtype
+            )
+        else:
+            onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+                model=model, exporter="onnx", task=task, library_name=library_name
+            )
+            onnx_config = onnx_config_constructor(
+                model.config,
+                int_dtype=int_dtype,
+                float_dtype=float_dtype,
+                preprocessors=preprocessors,
+                legacy=legacy,
+            )
+
+            onnx_config.variant = _variant
+            all_variants = "\n".join(
+                [f"    - {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
+            )
+            logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")
+
+            # TODO: this succession of if/else strongly suggests a refactor is needed.
+            if (
+                model.config.is_encoder_decoder
+                and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
+                and not monolith
+            ):
+                models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
+            elif task.startswith("text-generation") and not monolith:
+                models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy)
+            elif model.config.model_type == "sam":
+                models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
+            elif model.config.model_type == "speecht5":
+                models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs)
+            else:
+                models_and_onnx_configs = {"model": (model, onnx_config)}
+
+        # When specifying custom ONNX configs for supported transformers architectures, we do
+        # not force to specify a custom ONNX config for each submodel.
+        for key, custom_onnx_config in custom_onnx_configs.items():
+            models_and_onnx_configs[key] = (models_and_onnx_configs[key][0], custom_onnx_config)
+    else:
+        onnx_config = None
+        submodels_for_export = None
+        models_and_onnx_configs = {}
+
+        if fn_get_submodels is not None:
+            submodels_for_export = fn_get_submodels(model)
+        else:
+            if library_name == "diffusers":
+                submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
+            elif (
+                model.config.is_encoder_decoder
+                and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
+                and not monolith
+            ):
+                submodels_for_export = _get_submodels_for_export_encoder_decoder(
+                    model, use_past=task.endswith("-with-past")
+                )
+            elif task.startswith("text-generation") and not monolith:
+                submodels_for_export = _get_submodels_for_export_decoder(model, use_past=task.endswith("-with-past"))
+            else:
+                submodels_for_export = {"model": model}
+
+        if submodels_for_export.keys() != custom_onnx_configs.keys():
+            logger.error(f"ONNX custom configs for: {', '.join(custom_onnx_configs.keys())}")
+            logger.error(f"Submodels to export: {', '.join(submodels_for_export.keys())}")
+            raise ValueError(
+                "Trying to export a custom model, but could not find as many custom ONNX configs as the number of submodels to export. Please specify the fn_get_submodels argument, which should return a dictionary of submodules with as many items as the provided custom_onnx_configs dictionary."
+            )
+
+        for key, custom_onnx_config in custom_onnx_configs.items():
+            models_and_onnx_configs[key] = (submodels_for_export[key], custom_onnx_config)
+
+    # Default to the first ONNX config for the stable-diffusion and custom-architecture cases.
+    if onnx_config is None:
+        onnx_config = next(iter(models_and_onnx_configs.values()))[1]
+
+    return onnx_config, models_and_onnx_configs
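The custom-architecture branch above enforces a strict contract: whatever dictionary `fn_get_submodels` returns must have exactly the same keys as `custom_onnx_configs`, otherwise the `ValueError` is raised. A runnable toy illustrating just that contract (the module and configs are stand-ins, not real export inputs):

import torch


class ToyEncoderDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.decoder = torch.nn.Linear(4, 4)


def fn_get_submodels(model):
    # One entry per submodel to export; keys must mirror custom_onnx_configs.
    return {"encoder": model.encoder, "decoder": model.decoder}


# Stand-ins for real OnnxConfig instances, keyed like the submodels.
custom_onnx_configs = {"encoder": object(), "decoder": object()}

submodels_for_export = fn_get_submodels(ToyEncoderDecoder())
# This is the exact check performed above before pairing submodels and configs.
assert submodels_for_export.keys() == custom_onnx_configs.keys()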
2 changes: 2 additions & 0 deletions optimum/exporters/tasks.py
@@ -1721,6 +1721,8 @@ def standardize_model_attributes(
             library_name (`Optional[str]`, *optional*):
                 The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers".
         """
+        # TODO: make model_name_or_path an optional argument here.
+
         library_name = TasksManager.infer_library_from_model(
             model_name_or_path, subfolder, revision, cache_dir, library_name
         )
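The call above is also usable on its own to detect which library a checkpoint belongs to. A hedged sketch, assuming the remaining parameters (`subfolder`, `revision`, `cache_dir`, `library_name`) default sensibly (the checkpoint name is illustrative):

from optimum.exporters.tasks import TasksManager

# Should print "transformers" for a vanilla Transformers checkpoint.
print(TasksManager.infer_library_from_model("distilbert-base-uncased"))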
2 changes: 1 addition & 1 deletion tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -26,7 +26,7 @@
 from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow

 from optimum.exporters.error_utils import MinimumVersionError
-from optimum.exporters.onnx.__main__ import main_export
+from optimum.exporters.onnx import main_export
 from optimum.onnxruntime import (
     ONNX_DECODER_MERGED_NAME,
     ONNX_DECODER_NAME,
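With the re-export in `__init__.py`, the tests no longer reach into the private `__main__` module. `main_export` itself is unchanged: it still exports from a model name or path rather than an in-memory model. A minimal sketch (model ID and output directory are illustrative):

from optimum.exporters.onnx import main_export

# Downloads the checkpoint, infers the task, and writes the ONNX files to
# the output directory.
main_export("distilbert-base-uncased", output="distilbert_onnx")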
5 changes: 3 additions & 2 deletions tests/exporters/onnx/test_onnx_export.py
@@ -34,9 +34,10 @@
     get_decoder_models_for_export,
     get_encoder_decoder_models_for_export,
     get_stable_diffusion_models_for_export,
+    main_export,
+    onnx_export_from_model,
     validate_models_outputs,
 )
-from optimum.exporters.onnx.__main__ import main_export, onnx_export
 from optimum.exporters.onnx.base import ConfigBehavior
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig
 from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
@@ -632,7 +633,7 @@ def _onnx_export(
         preprocessors = None

         with TemporaryDirectory() as tmpdirname:
-            onnx_export(
+            onnx_export_from_model(
                 model=model,
                 output=Path(tmpdirname),
                 monolith=monolith,
