From 09dad3fc01422b02643e19227d90aecf28738f5a Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Mon, 23 Sep 2024 14:09:01 +0530 Subject: [PATCH] Address comments-2 --- .../torch-mlir-import-onnx/OnnxImporter.cpp | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp index b81e90c171cae..3cbc435cac989 100644 --- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp +++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp @@ -912,10 +912,11 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { size_t out_size; switch (tensor_proto.data_type()) { case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: { - const float *data = {0}; + const float *data = nullptr; if (has_raw_data) { data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data); + ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", + tensor_proto.DebugString()); } splat_attr = mlirDenseElementsAttrFloatSplatGet( tensorTypeFor(mlirF32TypeGet(context_)), @@ -923,10 +924,11 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { break; } case onnx::TensorProto::DataType::TensorProto_DataType_INT32: { - const int32_t *data = {0}; + const int32_t *data = nullptr; if (has_raw_data) { data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data); + ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", + tensor_proto.DebugString()); } splat_attr = mlirDenseElementsAttrInt32SplatGet( tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)), @@ -934,10 +936,11 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { break; } case onnx::TensorProto::DataType::TensorProto_DataType_INT64: { - const int64_t *data = {0}; + const int64_t *data = nullptr; if (has_raw_data) { data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data); + ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", + tensor_proto.DebugString()); } splat_attr = mlirDenseElementsAttrInt64SplatGet( tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)), @@ -945,10 +948,11 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { break; } case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: { - const double *data = {0}; + const double *data = nullptr; if (has_raw_data) { data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data); + ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", + tensor_proto.DebugString()); } splat_attr = mlirDenseElementsAttrDoubleSplatGet( tensorTypeFor(mlirF64TypeGet(context_)), @@ -956,10 +960,11 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { break; } case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: { - const uint64_t *data = {0}; + const uint64_t *data = nullptr; if (has_raw_data) { data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data); + ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", + tensor_proto.DebugString()); } splat_attr = mlirDenseElementsAttrUInt64SplatGet( tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)), @@ -967,10 +972,11 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { break; } case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { - const uint32_t *data = {0}; + const uint32_t *data = nullptr; if (has_raw_data) { data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data); + ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", + tensor_proto.DebugString()); } splat_attr = mlirDenseElementsAttrUInt32SplatGet( tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),