Skip to content

Commit

Permalink
[API] Create stable APIs for PyTorch 2.6 (#1896)
Browse files Browse the repository at this point in the history
- optimize is turned on. It will be controlled by an option in PyTorch
- Remove the `_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR` flag

Co-authored-by: Ti-Tai Wang <[email protected]>
  • Loading branch information
justinchuby and titaiwangms authored Oct 8, 2024
1 parent 1426e9f commit 37b11fc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 46 deletions.
72 changes: 26 additions & 46 deletions onnxscript/_framework_apis/torch_2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,10 @@
import pathlib
from typing import Callable

import onnx

from onnxscript import ir, optimizer
from onnxscript.function_libs.torch_lib import registration
from onnxscript.ir import _external_data

# Internal flag. Will go away.
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") != "0"
)


@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
Expand Down Expand Up @@ -83,45 +76,32 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
"""Save the model with external data. The model is unchanged after saving."""

# TODO(#1835): Decide if we want to externalize large attributes as well
if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor

else:
destination_path = pathlib.Path(model_path)
# Create the directory if it does not exist
data_path = f"{destination_path.name}.data"
proto = ir.serde.serialize_model(model)
onnx.save_model(
proto,
model_path,
save_as_external_data=True,
location=data_path,
)
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor


def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
Expand Down
26 changes: 26 additions & 0 deletions onnxscript/_framework_apis/torch_2_6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.6."""

from __future__ import annotations

__all__ = [
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
]
from onnxscript import ir, optimizer
from onnxscript._framework_apis.torch_2_5 import (
check_model,
convert_version,
get_torchlib_ops,
save_model_with_external_data,
)


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""
optimizer.optimize_ir(model)
return model

0 comments on commit 37b11fc

Please sign in to comment.