Skip to content

Commit

Permalink
[DAP] Fix LowerDAPPass (buddy-compiler#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
taiqzheng authored Dec 26, 2023
1 parent c5fc1cd commit 04ac92e
Showing 1 changed file with 42 additions and 125 deletions.
167 changes: 42 additions & 125 deletions midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "DAP/DAPDialect.h"
#include "DAP/DAPOps.h"
Expand Down Expand Up @@ -175,10 +175,7 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {
public:
using OpRewritePattern<dap::IirOp>::OpRewritePattern;

explicit DAPIirLowering(MLIRContext *context, int64_t strideParam)
: OpRewritePattern(context) {
stride = strideParam;
}
explicit DAPIirLowering(MLIRContext *context) : OpRewritePattern(context) {}

LogicalResult matchAndRewrite(dap::IirOp op,
PatternRewriter &rewriter) const override {
Expand All @@ -197,141 +194,60 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {

Value N = rewriter.create<memref::DimOp>(loc, input, c0);
Value filterSize = rewriter.create<memref::DimOp>(loc, kernel, c0);
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

FloatType f32 = FloatType::getF32(ctx);

VectorType vectorTy32 = VectorType::get({stride}, f32);

Value zr = rewriter.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
// calculate the upper bound of the FIR part <scf::ForOp>
Value strictN = rewriter.create<SubIOp>(loc, N, c2);
Value strideRem = rewriter.create<RemSIOp>(loc, strictN, strideVal);
Value upperN = rewriter.create<SubIOp>(loc, N, strideRem);

// loop over every row in SOS matrix
rewriter.create<scf::ForOp>(
loc, c0, filterSize, c1, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iargs) {
Value b0 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c0});
Value b1 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c1});
Value b2 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c2});
// Value a0 of kernel is not used
Value a1 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c4});
Value a2 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c5});
loc, c0, filterSize, c1, ValueRange{input},
[&](OpBuilder &builder, Location loc, Value iv, ValueRange iarg) {
Value b0 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c0});
Value b1 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c1});
Value b2 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c2});
Value a1 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c4});
Value a2 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c5});

Value z1 =
builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
Value z2 =
builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);

Value x0 = builder.create<memref::LoadOp>(loc, input, ValueRange{c0});
Value temp = builder.create<MulFOp>(loc, b0, x0);
builder.create<memref::StoreOp>(loc, temp, output, ValueRange{c0});

Value x1 = builder.create<memref::LoadOp>(loc, input, ValueRange{c1});
Value temp0 = builder.create<MulFOp>(loc, b0, x1);
Value temp1 = builder.create<MulFOp>(loc, b1, x0);
Value temp2 = builder.create<AddFOp>(loc, temp0, temp1);
builder.create<memref::StoreOp>(loc, temp2, output, ValueRange{c1});

Value Vecb0 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b0);
Value Vecb1 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b1);
Value Vecb2 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b2);

// A biquad filter expression:
// y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2];
// FIR part
builder.create<scf::ForOp>(
loc, c2, upperN, strideVal, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value idx0 = iv;
Value idx1 = builder.create<SubIOp>(loc, idx0, c1);
Value idx2 = builder.create<SubIOp>(loc, idx0, c2);

Value inputVec0 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx0});
Value inputVec1 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx1});
Value inputVec2 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx2});

Value outputVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zr);
Value resVec0 =
builder.create<FMAOp>(loc, inputVec0, Vecb0, outputVec);
Value resVec1 =
builder.create<FMAOp>(loc, inputVec1, Vecb1, resVec0);
Value resVec2 =
builder.create<FMAOp>(loc, inputVec2, Vecb2, resVec1);
builder.create<StoreOp>(loc, resVec2, output, ValueRange{idx0});

builder.create<scf::YieldOp>(loc, std::nullopt);
});

// Process the remaining data of the FIR part.
Value idx1 = builder.create<SubIOp>(loc, upperN, c1);
Value idx2 = builder.create<SubIOp>(loc, upperN, c2);
Value in1 =
builder.create<memref::LoadOp>(loc, input, ValueRange{idx1});
Value in2 =
builder.create<memref::LoadOp>(loc, input, ValueRange{idx2});

