From 6b0f78cbcf03bcae6d1bc43b236e0334b8d7bb3d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 16 Sep 2024 17:21:11 -0700 Subject: [PATCH] Enable optimization via environment variable Signed-off-by: Ganesan Ramalingam --- onnxscript/_framework_apis/torch_2_5.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index d011e0a17..980d376ab 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -19,7 +19,7 @@ import onnx -from onnxscript import ir +from onnxscript import ir, optimizer from onnxscript.function_libs.torch_lib import registration from onnxscript.ir import _external_data @@ -28,6 +28,9 @@ os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0" ) +# Internal flag. Will go away. +_TORCH_ONNX_ENABLE_OPTIMIZATION = os.getenv("TORCH_ONNX_ENABLE_OPTIMIZATION") == "1'" + @dataclasses.dataclass(frozen=True) class _OnnxFunctionMeta: @@ -50,7 +53,8 @@ class _OnnxFunctionMeta: def optimize(model: ir.Model) -> ir.Model: """Optimize the model.""" - # TODO(justinchuby): Use the optimizer + if _TORCH_ONNX_ENABLE_OPTIMIZATION: + optimizer.optimize_ir(model) return model