Skip to content

Commit

Permalink
[midend] Move insertZeroConstantOp from dip-utils to utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Sep 22, 2023
1 parent 07a8970 commit da71b6b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 30 deletions.
5 changes: 0 additions & 5 deletions midend/include/Utils/DIPUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ enum class DIP_OP { CORRELATION_2D, EROSION_2D, DILATION_2D };
// from lowering passes with appropriate messages.
enum class DIP_ERROR { INCONSISTENT_TYPES, UNSUPPORTED_TYPE, NO_ERROR };

// Inserts a constant op with value 0 into a location `loc` based on type
// `type`. Supported types are : f32, f64, integer types.
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
Type elemTy);

// Inserts FMA operation into a given location `loc` based on type `type`.
// Note: FMA is done by Multiply and Add for integer types, because there is no
// dedicated FMA operation for them.
Expand Down
5 changes: 5 additions & 0 deletions midend/include/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ using namespace mlir;

namespace buddy {

// Inserts a constant op with value 0 into a location `loc` based on type
// `type`. Supported types are : f32, f64, integer types.
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
Type elemTy);

// Function to test whether a value is equivalent to zero or not.
Value zeroCond(OpBuilder &builder, Location loc, Type elemType, Value value,
Value zeroElem);
Expand Down
9 changes: 3 additions & 6 deletions midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,7 @@ class DIPTopHat2DOpLowering : public OpRewritePattern<dip::TopHat2DOp> {
VectorType vectorTy32 = VectorType::get({stride}, inElemTy);
IntegerType i1 = IntegerType::get(ctx, 1);
VectorType vectorMaskTy = VectorType::get({stride}, i1);
Value zeroPaddingElem =
dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

Expand Down Expand Up @@ -1001,8 +1000,7 @@ class DIPBottomHat2DOpLowering : public OpRewritePattern<dip::BottomHat2DOp> {
VectorType vectorTy32 = VectorType::get({stride}, inElemTy);
IntegerType i1 = IntegerType::get(ctx, 1);
VectorType vectorMaskTy = VectorType::get({stride}, i1);
Value zeroPaddingElem =
dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

Expand Down Expand Up @@ -1203,8 +1201,7 @@ class DIPMorphGrad2DOpLowering : public OpRewritePattern<dip::MorphGrad2DOp> {
VectorType vectorTy32 = VectorType::get({stride}, inElemTy);
IntegerType i1 = IntegerType::get(ctx, 1);
VectorType vectorMaskTy = VectorType::get({stride}, i1);
Value zeroPaddingElem =
dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

Expand Down
19 changes: 0 additions & 19 deletions midend/lib/Utils/DIPUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,6 @@ DIP_ERROR checkDIPCommonTypes(DIPOP op, const std::vector<Value> &args) {
return DIP_ERROR::NO_ERROR;
}

// Inserts a constant op with value 0 into a location `loc` based on type
// `type`. Supported types are : f32, f64, integer types.
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
Type elemTy) {
Value op = {};
auto bitWidth = elemTy.getIntOrFloatBitWidth();
if (elemTy.isF32() || elemTy.isF64()) {
FloatType type =
elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx);
auto zero = APFloat::getZero(type.getFloatSemantics());
op = builder.create<arith::ConstantFloatOp>(loc, zero, type);
} else if (elemTy.isInteger(bitWidth)) {
IntegerType type = IntegerType::get(ctx, bitWidth);
op = builder.create<arith::ConstantIntOp>(loc, 0, type);
}

return op;
}

// Inserts FMA operation into a given location `loc` based on type `type`.
// Note: FMA is done by Multiply and Add for integer types, because there is no
// dedicated FMA operation for them.
Expand Down
19 changes: 19 additions & 0 deletions midend/lib/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ using namespace mlir;

namespace buddy {

// Inserts a constant op with value 0 into a location `loc` based on type
// `type`. Supported types are : f32, f64, integer types.
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
Type elemTy) {
Value op = {};
auto bitWidth = elemTy.getIntOrFloatBitWidth();
if (elemTy.isF32() || elemTy.isF64()) {
FloatType type =
elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx);
auto zero = APFloat::getZero(type.getFloatSemantics());
op = builder.create<arith::ConstantFloatOp>(loc, zero, type);
} else if (elemTy.isInteger(bitWidth)) {
IntegerType type = IntegerType::get(ctx, bitWidth);
op = builder.create<arith::ConstantIntOp>(loc, 0, type);
}

return op;
}

// Function to test whether a value is equivalent to zero or not.
Value zeroCond(OpBuilder &builder, Location loc, Type elemType, Value value,
Value zeroElem) {
Expand Down

0 comments on commit da71b6b

Please sign in to comment.