diff --git a/onnxscript/ir/passes/_remove_unused.py b/onnxscript/ir/passes/_remove_unused.py new file mode 100644 index 0000000000..47f34444ce --- /dev/null +++ b/onnxscript/ir/passes/_remove_unused.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Utilities for removing unused nodes the IR graph.""" + +from __future__ import annotations + +from collections import deque + +import onnxscript.ir as ir +from onnxscript.ir import Attr, Graph, Node, Value, _enums + + +class RemoveUnused: + def __init__(self, graph_like: Graph): + self._graph = graph_like + + def purge(self) -> None: + """Remove unused nodes in this graph (and all subgraphs) that do not contribute to main graph outputs.""" + # 1. Initialize: + # Gather all nodes from the graph and its subgraphs. + # Initialize sets to keep track of visited graphs, values, and nodes. + # 2. BFS traversal: + # Create a queue initialized with all output values of the main graph. + # While there are values in the queue: + # - Dequeue a value and retrieve its producer node. + # - Mark the producer node as visited, if it hasn't been visited. + # - Enqueue all output values of the attribute subgraphs of the producer node, + # if they haven't been visited. + # - Enqueue all input values of the producer node, if they haven't been visited. + # 3. Remove: + # Remove all nodes that have not been marked as visited during the BFS traversal. + + # Initialize + all_nodes: list[Node] = list(ir.traversal.RecursiveGraphIterator(self._graph)) + visited_graphs: set[Graph] = set() + visited_values: set[Value] = set() + visited_nodes: set[Node] = set() + + # BFS Traversal + queue: deque[Value] = deque() + + def add_graph_output_values_to_queue(graph: Graph | None) -> None: + """Helper function to add all output values of a graph to the queue.""" + if not graph or graph in visited_graphs: + return + visited_graphs.add(graph) + for output in graph.outputs: + if not output: + continue + queue.append(output) + visited_values.add(output) + + add_graph_output_values_to_queue(self._graph) + + while queue: + # Dequeue a value and retrieve its producer_node + # Add producer_node to visited_nodes + current_value = queue.popleft() + producer_node = current_value.producer() + if not producer_node or producer_node in visited_nodes: + continue + visited_nodes.add(producer_node) + # Add producer_node's subgraphs to visited_graphs + # Add subgraphs' output values to queue + for attr in producer_node.attributes.values(): + if not isinstance(attr, Attr): + continue + if attr.type == _enums.AttributeType.GRAPH: + add_graph_output_values_to_queue(attr.value) + elif attr.type == _enums.AttributeType.GRAPHS: + for subgraph in attr.value: + add_graph_output_values_to_queue(subgraph) + # Add producer_node's input values to queue + for input_value in producer_node.inputs: + if input_value and input_value not in visited_values: + visited_values.add(input_value) + queue.append(input_value) + + # Remove + for node in all_nodes: + if node not in visited_nodes: # type: ignore[union-attr]` + node.graph.remove(node) diff --git a/onnxscript/ir/passes/_remove_unused_test.py b/onnxscript/ir/passes/_remove_unused_test.py new file mode 100644 index 0000000000..32b4ee8104 --- /dev/null +++ b/onnxscript/ir/passes/_remove_unused_test.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +from onnxscript import ir +from onnxscript.ir.passes._remove_unused import RemoveUnused + + +class RemoveUnusedTest(unittest.TestCase): + def test_purge_empty(self): + graph = ir.Graph( + inputs=(), + outputs=(), + nodes=(), + opset_imports={"": 1}, + ) + remove_unused = RemoveUnused(graph) + remove_unused.purge() + self.assertEqual(tuple(graph), ()) + + def test_purge_a_single_node(self): + v0 = ir.Value(name="v0") + node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1) + node1 = ir.Node("", "Node1", inputs=(v0,), num_outputs=1) + node2 = ir.Node("", "Node2", inputs=(v0,), num_outputs=0) + node3 = ir.Node("", "Node3", inputs=(), num_outputs=1) + node4 = ir.Node("", "Node4", inputs=(None,), num_outputs=1) + graph = ir.Graph( + (v0,), + (node0.outputs[0], node3.outputs[0], node4.outputs[0]), + nodes=(node0, node1, node2, node3, node4), + opset_imports={"": 1}, + ) + remove_unused = RemoveUnused(graph) + remove_unused.purge() + self.assertEqual(tuple(graph), (node0, node3, node4)) + + def test_purge_a_tree(self): + v0 = ir.Value(name="v0") + node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1) + node1 = ir.Node("", "Node1", inputs=(node0.outputs[0],), num_outputs=1) + node2 = ir.Node("", "Node2", inputs=(node0.outputs[0],), num_outputs=1) + graph = ir.Graph( + (v0,), + (), + nodes=(node0, node1, node2), + opset_imports={"": 1}, + ) + remove_unused = RemoveUnused(graph) + remove_unused.purge() + self.assertEqual(tuple(graph), ()) + + def test_purge_subgraph_partial(self): + v0 = ir.Value(name="va") + v1 = ir.Value(name="vb") + v2 = ir.Value(name="vc") + v3 = ir.Value(name="vd") + node0 = ir.Node("", "a", inputs=(v0,), num_outputs=1) + node1 = ir.Node("", "b", inputs=(v1,), num_outputs=1) + node2 = ir.Node("", "c", inputs=(v2,), num_outputs=1) + node3 = ir.Node("", "d", inputs=(v3,), num_outputs=1) + node4 = ir.Node("", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1) + node5 = ir.Node("", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1) + node6 = ir.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1) + then_graph = ir.Graph( + inputs=(node2.outputs[0], node3.outputs[0]), + outputs=(node4.outputs[0],), + nodes=(node4,), + name="then_graph", + ) + else_graph = ir.Graph( + inputs=(node2.outputs[0], node3.outputs[0]), + outputs=(), + nodes=(node5,), + name="else_graph", + ) + + node7 = ir.Node( + "", + "if", + inputs=(node6.outputs[0],), + num_outputs=1, + attributes=[ + ir.AttrGraphs("subgraphs", [then_graph, else_graph]), + ], + ) + main_graph = ir.Graph( + inputs=(v0, v1, v2, v3), + outputs=(node7.outputs[0],), + nodes=(node0, node1, node2, node3, node6, node7), + name="main_graph", + opset_imports={"": 1}, + ) + remove_unused = RemoveUnused(main_graph) + remove_unused.purge() + self.assertEqual(tuple(main_graph), (node0, node1, node2, node3, node6, node7)) + self.assertEqual(tuple(then_graph), (node4,)) + self.assertEqual(tuple(else_graph), ()) + + def test_purge_subgraph_all(self): + v0 = ir.Value(name="v0") + node0 = ir.Node("", "c", inputs=(v0,), num_outputs=1) + node1 = ir.Node("", "sub", inputs=(node0.outputs[0],), num_outputs=1) + node2 = ir.Node("", ">", inputs=(v0,), num_outputs=1) + then_graph = ir.Graph( + inputs=(node0.outputs[0],), + outputs=(node1.outputs[0],), + nodes=(node1,), + name="then_graph", + ) + node4 = ir.Node( + "", + "if", + inputs=(node2.outputs[0],), + num_outputs=1, + attributes=[ + ir.AttrGraph("then_graph", then_graph), + ], + ) + main_graph = ir.Graph( + inputs=(v0,), + outputs=(), + nodes=(node0, node2, node4), + name="main_graph", + ) + remove_unused = RemoveUnused(main_graph) + remove_unused.purge() + self.assertEqual(tuple(main_graph), ()) + self.assertEqual(tuple(then_graph), ()) + + +if __name__ == "__main__": + unittest.main()