Skip to content

Commit

Permalink
[gccjit] add basic arith conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 4, 2024
1 parent 630b116 commit 51e5908
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 5 deletions.
3 changes: 3 additions & 0 deletions include/mlir-gccjit/Conversion/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class GCCJITTypeConverter : public TypeConverter {
~GCCJITTypeConverter();
// integral types
gccjit::IntType convertIndexType(mlir::IndexType type) const;
gccjit::IntType makeSigned(gccjit::IntType type) const;
gccjit::IntType makeUnsigned(gccjit::IntType type) const;
bool isSigned(gccjit::IntType type) const;
gccjit::IntType convertIntegerType(mlir::IntegerType type) const;
gccjit::IntAttr convertIntegerAttr(mlir::IntegerAttr attr) const;

Expand Down
139 changes: 138 additions & 1 deletion src/Conversion/ConvertArithToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/Transforms/DialectConversion.h>

#include "mlir-gccjit/Conversion/Conversions.h"
#include "mlir-gccjit/Conversion/TypeConverter.h"
Expand All @@ -33,16 +35,151 @@ struct ConvertArithToGCCJITPass
void runOnOperation() override final;
};

template <typename T>
class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
protected:
const GCCJITTypeConverter *getTypeConverter() const {
return static_cast<const GCCJITTypeConverter *>(this->typeConverter);
}

public:
using OpConversionPattern<T>::OpConversionPattern;
};

class ConstantOpLowering : public GCCJITLoweringPattern<arith::ConstantOp> {
public:
using GCCJITLoweringPattern::GCCJITLoweringPattern;
mlir::LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto attr = op.getValue();

if (auto value = dyn_cast<mlir::IntegerAttr>(attr)) {
rewriter.replaceOpWithNewOp<gccjit::ConstantOp>(
op, getTypeConverter()->convertIntegerAttr(value));
return mlir::success();
}

if (auto value = dyn_cast<mlir::FloatAttr>(attr)) {
rewriter.replaceOpWithNewOp<gccjit::ConstantOp>(
op, getTypeConverter()->convertFloatAttr(value));
return mlir::success();
}

return mlir::failure();
}
};

class CmpIOpLowering : public GCCJITLoweringPattern<arith::CmpIOp> {
void getComparison(gccjit::CmpOp &kind, bool &signedness,
arith::CmpIPredicate pred) const {
signedness = false;
switch (pred) {
case arith::CmpIPredicate::eq:
kind = gccjit::CmpOp::Eq;
break;
case arith::CmpIPredicate::ne:
kind = gccjit::CmpOp::Ne;
break;

case arith::CmpIPredicate::slt:
signedness = true;
[[fallthrough]];
case arith::CmpIPredicate::ult:
kind = gccjit::CmpOp::Lt;
break;

case arith::CmpIPredicate::sle:
signedness = true;
[[fallthrough]];
case arith::CmpIPredicate::ule:
kind = gccjit::CmpOp::Le;
break;

case arith::CmpIPredicate::sgt:
signedness = true;
[[fallthrough]];
case arith::CmpIPredicate::ugt:
kind = gccjit::CmpOp::Gt;
break;

case arith::CmpIPredicate::sge:
signedness = true;
[[fallthrough]];
case arith::CmpIPredicate::uge:
kind = gccjit::CmpOp::Ge;
break;
}
}

public:
using GCCJITLoweringPattern::GCCJITLoweringPattern;
mlir::LogicalResult
matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto lhs = adaptor.getLhs();
auto rhs = adaptor.getRhs();
auto inputTy = cast<IntType>(lhs.getType());
auto pred = adaptor.getPredicate();
gccjit::CmpOp kind;
bool signedness;
getComparison(kind, signedness, pred);
auto i1 = getTypeConverter()->convertType(op.getResult().getType());
if (signedness && !getTypeConverter()->isSigned(inputTy)) {
auto signedType = getTypeConverter()->makeSigned(inputTy);
lhs = rewriter.create<gccjit::BitCastOp>(op.getLoc(), signedType, lhs);
rhs = rewriter.create<gccjit::BitCastOp>(op.getLoc(), signedType, rhs);
}
auto cmpAttr = CmpOpAttr::get(op.getContext(), kind);
rewriter.replaceOpWithNewOp<gccjit::CompareOp>(op, i1, cmpAttr, lhs, rhs);
return mlir::success();
}
};
template <class Op, BOp Kind>
class TrivialBinOpConversion : public GCCJITLoweringPattern<Op> {
using GCCJITLoweringPattern<Op>::GCCJITLoweringPattern;
mlir::LogicalResult
matchAndRewrite(Op op, typename GCCJITLoweringPattern<Op>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto lhs = adaptor.getLhs();
auto rhs = adaptor.getRhs();
auto resultTy = lhs.getType();
auto kind = BOpAttr::get(op.getContext(), Kind);
rewriter.replaceOpWithNewOp<gccjit::BinaryOp>(op, resultTy, kind, lhs, rhs);
return mlir::success();
}
};

