Added support for f32 data #69
pthomadakis authored and johnpzh committed Dec 3, 2024
1 parent 84a0a97 commit 597fc1e
Showing 10 changed files with 255 additions and 128 deletions.
15 changes: 10 additions & 5 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
@@ -36,6 +36,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -1174,9 +1175,10 @@ namespace
auto *rhsLT = llvm::cast<LabeledTensorExprAST>(expr);
auto name = rhsLT->getTensorName();
mlir::Value tensorValue = symbolTable.lookup(name);
mlir::ShapedType shapedT = mlir::cast<mlir::ShapedType>(tensorValue.getType());
comet_debug() << " generate ta.sum op\n";
/// TODO(gkestor): look at reduceOp in linalg
sumVal = builder.create<mlir::tensorAlgebra::ReduceOp>(location, builder.getF64Type(), tensorValue);
sumVal = builder.create<mlir::tensorAlgebra::ReduceOp>(location, shapedT.getElementType(), tensorValue);
}

/// Case 2: SUM(A[i,j]*B[j,k])
@@ -1797,7 +1799,9 @@ namespace
/// for DenseTensorDeclOp create
mlir::StringRef format_strref = dyn_cast<DenseTensorDeclOp>(rhs_tensor.getDefiningOp()).getFormat();
mlir::StringAttr formatAttr = builder.getStringAttr(format_strref);
lhs_tensor = builder.create<DenseTensorDeclOp>(loc(transpose.loc()), mlir::RankedTensorType::get(shape, builder.getF64Type()), indices, formatAttr);
mlir::ShapedType shapedT = mlir::cast<mlir::ShapedType>(rhs_tensor.getType());

lhs_tensor = builder.create<DenseTensorDeclOp>(loc(transpose.loc()), mlir::RankedTensorType::get(shape, shapedT.getElementType()), indices, formatAttr);

/// populate formats
/// assumes lhs and rhs formats are same
@@ -1811,9 +1815,9 @@ namespace
mlir::StringAttr formatAttr = builder.getStringAttr(format_strref);

std::vector<int32_t> format = mlir::tensorAlgebra::getFormats(format_strref, shape.size(), builder.getContext());
mlir::Type element_type = builder.getF64Type();
mlir::ShapedType shapedT = mlir::cast<mlir::ShapedType>(rhs_tensor.getType());
mlir::Type element_type = shapedT.getElementType();
return_type = SparseTensorType::get(builder.getContext(), element_type, shape, format);

/// no lhs_LabeledTensor has been created. The output tensor of transpose doesn't have an explicit declaration,
/// BoolAttr is true to specify SparseTensorDeclOp is for temporaries
auto sp_tensor_type = SparseTensorType::get(builder.getContext(), element_type, shape, format);
@@ -1832,7 +1836,8 @@ namespace
auto strAttr = builder.getStrArrayAttr(formats);

comet_debug() << " create TransposeOp\n";
mlir::Value t = builder.create<mlir::tensorAlgebra::TransposeOp>(loc(transpose.loc()), mlir::RankedTensorType::get(shape, builder.getF64Type()),
mlir::ShapedType shapedT = mlir::cast<mlir::ShapedType>(rhs_tensor.getType());
mlir::Value t = builder.create<mlir::tensorAlgebra::TransposeOp>(loc(transpose.loc()), mlir::RankedTensorType::get(shape, shapedT.getElementType()),
rhs_tensor, all_labels_val, affineMapArrayAttr, strAttr);
builder.create<TensorSetOp>(loc(transpose.loc()), t.getDefiningOp()->getResult(0), lhs_tensor);
comet_vdump(t);
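The hunks above share one pattern: the element type of the result is read from the RHS tensor's ShapedType instead of being hard-coded to f64, so f32 tensors flow through unchanged. A minimal standalone sketch of that pattern follows; the helper name and includes are illustrative, not part of the commit.

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

/// Build a ranked tensor type whose element type matches the given operand,
/// so both f32 and f64 inputs produce results of the same element type.
static mlir::RankedTensorType resultTypeLike(mlir::Value rhsTensor,
                                             llvm::ArrayRef<int64_t> shape) {
  auto shapedT = mlir::cast<mlir::ShapedType>(rhsTensor.getType());
  return mlir::RankedTensorType::get(shape, shapedT.getElementType());
}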
10 changes: 6 additions & 4 deletions include/comet/Dialect/TensorAlgebra/IR/TAOps.td
@@ -619,7 +619,7 @@ def ReduceOp : TA_Op<"reduce",[Pure]> {

let arguments = (ins TA_AnyTensor:$rhs);

let results = (outs F64:$lhs);
let results = (outs AnyTypeOf<[F32,F64,Index]>:$lhs);

let builders = [OpBuilder<
(ins "Value":$input)>];
@@ -903,7 +903,9 @@ def PrintOp : TA_Op<"print"> {

/// The print operation takes an input tensor to print.
/// We can extend the list of supported data types for print with F64Tensor, I8MemRef, I64MemRef, F32MemRef, etc.
let arguments = (ins AnyTypeOf<[F64,
let arguments = (ins AnyTypeOf<[F32,
F64,
F32MemRef,
F64MemRef,
TA_AnyTensor]>:$input);
}
@@ -1048,7 +1050,7 @@ def TensorInsertOp : TA_Op<"TAInsertOp", [Pure, SameVariadicOperandSize]>{
ins TA_AnyTensor:$tensor,
Variadic<Index>:$pos,
Variadic<Index>:$crds,
F64:$value
AnyTypeOf<[F32,F64]>:$value
);

let results = (outs TA_AnyTensor);
@@ -1061,7 +1063,7 @@ def TensorExtractOp : TA_Op<"TAExtractOp", [Pure, SameVariadicOperandSize]>{
let arguments = (
ins TA_AnyTensor:$tensor,
Index:$pos,
F64Attr:$zero
AnyAttrOf<[F32Attr, F64Attr]>:$zero
);

let results = (outs AnyFloat);
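With $value relaxed to F32 or F64 and $zero to a matching FloatAttr, callers of TensorInsertOp and TensorExtractOp are expected to pass scalars that agree with the tensor's element type. A small illustrative check of that invariant (hypothetical helper, not part of the commit):

#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Value.h"

/// True when the scalar being inserted matches the tensor's element type
/// (an f32 value into an f32 tensor, f64 into f64).
static bool insertValueTypeMatches(mlir::Value tensor, mlir::Value value) {
  auto shaped = mlir::cast<mlir::ShapedType>(tensor.getType());
  return shaped.getElementType() == value.getType();
}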
1 change: 0 additions & 1 deletion include/comet/Dialect/Utils/Utils.h
@@ -39,7 +39,6 @@
#include <typeinfo>

/// TODO(gkestor): supports only f64 - need generalization
extern std::string VALUETYPE;

using namespace mlir::linalg;

36 changes: 34 additions & 2 deletions lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp
@@ -38,6 +38,8 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Pass/Pass.h"
#include "mlir/IR/Dominance.h"
@@ -446,10 +448,24 @@ LowerIndexTreeToSCFPass::convertOperand(IndexTreeOperandOp op, IRRewriter &rewri
}
Value pos = positions[positions.size() - 1];
double zero = 0;
FloatAttr zero_attr;

if(semiring == "minxy"){
zero = INFINITY;
}
return rewriter.create<TensorExtractOp>(loc, element_type, tensor, pos, rewriter.getF64FloatAttr(zero));
if(element_type.isF32())
{
zero_attr = rewriter.getF32FloatAttr(zero);
}
else if(element_type.isF64())
{
zero_attr = rewriter.getF64FloatAttr(zero);
}
else
{
assert(false && "Unsupported type");
}
return rewriter.create<TensorExtractOp>(loc, element_type, tensor, pos, zero_attr);
}
}

@@ -465,14 +481,30 @@ LowerIndexTreeToSCFPass::convertOperand(IndexTreeLHSOperandOp op, IRRewriter &re
if((tensor_type = llvm::dyn_cast<mlir::TensorType>(tensor.getType()))){
return rewriter.create<tensor::ExtractOp>(loc, tensor_type.getElementType(), tensor, crds);
} else {
ShapedType sparseT = mlir::cast<ShapedType>(tensor.getType());
// LHS may not be constant (i.e. if we are inserting into a tensor that we need to resize),
// so cannot directly lower like we can the RHS
Value pos = positions[positions.size() - 1];
double zero = 0;
FloatAttr zero_attr;

if(semiring == "minxy"){
zero = INFINITY;
}
return rewriter.create<tensorAlgebra::TensorExtractOp>(loc, rewriter.getF64Type(), tensor, pos, rewriter.getF64FloatAttr(zero));
if(sparseT.getElementType().isF32())
{
zero_attr = rewriter.getF32FloatAttr(zero);
}
else if(sparseT.getElementType().isF64())
{
zero_attr = rewriter.getF64FloatAttr(zero);
}
else
{
assert(false && "Unsupported type");
}

return rewriter.create<tensorAlgebra::TensorExtractOp>(loc, sparseT.getElementType(), tensor, pos, zero_attr);
}
}

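Both convertOperand overloads now build the TensorExtractOp zero attribute in the width of the tensor's element type. The selection logic they duplicate reduces to roughly the following sketch (helper name is illustrative, not part of the commit):

#include <cassert>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

/// Return a FloatAttr holding `zero` (0, or +inf for the "minxy" semiring)
/// in the width of `elementType`; only f32 and f64 are supported.
static mlir::FloatAttr getZeroAttrFor(mlir::OpBuilder &rewriter,
                                      mlir::Type elementType, double zero) {
  if (elementType.isF32())
    return rewriter.getF32FloatAttr(zero);
  if (elementType.isF64())
    return rewriter.getF64FloatAttr(zero);
  assert(false && "Unsupported type");
  return {};
}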
122 changes: 73 additions & 49 deletions lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp
@@ -33,11 +33,13 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>

using namespace mlir;
@@ -74,54 +76,37 @@ namespace
Location loc = op->getLoc();
auto module = op->getParentOfType<ModuleOp>();
auto *ctx = op->getContext();
FloatType f64Type = FloatType::getF64(ctx);
IndexType indexType = IndexType::get(ctx);
Type unrankedMemrefType_f64 = UnrankedMemRefType::get(f64Type, 0);
Type unrankedMemref_index = mlir::UnrankedMemRefType::get(indexType, 0);

auto printTensorF64Func = FunctionType::get(ctx, {mlir::UnrankedMemRefType::get(f64Type, 0)}, {});
auto printTensorIndexFunc = FunctionType::get(ctx, {mlir::UnrankedMemRefType::get(indexType, 0)}, {});
auto printScalarFunc = FunctionType::get(ctx, {FloatType::getF64(ctx)}, {});

func::FuncOp print_func;

auto inputType = op->getOperand(0).getType();

/// If the Input type is scalar (F64)
if (inputType.isa<FloatType>())
if(ShapedType shaped_type = mlir::dyn_cast<ShapedType>(inputType))
{
std::string print_scalar_f64Str = "printF64";
std::string print_newline_Str = "printNewline";
if (!hasFuncDeclaration(module, "printF64"))
{
print_func = func::FuncOp::create(loc, print_scalar_f64Str, printScalarFunc, ArrayRef<NamedAttribute>{});
print_func.setPrivate();
module.push_back(print_func);
auto unrankedMemrefType = mlir::UnrankedMemRefType::get(shaped_type.getElementType(), 0);
auto printTensor = FunctionType::get(ctx, {unrankedMemrefType}, {});

if (!hasFuncDeclaration(module, "printNewline"))
{
auto printNewLineFunc = FunctionType::get(ctx, {}, {});
func::FuncOp print_newline = func::FuncOp::create(loc, print_newline_Str, printNewLineFunc, ArrayRef<NamedAttribute>{});
print_newline.setPrivate();
module.push_back(print_newline);
}
std::string comet_print; //_f64Str = "comet_print_memref_f64";
if(shaped_type.getElementType().isF32())
{
comet_print = "comet_print_memref_f32";
}
rewriter.create<func::CallOp>(loc, print_scalar_f64Str, SmallVector<Type, 2>{}, ValueRange{op->getOperand(0)});
rewriter.create<func::CallOp>(loc, print_newline_Str, SmallVector<Type, 2>{}, ValueRange{});
}
else
{
std::string comet_print_f64Str = "comet_print_memref_f64";
if (!hasFuncDeclaration(module, comet_print_f64Str))
else if(shaped_type.getElementType().isF64())
{
print_func = func::FuncOp::create(loc, comet_print_f64Str, printTensorF64Func, ArrayRef<NamedAttribute>{});
print_func.setPrivate();
module.push_back(print_func);
comet_print = "comet_print_memref_f64";
}
else if(shaped_type.getElementType().isIndex())
{
comet_print = "comet_print_memref_i64";
}
else
{
assert(false && "Unexpected type to print");
}

std::string comet_print_i64Str = "comet_print_memref_i64";
if (!hasFuncDeclaration(module, comet_print_i64Str))

if (!hasFuncDeclaration(module, comet_print))
{
print_func = func::FuncOp::create(loc, comet_print_i64Str, printTensorIndexFunc, ArrayRef<NamedAttribute>{});
func::FuncOp print_func = func::FuncOp::create(loc, comet_print, printTensor, ArrayRef<NamedAttribute>{});
print_func.setPrivate();
module.push_back(print_func);
}
@@ -130,28 +115,67 @@
{
auto alloc_op = cast<memref::AllocOp>(op->getOperand(0).getDefiningOp());
comet_vdump(alloc_op);
auto u = rewriter.create<memref::CastOp>(loc, unrankedMemrefType_f64, alloc_op);
rewriter.create<func::CallOp>(loc, comet_print_f64Str, SmallVector<Type, 2>{}, ValueRange{u});
}else if (inputType.isa<TensorType>())
auto u = rewriter.create<memref::CastOp>(loc, unrankedMemrefType, alloc_op);
rewriter.create<func::CallOp>(loc, comet_print, SmallVector<Type, 2>{}, ValueRange{u});
}
else if (inputType.isa<TensorType>())
{
auto rhs = op->getOperand(0);
auto tensor_type = llvm::cast<TensorType>(inputType);
auto memref_type = MemRefType::get(tensor_type.getShape(), tensor_type.getElementType());
auto buffer = rewriter.create<bufferization::ToMemrefOp>(loc, memref_type, rhs);

if(llvm::isa<IndexType>(tensor_type.getElementType())){
auto u = rewriter.create<memref::CastOp>(loc, unrankedMemref_index, buffer);
rewriter.create<func::CallOp>(loc, comet_print_i64Str, SmallVector<Type, 2>{}, ValueRange{u});
} else {
auto u = rewriter.create<memref::CastOp>(loc, unrankedMemrefType_f64, buffer);
rewriter.create<func::CallOp>(loc, comet_print_f64Str, SmallVector<Type, 2>{}, ValueRange{u});
}
auto u = rewriter.create<memref::CastOp>(loc, unrankedMemrefType, buffer);
rewriter.create<func::CallOp>(loc, comet_print, SmallVector<Type, 2>{}, ValueRange{u});
}
else
{
llvm::errs() << __FILE__ << " " << __LINE__ << "Unknown Data type\n";
}
}
/// If the Input type is scalar (F64)
else if (inputType.isa<FloatType,IndexType>())
{
std::string print_scalar;
if(inputType.isF64())
{
print_scalar = "printF64";
}
else if (inputType.isF32())
{
print_scalar = "printF32";
}
else if (inputType.isIndex())
{
print_scalar = "printI64";
}
else
{
assert(false && "Unsupported float type");
}
FunctionType printScalarFunc = FunctionType::get(ctx, {inputType}, {});

std::string print_newline_Str = "printNewline";
if (!hasFuncDeclaration(module, print_scalar))
{
func::FuncOp print_func = func::FuncOp::create(loc, print_scalar, printScalarFunc, ArrayRef<NamedAttribute>{});
print_func.setPrivate();
module.push_back(print_func);

if (!hasFuncDeclaration(module, "printNewline"))
{
auto printNewLineFunc = FunctionType::get(ctx, {}, {});
func::FuncOp print_newline = func::FuncOp::create(loc, print_newline_Str, printNewLineFunc, ArrayRef<NamedAttribute>{});
print_newline.setPrivate();
module.push_back(print_newline);
}
}
rewriter.create<func::CallOp>(loc, print_scalar, SmallVector<Type, 2>{}, ValueRange{op->getOperand(0)});
rewriter.create<func::CallOp>(loc, print_newline_Str, SmallVector<Type, 2>{}, ValueRange{});
}
else
{
assert(false && "Unexpected type to print");
}

/// Notify the rewriter that this operation has been removed.
comet_pdump(op);
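PrintOpLowering now derives the runtime callee from the operand's element type rather than assuming f64. For the memref path, the mapping in the hunk above amounts to this sketch (the function names are the ones declared in the diff; the helper itself is illustrative):

#include <cassert>
#include <string>
#include "mlir/IR/BuiltinTypes.h"

/// Map an element type to the COMET runtime print function used for memrefs.
static std::string cometPrintFuncFor(mlir::Type elementType) {
  if (elementType.isF32())
    return "comet_print_memref_f32";
  if (elementType.isF64())
    return "comet_print_memref_f64";
  if (elementType.isIndex())
    return "comet_print_memref_i64";
  assert(false && "Unexpected type to print");
  return {};
}

Scalars follow the same shape, dispatching to printF32, printF64, or printI64 and then printNewline.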
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -926,7 +927,8 @@ class ConvertWorkspaceTensorExtractOp
[&] (OpBuilder& builder, Location loc) {
// TODO: Does the zero value depend on the semi-ring?
Type result_type = op->getResult(0).getType();
Value zero = builder.create<arith::ConstantOp>(loc, result_type, op.getZeroAttr());
FloatAttr zero_attr = op.getZeroAttr().cast<FloatAttr>();
Value zero = builder.create<arith::ConstantOp>(loc, result_type, zero_attr);
builder.create<scf::YieldOp>(loc, ArrayRef<Value>({zero}));
}
);
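In ConvertWorkspaceTensorExtractOp the zero attribute may now be f32 or f64, so it is cast to FloatAttr and the constant is built with the extract op's result type. In isolation the step looks roughly like this sketch, assuming the same arith::ConstantOp builder used in the hunk above:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

/// Materialize the extract op's zero attribute as an arith.constant of the
/// requested float type (f32 or f64).
static mlir::Value materializeZero(mlir::OpBuilder &builder, mlir::Location loc,
                                   mlir::Type resultType, mlir::FloatAttr zeroAttr) {
  return builder.create<mlir::arith::ConstantOp>(loc, resultType, zeroAttr);
}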