Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAP] Debug LowerDAPPass #245

Merged
merged 2 commits into from
Dec 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 42 additions & 125 deletions midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "DAP/DAPDialect.h"
#include "DAP/DAPOps.h"
Expand Down Expand Up @@ -175,10 +175,7 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {
public:
using OpRewritePattern<dap::IirOp>::OpRewritePattern;

explicit DAPIirLowering(MLIRContext *context, int64_t strideParam)
: OpRewritePattern(context) {
stride = strideParam;
}
explicit DAPIirLowering(MLIRContext *context) : OpRewritePattern(context) {}

LogicalResult matchAndRewrite(dap::IirOp op,
PatternRewriter &rewriter) const override {
Expand All @@ -197,141 +194,60 @@ class DAPIirLowering : public OpRewritePattern<dap::IirOp> {

Value N = rewriter.create<memref::DimOp>(loc, input, c0);
Value filterSize = rewriter.create<memref::DimOp>(loc, kernel, c0);
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

FloatType f32 = FloatType::getF32(ctx);

VectorType vectorTy32 = VectorType::get({stride}, f32);

Value zr = rewriter.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
// calculate the upper bound of the FIR part <scf::ForOp>
Value strictN = rewriter.create<SubIOp>(loc, N, c2);
Value strideRem = rewriter.create<RemSIOp>(loc, strictN, strideVal);
Value upperN = rewriter.create<SubIOp>(loc, N, strideRem);

// loop over every row in SOS matrix
rewriter.create<scf::ForOp>(
loc, c0, filterSize, c1, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iargs) {
Value b0 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c0});
Value b1 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c1});
Value b2 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c2});
// Value a0 of kernel is not used
Value a1 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c4});
Value a2 = builder.create<memref::LoadOp>(loc, kernel,
ValueRange{ivs[0], c5});
loc, c0, filterSize, c1, ValueRange{input},
[&](OpBuilder &builder, Location loc, Value iv, ValueRange iarg) {
Value b0 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c0});
Value b1 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c1});
Value b2 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c2});
Value a1 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c4});
Value a2 =
builder.create<memref::LoadOp>(loc, kernel, ValueRange{iv, c5});

Value z1 =
builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);
Value z2 =
builder.create<ConstantFloatOp>(loc, APFloat(float(0)), f32);

Value x0 = builder.create<memref::LoadOp>(loc, input, ValueRange{c0});
Value temp = builder.create<MulFOp>(loc, b0, x0);
builder.create<memref::StoreOp>(loc, temp, output, ValueRange{c0});

Value x1 = builder.create<memref::LoadOp>(loc, input, ValueRange{c1});
Value temp0 = builder.create<MulFOp>(loc, b0, x1);
Value temp1 = builder.create<MulFOp>(loc, b1, x0);
Value temp2 = builder.create<AddFOp>(loc, temp0, temp1);
builder.create<memref::StoreOp>(loc, temp2, output, ValueRange{c1});

Value Vecb0 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b0);
Value Vecb1 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b1);
Value Vecb2 =
builder.create<vector::BroadcastOp>(loc, vectorTy32, b2);

// A biquad filter expression:
// y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2];
// FIR part
builder.create<scf::ForOp>(
loc, c2, upperN, strideVal, ValueRange{std::nullopt},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value idx0 = iv;
Value idx1 = builder.create<SubIOp>(loc, idx0, c1);
Value idx2 = builder.create<SubIOp>(loc, idx0, c2);

Value inputVec0 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx0});
Value inputVec1 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx1});
Value inputVec2 = builder.create<LoadOp>(loc, vectorTy32, input,
ValueRange{idx2});

Value outputVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zr);
Value resVec0 =
builder.create<FMAOp>(loc, inputVec0, Vecb0, outputVec);
Value resVec1 =
builder.create<FMAOp>(loc, inputVec1, Vecb1, resVec0);
Value resVec2 =
builder.create<FMAOp>(loc, inputVec2, Vecb2, resVec1);
builder.create<StoreOp>(loc, resVec2, output, ValueRange{idx0});

builder.create<scf::YieldOp>(loc, std::nullopt);
});

// process the remain data of FIR part
Value idx1 = builder.create<SubIOp>(loc, upperN, c1);
Value idx2 = builder.create<SubIOp>(loc, upperN, c2);
Value in1 =
builder.create<memref::LoadOp>(loc, input, ValueRange{idx1});
Value in2 =
builder.create<memref::LoadOp>(loc, input, ValueRange{idx2});

