Fix graph building to exclude input, output and initializer from value_info

ghstack-source-id: 4c40be31d6af9e243296bfa3dca08ae4a4a5b3b5
Pull Request resolved: #1321
BowenBao committed Apr 1, 2024
1 parent 2612107 commit 908b378
Showing 2 changed files with 38 additions and 6 deletions.
onnxscript/function_libs/torch_lib/graph_building.py (10 additions & 4 deletions)
@@ -822,9 +822,15 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
         # nn.Modules exported by dynamo exporter have unique call sites, their function
         # op_type name can serve to form the unique identifier for value info.
         # Store inside top level GraphProto.
         existing_value_info.update(self.generate_subgraphs_value_info_proto())
         # Insert value info for nodes in top level graph.
-        existing_value_info.update(self.generate_maingraph_value_info_proto())
+        new_value_info = self.generate_maingraph_value_info_proto()
+        # Do not store input, output or initializer into value_info
+        for input in onnx_model.graph.input:
+            new_value_info.pop(input.name, None)
+        for output in onnx_model.graph.output:
+            new_value_info.pop(output.name, None)
+        for tensor in onnx_model.graph.initializer:
+            new_value_info.pop(tensor.name, None)
+        existing_value_info.update(new_value_info)
         onnx_model.graph.value_info.extend(existing_value_info.values())
 
         return onnx_model
@@ -918,7 +924,7 @@ def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
         return named_value_info
 
     @runtime_typing.checked
-    def generate_maingraph_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+    def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
        """Returns value info proto for values in the main graph."""
        named_value_info: Dict[str, onnx.ValueInfoProto] = {}
        for torch_value, tensor in self._value_to_tensor.items():
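
For illustration, the same pruning can be expressed as a standalone pass over any onnx.ModelProto. This is a minimal sketch, not part of the change itself; the helper name prune_value_info is hypothetical:

    import onnx

    def prune_value_info(model: onnx.ModelProto) -> None:
        """Drop value_info entries that duplicate graph inputs, outputs or initializers."""
        # Names already described by the graph's input, output and initializer lists.
        reserved = {i.name for i in model.graph.input}
        reserved |= {o.name for o in model.graph.output}
        reserved |= {t.name for t in model.graph.initializer}
        # Keep value_info only for intermediate values.
        kept = [vi for vi in model.graph.value_info if vi.name not in reserved]
        del model.graph.value_info[:]
        model.graph.value_info.extend(kept)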
onnxscript/function_libs/torch_lib/graph_building_test.py (28 additions & 2 deletions)

@@ -3,7 +3,6 @@
 # mypy: disable-error-code="arg-type,type-arg,valid-type"
 from __future__ import annotations
 
-import os
 import unittest
 
 import torch
@@ -140,7 +139,6 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(self):
 
 
 class TestModelSaving(unittest.TestCase):
-    @unittest.skipIf(os.getenv("CI") == "true", "CI is not ready to run dyanmo_export.")
     def test_save_initializer_to_files_for_large_model(self):
         class MLP(torch.nn.Module):
             def __init__(self, input_size, hidden_size, output_size):
@@ -167,6 +165,34 @@ def forward(self, x):
         # Assert model is larger than 2GB (~=3GB)
         self.assertGreater(model_proto.ByteSize(), 2**31)
 
+    def test_input_output_and_initializer_are_not_stored_in_value_info(self):
+        class MLP(torch.nn.Module):
+            def __init__(self, input_size, hidden_size, output_size):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(input_size, hidden_size)
+                self.fc2 = torch.nn.Linear(hidden_size, output_size)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                out = self.fc1(x)
+                out = self.relu(out)
+                out = self.fc2(out)
+                return out
+
+        batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
+        model = MLP(input_size, hidden_size, output_size)
+        x = torch.randn(batch_size, input_size)
+
+        model_proto = torch.onnx.dynamo_export(model, x).model_proto
+        v_names = {v.name for v in model_proto.graph.value_info}
+
+        for i in model_proto.graph.input:
+            self.assertNotIn(i.name, v_names)
+        for o in model_proto.graph.output:
+            self.assertNotIn(o.name, v_names)
+        for i in model_proto.graph.initializer:
+            self.assertNotIn(i.name, v_names)
+
 
 if __name__ == "__main__":
     unittest.main()
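
As a quick standalone check outside the test suite, the same invariant can be verified on any exported model. A minimal sketch, assuming a PyTorch build that still provides torch.onnx.dynamo_export; the tiny Linear module is only for demonstration:

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)
    proto = torch.onnx.dynamo_export(model, x).model_proto

    value_info_names = {vi.name for vi in proto.graph.value_info}
    # value_info should describe only intermediate values, never
    # the graph's inputs, outputs, or initializers.
    assert value_info_names.isdisjoint({i.name for i in proto.graph.input})
    assert value_info_names.isdisjoint({o.name for o in proto.graph.output})
    assert value_info_names.isdisjoint({t.name for t in proto.graph.initializer})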
