Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 102 additions & 61 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,10 +691,20 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
}

std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph, bool is_graph_split) const {
bool is_skip_fuse = false;
std::unordered_set<size_t> node_set;
node_set.reserve(graph_nodes_index.size());
for (const auto& index : graph_nodes_index) {
node_set.insert(index);
// If the subgraph contains an If node, skip fusion
if (graph.GetNode(index)->OpType() == "If") {
is_skip_fuse = true;
}
}

// If the parent graph is an If node, skip fusion
if ((graph.ParentNode() != nullptr && graph.ParentNode()->OpType() == "If")) {
is_skip_fuse = true;
}

// Get parent graph output names
Expand All @@ -706,76 +716,107 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st
// Find inputs and outputs of the subgraph
std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_set<const NodeArg*> erased;
int input_order = 0;
int output_order = 0;

for (const auto& index : graph_nodes_index) {
sub_graph->Nodes().push_back(index);
const auto& node = graph.GetNode(index);
for (const auto& input : node->InputDefs()) {
const auto& it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
if (!is_skip_fuse) {
std::unordered_set<const NodeArg*> erased;
int input_order = 0;
int output_order = 0;

for (const auto& index : graph_nodes_index) {
sub_graph->Nodes().push_back(index);
const auto& node = graph.GetNode(index);
for (const auto& input : node->InputDefs()) {
const auto& it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
}
}
}

for (const auto& input : node->ImplicitInputDefs()) {
const auto& it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
for (const auto& input : node->ImplicitInputDefs()) {
const auto& it = fused_outputs.find(input);
if (it != fused_outputs.end()) {
fused_outputs.erase(it);
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
}
}
}

// For output searching, there are two special cases,
// One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once,
// if the output is connected to nodes that don't belong to the subgraph, the output need to be added
// to the output list
// The other one is, if subgraph's node output is parent graph's output. the node output should
// be also added to the subgraph's output list
if (node->GetOutputEdgesCount() > node->OutputDefs().size()) {
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
const auto& node_idx = it->GetNode().Index();
const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()];
if (node_set.find(node_idx) != node_set.end()) {
const auto& iter = fused_inputs.find(output);
if (iter != fused_inputs.end()) {
fused_inputs.erase(iter);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (std::find(graph_output_names.begin(),
graph_output_names.end(), output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
// For output searching, there are two special cases,
// One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once,
// if the output is connected to nodes that don't belong to the subgraph, the output need to be added
// to the output list
// The other one is, if subgraph's node output is parent graph's output. the node output should
// be also added to the subgraph's output list
if (node->GetOutputEdgesCount() > node->OutputDefs().size()) {
for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) {
const auto& node_idx = it->GetNode().Index();
const auto input_defs = it->GetNode().InputDefs();
const size_t arg_index = it->GetDstArgIndex();
const NodeArg* output = nullptr;

// Account for implicit inputs for "If" operator destination indexing
// Addresses an index out of bounds issue but "If" operator still has functional issues.
if (it->GetNode().OpType() == "If") {
size_t num_of_explicit_inputs = it->GetNode().InputDefs().size();
if (num_of_explicit_inputs > arg_index) {
output = input_defs.at(arg_index);
} else {
if (num_of_explicit_inputs + it->GetNode().ImplicitInputDefs().size() > arg_index) {
output = it->GetNode().ImplicitInputDefs()[arg_index - num_of_explicit_inputs];
}
}
if (output == nullptr) {
ORT_THROW("Invalid destination node arg.");
}
} else {
if (arg_index >= input_defs.size()) {
ORT_THROW("the index of destination argument (" + std::to_string(arg_index) +
") is outside the "
"range of model input definitions ({" +
std::to_string(input_defs.size()) + "}");
continue;
}
fused_outputs[output] = output_order++;
output = input_defs.at(arg_index);
}

if (node_set.find(node_idx) != node_set.end()) {
const auto& iter = fused_inputs.find(output);
if (iter != fused_inputs.end()) {
fused_inputs.erase(iter);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (std::find(graph_output_names.begin(),
graph_output_names.end(), output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
fused_outputs[output] = output_order++;
}
} else {
fused_outputs_to_add[output] = output_order++;
}
} else {
fused_outputs_to_add[output] = output_order++;
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, add the output to output list
else {
if (erased.find(output) == erased.end()) {
if (std::find(graph_output_names.begin(),
graph_output_names.end(), output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, add the output to output list
else {
if (erased.find(output) == erased.end()) {
if (std::find(graph_output_names.begin(),
graph_output_names.end(), output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
fused_outputs[output] = output_order++;
}
fused_outputs[output] = output_order++;
}
}
}
Expand Down
Loading