builder.create<scf::ForOp>(
loc, upperN, N, c1, ValueRange{in1, in2},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value in0 =
builder.create<memref::LoadOp>(loc, input, ValueRange{iv});

Value temp0 = builder.create<MulFOp>(loc, b0, in0);
Value temp1 = builder.create<MulFOp>(loc, b1, in1);
Value temp2 = builder.create<MulFOp>(loc, b2, in2);
Value sum0 = builder.create<AddFOp>(loc, temp0, temp1);
Value sum1 = builder.create<AddFOp>(loc, sum0, temp2);

builder.create<memref::StoreOp>(loc, sum1, output, ValueRange{iv});

builder.create<scf::YieldOp>(loc, std::vector<Value>{in0, in1});
});

// IIR part
// Loop reordering, compute z1 for next iteration, z2 for the second
// following iteration.
builder.create<scf::ForOp>(
loc, c0, N, c1, ValueRange{z1, z2},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value x =
builder.create<memref::LoadOp>(loc, output, ValueRange{iv});
Value t1 = builder.create<MulFOp>(loc, a1, itrargs[1]);
Value t2 = builder.create<MulFOp>(loc, a2, itrargs[0]);
Value y = builder.create<AddFOp>(loc, t1, t2);
Value opt = builder.create<SubFOp>(loc, x, y);

builder.create<memref::StoreOp>(loc, opt, output,
ValueRange{iv});
ValueRange iargs) {
Value inElem = builder.create<memref::LoadOp>(loc, iarg[0], iv);
Value t0 = builder.create<arith::MulFOp>(loc, b0, inElem);
Value outElem =
builder.create<arith::AddFOp>(loc, t0, iargs[0]);

Value t1 = builder.create<arith::MulFOp>(loc, b1, inElem);
Value t2 = builder.create<arith::MulFOp>(loc, a1, outElem);
Value t3 = builder.create<arith::SubFOp>(loc, t1, t2);
Value z1Next = builder.create<arith::AddFOp>(loc, t3, iargs[1]);

Value t4 = builder.create<arith::MulFOp>(loc, b2, inElem);
Value t5 = builder.create<arith::MulFOp>(loc, a2, outElem);
Value z2Next = builder.create<arith::SubFOp>(loc, t4, t5);

builder.create<memref::StoreOp>(loc, outElem, output, iv);
builder.create<scf::YieldOp>(
loc, std::vector<Value>{itrargs[1], opt});
loc, std::vector<Value>{z1Next, z2Next});
});
builder.create<memref::CopyOp>(loc, output, input);
builder.create<scf::YieldOp>(loc, std::nullopt);

builder.create<scf::YieldOp>(loc, output);
});

rewriter.eraseOp(op);
return success();
}

private:
int64_t stride;
};

} // end anonymous namespace
Expand All @@ -340,7 +256,7 @@ void populateLowerDAPConversionPatterns(RewritePatternSet &patterns,
int64_t stride) {
patterns.add<DAPFirLowering>(patterns.getContext());
patterns.add<DAPBiquadLowering>(patterns.getContext(), stride);
patterns.add<DAPIirLowering>(patterns.getContext(), stride);
patterns.add<DAPIirLowering>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
Expand All @@ -363,7 +279,8 @@ class LowerDAPPass : public PassWrapper<LowerDAPPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<buddy::dap::DAPDialect, func::FuncDialect,
memref::MemRefDialect, scf::SCFDialect, VectorDialect,
affine::AffineDialect, arith::ArithDialect,linalg::LinalgDialect>();
affine::AffineDialect, arith::ArithDialect,
linalg::LinalgDialect>();
}
Option<int64_t> stride{*this, "DAP-vector-splitting",
llvm::cl::desc("Vector splitting size."),
Expand All @@ -376,10 +293,10 @@ void LowerDAPPass::runOnOperation() {
ModuleOp module = getOperation();

ConversionTarget target(*context);
target.addLegalDialect<affine::AffineDialect, scf::SCFDialect,
func::FuncDialect, memref::MemRefDialect,
VectorDialect, arith::ArithDialect,
linalg::LinalgDialect>();
target
.addLegalDialect<affine::AffineDialect, scf::SCFDialect,
func::FuncDialect, memref::MemRefDialect, VectorDialect,
arith::ArithDialect, linalg::LinalgDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();

RewritePatternSet patterns(context);
Expand Down

0 comments on commit 04ac92e

Please sign in to comment.