Skip to content

Commit

Permalink
Enable export of model with fixed shape (#1643)
Browse files Browse the repository at this point in the history
* add no dynamic axes arg

* add tests

* update doc

* update doc

* update test

* add test for shape check

* fix style
  • Loading branch information
mht-sharma authored Jan 17, 2024
1 parent 6984c10 commit 130197f
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Supported architectures from [🤗 Timm](https://huggingface.co/docs/timm/index)
- EfficientNet
- EfficientNet (Knapsack Pruned)
- Ensemble Adversarial Inception ResNet v2
- ESE-VoVNet (Partial support with static shapes)
- FBNet
- (Gluon) Inception v3
- (Gluon) ResNet
Expand Down
4 changes: 4 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def parse_args_onnx(parser):
"Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)
optional_group.add_argument(
"--no-dynamic-axes", action="store_true", help="Disable dynamic axes during ONNX export"
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
Expand Down Expand Up @@ -263,6 +266,7 @@ def run(self):
_variant=self.args.variant,
library_name=self.args.library_name,
legacy=self.args.legacy,
no_dynamic_axes=self.args.no_dynamic_axes,
model_kwargs=self.args.model_kwargs,
**input_shapes,
)
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def main_export(
_variant: str = "default",
library_name: Optional[str] = None,
legacy: bool = False,
no_dynamic_axes: bool = False,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -270,6 +271,8 @@ def main_export(
The library of the model (`"transformers"` or `"diffusers"` or `"timm"` or `"sentence_transformers"`). If not provided, will attempt to automatically detect the library name for the checkpoint.
legacy (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
Expand Down Expand Up @@ -556,6 +559,7 @@ def main_export(
input_shapes=input_shapes,
device=device,
dtype="fp16" if fp16 is True else None,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)

Expand Down
18 changes: 17 additions & 1 deletion optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def export_pytorch(
device: str = "cpu",
dtype: Optional["torch.dtype"] = None,
input_shapes: Optional[Dict] = None,
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -508,6 +509,8 @@ def export_pytorch(
Data type to remap the model inputs to. PyTorch-only. Only `torch.float16` is supported.
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -562,6 +565,11 @@ def remap(value):
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

if no_dynamic_axes:
dynamix_axes = None
else:
dynamix_axes = dict(chain(inputs.items(), config.outputs.items()))

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
Expand All @@ -570,7 +578,7 @@ def remap(value):
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
dynamic_axes=dynamix_axes,
do_constant_folding=True,
opset_version=opset,
)
Expand Down Expand Up @@ -694,6 +702,7 @@ def export_models(
input_shapes: Optional[Dict] = None,
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Expand All @@ -720,6 +729,8 @@ def export_models(
Whether to disable the default dynamic axes fixing.
dtype (`Optional[str]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -753,6 +764,7 @@ def export_models(
input_shapes=input_shapes,
disable_dynamic_axes_fix=disable_dynamic_axes_fix,
dtype=dtype,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)
)
Expand All @@ -770,6 +782,7 @@ def export(
input_shapes: Optional[Dict] = None,
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -793,6 +806,8 @@ def export(
Whether to disable the default dynamic axes fixing.
dtype (`Optional[str]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -855,6 +870,7 @@ def export(
device=device,
input_shapes=input_shapes,
dtype=torch_dtype,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)

Expand Down
21 changes: 21 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
"num_choices": [4],
}

NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS = {
"batch_size": [1, 3, 5],
"num_choices": [2, 4],
"sequence_length": [8, 33, 96],
}

PYTORCH_EXPORT_MODELS_TINY = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
Expand Down Expand Up @@ -325,3 +330,19 @@
"sentence-transformers-clip": "sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers-transformer": "sentence-transformers/clip-ViT-B-32-multilingual-v1",
}


PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
}


PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES = {
"default-timm-config": {
"timm/ese_vovnet39b.ra_in1k": ["image-classification"],
"timm/ese_vovnet19b_dw.ra_in1k": ["image-classification"],
}
}
105 changes: 104 additions & 1 deletion tests/exporters/onnx/test_exporters_onnx_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@
ONNX_DECODER_WITH_PAST_NAME,
ONNX_ENCODER_NAME,
)
from optimum.utils.testing_utils import require_diffusers, require_sentence_transformers, require_timm
from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_sentence_transformers, require_timm


if is_torch_available():
from optimum.exporters.tasks import TasksManager

from ..exporters_utils import (
NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS,
PYTORCH_EXPORT_MODELS_TINY,
PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
PYTORCH_STABLE_DIFFUSION_MODEL,
PYTORCH_TIMM_MODEL,
PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES,
PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES,
)


Expand Down Expand Up @@ -179,6 +182,7 @@ def _onnx_export(
device: str = "cpu",
fp16: bool = False,
variant: str = "default",
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict] = None,
):
with TemporaryDirectory() as tmpdir:
Expand All @@ -193,11 +197,54 @@ def _onnx_export(
monolith=monolith,
no_post_process=no_post_process,
_variant=variant,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)
except MinimumVersionError as e:
pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")

def _onnx_export_no_dynamic_axes(
self,
model_name: str,
task: str,
input_shape: dict,
input_shape_for_validation: tuple,
monolith: bool = False,
no_post_process: bool = False,
optimization_level: Optional[str] = None,
device: str = "cpu",
fp16: bool = False,
variant: str = "default",
model_kwargs: Optional[Dict] = None,
):
with TemporaryDirectory() as tmpdir:
try:
main_export(
model_name_or_path=model_name,
output=tmpdir,
task=task,
device=device,
fp16=fp16,
optimize=optimization_level,
monolith=monolith,
no_post_process=no_post_process,
_variant=variant,
no_dynamic_axes=True,
model_kwargs=model_kwargs,
**input_shape,
)

model = onnx.load(Path(tmpdir) / "model.onnx")

is_dynamic = any(dim.dim_param for dim in model.graph.input[0].type.tensor_type.shape.dim)
self.assertFalse(is_dynamic)

model_input_shape = [dim.dim_value for dim in model.graph.input[0].type.tensor_type.shape.dim]
self.assertEqual(model_input_shape, input_shape_for_validation)

except MinimumVersionError as e:
pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")

@parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items())
@require_torch
@require_vision
Expand Down Expand Up @@ -258,6 +305,32 @@ def test_exporters_cli_pytorch_cpu_timm(
):
self._onnx_export(model_name, task, monolith, no_post_process, variant=variant)

@parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES, library_name="timm"))
@require_torch
@require_vision
@require_timm
@slow
@pytest.mark.timm_test
@pytest.mark.run_slow
def test_exporters_cli_pytorch_cpu_timm_no_dynamic_axes(
self,
test_name: str,
model_type: str,
model_name: str,
task: str,
variant: str,
monolith: bool,
no_post_process: bool,
):
input_shapes_iterator = grid_parameters({"batch_size": [1, 3, 5]}, yield_dict=True, add_test_name=False)
for input_shape in input_shapes_iterator:
# NOTE: The timm models use input shapes from the model config, so we need to fix the other shapes of the model.
input_shape_for_validation = [input_shape["batch_size"], 3, 224, 224]

self._onnx_export_no_dynamic_axes(
model_name, task, input_shape, input_shape_for_validation, monolith, no_post_process, variant=variant
)

@parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL, library_name="timm"))
@require_torch_gpu
@require_vision
Expand Down Expand Up @@ -322,6 +395,36 @@ def test_exporters_cli_pytorch_cpu(

self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, model_kwargs=model_kwargs)

@parameterized.expand(_get_models_to_test(PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES))
@require_torch
@require_vision
def test_exporters_cli_pytorch_cpu_no_dynamic_axes(
self,
test_name: str,
model_type: str,
model_name: str,
task: str,
variant: str,
monolith: bool,
no_post_process: bool,
):
input_shapes_iterator = grid_parameters(
NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS, yield_dict=True, add_test_name=False
)
for input_shape in input_shapes_iterator:
if task == "multiple-choice":
input_shape_for_validation = [
input_shape["batch_size"],
input_shape["num_choices"],
input_shape["sequence_length"],
]
else:
input_shape_for_validation = [input_shape["batch_size"], input_shape["sequence_length"]]

self._onnx_export_no_dynamic_axes(
model_name, task, input_shape, input_shape_for_validation, monolith, no_post_process, variant=variant
)

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
@require_vision
@require_torch_gpu
Expand Down

0 comments on commit 130197f

Please sign in to comment.