[IREE EP] Fix ElementsAttr iteration error from MLIR (#10)
* [IREE-EP][Importer,JIT] Fix compile failures for empty inputs

* [IREE-EP] Do not take initializers as function args in IR

* [IREE EP][Importer] Fix ElementsAttr iteration error

* Address comments

* Fix lint errors
vinayakdsci authored Sep 20, 2024
1 parent 08acace commit c45ebb0
Showing 7 changed files with 96 additions and 85 deletions.
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/get_execution_providers.cc
@@ -186,6 +186,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
true,
#else
false,
#endif
},
{
kIreeExecutionProvider,
#ifdef USE_IREE
true,
#else
false,
#endif
},
{kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last
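
With this entry in place, the IREE provider is reported by the standard availability query whenever the build defines USE_IREE. A minimal usage sketch (the accessor name and header path are assumptions based on this file's conventions):

#include "core/providers/get_execution_providers.h"

// Names of the execution providers compiled into this build, in priority order.
const auto names = onnxruntime::GetAvailableExecutionProviderNames();
// With USE_IREE defined, names now contains "IreeExecutionProvider".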
30 changes: 8 additions & 22 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.cc
@@ -5,6 +5,7 @@
// (which require a pre-compilation step).

#include "core/providers/iree/compiler/jit_compiler.h"
#include "core/graph/graph_proto_serializer.h"
#include "core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h"
#include "mlir-c/BuiltinAttributes.h"

@@ -157,13 +158,12 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
opset_import->set_version(it.second);
}

// Unforgivably sharp edge: There is a ToGraphProto() that returns a value and another that returns a reference.
// And they differ by const-ness. We need to make sure we get the reference, obviously, so we assign it explicitly.
const ONNX_NAMESPACE::GraphProto& graph_proto = graph_view.GetGraph().ToGraphProto();
ONNX_NAMESPACE::GraphProto graph_proto;
GraphViewerToProto(graph_view, graph_proto, false, false);
// LOGS(session.logger, INFO) << " full graph: " << graph_proto.DebugString();

// Set up for subgraph import.
torch_mlir_onnx::GraphInfo subgraph_info(model_info, graph_proto);
torch_mlir_onnx::GraphInfo subgraph_info(graph_view, model_info, graph_proto);
if (torch_mlir_onnx::failed(subgraph_info.Initialize())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message());
}
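
The deleted comment above warned about an easy trap: Graph::ToGraphProto() has a const overload returning a reference and a non-const overload returning a copy. The replacement sidesteps the question by serializing through the GraphViewer instead, so the proto describes the subgraph this EP was actually handed. A sketch of the two paths (the boolean parameter names are assumptions):

// Old path: the const overload binds a reference, but it always describes the
// full underlying graph.
const ONNX_NAMESPACE::GraphProto& whole_graph =
    graph_view.GetGraph().ToGraphProto();

// New path: serialize exactly what the viewer exposes; passing false for the
// assumed include_initializer flag keeps initializers out of the proto, in
// line with "Do not take initializers as function args in IR".
ONNX_NAMESPACE::GraphProto subgraph_proto;
GraphViewerToProto(graph_view, subgraph_proto,
                   /*include_initializer=*/false,
                   /*include_outer_scope_args=*/false);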
@@ -193,24 +193,10 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
model_info.error_message(), ConsumeDiagnostics());
}

// Import each node. Note that the importer uses references internally and expects nodes to be located at fixed
// memory locations for the life of iteration. So we materialize them into a fixed vector first. This is because
// the onnxruntime does not keep the serialized proto form sync'd on its own.
auto node_indices = graph_view.GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_view.GetNode(node_indices[i])->ToProto(nodes[i]);
}
for (const auto& node : nodes) {
if (torch_mlir_onnx::failed(imp.ImportNode(node))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import node '", node.name(), "': ",
model_info.error_message(), " (node:\n", node.DebugString(), "\n)", ConsumeDiagnostics());
}
}

// Finalize.
if (torch_mlir_onnx::failed(imp.FinalizeGraph())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message(), ConsumeDiagnostics());
if (torch_mlir_onnx::failed(imp.ImportAll())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import nodes",
": ", model_info.error_message(),
ConsumeDiagnostics());
}

// Verify the function at the point of import because we have better diagnostics.
onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
@@ -154,11 +154,11 @@ void ModelInfo::DebugDumpProto() {
fprintf(stderr, "%s\n", debug_string.c_str());
}

Status ModelInfo::Initialize() {
Status ModelInfo::Initialize(const onnxruntime::GraphViewer &gv) {
if (!model_proto_.has_graph()) {
return SetError("ONNX ModelProto has no main graph");
}
main_graph_ = std::make_unique<GraphInfo>(*this, model_proto_.graph());
main_graph_ = std::make_unique<GraphInfo>(gv, *this, model_proto_.graph());
if (failed(main_graph_->Initialize())) {
return failure();
}
@@ -228,33 +228,25 @@ Status GraphInfo::Initialize() {
}

const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
// Node outputs don't typically have type information, but shape inference
// will associate them in the value_info. If not there, it may be a
// graph output, which must have type information.
{
auto it = value_info_map_.find(name);
if (it != value_info_map_.end()) {
return &it->second.type();
}
}
{
auto it = output_map_.find(name);
if (it != output_map_.end()) {
return &it->second.type();
}
}

std::string msg = "No type information associated with '";
msg.append(name);
msg.append("'. Run shape inference?");
model_info_.SetError(std::move(msg));
return nullptr;
return graph_viewer_.GetNodeArg(std::string{name})->TypeAsProto();
}
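
The rewritten lookup assumes the GraphViewer can always resolve the name to a typed NodeArg. A guarded variant that preserves the old diagnostic would look roughly like this (a sketch under that same API, not part of the commit):

const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
  const auto *arg = graph_viewer_.GetNodeArg(std::string{name});
  const onnx::TypeProto *tp = arg ? arg->TypeAsProto() : nullptr;
  if (!tp) {
    std::string msg = "No type information associated with '";
    msg.append(name);
    msg.append("'. Run shape inference?");
    model_info_.SetError(std::move(msg));
    return nullptr;
  }
  return tp;
}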

// ---------------------------------------------------------------------------//
// ContextCache
// ---------------------------------------------------------------------------//

// Parses "!torch.none" into an MlirType; this is used as the result type of
// the torch.constant.none op created by ImportNoneNode.
MlirType ContextCache::GetNoneType() {
auto t =
mlirTypeParseGet(context_, mlirStringRefCreateFromCString("!torch.none"));
if (mlirTypeIsNull(t)) {
std::string message = "internal error: could not parse !torch.none type: ";
model_info_.SetError(std::move(message));
}
return t;
}

MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) {
if (tp.has_tensor_type()) {
// Convert Tensor TypeProto.
@@ -392,8 +384,8 @@ ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) {
int8_conversion.reserve(tp.int32_data_size());
for (int32_t v : tp.int32_data())
int8_conversion.push_back(v);
return mlirDenseElementsAttrInt8Get(
tensor_type, int8_conversion.size(), int8_conversion.data());
return mlirDenseElementsAttrInt8Get(tensor_type, int8_conversion.size(),
int8_conversion.data());
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(),
@@ -511,6 +503,19 @@ NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc,
/*childLoc=*/{nullptr});
}

// Imports a single torch.constant.none and maps it to the empty name (''),
// the label ONNX uses for omitted optional inputs.
void NodeImporter::ImportNoneNode() {
auto it = nv_map_.find("");
if (it != nv_map_.end())
return;

MlirOperation new_op = createMlirOperationAtEnd(
body_block_, "torch.constant.none", default_loc_, cc_.GetNoneType());
MlirValue nne = mlirOperationGetResult(new_op, 0);
nv_map_.emplace("", nne);
}
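
For reference, the op created here prints roughly as the Torch-dialect line below (a sketch), and nv_map_[""] then resolves to its result, so later lookups of ''-named (omitted) inputs succeed instead of failing:

  %none = torch.constant.none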

Status NodeImporter::DefineFunction(std::optional<std::string> name,
MlirOperation *out_function_op) {
const onnx::GraphProto &p = graph_info_.graph_proto();
@@ -529,16 +534,16 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
std::vector<MlirType> input_types;
std::vector<MlirLocation> input_locs;
std::vector<MlirType> output_types;
for (auto *input : graph_info_.inputs()) {
MlirType t = cc_.ConvertTypeProto(input->type());
for (auto *input : graph_info_.graph_viewer().GetInputs()) {
MlirType t = cc_.ConvertTypeProto(*input->TypeAsProto());
if (mlirTypeIsNull(t)) {
return failure();
}
input_types.push_back(t);
input_locs.push_back(default_loc_);
}
for (auto *output : graph_info_.outputs()) {
MlirType t = cc_.ConvertTypeProto(output->type());
for (auto output : graph_info_.graph_proto().output()) {
MlirType t = cc_.ConvertTypeProto(output.type());
if (mlirTypeIsNull(t)) {
return failure();
}
@@ -561,8 +566,9 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
mlirRegionAppendOwnedBlock(bodyRegion, body_block_);

// Map the block args to names and store for evaluation.
for (int i = 0, e = graph_info_.inputs().size(); i < e; ++i) {
std::string_view name = graph_info_.inputs()[i]->name();
for (int i = 0, e = graph_info_.graph_viewer().GetInputs().size(); i < e;
++i) {
std::string_view name = graph_info_.graph_viewer().GetInputs()[i]->Name();
MlirValue value = mlirBlockGetArgument(body_block_, i);
nv_map_[name] = value;
}
@@ -622,15 +628,19 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
}

Status NodeImporter::ImportAll() {
// TODO: Consider pulling in initializers on demand since there can be so
// much unused crap.
for (auto it : graph_info_.initializer_map()) {
if (failed(ImportInitializer(it.second)))
return failure();
ImportNoneNode();

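// The importer keeps references into these protos while iterating, so the
// NodeProto copies must live at stable addresses for the life of the import;
// materializing them into a sized vector up front guarantees that.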
auto node_indices = graph_info_.graph_viewer().GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_info_.graph_viewer().GetNode(node_indices[i])->ToProto(nodes[i]);
}
for (auto it : graph_info_.graph_proto().node()) {
if (failed(ImportNode(it)))
return failure();

for (const auto &node : nodes) {
if (torch_mlir_onnx::failed(ImportNode(node))) {
return SetError("Failed to import node '" + node.name() +
"': " + "(node:\n" + node.DebugString() + "\n)");
}
}

return FinalizeGraph();
@@ -640,8 +650,8 @@ Status NodeImporter::FinalizeGraph() {
// Lookup the outputs, which should all be in the nv_map if the graph was
// properly formed.
std::vector<MlirValue> output_values;
for (const auto *output : graph_info_.outputs()) {
std::string_view name = output->name();
for (const auto &output : graph_info_.graph_proto().output()) {
std::string_view name = output.name();
auto found_it = nv_map_.find(name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX graph output '";
@@ -670,8 +680,11 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
return failure();

MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.vtensor.literal", loc, vtensor_type,
toMlirNamedAttribute("value", value_attr));
body_block_, "torch.operator", loc, vtensor_type,
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context_, toMlirStringRef("onnx.Constant"))),
toMlirNamedAttribute("torch.onnx.value", value_attr));
MlirValue result = mlirOperationGetResult(op, 0);

auto inserted = nv_map_.insert(std::make_pair(name, result));
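
After this change an initializer lowers to a generic ONNX custom op instead of a raw torch.vtensor.literal, printing roughly as (a sketch, with the tensor payload and shape elided):

  %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<...> : tensor<2x3xf32>} : () -> !torch.vtensor<[2,3],f32>

Keeping the tensor behind torch.onnx.value and letting the ONNX-to-Torch lowering materialize it is presumably what avoids the ElementsAttr iteration error named in the commit title.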
@@ -706,6 +719,11 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
// Map inputs to values.
std::vector<MlirValue> input_values;
for (auto &input_name : node.input()) {
if (auto inp = graph_info_.graph_viewer().GetConstantInitializer(input_name,
false)) {
ImportInitializer(*inp);
}

auto found_it = nv_map_.find(input_name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX node input '";
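
Note that the on-demand initializer import at the top of this hunk discards the returned Status; a stricter variant would propagate it (a sketch — the name of the second GetConstantInitializer parameter is an assumption):

if (const auto *inp = graph_info_.graph_viewer().GetConstantInitializer(
        input_name, /*check_outer_scope=*/false)) {
  if (failed(ImportInitializer(*inp)))
    return failure();
}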
@@ -720,7 +738,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
std::vector<MlirType> output_types;
for (auto &output_name : node.output()) {
const onnx::TypeProto *type_proto =
graph_info_.FindTypeProtoForName(output_name);
graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
if (!type_proto)
return failure();

@@ -955,15 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {

Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape) {
auto found_it = graph_info_.initializer_map().find(name);
if (found_it == graph_info_.initializer_map().end()) {
std::string message = "An immediate shape value for '";
message.append(name);
message.append("' was required but it is dynamically produced");
return SetError(std::move(message));
}

const onnx::TensorProto &tp = found_it->second;
const onnx::TensorProto &tp =
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
shape.clear();

// Since this is being interpreted as a shape, we only support some limited
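
As on the node-input path, the immediate dereference of GetConstantInitializer assumes the initializer exists. A guarded variant restoring the deleted diagnostic (a sketch):

const onnx::TensorProto *tp_ptr =
    graph_info_.graph_viewer().GetConstantInitializer(name, false);
if (!tp_ptr) {
  std::string message = "An immediate shape value for '";
  message.append(name);
  message.append("' was required but it is dynamically produced");
  return SetError(std::move(message));
}
const onnx::TensorProto &tp = *tp_ptr;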
@@ -1028,7 +1039,7 @@ void NodeImporter::DebugDumpModule() {
fwrite(sr.data, sizeof(char), sr.length, stderr);
};
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsEnableDebugInfo(flags, true, false);
mlirOpPrintingFlagsEnableDebugInfo(flags, false, true);
mlirOperationPrintWithFlags(module_op_, flags, callback, nullptr);
mlirOpPrintingFlagsDestroy(flags);
}
onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
@@ -17,9 +17,11 @@
// for class members/accessors because canonical protobuf coding presumes
// this kind of style.

#include "core/graph/graph_viewer.h"
#include "mlir-c/IR.h"
#include "onnx/onnx_pb.h"

#include <memory>
#include <optional>
#include <string_view>
#include <unordered_map>
@@ -64,8 +66,9 @@ static inline bool failed(Status status) { return !status.is_success(); }
// Accounting for a GraphProto.
class GraphInfo {
public:
GraphInfo(ModelInfo &model_info, const onnx::GraphProto &graph_proto)
: model_info_(model_info), graph_proto_(graph_proto) {}
GraphInfo(const onnxruntime::GraphViewer &gv, ModelInfo &model_info,
const onnx::GraphProto &graph_proto)
: graph_viewer_(gv), model_info_(model_info), graph_proto_(graph_proto) {}
ModelInfo &model_info() { return model_info_; }
const onnx::GraphProto &graph_proto() { return graph_proto_; }

@@ -101,12 +104,15 @@ class GraphInfo {
return output_map_;
}

const onnxruntime::GraphViewer &graph_viewer() { return graph_viewer_; }

std::unordered_map<std::string_view, const onnx::TensorProto &> &
initializer_map() {
return initializer_map_;
}

private:
const onnxruntime::GraphViewer &graph_viewer_;
ModelInfo &model_info_;
const onnx::GraphProto &graph_proto_;

@@ -131,7 +137,7 @@ class ModelInfo {
onnx::ModelProto &model_proto() { return model_proto_; }

/// Post-construction, failable initialization.
Status Initialize();
Status Initialize(const onnxruntime::GraphViewer &gv);

GraphInfo &main_graph() { return *main_graph_; }
const std::string &error_message() { return error_message_; }
@@ -157,6 +163,7 @@ class ContextCache {
: model_info_(model_info), context_(context) {}

MlirContext context() { return context_; }
MlirType GetNoneType();

/// Converts the TypeProto to an MlirType, returning a null type and
/// setting an error if not possible.
@@ -208,15 +215,17 @@ class NodeImporter {
/// Imports all nodes topologically. Internally calls FinalizeGraph.
Status ImportAll();

/// Creates a !torch.none value to substitute for inputs labelled `''` (the empty name).
void ImportNoneNode();
/// Import nodes one at a time. Must complete with a call to FinalizeGraph.
Status ImportNode(const onnx::NodeProto &node);
Status ImportInitializer(const onnx::TensorProto &initializer);
Status FinalizeGraph();

void DebugDumpModule();

private:
void PopulateGraphAttrs(MlirOperation container_op);
Status ImportInitializer(const onnx::TensorProto &initializer);
MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr);

// Special-form nodes.
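Taken together, the reworked entry points compose as in this condensed driver sketch (mirroring jit_compiler.cc above; MLIR context setup, NodeImporter construction, and error propagation are elided):

// Serialize the fused subgraph without initializers.
ONNX_NAMESPACE::GraphProto graph_proto;
GraphViewerToProto(graph_view, graph_proto, false, false);

// GraphInfo now also carries the GraphViewer for type/initializer queries.
torch_mlir_onnx::GraphInfo subgraph_info(graph_view, model_info, graph_proto);
if (torch_mlir_onnx::failed(subgraph_info.Initialize())) {
  // surface model_info.error_message()
}

// ImportAll() emits the none placeholder, pulls initializers on demand,
// imports nodes in topological order, and finalizes the graph.
if (torch_mlir_onnx::failed(imp.ImportAll())) {
  // surface model_info.error_message()
}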
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
@@ -76,10 +76,6 @@ std::vector<std::unique_ptr<ComputeCapability>> IREEExecutionProvider::GetCapabi
inputs.push_back(nodeArgPtr->Name());
}

for (auto& name : required_initializers) {
inputs.push_back(name);
}

for (auto& nodeArgPtr : graph_viewer.GetOutputs()) {
outputs.push_back(nodeArgPtr->Name());
}
1 change: 1 addition & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
@@ -274,6 +274,7 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
constexpr const char* kQnnExecutionProvider = "QNNExecutionProvider";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kIreeExecutionProvider = "IreeExecutionProvider";

template <typename T>
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1141,7 +1141,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#endif
} else if (type == kIreeExecutionProvider) {
#if USE_IREE
const auto &it = provider_options_map.find(type);
const auto& it = provider_options_map.find(type);
ProviderOptions iree_option_map = ProviderOptions{};
if (it != provider_options_map.end()) {
iree_option_map = it->second;