diff --git a/frontends/comet_dsl/mlir/MLIRGen.cpp b/frontends/comet_dsl/mlir/MLIRGen.cpp index af298dab..84c9d482 100644 --- a/frontends/comet_dsl/mlir/MLIRGen.cpp +++ b/frontends/comet_dsl/mlir/MLIRGen.cpp @@ -36,6 +36,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -1174,9 +1175,10 @@ namespace auto *rhsLT = llvm::cast(expr); auto name = rhsLT->getTensorName(); mlir::Value tensorValue = symbolTable.lookup(name); + mlir::ShapedType shapedT = mlir::cast(tensorValue.getType()); comet_debug() << " generate ta.sum op\n"; /// TODO(gkestor): look at reduceOp in linalg - sumVal = builder.create(location, builder.getF64Type(), tensorValue); + sumVal = builder.create(location, shapedT.getElementType(), tensorValue); } /// Case 2: SUM(A[i,j]*B[j,k]) @@ -1797,7 +1799,9 @@ namespace /// for DenseTensorDeclOp create mlir::StringRef format_strref = dyn_cast(rhs_tensor.getDefiningOp()).getFormat(); mlir::StringAttr formatAttr = builder.getStringAttr(format_strref); - lhs_tensor = builder.create(loc(transpose.loc()), mlir::RankedTensorType::get(shape, builder.getF64Type()), indices, formatAttr); + mlir::ShapedType shapedT = mlir::cast(rhs_tensor.getType()); + + lhs_tensor = builder.create(loc(transpose.loc()), mlir::RankedTensorType::get(shape, shapedT.getElementType()), indices, formatAttr); /// populate formats /// assumes lhs and rhs formats are same @@ -1811,9 +1815,9 @@ namespace mlir::StringAttr formatAttr = builder.getStringAttr(format_strref); std::vector format = mlir::tensorAlgebra::getFormats(format_strref, shape.size(), builder.getContext()); - mlir::Type element_type = builder.getF64Type(); + mlir::ShapedType shapedT = mlir::cast(rhs_tensor.getType()); + mlir::Type element_type = shapedT.getElementType(); return_type = SparseTensorType::get(builder.getContext(), element_type, shape, format); - /// no lhs_LabeledTensor has been created. The output tensor of tranpose doesn't have explicit declaration, /// BoolAttr is true to speficy SparseTensorDeclOp is for temporaries auto sp_tensor_type = SparseTensorType::get(builder.getContext(), element_type, shape, format); @@ -1832,7 +1836,8 @@ namespace auto strAttr = builder.getStrArrayAttr(formats); comet_debug() << " create TransposeOp\n"; - mlir::Value t = builder.create(loc(transpose.loc()), mlir::RankedTensorType::get(shape, builder.getF64Type()), + mlir::ShapedType shapedT = mlir::cast(rhs_tensor.getType()); + mlir::Value t = builder.create(loc(transpose.loc()), mlir::RankedTensorType::get(shape, shapedT.getElementType()), rhs_tensor, all_labels_val, affineMapArrayAttr, strAttr); builder.create(loc(transpose.loc()), t.getDefiningOp()->getResult(0), lhs_tensor); comet_vdump(t); diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td index 19696f2e..d0653959 100644 --- a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td +++ b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td @@ -619,7 +619,7 @@ def ReduceOp : TA_Op<"reduce",[Pure]> { let arguments = (ins TA_AnyTensor:$rhs); - let results = (outs F64:$lhs); + let results = (outs AnyTypeOf<[F32,F64,Index]>:$lhs); let builders = [OpBuilder< (ins "Value":$input)>]; @@ -903,7 +903,9 @@ def PrintOp : TA_Op<"print"> { /// The print operation takes an input tensor to print. /// We can extend the list of supported datatype for print with F64Tensor, I8MemRef, I64MemRef, F32MemRef, etc. - let arguments = (ins AnyTypeOf<[F64, + let arguments = (ins AnyTypeOf<[F32, + F64, + F32MemRef, F64MemRef, TA_AnyTensor]>:$input); } @@ -1048,7 +1050,7 @@ def TensorInsertOp : TA_Op<"TAInsertOp", [Pure, SameVariadicOperandSize]>{ ins TA_AnyTensor:$tensor, Variadic:$pos, Variadic:$crds, - F64:$value + AnyTypeOf<[F32,F64]>:$value ); let results = (outs TA_AnyTensor); @@ -1061,7 +1063,7 @@ def TensorExtractOp : TA_Op<"TAExtractOp", [Pure, SameVariadicOperandSize]>{ let arguments = ( ins TA_AnyTensor:$tensor, Index:$pos, - F64Attr:$zero + AnyAttrOf<[F32Attr, F64Attr]>:$zero ); let results = (outs AnyFloat); diff --git a/include/comet/Dialect/Utils/Utils.h b/include/comet/Dialect/Utils/Utils.h index 555bdf2f..826f9cc6 100644 --- a/include/comet/Dialect/Utils/Utils.h +++ b/include/comet/Dialect/Utils/Utils.h @@ -39,7 +39,6 @@ #include /// TODO(gkestor): supports only f64 - need generalization -extern std::string VALUETYPE; using namespace mlir::linalg; diff --git a/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp b/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp index 1ce40ee0..9d4ceb24 100644 --- a/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp +++ b/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp @@ -38,6 +38,8 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Pass/Pass.h" #include "mlir/IR/Dominance.h" @@ -446,10 +448,24 @@ LowerIndexTreeToSCFPass::convertOperand(IndexTreeOperandOp op, IRRewriter &rewri } Value pos = positions[positions.size() - 1]; double zero = 0; + FloatAttr zero_attr; + if(semiring == "minxy"){ zero = INFINITY; } - return rewriter.create(loc, element_type, tensor, pos, rewriter.getF64FloatAttr(zero)); + if(element_type.isF32()) + { + zero_attr = rewriter.getF32FloatAttr(zero); + } + else if(element_type.isF64()) + { + zero_attr = rewriter.getF64FloatAttr(zero); + } + else + { + assert(false && "Unsupported type"); + } + return rewriter.create(loc, element_type, tensor, pos, zero_attr); } } @@ -465,14 +481,30 @@ LowerIndexTreeToSCFPass::convertOperand(IndexTreeLHSOperandOp op, IRRewriter &re if((tensor_type = llvm::dyn_cast(tensor.getType()))){ return rewriter.create(loc, tensor_type.getElementType(), tensor, crds); } else { + ShapedType sparseT = mlir::cast(tensor.getType()); // LHS may not be constant (i.e. if we are inserting into a tensor that we need to resize), // so cannot directly lower like we can the RHS Value pos = positions[positions.size() - 1]; double zero = 0; + FloatAttr zero_attr; + if(semiring == "minxy"){ zero = INFINITY; } - return rewriter.create(loc, rewriter.getF64Type(), tensor, pos, rewriter.getF64FloatAttr(zero)); + if(sparseT.getElementType().isF32()) + { + zero_attr = rewriter.getF32FloatAttr(zero); + } + else if(sparseT.getElementType().isF64()) + { + zero_attr = rewriter.getF64FloatAttr(zero); + } + else + { + assert(false && "Unsupported type"); + } + + return rewriter.create(loc, sparseT.getElementType(), tensor, pos, zero_attr); } } diff --git a/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp b/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp index 0630f919..925924e8 100644 --- a/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp @@ -33,11 +33,13 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallVector.h" #include using namespace mlir; @@ -74,54 +76,37 @@ namespace Location loc = op->getLoc(); auto module = op->getParentOfType(); auto *ctx = op->getContext(); - FloatType f64Type = FloatType::getF64(ctx); - IndexType indexType = IndexType::get(ctx); - Type unrankedMemrefType_f64 = UnrankedMemRefType::get(f64Type, 0); - Type unrankedMemref_index = mlir::UnrankedMemRefType::get(indexType, 0); - auto printTensorF64Func = FunctionType::get(ctx, {mlir::UnrankedMemRefType::get(f64Type, 0)}, {}); - auto printTensorIndexFunc = FunctionType::get(ctx, {mlir::UnrankedMemRefType::get(indexType, 0)}, {}); - auto printScalarFunc = FunctionType::get(ctx, {FloatType::getF64(ctx)}, {}); - - func::FuncOp print_func; + auto inputType = op->getOperand(0).getType(); - /// If the Input type is scalar (F64) - if (inputType.isa()) + if(ShapedType shaped_type = mlir::dyn_cast(inputType)) { - std::string print_scalar_f64Str = "printF64"; - std::string print_newline_Str = "printNewline"; - if (!hasFuncDeclaration(module, "printF64")) - { - print_func = func::FuncOp::create(loc, print_scalar_f64Str, printScalarFunc, ArrayRef{}); - print_func.setPrivate(); - module.push_back(print_func); + auto unrankedMemrefType = mlir::UnrankedMemRefType::get(shaped_type.getElementType(), 0); + auto printTensor = FunctionType::get(ctx, {unrankedMemrefType}, {}); - if (!hasFuncDeclaration(module, "printNewline")) - { - auto printNewLineFunc = FunctionType::get(ctx, {}, {}); - func::FuncOp print_newline = func::FuncOp::create(loc, print_newline_Str, printNewLineFunc, ArrayRef{}); - print_newline.setPrivate(); - module.push_back(print_newline); - } + std::string comet_print; //_f64Str = "comet_print_memref_f64"; + if(shaped_type.getElementType().isF32()) + { + comet_print = "comet_print_memref_f32"; } - rewriter.create(loc, print_scalar_f64Str, SmallVector{}, ValueRange{op->getOperand(0)}); - rewriter.create(loc, print_newline_Str, SmallVector{}, ValueRange{}); - } - else - { - std::string comet_print_f64Str = "comet_print_memref_f64"; - if (!hasFuncDeclaration(module, comet_print_f64Str)) + else if(shaped_type.getElementType().isF64()) { - print_func = func::FuncOp::create(loc, comet_print_f64Str, printTensorF64Func, ArrayRef{}); - print_func.setPrivate(); - module.push_back(print_func); + comet_print = "comet_print_memref_f64"; + } + else if(shaped_type.getElementType().isIndex()) + { + comet_print = "comet_print_memref_i64"; + } + else + { + assert(false && "Unexpected type to print"); } - std::string comet_print_i64Str = "comet_print_memref_i64"; - if (!hasFuncDeclaration(module, comet_print_i64Str)) + + if (!hasFuncDeclaration(module, comet_print)) { - print_func = func::FuncOp::create(loc, comet_print_i64Str, printTensorIndexFunc, ArrayRef{}); + func::FuncOp print_func = func::FuncOp::create(loc, comet_print, printTensor, ArrayRef{}); print_func.setPrivate(); module.push_back(print_func); } @@ -130,28 +115,67 @@ namespace { auto alloc_op = cast(op->getOperand(0).getDefiningOp()); comet_vdump(alloc_op); - auto u = rewriter.create(loc, unrankedMemrefType_f64, alloc_op); - rewriter.create(loc, comet_print_f64Str, SmallVector{}, ValueRange{u}); - }else if (inputType.isa()) + auto u = rewriter.create(loc, unrankedMemrefType, alloc_op); + rewriter.create(loc, comet_print, SmallVector{}, ValueRange{u}); + } + else if (inputType.isa()) { auto rhs = op->getOperand(0); auto tensor_type = llvm::cast(inputType); auto memref_type = MemRefType::get(tensor_type.getShape(), tensor_type.getElementType()); auto buffer = rewriter.create(loc, memref_type, rhs); - - if(llvm::isa(tensor_type.getElementType())){ - auto u = rewriter.create(loc, unrankedMemref_index, buffer); - rewriter.create(loc, comet_print_i64Str, SmallVector{}, ValueRange{u}); - } else { - auto u = rewriter.create(loc, unrankedMemrefType_f64, buffer); - rewriter.create(loc, comet_print_f64Str, SmallVector{}, ValueRange{u}); - } + auto u = rewriter.create(loc, unrankedMemrefType, buffer); + rewriter.create(loc, comet_print, SmallVector{}, ValueRange{u}); } else { llvm::errs() << __FILE__ << " " << __LINE__ << "Unknown Data type\n"; } } + /// If the Input type is scalar (F64) + else if (inputType.isa()) + { + std::string print_scalar; + if(inputType.isF64()) + { + print_scalar = "printF64"; + } + else if (inputType.isF32()) + { + print_scalar = "printF32"; + } + else if (inputType.isIndex()) + { + print_scalar = "printI64"; + } + else + { + assert(false && "Unsupported float type"); + } + FunctionType printScalarFunc = FunctionType::get(ctx, {inputType}, {}); + + std::string print_newline_Str = "printNewline"; + if (!hasFuncDeclaration(module, print_scalar)) + { + func::FuncOp print_func = func::FuncOp::create(loc, print_scalar, printScalarFunc, ArrayRef{}); + print_func.setPrivate(); + module.push_back(print_func); + + if (!hasFuncDeclaration(module, "printNewline")) + { + auto printNewLineFunc = FunctionType::get(ctx, {}, {}); + func::FuncOp print_newline = func::FuncOp::create(loc, print_newline_Str, printNewLineFunc, ArrayRef{}); + print_newline.setPrivate(); + module.push_back(print_newline); + } + } + rewriter.create(loc, print_scalar, SmallVector{}, ValueRange{op->getOperand(0)}); + rewriter.create(loc, print_newline_Str, SmallVector{}, ValueRange{}); + } + else + { + assert(false && "Unexpected type to print"); + } /// Notify the rewriter that this operation has been removed. comet_pdump(op); diff --git a/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp b/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp index 9722bd37..b8786918 100644 --- a/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -926,7 +927,8 @@ class ConvertWorkspaceTensorExtractOp [&] (OpBuilder& builder, Location loc) { // TODO: Does the zero value depend on the semi-ring? Type result_type = op->getResult(0).getType(); - Value zero = builder.create(loc, result_type, op.getZeroAttr()); + FloatAttr zero_attr = op.getZeroAttr().cast(); + Value zero = builder.create(loc, result_type, zero_attr); builder.create(loc, ArrayRef({zero})); } ); diff --git a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp index 8984651d..e32bd69d 100644 --- a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp @@ -36,6 +36,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" @@ -261,12 +262,11 @@ namespace std::string formats_strOut(opFormatsArrayAttr[1].cast().getValue()); IntegerType i32Type = IntegerType::get(ctx, 32); IndexType indexType = IndexType::get(ctx); - FloatType f64Type = FloatType::getF64(ctx); Value input_perm_num = rewriter.create(loc, i32Type, rewriter.getI32IntegerAttr(pnum[0])); Value output_perm_num = rewriter.create(loc, i32Type, rewriter.getI32IntegerAttr(pnum[1])); - Type unrankedMemrefType_f64 = UnrankedMemRefType::get(f64Type, 0); + UnrankedMemRefType unrankedMemrefType_float = UnrankedMemRefType::get(spType.getElementType(), 0); Type unrankedMemrefType_index = UnrankedMemRefType::get(indexType, 0); mlir::func::FuncOp transpose_func; /// runtime call @@ -322,7 +322,7 @@ namespace } Value vals = rewriter.create(loc, RankedTensorType::get({ShapedType::kDynamic}, tensor_type.getElementType()), tensor); Value vals_memref = rewriter.create(loc, MemRefType::get({ShapedType::kDynamic}, tensor_type.getElementType()), vals); - Value vals_v = rewriter.create(loc, unrankedMemrefType_f64, vals_memref); + Value vals_v = rewriter.create(loc, unrankedMemrefType_float, vals_memref); alloc_sizes_cast_vecs[n].push_back(vals_v); auto dims_tensor = mlir::cast(tensors[n].getDefiningOp()).getDims(); @@ -356,26 +356,35 @@ namespace if (rank_size == 2) { /// 2D - auto transpose2DF64Func = FunctionType::get(ctx, + auto transpose2DFunc = FunctionType::get(ctx, {i32Type, i32Type, i32Type, i32Type, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, - unrankedMemrefType_f64, + unrankedMemrefType_float, i32Type, i32Type, i32Type, i32Type, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, - unrankedMemrefType_f64, + unrankedMemrefType_float, unrankedMemrefType_index}, {}); - std::string func_name = "transpose_2D_f64"; + std::string func_name; + if(unrankedMemrefType_float.getElementType().isF32()) + { + func_name = "transpose_2D_f32"; + } + else if(unrankedMemrefType_float.getElementType().isF64()) + { + func_name = "transpose_2D_f64"; + } + if (!hasFuncDeclaration(module, func_name)) { - transpose_func = mlir::func::FuncOp::create(loc, func_name, transpose2DF64Func, ArrayRef{}); + transpose_func = mlir::func::FuncOp::create(loc, func_name, transpose2DFunc, ArrayRef{}); transpose_func.setPrivate(); module.push_back(transpose_func); } @@ -402,7 +411,7 @@ namespace } else if (rank_size == 3) { /// 3D - auto transpose3DF64Func = FunctionType::get(ctx, + auto transpose3DFunc = FunctionType::get(ctx, {i32Type, i32Type, i32Type, i32Type, i32Type, i32Type, @@ -413,7 +422,7 @@ namespace unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, - unrankedMemrefType_f64, + unrankedMemrefType_float, i32Type, i32Type, i32Type, i32Type, i32Type, i32Type, @@ -423,14 +432,22 @@ namespace unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, unrankedMemrefType_index, - unrankedMemrefType_f64, + unrankedMemrefType_float, unrankedMemrefType_index}, {}); - std::string func_name = "transpose_3D_f64"; + std::string func_name; + if(unrankedMemrefType_float.getElementType().isF32()) + { + func_name = "transpose_3D_f32"; + } + else + { + func_name = "transpose_3D_f64"; + } if (!hasFuncDeclaration(module, func_name)) { - transpose_func = mlir::func::FuncOp::create(loc, func_name, transpose3DF64Func, ArrayRef{}); + transpose_func = mlir::func::FuncOp::create(loc, func_name, transpose3DFunc, ArrayRef{}); transpose_func.setPrivate(); module.push_back(transpose_func); } @@ -484,16 +501,31 @@ namespace comet_debug() << "Lowering Reduce operation to SCF\n"; Location loc = op.getLoc(); - auto f64Type = rewriter.getF64Type(); + // auto f64Type = rewriter.getF64Type(); auto inputType = op->getOperand(0).getType(); /// Allocate memory for the result and initialized it auto cst_zero = rewriter.create(loc, 0); /// need to access res alloc - MemRefType memTy_alloc_res = MemRefType::get({1}, f64Type); + ShapedType shapeT = mlir::cast(inputType); + MemRefType memTy_alloc_res = MemRefType::get({1}, shapeT.getElementType()); + Value res = rewriter.create(loc, memTy_alloc_res); - Value const_f64_0 = rewriter.create(loc, f64Type, rewriter.getF64FloatAttr(0)); + FloatAttr zero; + if(shapeT.getElementType().isF32()) + { + zero = rewriter.getF32FloatAttr(0); + } + else if(shapeT.getElementType().isF64()) + { + zero = rewriter.getF64FloatAttr(0); + } + else + { + assert(false && "Unexpected type"); + } + Value const_float_0 = rewriter.create(loc, shapeT.getElementType(), zero); std::vector alloc_zero_loc = {cst_zero}; - rewriter.create(loc, const_f64_0, + rewriter.create(loc, const_float_0, res, alloc_zero_loc); comet_vdump(res); @@ -535,7 +567,7 @@ namespace comet_debug() << " tensorRank: " << tensorRanks << " \n"; comet_debug() << "Tensor to reduce:\n"; comet_pdump(op->getOperand(0).getDefiningOp()); - Value sp_tensor_values = rewriter.create(loc, RankedTensorType::get({ShapedType::kDynamic,}, rewriter.getF64Type()), op->getOperand(0)); + Value sp_tensor_values = rewriter.create(loc, RankedTensorType::get({ShapedType::kDynamic,}, sp_tensor_type.getElementType()), op->getOperand(0)); Value upperBound = rewriter.create(loc, sp_tensor_values, 0); comet_debug() << "Upper Bound:\n"; diff --git a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp index f801efcb..f5e51a83 100644 --- a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp +++ b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp @@ -66,24 +66,16 @@ using namespace mlir::indexTree; //===----------------------------------------------------------------------===// namespace { - void insertReadFileLibCall(int rank_size, MLIRContext *ctx, ModuleOp &module, func::FuncOp function) + void insertReadFileLibCall(int rank_size, Type floatEleType,MLIRContext *ctx, ModuleOp &module, func::FuncOp function) { comet_debug() << "Inserting insertReadFileLibCall\n"; - FloatType f32Type, f64Type; - if (VALUETYPE.compare("f32") == 0) - { - f32Type = FloatType::getF32(ctx); - } - else - { - f64Type = FloatType::getF64(ctx); - } + IndexType indexType = IndexType::get(function.getContext()); IntegerType i32Type = IntegerType::get(ctx, 32); - auto unrankedMemref_f64 = mlir::UnrankedMemRefType::get(f64Type, 0); + auto unrankedMemref_f64 = mlir::UnrankedMemRefType::get(Float64Type::get(ctx), 0); /// TODO(gkestor): there is an issue with F32 UnrankedMemRefType - auto unrankedMemref_f32 = mlir::UnrankedMemRefType::get(f64Type, 0); + auto unrankedMemref_f32 = mlir::UnrankedMemRefType::get(Float32Type::get(ctx), 0); auto unrankedMemref_index = mlir::UnrankedMemRefType::get(indexType, 0); if (rank_size == 2) @@ -106,7 +98,7 @@ namespace unrankedMemref_f64, i32Type}, {}); - if (VALUETYPE.compare("f32") == 0) + if (floatEleType.isF32()) { std::string func_name = "read_input_2D_f32"; if (!hasFuncDeclaration(module, func_name)) @@ -118,7 +110,7 @@ namespace module.push_back(func1); } } - else /// f64 + else if (floatEleType.isF64()) { std::string func_name = "read_input_2D_f64"; if (!hasFuncDeclaration(module, func_name)) @@ -130,10 +122,14 @@ namespace module.push_back(func1); } } + else + { + assert(false && "Unexpected type"); + } auto readInputSizes2DF64Func = FunctionType::get(ctx, {i32Type, indexType, indexType, indexType, indexType, unrankedMemref_index, i32Type}, {}); /// last arg (i32Type): readMode - if (VALUETYPE.compare("f32") == 0) + if (floatEleType.isF32()) { std::string func_name = "read_input_sizes_2D_f32"; if (!hasFuncDeclaration(module, func_name)) @@ -145,7 +141,7 @@ namespace module.push_back(func1); } } - else + else if (floatEleType.isF64()) { std::string func_name = "read_input_sizes_2D_f64"; if (!hasFuncDeclaration(module, func_name)) @@ -157,6 +153,10 @@ namespace module.push_back(func1); } } + else + { + assert(false && "Unsupported type"); + } } /// 3D tensor else if (rank_size == 3) @@ -180,7 +180,7 @@ namespace unrankedMemref_f64, i32Type}, {}); - if (VALUETYPE.compare("f32") == 0) + if (floatEleType.isF32()) { std::string func_name = "read_input_3D_f32"; if (!hasFuncDeclaration(module, func_name)) @@ -191,24 +191,27 @@ namespace module.push_back(func1); } } - else + else if (floatEleType.isF64()) { std::string func_name = "read_input_3D_f64"; if (!hasFuncDeclaration(module, func_name)) { - comet_debug() << " Insert read_input_3D_f64 decl\n"; func::FuncOp func1 = func::FuncOp::create(function.getLoc(), func_name, readInput3DF64Func, ArrayRef{}); func1.setPrivate(); module.push_back(func1); } } + else + { + assert(false && "Unexpected type"); + } + auto readInputSizes3DF64Func = FunctionType::get(ctx, {i32Type, indexType, indexType, indexType, indexType, indexType, indexType, unrankedMemref_index, i32Type}, {}); /// last arg (i32Type): readMode - if (VALUETYPE.compare("f32") == 0) + if (floatEleType.isF32()) { - std::string func_name = "read_input_sizes_3D_f32"; if (!hasFuncDeclaration(module, func_name)) { @@ -218,7 +221,7 @@ namespace module.push_back(func1); } } - else + else if (floatEleType.isF64()) { std::string func_name = "read_input_sizes_3D_f64"; if (!hasFuncDeclaration(module, func_name)) @@ -230,6 +233,10 @@ namespace module.push_back(func1); } } + else + { + assert(false && "Unsupported type"); + } } else { @@ -332,9 +339,13 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, { comet_debug() << "lowerSparseOutputTensorDec::TempSparseOutputTensorDeclOp lowering\n"; } + else + { + assert(false && "Op should be either SparseOutputTensorDeclOp or TempSparseOutputTensorDeclOp"); + } + + SparseTensorType spType = mlir::cast(op->getResultTypes()[0]); - assert(isa(op) || (isa(op) && - "Op should be either SparseOutputTensorDeclOp or TempSparseOutputTensorDeclOp")); comet_vdump(op); auto loc = op.getLoc(); @@ -346,13 +357,11 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, auto rank_size = mlir::cast(op.getResult().getType()).getRank(); IndexType indexType = IndexType::get(op.getContext()); - FloatType f64Type = FloatType::getF64(op.getContext()); - if (VALUETYPE.compare(0, 3, "f32") == 0) - f64Type = FloatType::getF32(op.getContext()); + Type valsType = spType.getElementType(); /// A1_pos ... A_value auto dynamicmemTy_1d_index = MemRefType::get({ShapedType::kDynamic}, indexType); /// memref - auto dynamicmemTy_1d_f64 = MemRefType::get({ShapedType::kDynamic}, f64Type); /// memref + auto dynamicmemTy_1d_f64 = MemRefType::get({ShapedType::kDynamic}, valsType); /// memref comet_debug() << " " << formats_str << " isDense: " << isDense(formats_str, ", ") << "\n"; @@ -681,7 +690,7 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, } llvm::ArrayRef cur_memref_arrayref = llvm::ArrayRef(cur_memref); - MemRefType memrefType2 = MemRefType::get(cur_memref_arrayref, f64Type); + MemRefType memrefType2 = MemRefType::get(cur_memref_arrayref, valsType); Value alloc_sizes1 = insertAllocAndInitialize(loc, memrefType2, ValueRange(cur_indices), rewriter); comet_debug() << " AllocOp: "; comet_vdump(alloc_sizes1); @@ -815,9 +824,7 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, // auto rank_size = op.getNumOperands(); IndexType indexType = IndexType::get(op.getContext()); - FloatType f64Type = FloatType::getF64(op.getContext()); - if (VALUETYPE.compare(0, 3, "f32") == 0) - f64Type = FloatType::getF32(op.getContext()); + Type floatEleType = type.getElementType(); for (auto u1 : op.getOperation()->getUsers()) { @@ -914,10 +921,10 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, /// A1_pos ... A_value auto dynamicmemTy_1d_index = MemRefType::get({ShapedType::kDynamic}, indexType); /// memref - auto dynamicmemTy_1d_f64 = MemRefType::get({ShapedType::kDynamic}, f64Type); /// memref + auto dynamicmemTy_1d_float = MemRefType::get({ShapedType::kDynamic}, floatEleType); /// memref Type unrankedMemTy_index = UnrankedMemRefType::get(indexType, 0); - Type unrankedMemTy_f64 = UnrankedMemRefType::get(f64Type, 0); + Type unrankedMemTy_float = UnrankedMemRefType::get(floatEleType, 0); comet_debug() << " " << formats_str << " isDense: " << isDense(formats_str, ", ") << "\n"; @@ -1003,17 +1010,21 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, { /// 2D comet_debug() << " 2D\n"; /// Add function definition to the module - insertReadFileLibCall(rank_size, ctx, module, function); + insertReadFileLibCall(rank_size, floatEleType, ctx, module, function); std::string read_input_sizes_str; - if (VALUETYPE.compare(0, 3, "f32") == 0) + if (floatEleType.isF32()) { read_input_sizes_str = "read_input_sizes_2D_f32"; } - else + else if(floatEleType.isF64()) { read_input_sizes_str = "read_input_sizes_2D_f64"; } + else + { + assert(false && "Unexpected data type"); + } auto read_input_sizes_Call = rewriter.create(loc, read_input_sizes_str, SmallVector{}, ValueRange{sparseFileID, dim_format[0], dim_format[1], dim_format[2], dim_format[3], @@ -1024,17 +1035,22 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, { /// 3D comet_debug() << " 3D\n"; /// Add function definition to the module - insertReadFileLibCall(rank_size, ctx, module, function); + insertReadFileLibCall(rank_size, floatEleType, ctx, module, function); + std::string read_input_sizes_str; - if (VALUETYPE.compare(0, 3, "f32") == 0) + if (floatEleType.isF32()) { read_input_sizes_str = "read_input_sizes_3D_f32"; } - else - { /// default f64 + else if(floatEleType.isF64()) + { read_input_sizes_str = "read_input_sizes_3D_f64"; } + else + { + assert(false && "Unexpected data type"); + } auto read_input_sizes_3D_Call = rewriter.create(loc, read_input_sizes_str, SmallVector{}, ValueRange{sparseFileID, dim_format[0], dim_format[1], /// A1, A1_tile @@ -1079,11 +1095,11 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, { std::vector idxes; idxes.push_back(array_sizes[i]); - Value alloc_size = insertAllocAndInitialize(loc, dynamicmemTy_1d_f64, ValueRange{idxes}, rewriter); + Value alloc_size = insertAllocAndInitialize(loc, dynamicmemTy_1d_float, ValueRange{idxes}, rewriter); comet_debug() << " "; comet_vdump(alloc_size); alloc_sizes_vec.push_back(alloc_size); - Value alloc_size_cast = rewriter.create(loc, unrankedMemTy_f64, alloc_size); + Value alloc_size_cast = rewriter.create(loc, unrankedMemTy_float, alloc_size); alloc_sizes_cast_vec.push_back(alloc_size_cast); } @@ -1091,14 +1107,18 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, if (rank_size == 2) { /// 2D std::string read_input_str; - if (VALUETYPE.compare(0, 3, "f32") == 0) + if (floatEleType.isF32()) { read_input_str = "read_input_2D_f32"; } - else + else if (floatEleType.isF64()) { read_input_str = "read_input_2D_f64"; } + else + { + assert(false && "Unexpected type"); + } auto read_input_f64Call = rewriter.create(loc, read_input_str, SmallVector{}, ValueRange{sparseFileID, dim_format[0], dim_format[1], /// A1_format, A1_tile_format @@ -1117,11 +1137,11 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, else if (rank_size == 3) { /// 3D std::string read_input_str; - if (VALUETYPE.compare(0, 3, "f32") == 0) + if (floatEleType.isF32()) { read_input_str = "read_input_3D_f32"; } - else + else if (floatEleType.isF64()) { read_input_str = "read_input_3D_f64"; } diff --git a/lib/Dialect/Utils/Utils.cpp b/lib/Dialect/Utils/Utils.cpp index dcb212e7..e10cb041 100644 --- a/lib/Dialect/Utils/Utils.cpp +++ b/lib/Dialect/Utils/Utils.cpp @@ -46,7 +46,7 @@ // *********** For debug purpose *********// /// TODO(gkestor): supports only f64 - need generalization -std::string VALUETYPE = "f64"; +// std::string VALUETYPE = "f64"; using namespace mlir::arith; using namespace mlir::affine; diff --git a/lib/ExecutionEngine/StatUtils.cpp b/lib/ExecutionEngine/StatUtils.cpp index 33d5bc3f..5acbb5e8 100644 --- a/lib/ExecutionEngine/StatUtils.cpp +++ b/lib/ExecutionEngine/StatUtils.cpp @@ -91,6 +91,11 @@ extern "C" void _mlir_ciface_comet_print_memref_f64(UnrankedMemRefType * cometPrintMemRef(*M); } +extern "C" void _mlir_ciface_comet_print_memref_f32(UnrankedMemRefType *M) +{ + cometPrintMemRef(*M); +} + extern "C" void _mlir_ciface_comet_print_memref_i64(UnrankedMemRefType *M) { cometPrintMemRef(*M); @@ -102,6 +107,12 @@ extern "C" void comet_print_memref_f64(int64_t rank, void *ptr) _mlir_ciface_comet_print_memref_f64(&descriptor); } +extern "C" void comet_print_memref_f32(int64_t rank, void *ptr) +{ + UnrankedMemRefType descriptor = {rank, ptr}; + _mlir_ciface_comet_print_memref_f32(&descriptor); +} + extern "C" void comet_print_memref_i64(int64_t rank, void *ptr) { UnrankedMemRefType descriptor = {rank, ptr};