diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 95cfa7135483f..4d0b4c834e8c8 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -744,7 +744,21 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st // be also added to the subgraph's output list if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { - const auto& node_idx = it->GetNode().Index(); + + const auto& target_node = it->GetNode(); + const auto& target_op_type = target_node.OpType(); + + if (target_op_type == "If" || target_op_type == "Loop" || target_op_type == "Scan") { + const auto& src_output_idx = it->GetSrcArgIndex(); + if (src_output_idx < node->OutputDefs().size()) { + const auto* output_def = node->OutputDefs()[src_output_idx]; + if (output_def && fused_outputs.find(output_def) == fused_outputs.end() && erased.find(output_def) == erased.end()) { + fused_outputs_to_add[output_def] = output_order++; + } + } + continue; + } + const auto& node_idx = target_node.Index(); const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; if (node_set.find(node_idx) != node_set.end()) { const auto& iter = fused_inputs.find(output); @@ -989,6 +1003,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "SimplifiedLayerNormalization", "Sin", "Sinh", + "Size", "SkipLayerNormalization", "SkipSimplifiedLayerNormalization", "Slice", @@ -1068,6 +1083,18 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; + + if (graph_viewer.IsSubgraph()) + { + const auto* parent_node = graph_viewer.ParentNode(); + if (parent_node) { + const auto& parent_op_type = parent_node->OpType(); + if (parent_op_type == "If" || parent_op_type == "Loop" || parent_op_type == "Scan") { + return result; + } + } + } + auto model = graph_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); graph_viewer.ToProto(*model_proto->mutable_graph(), true, true);