diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index d011e0a17..980d376ab 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -19,7 +19,7 @@ import onnx -from onnxscript import ir +from onnxscript import ir, optimizer from onnxscript.function_libs.torch_lib import registration from onnxscript.ir import _external_data @@ -28,6 +28,9 @@ os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0" ) +# Internal flag. Will go away. +_TORCH_ONNX_ENABLE_OPTIMIZATION = os.getenv("TORCH_ONNX_ENABLE_OPTIMIZATION") == "1'" + @dataclasses.dataclass(frozen=True) class _OnnxFunctionMeta: @@ -50,7 +53,8 @@ class _OnnxFunctionMeta: def optimize(model: ir.Model) -> ir.Model: """Optimize the model.""" - # TODO(justinchuby): Use the optimizer + if _TORCH_ONNX_ENABLE_OPTIMIZATION: + optimizer.optimize_ir(model) return model