From c45ebb01c6248fe3be1042c4270195adde1a9ee3 Mon Sep 17 00:00:00 2001
From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com>
Date: Fri, 20 Sep 2024 15:11:56 +0530
Subject: [PATCH] [IREE EP] Fix ElementsAttr iteration error from MLIR (#10)

* [IREE-EP][Importer,JIT] Fix compile failures for empty inputs

* [IREE-EP] Do not take initializers as function args in IR

* [IREE EP][Importer] Fix ElementsAttr iteration error

* Address comments

* Fix lint errors
---
 .../core/providers/get_execution_providers.cc |   8 ++
 .../providers/iree/compiler/jit_compiler.cc   |  30 ++---
 .../torch-mlir-import-onnx/OnnxImporter.cpp   | 119 ++++++++++--------
 .../torch-mlir-import-onnx/OnnxImporter.h     |  17 ++-
 .../providers/iree/iree_execution_provider.cc |   4 -
 .../providers/shared_library/provider_api.h   |   1 +
 .../python/onnxruntime_pybind_state.cc        |   2 +-
 7 files changed, 96 insertions(+), 85 deletions(-)

diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc
index 61c035bc29ed5..0f2f26f5ef8c9 100644
--- a/onnxruntime/core/providers/get_execution_providers.cc
+++ b/onnxruntime/core/providers/get_execution_providers.cc
@@ -186,6 +186,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
         true,
 #else
         false,
+#endif
+    },
+    {
+        kIreeExecutionProvider,
+#ifdef USE_IREE
+        true,
+#else
+        false,
 #endif
     },
     {kCpuExecutionProvider, true},  // kCpuExecutionProvider is always last
diff --git a/onnxruntime/core/providers/iree/compiler/jit_compiler.cc b/onnxruntime/core/providers/iree/compiler/jit_compiler.cc
index b2627753887c4..382208cce528f 100644
--- a/onnxruntime/core/providers/iree/compiler/jit_compiler.cc
+++ b/onnxruntime/core/providers/iree/compiler/jit_compiler.cc
@@ -5,6 +5,7 @@
 // (which require a pre-compilation step).
 
 #include "core/providers/iree/compiler/jit_compiler.h"
+#include "core/graph/graph_proto_serializer.h"
 #include "core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h"
 
 #include "mlir-c/BuiltinAttributes.h"
@@ -157,13 +158,12 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
     opset_import->set_version(it.second);
   }
 
-  // Unforgivably sharp edge: There is a ToGraphProto() that returns a value and another that returns a reference.
-  // And they differ by const-ness. We need to make sure we get the reference, obviously, so we assign it explicitly.
-  const ONNX_NAMESPACE::GraphProto& graph_proto = graph_view.GetGraph().ToGraphProto();
+  ONNX_NAMESPACE::GraphProto graph_proto;
+  GraphViewerToProto(graph_view, graph_proto, false, false);
   // LOGS(session.logger, INFO) << " full graph: " << graph_proto.DebugString();
 
   // Set up for subgraph import.
-  torch_mlir_onnx::GraphInfo subgraph_info(model_info, graph_proto);
+  torch_mlir_onnx::GraphInfo subgraph_info(graph_view, model_info, graph_proto);
   if (torch_mlir_onnx::failed(subgraph_info.Initialize())) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message());
   }
@@ -193,24 +193,10 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
                            model_info.error_message(), ConsumeDiagnostics());
   }
 
-  // Import each node. Note that the importer uses references internally and expects nodes to be located at fixed
-  // memory locations for the life of iteration. So we materialize them into a fixed vector first. This is because
-  // the onnxruntime does not keep the serialized proto form sync'd on its own.
-  auto node_indices = graph_view.GetNodesInTopologicalOrder();
-  std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
-  for (size_t i = 0; i < node_indices.size(); ++i) {
-    graph_view.GetNode(node_indices[i])->ToProto(nodes[i]);
-  }
-  for (const auto& node : nodes) {
-    if (torch_mlir_onnx::failed(imp.ImportNode(node))) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import node '", node.name(), "': ",
-                             model_info.error_message(), " (node:\n", node.DebugString(), "\n)", ConsumeDiagnostics());
-    }
-  }
-
-  // Finalize.
-  if (torch_mlir_onnx::failed(imp.FinalizeGraph())) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message(), ConsumeDiagnostics());
+  if (torch_mlir_onnx::failed(imp.ImportAll())) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import nodes",
+                           ": ", model_info.error_message(),
+                           ConsumeDiagnostics());
   }
 
   // Verify the function at the point of import because we have better diagnostics.
diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
index 1f1901e06473b..0ba67cf33fd4c 100644
--- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
+++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
@@ -154,11 +154,11 @@ void ModelInfo::DebugDumpProto() {
   fprintf(stderr, "%s\n", debug_string.c_str());
 }
 
-Status ModelInfo::Initialize() {
+Status ModelInfo::Initialize(const onnxruntime::GraphViewer &gv) {
   if (!model_proto_.has_graph()) {
     return SetError("ONNX ModelProto has no main graph");
   }
-  main_graph_ = std::make_unique<GraphInfo>(*this, model_proto_.graph());
+  main_graph_ = std::make_unique<GraphInfo>(gv, *this, model_proto_.graph());
   if (failed(main_graph_->Initialize())) {
     return failure();
   }
@@ -228,33 +228,25 @@ Status GraphInfo::Initialize() {
 }
 
 const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
-  // Node outputs don't typically have type information, but shape inference
-  // will associate them in the value_info. If not there, it may be a
-  // graph output, which must have type information.
-  {
-    auto it = value_info_map_.find(name);
-    if (it != value_info_map_.end()) {
-      return &it->second.type();
-    }
-  }
-  {
-    auto it = output_map_.find(name);
-    if (it != output_map_.end()) {
-      return &it->second.type();
-    }
-  }
-
-  std::string msg = "No type information associated with '";
-  msg.append(name);
-  msg.append("'. Run shape inference?");
-  model_info_.SetError(std::move(msg));
-  return nullptr;
+  return graph_viewer_.GetNodeArg(std::string{name})->TypeAsProto();
 }
 
 // ---------------------------------------------------------------------------//
 // ContextCache
 // ---------------------------------------------------------------------------//
 
+// Parses !torch.none into an MlirType. This is used as the result type of
+// the torch.constant.none op created by ImportNoneNode.
+MlirType ContextCache::GetNoneType() {
+  auto t =
+      mlirTypeParseGet(context_, mlirStringRefCreateFromCString("!torch.none"));
+  if (mlirTypeIsNull(t)) {
+    std::string message = "internal error: could not parse !torch.none type: ";
+    model_info_.SetError(std::move(message));
+  }
+  return t;
+}
+
 MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) {
   if (tp.has_tensor_type()) {
     // Convert Tensor TypeProto.
@@ -392,8 +384,8 @@ ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) {
       int8_conversion.reserve(tp.int32_data_size());
       for (int32_t v : tp.int32_data())
         int8_conversion.push_back(v);
-      return mlirDenseElementsAttrInt8Get(
-          tensor_type, int8_conversion.size(), int8_conversion.data());
+      return mlirDenseElementsAttrInt8Get(tensor_type, int8_conversion.size(),
+                                          int8_conversion.data());
     }
     case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
       return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(),
@@ -511,6 +503,19 @@ NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc,
                            /*childLoc=*/{nullptr});
 }
 
+// Imports a single !torch.none value bound to the empty-string label '',
+// which ONNX uses for omitted optional inputs.
+void NodeImporter::ImportNoneNode() {
+  auto it = nv_map_.find("");
+  if (it != nv_map_.end())
+    return;
+
+  MlirOperation new_op = createMlirOperationAtEnd(
+      body_block_, "torch.constant.none", default_loc_, cc_.GetNoneType());
+  MlirValue nne = mlirOperationGetResult(new_op, 0);
+  nv_map_.emplace("", nne);
+}
+
 Status NodeImporter::DefineFunction(std::optional<std::string> name,
                                     MlirOperation *out_function_op) {
   const onnx::GraphProto &p = graph_info_.graph_proto();
@@ -529,16 +534,16 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
   std::vector<MlirType> input_types;
   std::vector<MlirLocation> input_locs;
   std::vector<MlirType> output_types;
-  for (auto *input : graph_info_.inputs()) {
-    MlirType t = cc_.ConvertTypeProto(input->type());
+  for (auto *input : graph_info_.graph_viewer().GetInputs()) {
+    MlirType t = cc_.ConvertTypeProto(*input->TypeAsProto());
     if (mlirTypeIsNull(t)) {
       return failure();
     }
     input_types.push_back(t);
     input_locs.push_back(default_loc_);
   }
-  for (auto *output : graph_info_.outputs()) {
-    MlirType t = cc_.ConvertTypeProto(output->type());
+  for (auto output : graph_info_.graph_proto().output()) {
+    MlirType t = cc_.ConvertTypeProto(output.type());
     if (mlirTypeIsNull(t)) {
       return failure();
     }
@@ -561,8 +566,9 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
   mlirRegionAppendOwnedBlock(bodyRegion, body_block_);
 
   // Map the block args to names and store for evaluation.
-  for (int i = 0, e = graph_info_.inputs().size(); i < e; ++i) {
-    std::string_view name = graph_info_.inputs()[i]->name();
+  for (int i = 0, e = graph_info_.graph_viewer().GetInputs().size(); i < e;
+       ++i) {
+    std::string_view name = graph_info_.graph_viewer().GetInputs()[i]->Name();
     MlirValue value = mlirBlockGetArgument(body_block_, i);
     nv_map_[name] = value;
   }
@@ -622,15 +628,19 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
 }
 
 Status NodeImporter::ImportAll() {
-  // TODO: Consider pulling in initializers on demand since there can be so
-  // much unused crap.
-  for (auto it : graph_info_.initializer_map()) {
-    if (failed(ImportInitializer(it.second)))
-      return failure();
+  ImportNoneNode();
+
+  auto node_indices = graph_info_.graph_viewer().GetNodesInTopologicalOrder();
+  std::vector<onnx::NodeProto> nodes(node_indices.size());
+  for (size_t i = 0; i < node_indices.size(); ++i) {
+    graph_info_.graph_viewer().GetNode(node_indices[i])->ToProto(nodes[i]);
   }
-  for (auto it : graph_info_.graph_proto().node()) {
-    if (failed(ImportNode(it)))
-      return failure();
+
+  for (const auto &node : nodes) {
+    if (torch_mlir_onnx::failed(ImportNode(node))) {
+      return SetError("Failed to import node '" + node.name() +
+                      "': " + "(node:\n" + node.DebugString() + "\n)");
+    }
   }
 
   return FinalizeGraph();
@@ -640,8 +650,8 @@ Status NodeImporter::FinalizeGraph() {
   // Lookup the outputs, which should all be in the nv_map if the graph was
   // properly formed.
   std::vector<MlirValue> output_values;
-  for (const auto *output : graph_info_.outputs()) {
-    std::string_view name = output->name();
+  for (const auto &output : graph_info_.graph_proto().output()) {
+    std::string_view name = output.name();
     auto found_it = nv_map_.find(name);
     if (found_it == nv_map_.end()) {
       std::string msg = "Non topologically produced ONNX graph output '";
@@ -670,8 +680,11 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
     return failure();
 
   MlirOperation op = createMlirOperationAtEnd(
-      body_block_, "torch.vtensor.literal", loc, vtensor_type,
-      toMlirNamedAttribute("value", value_attr));
+      body_block_, "torch.operator", loc, vtensor_type,
+      toMlirNamedAttribute(
+          "name",
+          mlirStringAttrGet(context_, toMlirStringRef("onnx.Constant"))),
+      toMlirNamedAttribute("torch.onnx.value", value_attr));
   MlirValue result = mlirOperationGetResult(op, 0);
 
   auto inserted = nv_map_.insert(std::make_pair(name, result));
@@ -706,6 +719,11 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   // Map inputs to values.
   std::vector<MlirValue> input_values;
   for (auto &input_name : node.input()) {
+    if (auto inp = graph_info_.graph_viewer().GetConstantInitializer(input_name,
+                                                                     false)) {
+      ImportInitializer(*inp);
+    }
+
     auto found_it = nv_map_.find(input_name);
     if (found_it == nv_map_.end()) {
       std::string msg = "Non topologically produced ONNX node input '";
@@ -720,7 +738,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   std::vector<MlirType> output_types;
   for (auto &output_name : node.output()) {
     const onnx::TypeProto *type_proto =
-        graph_info_.FindTypeProtoForName(output_name);
+        graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
     if (!type_proto)
       return failure();
 
@@ -955,15 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
 
 Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
                                              std::vector<int64_t> &shape) {
-  auto found_it = graph_info_.initializer_map().find(name);
-  if (found_it == graph_info_.initializer_map().end()) {
-    std::string message = "An immediate shape value for '";
-    message.append(name);
-    message.append("' was required but it is dynamically produced");
-    return SetError(std::move(message));
-  }
-
-  const onnx::TensorProto &tp = found_it->second;
+  const onnx::TensorProto &tp =
+      *graph_info_.graph_viewer().GetConstantInitializer(name, false);
   shape.clear();
 
   // Since this is being interpreted as a shape, we only support some limited
@@ -1028,7 +1039,7 @@ void NodeImporter::DebugDumpModule() {
     fwrite(sr.data, sizeof(char), sr.length, stderr);
   };
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
-  mlirOpPrintingFlagsEnableDebugInfo(flags, true, false);
+  mlirOpPrintingFlagsEnableDebugInfo(flags, false, true);
   mlirOperationPrintWithFlags(module_op_, flags, callback, nullptr);
   mlirOpPrintingFlagsDestroy(flags);
 }
diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
index 0fd24cc3ad004..733aefbda2583 100644
--- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
+++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
@@ -17,9 +17,11 @@
 // for class members/accessors because canonical protobuf coding presumes
 // this kind of style.
 
+#include "core/graph/graph_viewer.h"
 #include "mlir-c/IR.h"
 #include "onnx/onnx_pb.h"
 
+#include
 #include
 #include
 #include
@@ -64,8 +66,9 @@ static inline bool failed(Status status) { return !status.is_success(); }
 // Accounting for a GraphProto.
 class GraphInfo {
 public:
-  GraphInfo(ModelInfo &model_info, const onnx::GraphProto &graph_proto)
-      : model_info_(model_info), graph_proto_(graph_proto) {}
+  GraphInfo(const onnxruntime::GraphViewer &gv, ModelInfo &model_info,
+            const onnx::GraphProto &graph_proto)
+      : graph_viewer_(gv), model_info_(model_info), graph_proto_(graph_proto) {}
   ModelInfo &model_info() { return model_info_; }
   const onnx::GraphProto &graph_proto() { return graph_proto_; }
 
@@ -101,12 +104,15 @@ class GraphInfo {
     return output_map_;
   }
 
+  const onnxruntime::GraphViewer &graph_viewer() { return graph_viewer_; }
+
   std::unordered_map<std::string_view, const onnx::TensorProto &> &
   initializer_map() {
     return initializer_map_;
  }
 
 private:
+  const onnxruntime::GraphViewer &graph_viewer_;
   ModelInfo &model_info_;
   const onnx::GraphProto &graph_proto_;
 
@@ -131,7 +137,7 @@ class ModelInfo {
   onnx::ModelProto &model_proto() { return model_proto_; }
 
   /// Post-construction, failable initialization.
-  Status Initialize();
+  Status Initialize(const onnxruntime::GraphViewer &gv);
 
   GraphInfo &main_graph() { return *main_graph_; }
   const std::string &error_message() { return error_message_; }
@@ -157,6 +163,7 @@ class ContextCache {
       : model_info_(model_info), context_(context) {}
 
   MlirContext context() { return context_; }
+  MlirType GetNoneType();
 
   /// Converts the TypeProto to an MlirType, returning a null type and
   /// setting an error if not possible.
@@ -208,15 +215,17 @@ class NodeImporter {
 
   /// Imports all nodes topologically. Internally calls FinalizeGraph.
   Status ImportAll();
+  /// Substitutes !torch.none for inputs labelled with the empty string ''.
+  void ImportNoneNode();
   /// Import nodes one at a time. Must complete with a call to FinalizeGraph.
   Status ImportNode(const onnx::NodeProto &node);
+  Status ImportInitializer(const onnx::TensorProto &initializer);
   Status FinalizeGraph();
   void DebugDumpModule();
 
 private:
   void PopulateGraphAttrs(MlirOperation container_op);
-  Status ImportInitializer(const onnx::TensorProto &initializer);
   MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr);
 
   // Special-form nodes.
diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc
index 70094efa23788..772475224deaa 100644
--- a/onnxruntime/core/providers/iree/iree_execution_provider.cc
+++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc
@@ -76,10 +76,6 @@ std::vector<std::unique_ptr<ComputeCapability>> IREEExecutionProvider::GetCapabi
     inputs.push_back(nodeArgPtr->Name());
   }
 
-  for (auto& name : required_initializers) {
-    inputs.push_back(name);
-  }
-
   for (auto& nodeArgPtr : graph_viewer.GetOutputs()) {
     outputs.push_back(nodeArgPtr->Name());
   }
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index b84825236a453..54ad641983006 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -274,6 +274,7 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
 constexpr const char* kQnnExecutionProvider = "QNNExecutionProvider";
 constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
 constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
+constexpr const char* kIreeExecutionProvider = "IreeExecutionProvider";
 
 template <typename T>
 using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 9a74b527ef6ee..357af596f55af 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1141,7 +1141,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
 #endif
   } else if (type == kIreeExecutionProvider) {
 #if USE_IREE
-    const auto &it = provider_options_map.find(type);
+    const auto& it = provider_options_map.find(type);
     ProviderOptions iree_option_map = ProviderOptions{};
     if (it != provider_options_map.end()) {
      iree_option_map = it->second;
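
A note for reviewers on the empty-input fix above: ONNX encodes an omitted
optional input as the empty string "", which is the label that ImportNoneNode
now binds to a single torch.constant.none value. The sketch below (standard
onnx.helper API, not part of this patch; the Clip operator and all names are
illustrative) builds a minimal model exercising both that path and the
on-demand initializer import:

    import onnx
    from onnx import TensorProto, helper

    # "Clip" with its optional "min" input omitted: ONNX encodes the missing
    # input as the empty string "", which ImportNoneNode maps to !torch.none.
    node = helper.make_node("Clip", inputs=["x", "", "max"], outputs=["y"])

    # "max" is an initializer rather than a graph input, so after this patch
    # it is no longer a function argument in the imported IR; ImportGeneralNode
    # pulls it in on demand via GetConstantInitializer.
    graph = helper.make_graph(
        [node], "clip_no_min",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
        initializer=[helper.make_tensor("max", TensorProto.FLOAT, [], [1.0])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
    onnx.checker.check_model(model)
    onnx.save(model, "clip_no_min.onnx")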
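
To exercise the registration hunks end to end, a session can then be created
against this provider from Python. The provider name string matches
kIreeExecutionProvider above; the option key shown is a hypothetical
placeholder, since the patch does not define any option names:

    import onnxruntime as ort

    # Present only in builds compiled with USE_IREE (see the
    # kProvidersInPriorityOrder hunk above).
    assert "IreeExecutionProvider" in ort.get_available_providers()

    # CreateExecutionProviderInstance() receives this dict through
    # provider_options_map; "example_option" is a hypothetical key.
    sess = ort.InferenceSession(
        "clip_no_min.onnx",
        providers=[("IreeExecutionProvider", {"example_option": "value"})],
    )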