
move check_dummy_inputs_allowed to common export utils #2114

Merged: 9 commits, Dec 19, 2024
12 changes: 8 additions & 4 deletions optimum/exporters/onnx/base.py
@@ -27,16 +27,12 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import onnx
from transformers.utils import is_accelerate_available, is_torch_available

from ...onnx import remove_duplicate_weights_from_tied_info


if is_torch_available():
    import torch.nn as nn

from ...onnx import merge_decoders
Comment on lines -30 to -39

Collaborator:

Would prefer to keep the import there, why move it?

Contributor (Author):

merge_decoders is imported from the graph_transformations submodule, which itself imports onnx:
https://github.com/huggingface/optimum/blob/main/optimum/onnx/graph_transformations.py#L19

That submodule depends heavily on onnx functionality, while merge_decoders is only used for post-processing of specific configs, triggered only when a specific config behaviour is selected. So I think it is better to let the rest of this module be usable without requiring the onnx import.

As I said, there is a case where `import onnx` crashes on Windows. This file contains the basic functionality for configs that we also reuse to enable new models in optimum-intel; moving this import lets us keep using the base config classes even when onnx is broken.
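To make the tradeoff concrete, here is a minimal sketch of the deferred-import pattern the PR applies. The function names and config dict below are illustrative assumptions, not optimum's actual API; the merge_decoders call mirrors the hunks shown further down.

```python
# Illustrative sketch only: the deferred ("lazy") import pattern.

def build_base_config(model_type: str) -> dict:
    # Base config functionality stays importable and usable even on
    # platforms where `import onnx` itself crashes (e.g. the Windows
    # case mentioned above), because this module never imports onnx
    # at the top level.
    return {"model_type": model_type, "opset": 14}


def merge_exported_decoders(decoder_path, decoder_with_past_path, merged_path):
    # Only this code path needs onnx, so the import lives inside it:
    # a broken onnx install now fails here, not at module import time.
    from optimum.onnx import merge_decoders  # transitively imports onnx

    merge_decoders(decoder_path, decoder_with_past_path, save_path=merged_path, strict=False)
```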

from ...utils import (
    DEFAULT_DUMMY_SHAPES,
    DummyInputGenerator,
@@ -54,6 +50,8 @@
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher


# TODO: move the onnx imports (relocated in https://github.com/huggingface/optimum/pull/2114/files) back to module level after the refactoring

if is_accelerate_available():
    from accelerate.utils import find_tied_parameters

@@ -542,6 +540,10 @@ def post_process_exported_models(
        first_key = next(iter(models_and_onnx_configs))
        if is_torch_available() and isinstance(models_and_onnx_configs[first_key][0], nn.Module):
            if is_accelerate_available():
                import onnx

                from ...onnx import remove_duplicate_weights_from_tied_info

                logger.info("Deduplicating shared (tied) weights...")
                for subpath, key in zip(onnx_files_subpaths, models_and_onnx_configs):
                    torch_model = models_and_onnx_configs[key][0]
@@ -934,6 +936,8 @@ def post_process_exported_models(
            decoder_with_past_path = Path(path, onnx_files_subpaths[2])
            decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
            try:
                from ...onnx import merge_decoders

                # The decoder with past does not output the cross attention past key values as they are constant,
                # hence the need for strict=False
                merge_decoders(
6 changes: 5 additions & 1 deletion optimum/exporters/onnx/config.py
@@ -20,7 +20,6 @@

from transformers.utils import is_tf_available

from ...onnx import merge_decoders
from ...utils import (
    DummyAudioInputGenerator,
    DummyBboxInputGenerator,
@@ -38,6 +37,9 @@
from .model_patcher import DecoderModelPatcher


# TODO: move the onnx imports (relocated in https://github.com/huggingface/optimum/pull/2114/files) back to module level after the refactoring


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

@@ -129,6 +131,8 @@ def post_process_exported_models(

        # Attempt to merge only if the decoder-only was exported separately without/with past
        if self.use_past is True and len(models_and_onnx_configs) == 2:
            from ...onnx import merge_decoders

            decoder_path = Path(path, onnx_files_subpaths[0])
            decoder_with_past_path = Path(path, onnx_files_subpaths[1])
            decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
29 changes: 4 additions & 25 deletions optimum/exporters/onnx/convert.py
@@ -22,7 +22,7 @@
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import onnx
@@ -45,6 +45,7 @@
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from ..utils import check_dummy_inputs_are_allowed
from .base import OnnxConfig
from .constants import UNPICKABLE_ARCHS
from .model_configs import SpeechT5OnnxConfig
@@ -56,6 +57,8 @@
)


# TODO: move the onnx imports (relocated in https://github.com/huggingface/optimum/pull/2114/files) back to module level after the refactoring

if is_torch_available():
    import torch
    import torch.nn as nn
@@ -75,30 +78,6 @@ class DynamicAxisNameError(ValueError):
    pass


def check_dummy_inputs_are_allowed(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
    """
    Checks that the dummy inputs from the ONNX config are a subset of the allowed inputs for `model`.

    Args:
        model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel]`):
            The model instance.
        dummy_input_names (`Iterable[str]`):
            The dummy input names to check against the model's signature.
    """
    forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call
    forward_parameters = signature(forward).parameters
    forward_inputs_set = set(forward_parameters.keys())
    dummy_input_names = set(dummy_input_names)

    # The model is allowed to accept more inputs than the config provides
    if not dummy_input_names.issubset(forward_inputs_set):
        raise ValueError(
            f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
        )


def validate_models_outputs(
    models_and_onnx_configs: Dict[
        str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
6 changes: 5 additions & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -21,7 +21,6 @@
from packaging import version
from transformers.utils import is_tf_available

from ...onnx import merge_decoders
from ...utils import (
    DEFAULT_DUMMY_SHAPES,
    BloomDummyPastKeyValuesGenerator,
@@ -93,6 +92,9 @@
)


# TODO: move the onnx imports (relocated in https://github.com/huggingface/optimum/pull/2114/files) back to module level after the refactoring


if TYPE_CHECKING:
    from transformers import PretrainedConfig
    from transformers.modeling_utils import PreTrainedModel
@@ -1875,6 +1877,8 @@ def post_process_exported_models(
            decoder_with_past_path = Path(path, onnx_files_subpaths[3])
            decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
            try:
                from ...onnx import merge_decoders

                # The decoder with past does not output the cross attention past key values as they are constant,
                # hence the need for strict=False
                merge_decoders(
27 changes: 26 additions & 1 deletion optimum/exporters/utils.py
@@ -16,7 +16,8 @@
"""Utilities for model preparation to export."""

import copy
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from packaging import version
@@ -675,3 +676,27 @@ def _get_submodels_and_export_configs(
        export_config = next(iter(models_and_export_configs.values()))[1]

    return export_config, models_and_export_configs


def check_dummy_inputs_are_allowed(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
    """
    Checks that the dummy inputs from the ONNX config are a subset of the allowed inputs for `model`.

    Args:
        model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel]`):
            The model instance.
        dummy_input_names (`Iterable[str]`):
            The dummy input names to check against the model's signature.
    """
    forward = model.forward if is_torch_available() and isinstance(model, torch.nn.Module) else model.call
    forward_parameters = signature(forward).parameters
    forward_inputs_set = set(forward_parameters.keys())
    dummy_input_names = set(dummy_input_names)

    # The model is allowed to accept more inputs than the config provides
    if not dummy_input_names.issubset(forward_inputs_set):
        raise ValueError(
            f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
        )
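For reference, a quick usage sketch of the relocated helper. The toy module is hypothetical; the import path is the function's new home introduced by this PR.

```python
import torch

from optimum.exporters.utils import check_dummy_inputs_are_allowed


class ToyModel(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None):
        return input_ids


model = ToyModel()

# OK: every dummy input name appears in ToyModel.forward's signature.
check_dummy_inputs_are_allowed(model, ["input_ids", "attention_mask"])

# Raises: "pixel_values" is not a parameter of forward().
try:
    check_dummy_inputs_are_allowed(model, ["input_ids", "pixel_values"])
except ValueError as err:
    print(err)
```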