From 7aa2bb6dcf038d93b75af6edef1098b10906a9b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 21:58:14 +0000 Subject: [PATCH 1/4] Initial plan From 1f758586eeda5ed81a4a326d19e7d0ad65c2f856 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:13:04 +0000 Subject: [PATCH 2/4] Implement NameFixPass for ensuring unique names Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/__init__.py | 2 + src/onnx_ir/passes/common/naming.py | 255 ++++++++++++++ src/onnx_ir/passes/common/naming_test.py | 410 +++++++++++++++++++++++ 3 files changed, 667 insertions(+) create mode 100644 src/onnx_ir/passes/common/naming.py create mode 100644 src/onnx_ir/passes/common/naming_test.py diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 2aee4df..a8d7317 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -11,6 +11,7 @@ "InlinePass", "LiftConstantsToInitializersPass", "LiftSubgraphInitializersToMainGraphPass", + "NameFixPass", "RemoveInitializersFromInputsPass", "RemoveUnusedFunctionsPass", "RemoveUnusedNodesPass", @@ -38,6 +39,7 @@ DeduplicateInitializersPass, ) from onnx_ir.passes.common.inliner import InlinePass +from onnx_ir.passes.common.naming import NameFixPass from onnx_ir.passes.common.onnx_checker import CheckerPass from onnx_ir.passes.common.shape_inference import ShapeInferencePass from onnx_ir.passes.common.topological_sort import TopologicalSortPass diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py new file mode 100644 index 0000000..09ccfe9 --- /dev/null +++ b/src/onnx_ir/passes/common/naming.py @@ -0,0 +1,255 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Name fix pass for ensuring unique names for all values and nodes.""" + +from __future__ import annotations + +__all__ = [ + "NameFixPass", +] + +import logging +from collections.abc import Set as AbstractSet + +import onnx_ir as ir + +logger = logging.getLogger(__name__) + + +class NameFixPass(ir.passes.InPlacePass): + """Pass for fixing names to ensure all values and nodes have unique names. + + This pass ensures that: + 1. Graph inputs and outputs have unique names (take precedence) + 2. All intermediate values have unique names (assign names to unnamed values) + 3. All values in subgraphs have unique names + 4. All nodes have unique names (assign names to unnamed nodes) + + The pass maintains global uniqueness across the entire model. + """ + + def call(self, model: ir.Model) -> ir.passes.PassResult: + """Main entry point for the name fix pass.""" + modified = False + + # Use sets to track seen names globally + seen_value_names: set[str] = set() + seen_node_names: set[str] = set() + + # Counters for generating unique names (using list to pass by reference) + value_counter = [0] + node_counter = [0] + + # Process the main graph + if self._fix_graph_names( + model.graph, seen_value_names, seen_node_names, value_counter, node_counter + ): + modified = True + + # Process functions + for function in model.functions.values(): + if self._fix_function_names( + function, seen_value_names, seen_node_names, value_counter, node_counter + ): + modified = True + + if modified: + logger.info("Name fix pass modified the model") + + return ir.passes.PassResult(model, modified=modified) + + + def _fix_graph_names( + self, + graph: ir.Graph, + seen_value_names: set[str], + seen_node_names: set[str], + value_counter: list[int], + node_counter: list[int], + ) -> bool: + """Fix names in a graph and return whether modifications were made.""" + modified = False + + # Keep track of values we've already processed to avoid double-processing + processed_values: set[ir.Value] = set() + + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph.inputs: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Step 2: Fix graph output names (they have precedence) + for output_value in graph.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + # Step 3: Fix initializer names + for initializer in graph.initializers.values(): + if self._process_value(initializer, seen_value_names, value_counter, processed_values): + modified = True + + # Step 4: Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(graph): + # Fix node name + if node.name is None or node.name == "": + if self._assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if self._fix_duplicate_node_name(node, seen_node_names): + modified = True + + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + return modified + + def _fix_function_names( + self, + function: ir.Function, + seen_value_names: set[str], + seen_node_names: set[str], + value_counter: list[int], + node_counter: list[int], + ) -> bool: + """Fix names in a function and return whether modifications were made.""" + modified = False + + # Keep track of values we've already processed to avoid double-processing + processed_values: set[ir.Value] = set() + + # Process function inputs first (they have precedence) + for input_value in function.inputs: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Process function outputs (they have precedence) + for output_value in function.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + # Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(function): + # Fix node name + if node.name is None or node.name == "": + if self._assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if self._fix_duplicate_node_name(node, seen_node_names): + modified = True + + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if self._process_value(input_value, seen_value_names, value_counter, processed_values): + modified = True + + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if self._process_value(output_value, seen_value_names, value_counter, processed_values): + modified = True + + return modified + + def _process_value( + self, + value: ir.Value, + seen_value_names: set[str], + value_counter: list[int], + processed_values: set[ir.Value] + ) -> bool: + """Process a value only if it hasn't been processed before.""" + if value in processed_values: + return False + + processed_values.add(value) + + if value.name is None or value.name == "": + return self._assign_value_name(value, seen_value_names, value_counter) + else: + return self._fix_duplicate_value_name(value, seen_value_names) + + def _assign_value_name( + self, value: ir.Value, seen_names: set[str], counter: list[int] + ) -> bool: + """Assign a name to an unnamed value. Returns True if modified.""" + while True: + new_name = f"val_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + value.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed value", new_name) + return True + + def _assign_node_name( + self, node: ir.Node, seen_names: set[str], counter: list[int] + ) -> bool: + """Assign a name to an unnamed node. Returns True if modified.""" + while True: + new_name = f"node_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + node.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed node", new_name) + return True + + def _fix_duplicate_value_name( + self, value: ir.Value, seen_names: set[str] + ) -> bool: + """Fix a value's name if it conflicts with existing names. Returns True if modified.""" + original_name = value.name + + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 + while True: + new_name = f"{base_name}_{suffix}" + if new_name not in seen_names: + value.name = new_name + seen_names.add(new_name) + logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) + return True + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False + + def _fix_duplicate_node_name( + self, node: ir.Node, seen_names: set[str] + ) -> bool: + """Fix a node's name if it conflicts with existing names. Returns True if modified.""" + original_name = node.name + + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 + while True: + new_name = f"{base_name}_{suffix}" + if new_name not in seen_names: + node.name = new_name + seen_names.add(new_name) + logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) + return True + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False \ No newline at end of file diff --git a/src/onnx_ir/passes/common/naming_test.py b/src/onnx_ir/passes/common/naming_test.py new file mode 100644 index 0000000..5eb7863 --- /dev/null +++ b/src/onnx_ir/passes/common/naming_test.py @@ -0,0 +1,410 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the name fix pass.""" + +from __future__ import annotations + +import unittest + +import onnx_ir as ir +from onnx_ir.passes.common import naming + + +class TestNameFixPass(unittest.TestCase): + """Test cases for NameFixPass.""" + + def test_assign_names_to_unnamed_values(self): + """Test ensuring all values have names even if IR auto-assigned them.""" + # Create a simple model with auto-assigned names + input_value = ir.Input( + None, shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) # Will get auto-assigned name when added to graph + + # Create Add node + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify IR has auto-assigned names + self.assertIsNotNone(input_value.name) + self.assertIsNotNone(add_node.outputs[0].name) + + # Store original names + original_input_name = input_value.name + original_output_name = add_node.outputs[0].name + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass didn't modify anything (names were already assigned and unique) + self.assertFalse(result.modified) + + # Verify names remain the same + self.assertEqual(input_value.name, original_input_name) + self.assertEqual(add_node.outputs[0].name, original_output_name) + + def test_assign_names_to_unnamed_nodes(self): + """Test ensuring all nodes have names even if IR auto-assigned them.""" + # Create a simple model + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Add node - IR will auto-assign name when added to graph + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify IR has auto-assigned node name + self.assertIsNotNone(add_node.name) + original_node_name = add_node.name + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass didn't modify anything (node already had unique name) + self.assertFalse(result.modified) + + # Verify node name remains the same + self.assertEqual(add_node.name, original_node_name) + + def test_assigns_names_when_truly_unnamed(self): + """Test that the pass assigns names when values/nodes are created without names and manually cleared.""" + # Create a model and manually clear names to test assignment + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Manually clear some names to test assignment + add_node.name = None + add_node.outputs[0].name = "" + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names were assigned + self.assertIsNotNone(add_node.name) + self.assertIsNotNone(add_node.outputs[0].name) + self.assertNotEqual(add_node.outputs[0].name, "") + + def test_handles_global_uniqueness_across_subgraphs(self): + """Test that names are unique globally, including across subgraphs.""" + # Create main graph input + main_input = ir.Input( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a simple subgraph for an If node + # Subgraph input and output (with potential name conflicts) + sub_input = ir.Input( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) # Same name as main input - should cause conflict + + sub_add_node = ir.Node("", "Add", inputs=[sub_input, sub_input]) + sub_add_node.outputs[0].name = "main_input" # Another conflict + sub_add_node.outputs[0].shape = sub_input.shape + sub_add_node.outputs[0].type = sub_input.type + + subgraph = ir.Graph( + inputs=[sub_input], + outputs=[sub_add_node.outputs[0]], + nodes=[sub_add_node], + name="subgraph", + ) + + # Create condition input for If node + condition_input = ir.Input( + "condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) + + # Create If node with subgraph + if_node = ir.Node( + "", "If", + inputs=[condition_input], + attributes={"then_branch": ir.Attr("then_branch", ir.AttributeType.GRAPH, subgraph)} + ) + if_node.outputs[0].name = "if_output" + if_node.outputs[0].shape = main_input.shape + if_node.outputs[0].type = main_input.type + + # Create main graph + main_graph = ir.Graph( + inputs=[main_input, condition_input], + outputs=[if_node.outputs[0]], + nodes=[if_node], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10) + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied (should fix duplicates) + self.assertTrue(result.modified) + + # Collect all value names to verify uniqueness + all_value_names = set() + + # Main graph values + for input_val in main_graph.inputs: + self.assertIsNotNone(input_val.name) + self.assertNotIn(input_val.name, all_value_names, f"Duplicate value name: {input_val.name}") + all_value_names.add(input_val.name) + + for output_val in main_graph.outputs: + self.assertIsNotNone(output_val.name) + if output_val.name not in all_value_names: # Could be same as input + all_value_names.add(output_val.name) + + # Node values in main graph + for node in main_graph: + for input_val in node.inputs: + if input_val is not None: + if input_val.name not in all_value_names: + all_value_names.add(input_val.name) + for output_val in node.outputs: + if output_val.name not in all_value_names: + all_value_names.add(output_val.name) + + # Subgraph values + for input_val in subgraph.inputs: + self.assertIsNotNone(input_val.name) + self.assertNotIn(input_val.name, all_value_names, f"Duplicate value name in subgraph: {input_val.name}") + all_value_names.add(input_val.name) + + for output_val in subgraph.outputs: + if output_val.name not in all_value_names: # Could be same as input + all_value_names.add(output_val.name) + + # Node values in subgraph + for node in subgraph: + for input_val in node.inputs: + if input_val is not None: + if input_val.name not in all_value_names: + all_value_names.add(input_val.name) + for output_val in node.outputs: + if output_val.name not in all_value_names: + all_value_names.add(output_val.name) + + # Verify main_input keeps its name (has precedence as graph input) + self.assertEqual(main_input.name, "main_input") + + def test_handle_duplicate_value_names(self): + """Test handling duplicate value names by making them unique.""" + # Create values with duplicate names + input1 = ir.Input( + "duplicate_name", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.Input( + "duplicate_name", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input1, input2]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input1.shape + add_node.outputs[0].type = input1.type + + graph = ir.Graph( + inputs=[input1, input2], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify both inputs have the same name initially + self.assertEqual(input1.name, "duplicate_name") + self.assertEqual(input2.name, "duplicate_name") + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names are now unique + self.assertNotEqual(input1.name, input2.name) + # One should keep the original name, the other should have a suffix + names = {input1.name, input2.name} + self.assertIn("duplicate_name", names) + self.assertTrue( + "duplicate_name_1" in names, + f"Expected 'duplicate_name_1' in {names}" + ) + + def test_handle_duplicate_node_names(self): + """Test handling duplicate node names by making them unique.""" + input_value = ir.Input( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create nodes with duplicate names + add_node1 = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node1.name = "duplicate_node" + add_node1.outputs[0].name = "output1" + add_node1.outputs[0].shape = input_value.shape + add_node1.outputs[0].type = input_value.type + + add_node2 = ir.Node("", "Add", inputs=[input_value, add_node1.outputs[0]]) + add_node2.name = "duplicate_node" # Same name as first node + add_node2.outputs[0].name = "output2" + add_node2.outputs[0].shape = input_value.shape + add_node2.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node2.outputs[0]], + nodes=[add_node1, add_node2], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Verify both nodes have the same name initially + self.assertEqual(add_node1.name, "duplicate_node") + self.assertEqual(add_node2.name, "duplicate_node") + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify names are now unique + self.assertNotEqual(add_node1.name, add_node2.name) + # One should keep the original name, the other should have a suffix + names = {add_node1.name, add_node2.name} + self.assertIn("duplicate_node", names) + self.assertTrue( + "duplicate_node_1" in names, + f"Expected 'duplicate_node_1' in {names}" + ) + + def test_no_modification_when_all_names_unique(self): + """Test that the pass doesn't modify anything when all names are already unique.""" + input_value = ir.Input( + "unique_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.name = "unique_node" + add_node.outputs[0].name = "unique_output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Store original names + original_input_name = input_value.name + original_node_name = add_node.name + original_output_name = add_node.outputs[0].name + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass didn't modify anything + self.assertFalse(result.modified) + + # Verify names remain unchanged + self.assertEqual(input_value.name, original_input_name) + self.assertEqual(add_node.name, original_node_name) + self.assertEqual(add_node.outputs[0].name, original_output_name) + + def test_graph_inputs_outputs_have_precedence(self): + """Test that graph inputs and outputs keep their names when there are conflicts.""" + # Create an input with a specific name + input_value = ir.Input( + "important_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a node that produces an intermediate value with the same name + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "important_input" # Conflicts with input name + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + # Create another node that uses the intermediate value and produces the final output + mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], input_value]) + mul_node.outputs[0].name = "important_output" + mul_node.outputs[0].shape = input_value.shape + mul_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[mul_node.outputs[0]], + nodes=[add_node, mul_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = naming.NameFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify input keeps its original name (has precedence) + self.assertEqual(input_value.name, "important_input") + + # Verify output keeps its original name (has precedence) + self.assertEqual(mul_node.outputs[0].name, "important_output") + + # Verify intermediate value got renamed to avoid conflict + self.assertNotEqual(add_node.outputs[0].name, "important_input") + self.assertTrue(add_node.outputs[0].name.startswith("important_input_")) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From b979a3d444b274c3d078b82d655e7b52350ad5dd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:22:57 +0000 Subject: [PATCH 3/4] Replace processed values set with value-to-name mapping for clearer tracking Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/naming.py | 50 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index 09ccfe9..c2f86d5 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -36,20 +36,23 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: seen_value_names: set[str] = set() seen_node_names: set[str] = set() + # Dictionary to track which values have been assigned names + value_to_name: dict[ir.Value, str] = {} + # Counters for generating unique names (using list to pass by reference) value_counter = [0] node_counter = [0] # Process the main graph if self._fix_graph_names( - model.graph, seen_value_names, seen_node_names, value_counter, node_counter + model.graph, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter ): modified = True # Process functions for function in model.functions.values(): if self._fix_function_names( - function, seen_value_names, seen_node_names, value_counter, node_counter + function, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter ): modified = True @@ -64,28 +67,26 @@ def _fix_graph_names( graph: ir.Graph, seen_value_names: set[str], seen_node_names: set[str], + value_to_name: dict[ir.Value, str], value_counter: list[int], node_counter: list[int], ) -> bool: """Fix names in a graph and return whether modifications were made.""" modified = False - - # Keep track of values we've already processed to avoid double-processing - processed_values: set[ir.Value] = set() # Step 1: Fix graph input names first (they have precedence) for input_value in graph.inputs: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Step 2: Fix graph output names (they have precedence) for output_value in graph.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True # Step 3: Fix initializer names for initializer in graph.initializers.values(): - if self._process_value(initializer, seen_value_names, value_counter, processed_values): + if self._process_value(initializer, seen_value_names, value_to_name, value_counter): modified = True # Step 4: Process all nodes and their values @@ -101,12 +102,12 @@ def _fix_graph_names( # Fix input value names (only if not already processed) for input_value in node.inputs: if input_value is not None: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True return modified @@ -116,23 +117,21 @@ def _fix_function_names( function: ir.Function, seen_value_names: set[str], seen_node_names: set[str], + value_to_name: dict[ir.Value, str], value_counter: list[int], node_counter: list[int], ) -> bool: """Fix names in a function and return whether modifications were made.""" modified = False - - # Keep track of values we've already processed to avoid double-processing - processed_values: set[ir.Value] = set() # Process function inputs first (they have precedence) for input_value in function.inputs: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Process function outputs (they have precedence) for output_value in function.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True # Process all nodes and their values @@ -148,12 +147,12 @@ def _fix_function_names( # Fix input value names (only if not already processed) for input_value in node.inputs: if input_value is not None: - if self._process_value(input_value, seen_value_names, value_counter, processed_values): + if self._process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True # Fix output value names (only if not already processed) for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_counter, processed_values): + if self._process_value(output_value, seen_value_names, value_to_name, value_counter): modified = True return modified @@ -162,19 +161,22 @@ def _process_value( self, value: ir.Value, seen_value_names: set[str], - value_counter: list[int], - processed_values: set[ir.Value] + value_to_name: dict[ir.Value, str], + value_counter: list[int] ) -> bool: """Process a value only if it hasn't been processed before.""" - if value in processed_values: + if value in value_to_name: return False - processed_values.add(value) - + modified = False if value.name is None or value.name == "": - return self._assign_value_name(value, seen_value_names, value_counter) + modified = self._assign_value_name(value, seen_value_names, value_counter) else: - return self._fix_duplicate_value_name(value, seen_value_names) + modified = self._fix_duplicate_value_name(value, seen_value_names) + + # Record the final name for this value + value_to_name[value] = value.name + return modified def _assign_value_name( self, value: ir.Value, seen_names: set[str], counter: list[int] From 3cbe736d4b0aa268ef8cfb54066b1a5d0ce8c47b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Jul 2025 09:05:17 -0700 Subject: [PATCH 4/4] wip Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/naming.py | 355 ++++++++++++++-------------- 1 file changed, 181 insertions(+), 174 deletions(-) diff --git a/src/onnx_ir/passes/common/naming.py b/src/onnx_ir/passes/common/naming.py index c2f86d5..4e0a303 100644 --- a/src/onnx_ir/passes/common/naming.py +++ b/src/onnx_ir/passes/common/naming.py @@ -9,7 +9,6 @@ ] import logging -from collections.abc import Set as AbstractSet import onnx_ir as ir @@ -29,30 +28,39 @@ class NameFixPass(ir.passes.InPlacePass): """ def call(self, model: ir.Model) -> ir.passes.PassResult: - """Main entry point for the name fix pass.""" modified = False # Use sets to track seen names globally seen_value_names: set[str] = set() seen_node_names: set[str] = set() - + # Dictionary to track which values have been assigned names value_to_name: dict[ir.Value, str] = {} - + # Counters for generating unique names (using list to pass by reference) value_counter = [0] node_counter = [0] # Process the main graph - if self._fix_graph_names( - model.graph, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter + if _fix_graph_names( + model.graph, + seen_value_names, + seen_node_names, + value_to_name, + value_counter, + node_counter, ): modified = True # Process functions for function in model.functions.values(): - if self._fix_function_names( - function, seen_value_names, seen_node_names, value_to_name, value_counter, node_counter + if _fix_function_names( + function, + seen_value_names, + seen_node_names, + value_to_name, + value_counter, + node_counter, ): modified = True @@ -62,196 +70,195 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: return ir.passes.PassResult(model, modified=modified) - def _fix_graph_names( - self, - graph: ir.Graph, - seen_value_names: set[str], - seen_node_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], - node_counter: list[int], - ) -> bool: - """Fix names in a graph and return whether modifications were made.""" - modified = False +def _fix_graph_names( + graph: ir.Graph, + seen_value_names: set[str], + seen_node_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], + node_counter: list[int], +) -> bool: + """Fix names in a graph and return whether modifications were made.""" + modified = False + + # Step 1: Fix graph input names first (they have precedence) + for input_value in graph.inputs: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): + modified = True - # Step 1: Fix graph input names first (they have precedence) - for input_value in graph.inputs: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True + # Step 2: Fix graph output names (they have precedence) + for output_value in graph.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True - # Step 2: Fix graph output names (they have precedence) - for output_value in graph.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True + # Step 3: Fix initializer names + for initializer in graph.initializers.values(): + if _process_value(initializer, seen_value_names, value_to_name, value_counter): + modified = True - # Step 3: Fix initializer names - for initializer in graph.initializers.values(): - if self._process_value(initializer, seen_value_names, value_to_name, value_counter): + # Step 4: Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(graph): + # Fix node name + if node.name is None or node.name == "": + if _assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if _fix_duplicate_node_name(node, seen_node_names): modified = True - # Step 4: Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator(graph): - # Fix node name - if node.name is None or node.name == "": - if self._assign_node_name(node, seen_node_names, node_counter): - modified = True - else: - if self._fix_duplicate_node_name(node, seen_node_names): + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True - # Fix input value names (only if not already processed) - for input_value in node.inputs: - if input_value is not None: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True - # Fix output value names (only if not already processed) - for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True + return modified - return modified - - def _fix_function_names( - self, - function: ir.Function, - seen_value_names: set[str], - seen_node_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int], - node_counter: list[int], - ) -> bool: - """Fix names in a function and return whether modifications were made.""" - modified = False - # Process function inputs first (they have precedence) - for input_value in function.inputs: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True +def _fix_function_names( + function: ir.Function, + seen_value_names: set[str], + seen_node_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], + node_counter: list[int], +) -> bool: + """Fix names in a function and return whether modifications were made.""" + modified = False - # Process function outputs (they have precedence) - for output_value in function.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): - modified = True + # Process function inputs first (they have precedence) + for input_value in function.inputs: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): + modified = True - # Process all nodes and their values - for node in ir.traversal.RecursiveGraphIterator(function): - # Fix node name - if node.name is None or node.name == "": - if self._assign_node_name(node, seen_node_names, node_counter): - modified = True - else: - if self._fix_duplicate_node_name(node, seen_node_names): - modified = True + # Process function outputs (they have precedence) + for output_value in function.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True - # Fix input value names (only if not already processed) - for input_value in node.inputs: - if input_value is not None: - if self._process_value(input_value, seen_value_names, value_to_name, value_counter): - modified = True + # Process all nodes and their values + for node in ir.traversal.RecursiveGraphIterator(function): + # Fix node name + if node.name is None or node.name == "": + if _assign_node_name(node, seen_node_names, node_counter): + modified = True + else: + if _fix_duplicate_node_name(node, seen_node_names): + modified = True - # Fix output value names (only if not already processed) - for output_value in node.outputs: - if self._process_value(output_value, seen_value_names, value_to_name, value_counter): + # Fix input value names (only if not already processed) + for input_value in node.inputs: + if input_value is not None: + if _process_value(input_value, seen_value_names, value_to_name, value_counter): modified = True - return modified - - def _process_value( - self, - value: ir.Value, - seen_value_names: set[str], - value_to_name: dict[ir.Value, str], - value_counter: list[int] - ) -> bool: - """Process a value only if it hasn't been processed before.""" - if value in value_to_name: - return False - - modified = False - if value.name is None or value.name == "": - modified = self._assign_value_name(value, seen_value_names, value_counter) - else: - modified = self._fix_duplicate_value_name(value, seen_value_names) - - # Record the final name for this value - value_to_name[value] = value.name - return modified - - def _assign_value_name( - self, value: ir.Value, seen_names: set[str], counter: list[int] - ) -> bool: - """Assign a name to an unnamed value. Returns True if modified.""" + # Fix output value names (only if not already processed) + for output_value in node.outputs: + if _process_value(output_value, seen_value_names, value_to_name, value_counter): + modified = True + + return modified + + +def _process_value( + value: ir.Value, + seen_value_names: set[str], + value_to_name: dict[ir.Value, str], + value_counter: list[int], +) -> bool: + """Process a value only if it hasn't been processed before.""" + if value in value_to_name: + return False + + modified = False + if value.name is None or value.name == "": + modified = _assign_value_name(value, seen_value_names, value_counter) + else: + modified = _fix_duplicate_value_name(value, seen_value_names) + + # Record the final name for this value + value_to_name[value] = value.name + return modified + + +def _assign_value_name(value: ir.Value, seen_names: set[str], counter: list[int]) -> bool: + """Assign a name to an unnamed value. Returns True if modified.""" + while True: + new_name = f"val_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + value.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed value", new_name) + return True + + +def _assign_node_name(node: ir.Node, seen_names: set[str], counter: list[int]) -> bool: + """Assign a name to an unnamed node. Returns True if modified.""" + while True: + new_name = f"node_{counter[0]}" + counter[0] += 1 + if new_name not in seen_names: + node.name = new_name + seen_names.add(new_name) + logger.debug("Assigned name %s to unnamed node", new_name) + return True + + +def _fix_duplicate_value_name(value: ir.Value, seen_names: set[str]) -> bool: + """Fix a value's name if it conflicts with existing names. Returns True if modified.""" + original_name = value.name + + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 while True: - new_name = f"val_{counter[0]}" - counter[0] += 1 + new_name = f"{base_name}_{suffix}" if new_name not in seen_names: value.name = new_name seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed value", new_name) + logger.debug( + "Renamed value from %s to %s for uniqueness", original_name, new_name + ) return True + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False + + +def _fix_duplicate_node_name(node: ir.Node, seen_names: set[str]) -> bool: + """Fix a node's name if it conflicts with existing names. Returns True if modified.""" + original_name = node.name - def _assign_node_name( - self, node: ir.Node, seen_names: set[str], counter: list[int] - ) -> bool: - """Assign a name to an unnamed node. Returns True if modified.""" + if original_name is None or original_name == "": + return False # Should not happen if called correctly + + # If name is already seen, make it unique + if original_name in seen_names: + base_name = original_name + suffix = 1 while True: - new_name = f"node_{counter[0]}" - counter[0] += 1 + new_name = f"{base_name}_{suffix}" if new_name not in seen_names: node.name = new_name seen_names.add(new_name) - logger.debug("Assigned name %s to unnamed node", new_name) + logger.debug( + "Renamed node from %s to %s for uniqueness", original_name, new_name + ) return True - - def _fix_duplicate_value_name( - self, value: ir.Value, seen_names: set[str] - ) -> bool: - """Fix a value's name if it conflicts with existing names. Returns True if modified.""" - original_name = value.name - - if original_name is None or original_name == "": - return False # Should not happen if called correctly - - # If name is already seen, make it unique - if original_name in seen_names: - base_name = original_name - suffix = 1 - while True: - new_name = f"{base_name}_{suffix}" - if new_name not in seen_names: - value.name = new_name - seen_names.add(new_name) - logger.debug("Renamed value from %s to %s for uniqueness", original_name, new_name) - return True - suffix += 1 - else: - # Name is unique, just record it - seen_names.add(original_name) - return False - - def _fix_duplicate_node_name( - self, node: ir.Node, seen_names: set[str] - ) -> bool: - """Fix a node's name if it conflicts with existing names. Returns True if modified.""" - original_name = node.name - - if original_name is None or original_name == "": - return False # Should not happen if called correctly - - # If name is already seen, make it unique - if original_name in seen_names: - base_name = original_name - suffix = 1 - while True: - new_name = f"{base_name}_{suffix}" - if new_name not in seen_names: - node.name = new_name - seen_names.add(new_name) - logger.debug("Renamed node from %s to %s for uniqueness", original_name, new_name) - return True - suffix += 1 - else: - # Name is unique, just record it - seen_names.add(original_name) - return False \ No newline at end of file + suffix += 1 + else: + # Name is unique, just record it + seen_names.add(original_name) + return False