-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Implement pass to remove unused nodes in graph #1841
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1841 +/- ##
==========================================
+ Coverage 75.20% 75.28% +0.07%
==========================================
Files 251 253 +2
Lines 27429 27537 +108
Branches 5032 5047 +15
==========================================
+ Hits 20629 20730 +101
- Misses 5828 5830 +2
- Partials 972 977 +5 ☔ View full report in Codecov by Sentry. |
# Remove | ||
for node in all_nodes: | ||
if node not in visited_nodes: | ||
node.graph.remove(node) |
Check failure
Code scanning / lintrunner
MYPY/union-attr Error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Acknowledged.
Just updated with upstream. No need for rerun. |
visited_nodes: set[Node] = set() | ||
|
||
# BFS Traversal | ||
value_queue: deque[Value] = deque( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be better if a node's subgraphs are processed only after the node is itself determined to be useful (that is, added to visited_nodes
. This will handle examples such as the one below better:
x = ...
y = If ( cond, ... x ..., ...)
Here, if y
is not used, then we may not need x
either. But the current logic will, I believe, mark x as visited since it is used to compute the output of the If's then subgraph's output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Do we really need all subgraphs' outputs? @justinchuby
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I adopted @gramalingam 's idea, modified the code and added a testcase.
|
||
|
||
class RemoveUnused: | ||
def __init__(self, graph_like: Graph): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ir.Graph?
if not isinstance(attr, Attr): | ||
continue | ||
if attr.type == _enums.AttributeType.GRAPH: | ||
add_graph_output_values_to_queue(attr.value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some subgraphs use intermediate results declared in the main graph. You need to loop over nodes inside subgraphs as well. You'll have to handle inputs/outputs with the same name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The RecursiveGraphIterator will loop over all nodes in subgraphs. So all_nodes
includes nodes from the subgraph already
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the code handles Xavier's case, but not because of the recursive graph iterator. (That seems to be used only in the later loop below to remove nodes). The code above goes from value to the producer of the value: this should go from a use inside a subgraph to a producer outside the subgraph (as long as the IR is constructed correctly.)
|
||
add_graph_output_values_to_queue(self._graph) | ||
|
||
while queue: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible to avoid the queue by looping over all nodes in the backward order (assuming the ir preserve a consistent order on the nodes +@justinchuby).
from collections import deque | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.ir import Attr, Graph, Node, Value, _enums |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import modules only
I will come back to this later this week. Thanks for your patience! |
@@ -0,0 +1,82 @@ | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning
|
||
# Remove | ||
for node in all_nodes: | ||
if node not in visited_nodes: # type: ignore[union-attr]` |
Check failure
Code scanning / lintrunner
MYPY/syntax Error
# Remove all nodes that have not been marked as visited during the BFS traversal. | ||
|
||
# Initialize | ||
all_nodes: list[Node] = list(ir.traversal.RecursiveGraphIterator(self._graph)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: now this can be moved down to line 79, which is where it is used, I think ...
@justinchuby
Fixes #1474