using AddIOpLowering = TrivialBinOpConversion<arith::AddIOp, BOp::Plus>;
using AddFOpLowering = TrivialBinOpConversion<arith::AddFOp, BOp::Plus>;
using MulFOpLowering = TrivialBinOpConversion<arith::MulFOp, BOp::Mult>;

void ConvertArithToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
// unrealized conversions
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.add<ConstantOpLowering, CmpIOpLowering, AddIOpLowering,
AddFOpLowering, MulFOpLowering>(typeConverter, &getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addIllegalDialect<arith::ArithDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
for (auto func : moduleOp.getOps<gccjit::FuncOp>())
ops.push_back(func);
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
signalPassFailure();
}
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
// limitations under the License.

#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>

#include "mlir-gccjit/Conversion/Conversions.h"
#include "mlir-gccjit/Conversion/TypeConverter.h"
#include "mlir-gccjit/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

using namespace mlir;
using namespace mlir::gccjit;
Expand Down
77 changes: 75 additions & 2 deletions src/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include <mlir/IR/BuiltinTypes.h>

using namespace mlir;
using namespace mlir::gccjit;
Expand Down Expand Up @@ -82,8 +83,17 @@ GCCJITTypeConverter::convertIntegerType(mlir::IntegerType type) const {
gccjit::IntAttr
GCCJITTypeConverter::convertIntegerAttr(mlir::IntegerAttr attr) const {
auto value = attr.getValue();
auto type = convertIntegerType(cast<IntegerType>(attr.getType()));
return IntAttr::get(attr.getContext(), type, value);
if (auto intType = dyn_cast<IntegerType>(attr.getType())) {
auto type = convertIntegerType(intType);
return IntAttr::get(attr.getContext(), type, value);
}

if (auto indexType = dyn_cast<IndexType>(attr.getType())) {
auto type = convertIndexType(indexType);
return IntAttr::get(attr.getContext(), type, value);
}

return {};
}

gccjit::FloatType
Expand Down Expand Up @@ -216,3 +226,66 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(
auto fieldsAttr = ArrayAttr::get(func.getContext(), fields);
return StructType::get(func.getContext(), nameAttr, fieldsAttr);
}

bool GCCJITTypeConverter::isSigned(gccjit::IntType type) const {
switch (type.getKind()) {
case GCC_JIT_TYPE_UNSIGNED_INT:
case GCC_JIT_TYPE_UNSIGNED_LONG:
case GCC_JIT_TYPE_UNSIGNED_LONG_LONG:
case GCC_JIT_TYPE_UINT8_T:
case GCC_JIT_TYPE_UINT16_T:
case GCC_JIT_TYPE_UINT32_T:
case GCC_JIT_TYPE_UINT64_T:
case GCC_JIT_TYPE_UINT128_T:
return false;
default:
return true;
}
}

gccjit::IntType GCCJITTypeConverter::makeSigned(gccjit::IntType type) const {
switch (type.getKind()) {
case GCC_JIT_TYPE_UNSIGNED_INT:
return IntType::get(type.getContext(), GCC_JIT_TYPE_INT);
case GCC_JIT_TYPE_UNSIGNED_LONG:
return IntType::get(type.getContext(), GCC_JIT_TYPE_LONG);
case GCC_JIT_TYPE_UNSIGNED_LONG_LONG:
return IntType::get(type.getContext(), GCC_JIT_TYPE_LONG_LONG);
case GCC_JIT_TYPE_UINT8_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_INT8_T);
case GCC_JIT_TYPE_UINT16_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_INT16_T);
case GCC_JIT_TYPE_UINT32_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_INT32_T);
case GCC_JIT_TYPE_UINT64_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_INT64_T);
case GCC_JIT_TYPE_UINT128_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_INT128_T);
default:
return type;
}
}

// the counterpart of makeSigned
gccjit::IntType GCCJITTypeConverter::makeUnsigned(gccjit::IntType type) const {
switch (type.getKind()) {
case GCC_JIT_TYPE_INT:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UNSIGNED_INT);
case GCC_JIT_TYPE_LONG:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UNSIGNED_LONG);
case GCC_JIT_TYPE_LONG_LONG:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UNSIGNED_LONG_LONG);
case GCC_JIT_TYPE_INT8_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UINT8_T);
case GCC_JIT_TYPE_INT16_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UINT16_T);
case GCC_JIT_TYPE_INT32_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UINT32_T);
case GCC_JIT_TYPE_INT64_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UINT64_T);
case GCC_JIT_TYPE_INT128_T:
return IntType::get(type.getContext(), GCC_JIT_TYPE_UINT128_T);
default:
return type;
}
}
2 changes: 1 addition & 1 deletion test/lowering/gemm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-func-to-gccjit | %filecheck %s
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s
module {
// CHECK-NOT: func.func
// CHECK-NOT: func.return
Expand Down

0 comments on commit 51e5908

Please sign in to comment.