Skip to content

Commit

Permalink
[DAP] Debug LowerDAPPass
Browse files Browse the repository at this point in the history
  • Loading branch information
taiqzheng committed Nov 27, 2023
1 parent ce466ba commit 74913bb
Showing 1 changed file with 42 additions and 126 deletions.
168 changes: 42 additions & 126 deletions midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {
public:
using OpRewritePattern<dap::IirOp>::OpRewritePattern;

explicit DAPIirLowering(MLIRContext *context, int64_t strideParam)
: OpRewritePattern(context) {
stride = strideParam;
}
explicit DAPIirLowering(MLIRContext *context) : OpRewritePattern(context) {}

LogicalResult matchAndRewrite(dap::IirOp op,
PatternRewriter &rewriter) const override {
Expand All @@ -197,141 +194,60 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {

Value N = rewriter.create<memref::DimOp>(loc, input, c0);
Value filterSize = rewriter.create<memref::DimOp>(loc, kernel, c0);
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

FloatType f32 = FloatType::getF32(ctx);

VectorType vectorTy32 = VectorType::get({stride}, f32);

Value zr = rewriter.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
// calculate the upper bound of the FIR part <scf::ForOp>
Value strictN = rewriter.create<SubIOp>(loc, N, c2);
Value strideRem = rewriter.create<RemSIOp>(loc, strictN, strideVal);
Value upperN = rewriter.create<SubIOp>(loc, N, strideRem);

// loop over every row in SOS matrix
rewriter.create<scf::ForOp>(
loc, c0, filterSize, c1, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iargs) {
Value b0 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c0});
Value b1 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c1});
Value b2 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c2});
// Value a0 of kernel is not used
Value a1 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c4});
Value a2 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c5});

Value z1 =
builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
Value z2 =
builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);

Value x0 = builder.create<memref::LoadOp>(loc, input, ValueRange{c0});
Value temp = builder.create<MulFOp>(loc, b0, x0);
builder.create<memref::StoreOp>(loc, temp, output, ValueRange{c0});

Value x1 = builder.create<memref::LoadOp>(loc, input, ValueRange{c1});
Value temp0 = builder.create<MulFOp>(loc, b0, x1);
Value temp1 = builder.create<MulFOp>(loc, b1, x0);
Value temp2 = builder.create<AddFOp>(loc, temp0, temp1);
builder.create<memref::StoreOp>(loc, temp2, output, ValueRange{c1});

Value Vecb0 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b0);
Value Vecb1 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b1);
Value Vecb2 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b2);

// A biquad filter expression:
// y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2];
// FIR part
builder.create<scf::ForOp>(
loc, c2, upperN, strideVal, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value idx0 = iv;
Value idx1 = builder.create<SubIOp>(loc, idx0, c1);
Value idx2 = builder.create<SubIOp>(loc, idx0, c2);

Value inputVec0 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx0});
Value inputVec1 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx1});
Value inputVec2 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx2});

Value outputVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zr);
Value resVec0 =
builder.create<FMAOp>(loc, inputVec0, Vecb0, outputVec);
Value resVec1 =
builder.create<FMAOp>(loc, inputVec1, Vecb1, resVec0);
Value resVec2 =
builder.create<FMAOp>(loc, inputVec2, Vecb2, resVec1);
builder.create<StoreOp>(loc, resVec2, output, ValueRange{idx0});

builder.create<scf::YieldOp>(loc, std::nullopt);
});

// process the remain data of FIR part
Value idx1 = builder.create<SubIOp>(loc, upperN, c1);
Value idx2 = builder.create<SubIOp>(loc, upperN, c2);
Value in1 =
builder.create<memref::LoadOp>(loc, input, ValueRange{idx1});
Value in2 =
builder.create<memref::LoadOp>(loc, input, ValueRange{idx2});

builder.create<scf::ForOp>(
loc, upperN, N, c1, ValueRange{in1, in2},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value in0 =
builder.create<memref::LoadOp>(loc, input, ValueRange{iv});

Value temp0 = builder.create<MulFOp>(loc, b0, in0);
Value temp1 = builder.create<MulFOp>(loc, b1, in1);
Value temp2 = builder.create<MulFOp>(loc, b2, in2);
Value sum0 = builder.create<AddFOp>(loc, temp0, temp1);
Value sum1 = builder.create<AddFOp>(loc, sum0, temp2);

builder.create<memref::StoreOp>(loc, sum1, output, ValueRange{iv});

builder.create<scf::YieldOp>(loc, std::vector<Value>{in0, in1});
});

// IIR part
loc, c0, filterSize, c1, ValueRange{input},
[&](OpBuilder &builder, Location loc, Value iv, ValueRange iarg) {
Value b0 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c0});
Value b1 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c1});
Value b2 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c2});
Value a1 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c4});
Value a2 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c5});

Value z1 = builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
Value z2 = builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);

// Loop reordering, compute z1 for next iteration, z2 for the second
// following iteration.
builder.create<scf::ForOp>(
loc, c0, N, c1, ValueRange{z1, z2},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value x =
builder.create<memref::LoadOp>(loc, output, ValueRange{iv});
Value t1 = builder.create<MulFOp>(loc, a1, itrargs[1]);
Value t2 = builder.create<MulFOp>(loc, a2, itrargs[0]);
Value y = builder.create<AddFOp>(loc, t1, t2);
Value opt = builder.create<SubFOp>(loc, x, y);

builder.create<memref::StoreOp>(loc, opt, output,
ValueRange{iv});
[&](OpBuilder &builder, Location loc, Value iv_i,
ValueRange iargs) {
Value in_elem =
builder.create<memref::LoadOp>(loc, iarg[0], iv_i);
Value t0 = builder.create<arith::MulFOp>(loc, b0, in_elem);
Value out_elem =
builder.create<arith::AddFOp>(loc, t0, iargs[0]);

Value t1 = builder.create<arith::MulFOp>(loc, b1, in_elem);
Value t2 = builder.create<arith::MulFOp>(loc, a1, out_elem);
Value t3 = builder.create<arith::SubFOp>(loc, t1, t2);
Value z1_next =
builder.create<arith::AddFOp>(loc, t3, iargs[1]);

Value t4 = builder.create<arith::MulFOp>(loc, b2, in_elem);
Value t5 = builder.create<arith::MulFOp>(loc, a2, out_elem);
Value z2_next = builder.create<arith::SubFOp>(loc, t4, t5);

builder.create<memref::StoreOp>(loc, out_elem, output, iv_i);
builder.create<scf::YieldOp>(
loc, std::vector<Value>{itrargs[1], opt});
loc, std::vector<Value>{z1_next, z2_next});
});
builder.create<memref::CopyOp>(loc, output, input);
builder.create<scf::YieldOp>(loc, std::nullopt);

builder.create<scf::YieldOp>(loc, output);
});

rewriter.eraseOp(op);
return success();
}

private:
int64_t stride;
};

} // end anonymous namespace
Expand All @@ -340,7 +256,7 @@ void populateLowerDAPConversionPatterns(RewritePatternSet &patterns,
int64_t stride) {
patterns.add<DAPFirLowering>(patterns.getContext());
patterns.add<DAPBiquadLowering>(patterns.getContext(), stride);
patterns.add<DAPIirLowering>(patterns.getContext(), stride);
patterns.add<DAPIirLowering>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 74913bb

Please sign in to comment.