builder.create<scf::ForOp>(
loc, upperN, N, c1, ValueRange{in1, in2},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value in0 =
builder.create<memref::LoadOp>(loc, input, ValueRange{iv});

Value temp0 = builder.create<MulFOp>(loc, b0, in0);
Value temp1 = builder.create<MulFOp>(loc, b1, in1);
Value temp2 = builder.create<MulFOp>(loc, b2, in2);
Value sum0 = builder.create<AddFOp>(loc, temp0, temp1);
Value sum1 = builder.create<AddFOp>(loc, sum0, temp2);

builder.create<memref::StoreOp>(loc, sum1, output, ValueRange{iv});

builder.create<scf::YieldOp>(loc, std::vector<Value>{in0, in1});
});

// IIR part
// Loop reordering, compute z1 for next iteration, z2 for the second
// following iteration.
builder.create<scf::ForOp>(
loc, c0, N, c1, ValueRange{z1, z2},
[&](OpBuilder &builder, Location loc, Value iv,
ValueRange itrargs) {
Value x =
builder.create<memref::LoadOp>(loc, output, ValueRange{iv});
Value t1 = builder.create<MulFOp>(loc, a1, itrargs[1]);
Value t2 = builder.create<MulFOp>(loc, a2, itrargs[0]);
Value y = builder.create<AddFOp>(loc, t1, t2);
Value opt = builder.create<SubFOp>(loc, x, y);

builder.create<memref::StoreOp>(loc, opt, output,
ValueRange{iv});
ValueRange iargs) {
Value inElem = builder.create<memref::LoadOp>(loc, iarg[0], iv);
Value t0 = builder.create<arith::MulFOp>(loc, b0, inElem);
Value outElem =
builder.create<arith::AddFOp>(loc, t0, iargs[0]);

Value t1 = builder.create<arith::MulFOp>(loc, b1, inElem);
Value t2 = builder.create<arith::MulFOp>(loc, a1, outElem);
Value t3 = builder.create<arith::SubFOp>(loc, t1, t2);
Value z1Next = builder.create<arith::AddFOp>(loc, t3, iargs[1]);

Value t4 = builder.create<arith::MulFOp>(loc, b2, inElem);
Value t5 = builder.create<arith::MulFOp>(loc, a2, outElem);
Value z2Next = builder.create<arith::SubFOp>(loc, t4, t5);

builder.create<memref::StoreOp>(loc, outElem, output, iv);
builder.create<scf::YieldOp>(
loc, std::vector<Value>{itrargs[1], opt});
loc, std::vector<Value>{z1Next, z2Next});
});
builder.create<memref::CopyOp>(loc, output, input);
builder.create<scf::YieldOp>(loc, std::nullopt);

builder.create<scf::YieldOp>(loc, output);
});

rewriter.eraseOp(op);
return success();
}

private:
int64_t stride;
};

} // end anonymous namespace
Expand All @@ -340,7 +256,7 @@ void populateLowerDAPConversionPatterns(RewritePatternSet &patterns,
int64_t stride) {
patterns.add<DAPFirLowering>(patterns.getContext());
patterns.add<DAPBiquadLowering>(patterns.getContext(), stride);
patterns.add<DAPIirLowering>(patterns.getContext(), stride);
patterns.add<DAPIirLowering>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
Expand All @@ -363,7 +279,8 @@ class LowerDAPPass : public PassWrapper<LowerDAPPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<buddy::dap::DAPDialect, func::FuncDialect,
memref::MemRefDialect, scf::SCFDialect, VectorDialect,
affine::AffineDialect, arith::ArithDialect,linalg::LinalgDialect>();
affine::AffineDialect, arith::ArithDialect,
linalg::LinalgDialect>();
}
Option<int64_t> stride{*this, "DAP-vector-splitting",
llvm::cl::desc("Vector splitting size."),
Expand All @@ -376,10 +293,10 @@ void LowerDAPPass::runOnOperation() {
ModuleOp module = getOperation();

ConversionTarget target(*context);
target.addLegalDialect<affine::AffineDialect, scf::SCFDialect,
func::FuncDialect, memref::MemRefDialect,
VectorDialect, arith::ArithDialect,
linalg::LinalgDialect>();
target
.addLegalDialect<affine::AffineDialect, scf::SCFDialect,
func::FuncDialect, memref::MemRefDialect, VectorDialect,
arith::ArithDialect, linalg::LinalgDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();

RewritePatternSet patterns(context);
Expand Down