From 37b11fcff94681dea368ebfe0c4768844f2d3149 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Tue, 8 Oct 2024 14:47:36 -0700
Subject: [PATCH] [API] Create stable APIs for PyTorch 2.6 (#1896)

- optimize is turned on. It will be controlled by an option in PyTorch
- Remove the `_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR` flag

Co-authored-by: Ti-Tai Wang
---
 onnxscript/_framework_apis/torch_2_5.py | 72 +++++++++----------------
 onnxscript/_framework_apis/torch_2_6.py | 26 +++++++++
 2 files changed, 52 insertions(+), 46 deletions(-)
 create mode 100644 onnxscript/_framework_apis/torch_2_6.py

diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py
index 642660a43..eeebbb63d 100644
--- a/onnxscript/_framework_apis/torch_2_5.py
+++ b/onnxscript/_framework_apis/torch_2_5.py
@@ -17,17 +17,10 @@
 import pathlib
 from typing import Callable
 
-import onnx
-
 from onnxscript import ir, optimizer
 from onnxscript.function_libs.torch_lib import registration
 from onnxscript.ir import _external_data
 
-# Internal flag. Will go away.
-_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
-    os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0"
-)
-
 
 @dataclasses.dataclass(frozen=True)
 class _OnnxFunctionMeta:
@@ -83,45 +76,32 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
     """Save the model with external data. The model is unchanged after saving."""
     # TODO(#1835): Decide if we want to externalize large attributes as well
-    if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
-        initializer_values = tuple(model.graph.initializers.values())
-        tensors = [v.const_value for v in initializer_values]
-        for tensor in tensors:
-            if tensor is None:
-                raise ValueError(
-                    "The model contains uninitialized initializer values. "
-                    "Please make sure all initializer values are initialized."
-                )
-        destination_path = pathlib.Path(model_path)
-        base_dir = destination_path.parent
-        data_path = f"{destination_path.name}.data"
-
-        external_tensors = _external_data.convert_tensors_to_external(
-            tensors,  # type: ignore[arg-type]
-            base_dir,
-            data_path,
-        )
-
-        # Replace the initializer values with external tensors and save the model
-        for initializer, external_tensor in zip(initializer_values, external_tensors):
-            initializer.const_value = external_tensor
-        ir.save(model, model_path)
-
-        # Restore the original initializer values so the model is unchanged
-        for initializer, tensor in zip(initializer_values, tensors):
-            initializer.const_value = tensor
-
-    else:
-        destination_path = pathlib.Path(model_path)
-        # Create the directory if it does not exist
-        data_path = f"{destination_path.name}.data"
-        proto = ir.serde.serialize_model(model)
-        onnx.save_model(
-            proto,
-            model_path,
-            save_as_external_data=True,
-            location=data_path,
-        )
+    initializer_values = tuple(model.graph.initializers.values())
+    tensors = [v.const_value for v in initializer_values]
+    for tensor in tensors:
+        if tensor is None:
+            raise ValueError(
+                "The model contains uninitialized initializer values. "
+                "Please make sure all initializer values are initialized."
+            )
+    destination_path = pathlib.Path(model_path)
+    base_dir = destination_path.parent
+    data_path = f"{destination_path.name}.data"
+
+    external_tensors = _external_data.convert_tensors_to_external(
+        tensors,  # type: ignore[arg-type]
+        base_dir,
+        data_path,
+    )
+
+    # Replace the initializer values with external tensors and save the model
+    for initializer, external_tensor in zip(initializer_values, external_tensors):
+        initializer.const_value = external_tensor
+    ir.save(model, model_path)
+
+    # Restore the original initializer values so the model is unchanged
+    for initializer, tensor in zip(initializer_values, tensors):
+        initializer.const_value = tensor
 
 
 def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py
new file mode 100644
index 000000000..ec929a1d8
--- /dev/null
+++ b/onnxscript/_framework_apis/torch_2_6.py
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Stable APIs for PyTorch 2.6."""
+
+from __future__ import annotations
+
+__all__ = [
+    "check_model",
+    "convert_version",
+    "get_torchlib_ops",
+    "optimize",
+    "save_model_with_external_data",
+]
+from onnxscript import ir, optimizer
+from onnxscript._framework_apis.torch_2_5 import (
+    check_model,
+    convert_version,
+    get_torchlib_ops,
+    save_model_with_external_data,
+)
+
+
+def optimize(model: ir.Model) -> ir.Model:
+    """Optimize the model."""
+    optimizer.optimize_ir(model)
+    return model
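For reviewers, a minimal usage sketch of the torch_2_6 stable API surface touched by this
patch. The model path and the ir.load() call are illustrative assumptions and not part of
this change; only optimize() and save_model_with_external_data() come from the APIs above.

    from onnxscript import ir
    from onnxscript._framework_apis import torch_2_6

    # Placeholder path: an ONNX model previously produced by the exporter.
    model = ir.load("model.onnx")

    # optimize() now always runs the IR optimizer; per the commit message,
    # turning it on or off is expected to be controlled by an option in PyTorch.
    model = torch_2_6.optimize(model)

    # Initializer data is written to "model.onnx.data" next to the model file,
    # and the in-memory model is restored afterwards (unchanged after saving).
    torch_2_6.save_model_with_external_data(model, "model.onnx")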