Revert "[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape" #12

Merged
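This PR reverts the earlier "[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape" change. Judging from the diff below, the revert spans two files, the node importer implementation and the GraphInfo header: it drops the raw_data handling and the element-typed MLIR splat-attribute getters from ImportConstantOfShapeNode, restores the pre-fix error paths in ImportAll and ImportGeneralNode, removes the null-initializer guard from GetImmediateShapeTensor, and deletes the value_info_map() accessor from GraphInfo.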
@@ -638,7 +638,8 @@ Status NodeImporter::ImportAll() {
 
   for (const auto &node : nodes) {
     if (torch_mlir_onnx::failed(ImportNode(node))) {
-      return failure();
+      return SetError("Failed to import node '" + node.name() +
+                      "': " + "(node:\n" + node.DebugString() + "\n)");
    }
  }
 
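This hunk and the two that follow mostly shuffle where diagnostic detail gets attached; both sides report failure through the importer's Status machinery. A minimal sketch of that idiom, assuming a trivial Status wrapper (an assumption for illustration; the project's actual definitions may differ):

#include <optional>
#include <string>

struct Status {
  bool ok;
};
inline Status failure() { return {false}; }

class NodeImporter {
 public:
  // Record a human-readable message for the caller, then fail.
  Status SetError(std::string msg) {
    error_message_ = std::move(msg);
    return failure();
  }

 private:
  std::optional<std::string> error_message_;
};

Under this reading, failure() signals the error with no context, while SetError() additionally preserves a message such as the node.DebugString() dump built above.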
@@ -727,8 +728,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
     if (found_it == nv_map_.end()) {
       std::string msg = "Non topologically produced ONNX node input '";
       msg.append(input_name);
-      msg.append("': ");
-      msg.append(node.DebugString());
+      msg.append("'");
       return SetError(std::move(msg));
     }
     input_values.push_back(found_it->second);
@@ -739,9 +739,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   for (auto &output_name : node.output()) {
     const onnx::TypeProto *type_proto =
         graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
-    if (!type_proto) {
-      return SetError("Failed to obtain TypeProto for tensor");
-    }
+    if (!type_proto)
+      return failure();
 
     MlirType t = cc_.ConvertTypeProto(*type_proto);
     if (mlirTypeIsNull(t))
@@ -907,83 +906,38 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
     return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
                                    /*encoding*/ {nullptr});
   };
-  const bool has_raw_data = tensor_proto.has_raw_data();
   MlirAttribute splat_attr = {nullptr};
-  size_t out_size;
   switch (tensor_proto.data_type()) {
-  case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
-    const float *data = nullptr;
-    if (has_raw_data) {
-      data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
-      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
-                  tensor_proto.DebugString());
-    }
+  case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
     splat_attr = mlirDenseElementsAttrFloatSplatGet(
-        tensorTypeFor(mlirF32TypeGet(context_)),
-        has_raw_data ? data[0] : tensor_proto.float_data(0));
+        tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
     break;
-  }
-  case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
-    const int32_t *data = nullptr;
-    if (has_raw_data) {
-      data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
-      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
-                  tensor_proto.DebugString());
-    }
-    splat_attr = mlirDenseElementsAttrInt32SplatGet(
+  case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
+    splat_attr = mlirDenseElementsAttrFloatSplatGet(
         tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
-        has_raw_data ? data[0] : tensor_proto.int32_data(0));
+        tensor_proto.int32_data(0));
     break;
-  }
-  case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
-    const int64_t *data = nullptr;
-    if (has_raw_data) {
-      data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
-      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
-                  tensor_proto.DebugString());
-    }
-    splat_attr = mlirDenseElementsAttrInt64SplatGet(
+  case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
+    splat_attr = mlirDenseElementsAttrFloatSplatGet(
         tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
-        has_raw_data ? data[0] : tensor_proto.int64_data(0));
+        tensor_proto.int64_data(0));
     break;
-  }
-  case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
-    const double *data = nullptr;
-    if (has_raw_data) {
-      data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
-      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
-                  tensor_proto.DebugString());
-    }
-    splat_attr = mlirDenseElementsAttrDoubleSplatGet(
-        tensorTypeFor(mlirF64TypeGet(context_)),
-        has_raw_data ? data[0] : tensor_proto.double_data(0));
+  case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
+    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+        tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
     break;
-  }
-  case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
-    const uint64_t *data = nullptr;
-    if (has_raw_data) {
-      data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
-      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
-                  tensor_proto.DebugString());
-    }
-    splat_attr = mlirDenseElementsAttrUInt64SplatGet(
+  case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
+    splat_attr = mlirDenseElementsAttrFloatSplatGet(
         tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
-        has_raw_data ? data[0] : tensor_proto.uint64_data(0));
+        tensor_proto.uint64_data(0));
     break;
-  }
-  case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
-    const uint32_t *data = nullptr;
-    if (has_raw_data) {
-      data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
-      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
-                  tensor_proto.DebugString());
-    }
-    splat_attr = mlirDenseElementsAttrUInt32SplatGet(
+  case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
+    // Special case: inline data is stored in uint64.
+    splat_attr = mlirDenseElementsAttrFloatSplatGet(
         tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
-        has_raw_data ? data[0] : tensor_proto.float_data(0));
+        tensor_proto.uint64_data(0));
     break;
-  }
   }
 
   if (mlirAttributeIsNull(splat_attr)) {
     std::string message =
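The minus side of this hunk is the heart of the reverted fix: it consulted TensorProto's raw_data field when present and used the element-typed MLIR C API splat getters (mlirDenseElementsAttrInt32SplatGet and friends), while the restored plus side reads only the typed repeated fields and routes every case through mlirDenseElementsAttrFloatSplatGet. GetOptionalRawData<T>() is a project-specific helper; a rough sketch of what such an accessor might do, assuming a little-endian host (a hypothetical reconstruction, not the project's code):

#include <cstddef>

#include "onnx/onnx_pb.h"

// ONNX initializers carry their payload either in typed repeated fields
// (float_data, int64_data, ...) or packed little-endian in raw_data.
template <typename T>
const T *OptionalRawData(const onnx::TensorProto &tp, size_t &out_size) {
  if (!tp.has_raw_data())
    return nullptr;  // caller falls back to the typed fields
  if (tp.raw_data().size() % sizeof(T) != 0)
    return nullptr;  // payload is not a whole number of T elements
  out_size = tp.raw_data().size() / sizeof(T);
  // Assumes a little-endian host; a robust helper would byte-swap as needed.
  return reinterpret_cast<const T *>(tp.raw_data().data());
}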
@@ -1004,7 +958,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
       toMlirNamedAttribute("value", splat_attr));
   MlirValue result = mlirOperationGetResult(op, 0);
 
-  auto inserted = nv_map_.emplace(node.output(0), result);
+  // Export to the nv_map.
+  auto inserted = nv_map_.insert(std::make_pair(name, result));
   if (!inserted.second) {
     std::string msg = "Multiple nodes produced a value for '";
     msg.append(name);
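Aside from the key expression (node.output(0) versus name, which presumably hold the same output name at this point), emplace and insert are interchangeable for this check: both return a pair whose bool member reports whether the key was newly inserted, which is what the duplicate-producer error below keys off. A self-contained illustration:

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> nv_map;
  assert(nv_map.emplace("out0", 1).second);                  // newly inserted
  assert(!nv_map.insert(std::make_pair("out0", 2)).second);  // duplicate key
  assert(nv_map.at("out0") == 1);  // the first mapping is kept unchanged
}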
@@ -1018,17 +973,8 @@
 
 Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
                                              std::vector<int64_t> &shape) {
-  const onnx::TensorProto *tensor =
-      graph_info_.graph_viewer().GetConstantInitializer(name, false);
-  if (!tensor) {
-    std::string msg = "Could not find the immediate shape tensor ";
-    msg.append(name);
-    msg.append(" in constant graph initializers. It was possibly produced "
-               "dynamically.");
-    return SetError(msg);
-  }
-  const onnx::TensorProto &tp = *tensor;
-
+  const onnx::TensorProto &tp =
+      *graph_info_.graph_viewer().GetConstantInitializer(name, false);
   shape.clear();
 
   // Since this is being interpreted as a shape, we only support some limited
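Note the behavioral consequence of this hunk: the removed guard returned a descriptive error when GetConstantInitializer() found no constant initializer for the shape (for instance, when the shape is produced dynamically), whereas the restored code dereferences the returned pointer unconditionally and would crash on a null result.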
@@ -93,10 +93,6 @@ class GraphInfo {
     return nullptr;
   }
 
-  std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
-  value_info_map() {
-    return value_info_map_;
-  }
   std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
   std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
   input_map() {
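The revert also drops the value_info_map() accessor that the fix had added to GraphInfo; the neighboring inputs() and input_map() accessors are unchanged.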