Skip to content

Commit

Permalink
Address comments-2
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Sep 23, 2024
1 parent 62e0867 commit 09dad3f
Showing 1 changed file with 18 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -912,65 +912,71 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
size_t out_size;
switch (tensor_proto.data_type()) {
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
const float *data = {0};
const float *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
ORT_ENFORCE(data);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirF32TypeGet(context_)),
has_raw_data ? data[0] : tensor_proto.float_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
const int32_t *data = {0};
const int32_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrInt32SplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
has_raw_data ? data[0] : tensor_proto.int32_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
const int64_t *data = {0};
const int64_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrInt64SplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
has_raw_data ? data[0] : tensor_proto.int64_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
const double *data = {0};
const double *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
ORT_ENFORCE(data);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrDoubleSplatGet(
tensorTypeFor(mlirF64TypeGet(context_)),
has_raw_data ? data[0] : tensor_proto.double_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
const uint64_t *data = {0};
const uint64_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrUInt64SplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
has_raw_data ? data[0] : tensor_proto.uint64_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
const uint32_t *data = {0};
const uint32_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrUInt32SplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
Expand Down

0 comments on commit 09dad3f

Please sign in to comment.