Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graph building to exclude input, output and initializer from value_info #1321

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,9 +822,15 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
# nn.Modules exported by dynamo exporter have unique call sites, their function
# op_type name can serve to form the unique identifier for value info.
# Store inside top level GraphProto.
existing_value_info.update(self.generate_subgraphs_value_info_proto())
# Insert value info for nodes in top level graph.
existing_value_info.update(self.generate_maingraph_value_info_proto())
new_value_info = self.generate_maingraph_value_info_proto()
# Do not store input, output or initializer into value_info
for input in onnx_model.graph.input:
new_value_info.pop(input.name, None)
for output in onnx_model.graph.output:
new_value_info.pop(output.name, None)
for tensor in onnx_model.graph.initializer:
new_value_info.pop(tensor.name, None)
existing_value_info.update(new_value_info)
onnx_model.graph.value_info.extend(existing_value_info.values())

return onnx_model
Expand Down Expand Up @@ -918,7 +924,7 @@ def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProt
return named_value_info

@runtime_typing.checked
def generate_maingraph_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
"""Returns value info proto for values in the main graph."""
named_value_info: Dict[str, onnx.ValueInfoProto] = {}
for torch_value, tensor in self._value_to_tensor.items():
Expand Down
30 changes: 28 additions & 2 deletions onnxscript/function_libs/torch_lib/graph_building_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# mypy: disable-error-code="arg-type,type-arg,valid-type"
from __future__ import annotations

import os
import unittest

import torch
Expand Down Expand Up @@ -140,7 +139,6 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel


class TestModelSaving(unittest.TestCase):
@unittest.skipIf(os.getenv("CI") == "true", "CI is not ready to run dyanmo_export.")
def test_save_initializer_to_files_for_large_model(self):
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
Expand All @@ -167,6 +165,34 @@ def forward(self, x):
# Assert model is larger than 2GB (~=3GB)
self.assertGreater(model_proto.ByteSize(), 2**31)

def test_input_output_and_initializer_are_not_stored_in_value_info(self):
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(hidden_size, output_size)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out

batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
model = MLP(input_size, hidden_size, output_size)
x = torch.randn(batch_size, input_size)

model_proto = torch.onnx.dynamo_export(model, x).model_proto
v_names = {v.name for v in model_proto.graph.value_info}

for i in model_proto.graph.input:
self.assertNotIn(i.name, v_names)
for o in model_proto.graph.output:
self.assertNotIn(o.name, v_names)
for i in model_proto.graph.initializer:
self.assertNotIn(i.name, v_names)


if __name__ == "__main__":
unittest.main()
Loading