From 6c1eb7513bacc00b218e6cf4622804b2507414ad Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Fri, 29 Mar 2024 11:18:34 -0700
Subject: [PATCH 1/2] Fix graph building to exclude input, output and
 initializer from value_info

[ghstack-poisoned]
---
 .../function_libs/torch_lib/graph_building.py | 12 ++++++--
 .../torch_lib/graph_building_test.py          | 30 +++++++++++++++++--
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index 6e15b3a04..2f4209805 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -822,9 +822,15 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
         # nn.Modules exported by dynamo exporter have unique call sites, their function
         # op_type name can serve to form the unique identifier for value info.
         # Store inside top level GraphProto.
-        existing_value_info.update(self.generate_subgraphs_value_info_proto())
-        # Insert value info for nodes in top level graph.
-        existing_value_info.update(self.generate_maingraph_value_info_proto())
+        new_value_info = self.generate_maingraph_value_info_proto()
+        # Do not store input, output or initializer into value_info
+        for name in onnx_model.graph.input:
+            new_value_info.pop(name.name, None)
+        for name in onnx_model.graph.output:
+            new_value_info.pop(name.name, None)
+        for name in self.initializers:
+            new_value_info.pop(name, None)
+        existing_value_info.update(new_value_info)
         onnx_model.graph.value_info.extend(existing_value_info.values())
 
         return onnx_model
diff --git a/onnxscript/function_libs/torch_lib/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building_test.py
index 3ff366d24..d36478127 100644
--- a/onnxscript/function_libs/torch_lib/graph_building_test.py
+++ b/onnxscript/function_libs/torch_lib/graph_building_test.py
@@ -3,7 +3,6 @@
 # mypy: disable-error-code="arg-type,type-arg,valid-type"
 from __future__ import annotations
 
-import os
 import unittest
 
 import torch
@@ -140,7 +139,6 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel
 
 
 class TestModelSaving(unittest.TestCase):
-    @unittest.skipIf(os.getenv("CI") == "true", "CI is not ready to run dyanmo_export.")
     def test_save_initializer_to_files_for_large_model(self):
         class MLP(torch.nn.Module):
             def __init__(self, input_size, hidden_size, output_size):
@@ -167,6 +165,34 @@ def forward(self, x):
         # Assert model is larger than 2GB (~=3GB)
         self.assertGreater(model_proto.ByteSize(), 2**31)
 
+    def test_input_output_and_initializer_are_not_stored_in_value_info(self):
+        class MLP(torch.nn.Module):
+            def __init__(self, input_size, hidden_size, output_size):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(input_size, hidden_size)
+                self.fc2 = torch.nn.Linear(hidden_size, output_size)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                out = self.fc1(x)
+                out = self.relu(out)
+                out = self.fc2(out)
+                return out
+
+        batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
+        model = MLP(input_size, hidden_size, output_size)
+        x = torch.randn(batch_size, input_size)
+
+        model_proto = torch.onnx.dynamo_export(model, x).model_proto
+        v_names = set(v.name for v in model_proto.graph.value_info)
+        print(v_names)
+        for i in model_proto.graph.input:
+            self.assertNotIn(i.name, v_names)
+        for o in model_proto.graph.output:
+            self.assertNotIn(o.name, v_names)
+        for i in model_proto.graph.initializer:
+            self.assertNotIn(i.name, v_names)
+
 
 if __name__ == "__main__":
     unittest.main()
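Reviewer sketch (not part of the mbox): the core idea of PATCH 1 is to build the candidate value_info map first, then evict any name the GraphProto already declares as an input, output, or initializer. Below is a runnable, torch-free rendering of that eviction on a hand-built model, written in the form the series settles on (iterating `graph.initializer`). All names here — the `demo` graph, `X`, `W`, `Y`, `hidden` — are invented for illustration; none of this is onnxscript API.

```python
import onnx
from onnx import TensorProto, helper

# A two-node graph with one true intermediate value:
#   hidden = MatMul(X, W);  Y = Relu(hidden)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4])
W = helper.make_tensor("W", TensorProto.FLOAT, [4, 4], [0.0] * 16)
nodes = [
    helper.make_node("MatMul", ["X", "W"], ["hidden"]),
    helper.make_node("Relu", ["hidden"], ["Y"]),
]
graph = helper.make_graph(nodes, "demo", [X], [Y], initializer=[W])
model = helper.make_model(graph)

# Candidate value info keyed by name, deliberately including entries that
# shadow the graph input, initializer, and output.
candidates = {
    name: helper.make_tensor_value_info(name, TensorProto.FLOAT, [1, 4])
    for name in ("X", "W", "Y", "hidden")
}

# The eviction step the patch adds; pop(..., None) tolerates names that
# never made it into the candidate map.
for graph_input in model.graph.input:
    candidates.pop(graph_input.name, None)
for graph_output in model.graph.output:
    candidates.pop(graph_output.name, None)
for initializer in model.graph.initializer:
    candidates.pop(initializer.name, None)

# Only the genuine intermediate survives, and the model stays valid.
model.graph.value_info.extend(candidates.values())
assert [v.name for v in model.graph.value_info] == ["hidden"]
onnx.checker.check_model(model)
```

The `pop(name, None)` form is doing real work: a name that was never in the candidate map (e.g. a value with no recorded shape info) would make a plain `del` raise KeyError.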
From 17b53c148f5b5d886b31e45d33359ab2c6a6dd96 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Fri, 29 Mar 2024 14:14:14 -0700
Subject: [PATCH 2/2] Update on "Fix graph building to exclude input, output
 and initializer from value_info"

Otherwise the emitted model proto violates the spec definition of
value_info, which is meant to carry type and shape information only for
values that are not graph inputs, outputs, or initializers.

[ghstack-poisoned]
---
 .../function_libs/torch_lib/graph_building.py      | 14 +++++++-------
 .../function_libs/torch_lib/graph_building_test.py |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index 2f4209805..a35a61605 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -824,12 +824,12 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
         # Store inside top level GraphProto.
         new_value_info = self.generate_maingraph_value_info_proto()
         # Do not store input, output or initializer into value_info
-        for name in onnx_model.graph.input:
-            new_value_info.pop(name.name, None)
-        for name in onnx_model.graph.output:
-            new_value_info.pop(name.name, None)
-        for name in self.initializers:
-            new_value_info.pop(name, None)
+        for input in onnx_model.graph.input:
+            new_value_info.pop(input.name, None)
+        for output in onnx_model.graph.output:
+            new_value_info.pop(output.name, None)
+        for tensor in onnx_model.graph.initializer:
+            new_value_info.pop(tensor.name, None)
         existing_value_info.update(new_value_info)
         onnx_model.graph.value_info.extend(existing_value_info.values())
 
@@ -924,7 +924,7 @@ def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProt
         return named_value_info
 
     @runtime_typing.checked
-    def generate_maingraph_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+    def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
         """Returns value info proto for values in the main graph."""
         named_value_info: Dict[str, onnx.ValueInfoProto] = {}
         for torch_value, tensor in self._value_to_tensor.items():
diff --git a/onnxscript/function_libs/torch_lib/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building_test.py
index d36478127..9787a16b4 100644
--- a/onnxscript/function_libs/torch_lib/graph_building_test.py
+++ b/onnxscript/function_libs/torch_lib/graph_building_test.py
@@ -184,8 +184,8 @@ def forward(self, x):
         x = torch.randn(batch_size, input_size)
 
         model_proto = torch.onnx.dynamo_export(model, x).model_proto
-        v_names = set(v.name for v in model_proto.graph.value_info)
-        print(v_names)
+        v_names = {v.name for v in model_proto.graph.value_info}
+
         for i in model_proto.graph.input:
             self.assertNotIn(i.name, v_names)
         for o in model_proto.graph.output:
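Reviewer sketch (not part of the mbox): beyond the cleanup, PATCH 2 makes two substantive tweaks — the initializer loop now iterates `onnx_model.graph.initializer` instead of `self.initializers`, so the names evicted are exactly the names serialized into the proto, and `generate_maingraph_value_info_proto` is re-annotated from `Mapping` to `Dict`, since `Mapping` does not expose `pop`. The invariant the series enforces can be re-checked on any exported model with a few lines of onnx. In the sketch below, `model.onnx` is a hypothetical path and `assert_value_info_disjoint` an invented helper, not onnxscript or torch API.

```python
# Re-check the invariant from this patch series on an exported model:
# value_info must not repeat names already declared as graph inputs,
# outputs, or initializers.
import onnx


def assert_value_info_disjoint(model: onnx.ModelProto) -> None:
    graph = model.graph
    # Names whose type/shape info already lives elsewhere in the GraphProto.
    reserved = (
        {i.name for i in graph.input}
        | {o.name for o in graph.output}
        | {t.name for t in graph.initializer}
    )
    duplicated = {v.name for v in graph.value_info} & reserved
    if duplicated:
        raise AssertionError(
            f"value_info duplicates graph-level names: {sorted(duplicated)}"
        )


if __name__ == "__main__":
    assert_value_info_disjoint(onnx.load("model.onnx"))
```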