Skip to content

Commit

Permalink
Add option to disable ONNX constant folding (#1682)
Browse files Browse the repository at this point in the history
* optionally disable onnx constant folding

* Update optimum/exporters/onnx/__main__.py

Co-authored-by: Michael Benayoun <[email protected]>

---------

Co-authored-by: Michael Benayoun <[email protected]>
  • Loading branch information
fxmarty and michaelbenayoun authored Feb 6, 2024
1 parent e0e12ed commit c05ab93
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
6 changes: 6 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def parse_args_onnx(parser):
optional_group.add_argument(
"--no-dynamic-axes", action="store_true", help="Disable dynamic axes during ONNX export"
)
optional_group.add_argument(
"--no-constant-folding",
action="store_true",
help="PyTorch-only argument. Disables PyTorch ONNX export constant folding.",
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
Expand Down Expand Up @@ -276,5 +281,6 @@ def run(self):
legacy=self.args.legacy,
no_dynamic_axes=self.args.no_dynamic_axes,
model_kwargs=self.args.model_kwargs,
do_constant_folding=not self.args.no_constant_folding,
**input_shapes,
)
7 changes: 7 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def main_export(
library_name: Optional[str] = None,
legacy: bool = False,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -275,6 +276,8 @@ def main_export(
Disable the use of position_ids for text-generation models that require it for batched generation. Also enables exporting decoder-only models in three files (without past, with past, and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows overriding the default shapes used during the ONNX export.
Expand Down Expand Up @@ -485,6 +488,7 @@ def main_export(
no_dynamic_axes=no_dynamic_axes,
task=task,
use_subprocess=use_subprocess,
do_constant_folding=do_constant_folding,
**kwargs_shapes,
)

Expand All @@ -508,6 +512,7 @@ def onnx_export(
no_dynamic_axes: bool = False,
task: Optional[str] = None,
use_subprocess: bool = False,
do_constant_folding: bool = True,
**kwargs_shapes,
):
library_name = TasksManager._infer_library_from_model(model)
Expand Down Expand Up @@ -676,6 +681,7 @@ def onnx_export(
device=device,
dtype=float_dtype,
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
model_kwargs=model_kwargs,
)

Expand Down Expand Up @@ -775,6 +781,7 @@ def main():
for_ort=args.for_ort,
library_name=args.library_name,
legacy=args.legacy,
do_constant_folding=not args.no_constant_folding,
**input_shapes,
)

Expand Down
13 changes: 12 additions & 1 deletion optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def export_pytorch(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -498,6 +499,8 @@ def export_pytorch(
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -566,7 +569,7 @@ def remap(value):
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamix_axes,
do_constant_folding=True,
do_constant_folding=do_constant_folding,
opset_version=opset,
)

Expand Down Expand Up @@ -690,6 +693,7 @@ def export_models(
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Expand Down Expand Up @@ -718,6 +722,8 @@ def export_models(
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -752,6 +758,7 @@ def export_models(
disable_dynamic_axes_fix=disable_dynamic_axes_fix,
dtype=dtype,
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
model_kwargs=model_kwargs,
)
)
Expand All @@ -770,6 +777,7 @@ def export(
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -795,6 +803,8 @@ def export(
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -851,6 +861,7 @@ def export(
device=device,
input_shapes=input_shapes,
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
model_kwargs=model_kwargs,
)

Expand Down

0 comments on commit c05ab93

Please sign in to comment.