// subgraph_matcher.cpp (forked from pytorch/pytorch)
#include <torch/csrc/jit/subgraph_matcher.h>
#include <stack>
namespace torch {
namespace jit {
namespace {
/**
* \brief A class implementing an API for comparing subgraphs.
*/
class SubgraphMatcher {
public:
explicit SubgraphMatcher(const Graph& pattern) : pattern_(pattern) {}
/**
* \brief Compare matchGraph with the part of the graph denoted by a node \p
* ANCHOR.
*
* The anchor node would be compared against the deepest node in the
* match-graph. A node is considered matching if its number of inputs/outputs
* is the same as in the corresponding matchGraph node, its type is the same,
* and all nodes producing input-values also match.
*/
bool matchesSubgraphFromAnchorNode(Node* anchor);
/** \brief Return match map for nodes. */
std::unordered_map<const Node*, Node*> nodes_map() const {
return nodes_map_;
}
/** \brief Return match map for values. */
std::unordered_map<const Value*, Value*> values_map() const {
return values_map_;
}
private:
bool matchValues(const Value* v1, Value* v2);
bool matchNodes(const Node* n1, Node* n2);
std::unordered_map<const Node*, Node*> nodes_map_;
std::unordered_map<const Value*, Value*> values_map_;
const Graph& pattern_;
const Node* anchor_ = nullptr;
};
/**
* \brief A function to verify that \p PATTERN is valid. Concrete requirements
* for validity can be found in subgraph_matcher.h.
*/
bool patternGraphIsValid(const Graph& pattern) {
// Verify that pattern graph has a single block.
for (const Node* n : pattern.nodes()) {
if (!n->blocks().empty()) {
return false;
}
}
// Verify that pattern graph returns only one value.
const Node* bottom_node = *(pattern.nodes().end());
if (bottom_node->inputs().size() != 1) {
return false;
}
// TODO: Verify that nodes in the pattern don't alias.
return true;
}
/**
* Compare two Values. V1 is from pattern, V2 is from the actual graph.
*
* The values are considered matching if:
* 1) the nodes defining them match
* 2) they have the same number of uses, except they are entry or exit nodes.
*/
bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
// Check if we've already visited these values.
if (values_map_.count(v1)) {
return values_map_.at(v1) == v2;
}
// When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
// PARAM, we're comparing entering values - in these two cases the number of
// uses don't need to be the same.
if (v1->uses().size() != v2->uses().size() && v2->node() != anchor_ &&
v1->node()->kind() != prim::Param) {
return false;
}
// Add the values to the map before calling matchNodes to avoid infinite
// recursion.
values_map_[v1] = v2;
return matchNodes(v1->node(), v2->node());
}
/**
* Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
*
* The nodes are considered matching if:
* 1) N1 and N2 are of the same kind.
* 2) Number of inputs and outputs is the same.
* 3) All input and output values match.
*
* A special case is when N1 is PARAM - this is considered outside the pattern,
* so it matches everything.
*/
bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
// Check if we've already visited these nodes.
if (nodes_map_.count(n1)) {
return nodes_map_.at(n1) == n2;
}
// Param node in pattern graph matches everything.
if (n1->kind() == prim::Param) {
return true;
}
// We don't allow matches to span across blocks, so check if N2 is in the same
// block as the first (anchor) node.
if (n2->owningBlock() != anchor_->owningBlock()) {
return false;
}
if (n1->kind() != n2->kind() ||
n1->outputs().size() != n2->outputs().size() ||
n1->inputs().size() != n2->inputs().size() ||
n1->numAttributes() != n2->numAttributes()) {
return false;
}
// Add nodes to the map before calling matchValues to avoid infinite
// recursion.
nodes_map_[n1] = n2;
for (size_t i = 0; i < n1->outputs().size(); i++) {
if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
return false;
}
}
for (size_t i = 0; i < n1->inputs().size(); i++) {
if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
return false;
}
}
for (const Symbol& attr_name : n1->attributeNames()) {
if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
return false;
}
switch (n1->kindOf(attr_name)) {
case AttributeKind::s:
if (n1->s(attr_name) != n2->s(attr_name)) {
return false;
}
break;
case AttributeKind::f:
if (n1->f(attr_name) != n2->f(attr_name)) {
return false;
}
break;
case AttributeKind::i:
if (n1->i(attr_name) != n2->i(attr_name)) {
return false;
}
break;
default:
// Other attributes types not supported yet
return false;
}
}
return true;
}
/**
* Recursively try to match pattern with the actual graph starting from the
* exiting node in the pattern and anchor node in the actual graph.
*/
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
nodes_map_.clear();
values_map_.clear();
anchor_ = anchor;
const Node* bottom_node = *(pattern_.nodes().end());
AT_ASSERT(bottom_node->inputs().size() == 1);
bottom_node = bottom_node->input()->node();
if (!matchNodes(bottom_node, anchor)) {
return false;
}
return true;
}
} // unnamed namespace
// Main entry point for the subgraph matching.
std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
AT_ASSERT(patternGraphIsValid(pattern));
SubgraphMatcher m(pattern);
std::vector<Match> matches;
std::stack<Block*> blocks_to_visit;
// Iterate over all nodes in the graph (including nodes in subblocks) trying
// to match the pattern each node.
blocks_to_visit.push(graph.block());
while (!blocks_to_visit.empty()) {
Block* block = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : block->nodes()) {
if (m.matchesSubgraphFromAnchorNode(n)) {
matches.push_back({n, m.nodes_map(), m.values_map()});
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
return matches;
}
} // namespace jit
} // namespace torch