From 2298a502e9f50529e9aa2a62e4e31867978a68fb Mon Sep 17 00:00:00 2001 From: Taiqi Zheng <56971484+taiqzheng@users.noreply.github.com> Date: Tue, 26 Dec 2023 21:49:48 +0800 Subject: [PATCH] [DAP] Fix LowerDAPPass (#245) --- .../lib/Conversion/LowerDAP/LowerDAPPass.cpp | 167 +++++------------- 1 file changed, 42 insertions(+), 125 deletions(-) diff --git a/midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp b/midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp index 33148d5478..bf77f358bb 100644 --- a/midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp +++ b/midend/lib/Conversion/LowerDAP/LowerDAPPass.cpp @@ -21,11 +21,11 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "DAP/DAPDialect.h" #include "DAP/DAPOps.h" @@ -175,10 +175,7 @@ class DAPIirLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - explicit DAPIirLowering(MLIRContext *context, int64_t strideParam) - : OpRewritePattern(context) { - stride = strideParam; - } + explicit DAPIirLowering(MLIRContext *context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(dap::IirOp op, PatternRewriter &rewriter) const override { @@ -197,141 +194,60 @@ class DAPIirLowering : public OpRewritePattern { Value N = rewriter.create(loc, input, c0); Value filterSize = rewriter.create(loc, kernel, c0); - Value strideVal = rewriter.create(loc, stride); FloatType f32 = FloatType::getF32(ctx); - VectorType vectorTy32 = VectorType::get({stride}, f32); - - Value zr = rewriter.create(loc, APFloat(float(0)), f32); - // calculate the upper bound of the FIR part - Value strictN = rewriter.create(loc, N, c2); - Value strideRem = rewriter.create(loc, strictN, strideVal); - Value upperN = rewriter.create(loc, N, strideRem); - // loop over every row in SOS matrix rewriter.create( - loc, c0, filterSize, c1, ValueRange{std::nullopt}, - [&](OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iargs) { - Value b0 = builder.create(loc, kernel, - ValueRange{ivs[0], c0}); - Value b1 = builder.create(loc, kernel, - ValueRange{ivs[0], c1}); - Value b2 = builder.create(loc, kernel, - ValueRange{ivs[0], c2}); - // Value a0 of kernel is not used - Value a1 = builder.create(loc, kernel, - ValueRange{ivs[0], c4}); - Value a2 = builder.create(loc, kernel, - ValueRange{ivs[0], c5}); + loc, c0, filterSize, c1, ValueRange{input}, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iarg) { + Value b0 = + builder.create(loc, kernel, ValueRange{iv, c0}); + Value b1 = + builder.create(loc, kernel, ValueRange{iv, c1}); + Value b2 = + builder.create(loc, kernel, ValueRange{iv, c2}); + Value a1 = + builder.create(loc, kernel, ValueRange{iv, c4}); + Value a2 = + builder.create(loc, kernel, ValueRange{iv, c5}); Value z1 = builder.create(loc, APFloat(float(0)), f32); Value z2 = builder.create(loc, APFloat(float(0)), f32); - Value x0 = builder.create(loc, input, ValueRange{c0}); - Value temp = builder.create(loc, b0, x0); - builder.create(loc, temp, output, ValueRange{c0}); - - Value x1 = builder.create(loc, input, ValueRange{c1}); - Value temp0 = builder.create(loc, b0, x1); - Value temp1 = builder.create(loc, b1, x0); - Value temp2 = builder.create(loc, temp0, temp1); - builder.create(loc, temp2, output, ValueRange{c1}); - - Value Vecb0 = - builder.create(loc, vectorTy32, b0); - Value Vecb1 = - builder.create(loc, vectorTy32, b1); - Value Vecb2 = - builder.create(loc, vectorTy32, b2); - - // A biquad filter expression: - // y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]; - // FIR part - builder.create( - loc, c2, upperN, strideVal, ValueRange{std::nullopt}, - [&](OpBuilder &builder, Location loc, Value iv, - ValueRange itrargs) { - Value idx0 = iv; - Value idx1 = builder.create(loc, idx0, c1); - Value idx2 = builder.create(loc, idx0, c2); - - Value inputVec0 = builder.create(loc, vectorTy32, input, - ValueRange{idx0}); - Value inputVec1 = builder.create(loc, vectorTy32, input, - ValueRange{idx1}); - Value inputVec2 = builder.create(loc, vectorTy32, input, - ValueRange{idx2}); - - Value outputVec = - rewriter.create(loc, vectorTy32, zr); - Value resVec0 = - builder.create(loc, inputVec0, Vecb0, outputVec); - Value resVec1 = - builder.create(loc, inputVec1, Vecb1, resVec0); - Value resVec2 = - builder.create(loc, inputVec2, Vecb2, resVec1); - builder.create(loc, resVec2, output, ValueRange{idx0}); - - builder.create(loc, std::nullopt); - }); - - // process the remain data of FIR part - Value idx1 = builder.create(loc, upperN, c1); - Value idx2 = builder.create(loc, upperN, c2); - Value in1 = - builder.create(loc, input, ValueRange{idx1}); - Value in2 = - builder.create(loc, input, ValueRange{idx2}); - - builder.create( - loc, upperN, N, c1, ValueRange{in1, in2}, - [&](OpBuilder &builder, Location loc, Value iv, - ValueRange itrargs) { - Value in0 = - builder.create(loc, input, ValueRange{iv}); - - Value temp0 = builder.create(loc, b0, in0); - Value temp1 = builder.create(loc, b1, in1); - Value temp2 = builder.create(loc, b2, in2); - Value sum0 = builder.create(loc, temp0, temp1); - Value sum1 = builder.create(loc, sum0, temp2); - - builder.create(loc, sum1, output, ValueRange{iv}); - - builder.create(loc, std::vector{in0, in1}); - }); - - // IIR part + // Loop reordering, compute z1 for next iteration, z2 for the second + // following iteration. builder.create( loc, c0, N, c1, ValueRange{z1, z2}, [&](OpBuilder &builder, Location loc, Value iv, - ValueRange itrargs) { - Value x = - builder.create(loc, output, ValueRange{iv}); - Value t1 = builder.create(loc, a1, itrargs[1]); - Value t2 = builder.create(loc, a2, itrargs[0]); - Value y = builder.create(loc, t1, t2); - Value opt = builder.create(loc, x, y); - - builder.create(loc, opt, output, - ValueRange{iv}); + ValueRange iargs) { + Value inElem = builder.create(loc, iarg[0], iv); + Value t0 = builder.create(loc, b0, inElem); + Value outElem = + builder.create(loc, t0, iargs[0]); + + Value t1 = builder.create(loc, b1, inElem); + Value t2 = builder.create(loc, a1, outElem); + Value t3 = builder.create(loc, t1, t2); + Value z1Next = builder.create(loc, t3, iargs[1]); + + Value t4 = builder.create(loc, b2, inElem); + Value t5 = builder.create(loc, a2, outElem); + Value z2Next = builder.create(loc, t4, t5); + + builder.create(loc, outElem, output, iv); builder.create( - loc, std::vector{itrargs[1], opt}); + loc, std::vector{z1Next, z2Next}); }); - builder.create(loc, output, input); - builder.create(loc, std::nullopt); + + builder.create(loc, output); }); rewriter.eraseOp(op); return success(); } - -private: - int64_t stride; }; } // end anonymous namespace @@ -340,7 +256,7 @@ void populateLowerDAPConversionPatterns(RewritePatternSet &patterns, int64_t stride) { patterns.add(patterns.getContext()); patterns.add(patterns.getContext(), stride); - patterns.add(patterns.getContext(), stride); + patterns.add(patterns.getContext()); } //===----------------------------------------------------------------------===// @@ -363,7 +279,8 @@ class LowerDAPPass : public PassWrapper> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + affine::AffineDialect, arith::ArithDialect, + linalg::LinalgDialect>(); } Option stride{*this, "DAP-vector-splitting", llvm::cl::desc("Vector splitting size."), @@ -376,10 +293,10 @@ void LowerDAPPass::runOnOperation() { ModuleOp module = getOperation(); ConversionTarget target(*context); - target.addLegalDialect(); + target + .addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(context);