From 908b3782e25cfb8b1787099c44dac2ed07d75686 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Mon, 1 Apr 2024 09:36:57 -0700 Subject: [PATCH] Fix graph building to exclude input, output and initializer from value_info ghstack-source-id: 4c40be31d6af9e243296bfa3dca08ae4a4a5b3b5 Pull Request resolved: https://github.com/microsoft/onnx-script/pull/1321 --- .../function_libs/torch_lib/graph_building.py | 14 ++++++--- .../torch_lib/graph_building_test.py | 30 +++++++++++++++++-- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index 6e15b3a04..a35a61605 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -822,9 +822,15 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto): # nn.Modules exported by dynamo exporter have unique call sites, their function # op_type name can serve to form the unique identifier for value info. # Store inside top level GraphProto. - existing_value_info.update(self.generate_subgraphs_value_info_proto()) - # Insert value info for nodes in top level graph. - existing_value_info.update(self.generate_maingraph_value_info_proto()) + new_value_info = self.generate_maingraph_value_info_proto() + # Do not store input, output or initializer into value_info + for input in onnx_model.graph.input: + new_value_info.pop(input.name, None) + for output in onnx_model.graph.output: + new_value_info.pop(output.name, None) + for tensor in onnx_model.graph.initializer: + new_value_info.pop(tensor.name, None) + existing_value_info.update(new_value_info) onnx_model.graph.value_info.extend(existing_value_info.values()) return onnx_model @@ -918,7 +924,7 @@ def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProt return named_value_info @runtime_typing.checked - def generate_maingraph_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]: + def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: """Returns value info proto for values in the main graph.""" named_value_info: Dict[str, onnx.ValueInfoProto] = {} for torch_value, tensor in self._value_to_tensor.items(): diff --git a/onnxscript/function_libs/torch_lib/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building_test.py index 3ff366d24..9787a16b4 100644 --- a/onnxscript/function_libs/torch_lib/graph_building_test.py +++ b/onnxscript/function_libs/torch_lib/graph_building_test.py @@ -3,7 +3,6 @@ # mypy: disable-error-code="arg-type,type-arg,valid-type" from __future__ import annotations -import os import unittest import torch @@ -140,7 +139,6 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel class TestModelSaving(unittest.TestCase): - @unittest.skipIf(os.getenv("CI") == "true", "CI is not ready to run dyanmo_export.") def test_save_initializer_to_files_for_large_model(self): class MLP(torch.nn.Module): def __init__(self, input_size, hidden_size, output_size): @@ -167,6 +165,34 @@ def forward(self, x): # Assert model is larger than 2GB (~=3GB) self.assertGreater(model_proto.ByteSize(), 2**31) + def test_input_output_and_initializer_are_not_stored_in_value_info(self): + class MLP(torch.nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.fc2 = torch.nn.Linear(hidden_size, output_size) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 + model = MLP(input_size, hidden_size, output_size) + x = torch.randn(batch_size, input_size) + + model_proto = torch.onnx.dynamo_export(model, x).model_proto + v_names = {v.name for v in model_proto.graph.value_info} + + for i in model_proto.graph.input: + self.assertNotIn(i.name, v_names) + for o in model_proto.graph.output: + self.assertNotIn(o.name, v_names) + for i in model_proto.graph.initializer: + self.assertNotIn(i.name, v_names) + if __name__ == "__main__": unittest.main()