Skip to content

Commit

Permalink
AIRSpecializeChannelWrapAndStride: enable canonicalizing affine map j…
Browse files Browse the repository at this point in the history
…oining multiple for loops in a loop nest (#671)

* Separate affine.apply folding and canonicalization from folding for loops; relax unnecessary failure conditions

* Enable affine apply map joining multiple for loops

* Test

* Add canonicalizer folding arith.addi and muli into the for loop

* Folding is only safe when iv only has one user (i.e. the arith op)
  • Loading branch information
erwei-xilinx authored Jul 19, 2024
1 parent 4905fcc commit debcde7
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 91 deletions.
317 changes: 243 additions & 74 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,53 @@ struct CanonicalizeAIRExecute : public OpRewritePattern<air::ExecuteOp> {
private:
};

affine::AffineForOp updateAffineForBounds(OpBuilder builder, IRMapping &remap,
affine::AffineForOp loop_op, int lb,
int ub, int step) {
affine::AffineForOp new_loop_op = builder.create<affine::AffineForOp>(
builder.getUnknownLoc(), lb, ub, step);
remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar());
// remap.map(old_apply.getResult(), new_loop_op.getInductionVar());
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(new_loop_op.getBody());
for (Operation &child_op : loop_op.getBody()->getOperations()) {
if (&child_op == loop_op.getBody()->getTerminator()) { /*Skip*/
} else
builder.clone(child_op, remap);
}
builder.restoreInsertionPoint(insertionCheckpoint);
return new_loop_op;
}

scf::ForOp updateScfForBounds(OpBuilder builder, IRMapping &remap,
scf::ForOp loop_op, int lb, int ub, int step) {
auto loc = loop_op->getLoc();
SmallVector<Value, 1> deps =
loop_op.getOperands().drop_front(loop_op.getNumControlOperands());
scf::ForOp new_loop_op = builder.create<scf::ForOp>(
builder.getUnknownLoc(), builder.create<arith::ConstantIndexOp>(loc, lb),
builder.create<arith::ConstantIndexOp>(loc, ub),
builder.create<arith::ConstantIndexOp>(loc, step), deps);
remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar());
for (unsigned i = 0; i < loop_op.getRegionIterArgs().size(); i++)
remap.map(loop_op.getRegionIterArgs()[i],
new_loop_op.getRegionIterArgs()[i]);
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(new_loop_op.getBody());
for (Operation &child_op : loop_op.getBody()->getOperations()) {
if (&child_op == loop_op.getBody()->getTerminator()) {
if (!new_loop_op.getBody()->mightHaveTerminator())
builder.clone(child_op, remap);
} else
builder.clone(child_op, remap);
}
for (unsigned i = 0; i < loop_op->getNumResults(); i++)
loop_op->getResult(i).replaceAllUsesWith(new_loop_op->getResult(i));
builder.restoreInsertionPoint(insertionCheckpoint);
return new_loop_op;
}

// Fold affine.apply op operating on loop induction variable into loop bounds.
struct CanonicalizeAffineApplyOnLoopInductionVar
: public OpRewritePattern<affine::AffineApplyOp> {
using OpRewritePattern<affine::AffineApplyOp>::OpRewritePattern;
Expand Down Expand Up @@ -1709,54 +1756,184 @@ struct CanonicalizeAffineApplyOnLoopInductionVar
output = simplifyAffineMap(newmap).getSingleConstantResult();
return output;
}
};

affine::AffineForOp updateAffineForBounds(OpBuilder builder, IRMapping &remap,
affine::AffineForOp loop_op, int lb,
int ub, int step) const {
affine::AffineForOp new_loop_op = builder.create<affine::AffineForOp>(
builder.getUnknownLoc(), lb, ub, step);
remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar());
// remap.map(old_apply.getResult(), new_loop_op.getInductionVar());
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(new_loop_op.getBody());
for (Operation &child_op : loop_op.getBody()->getOperations()) {
if (&child_op == loop_op.getBody()->getTerminator()) { /*Skip*/
} else
builder.clone(child_op, remap);
// Fold arith.muli op operating on loop induction variable into loop bounds.
struct CanonicalizeArithMuliOpOnLoopInductionVar
: public OpRewritePattern<arith::MulIOp> {
using OpRewritePattern<arith::MulIOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::MulIOp op,
PatternRewriter &rewriter) const override {
Operation *containingOp = nullptr;
Value const_val = nullptr;
Value var_val = nullptr;
for (auto val : SmallVector<Value>{op.getLhs(), op.getRhs()}) {
if (getConstantIntValue(val)) {
const_val = val;
continue;
}
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg)
continue;
if (!ivArg.getOwner())
continue;
if (!val.hasOneUse())
continue;
if (op.getResult().use_empty())
continue;
if (auto exec_muli = dyn_cast<air::ExecuteOp>(op->getParentOp()))
if (exec_muli->getResult(1).use_empty())
continue;
if (isa<scf::ForOp>(ivArg.getOwner()->getParentOp())) {
containingOp = ivArg.getOwner()->getParentOp();
var_val = val;
} else if (isa<affine::AffineForOp>(ivArg.getOwner()->getParentOp())) {
containingOp = ivArg.getOwner()->getParentOp();
var_val = val;
}
}
builder.restoreInsertionPoint(insertionCheckpoint);
return new_loop_op;
if (!containingOp)
return failure();
if (!const_val)
return failure();
if (!var_val)
return failure();

// Apply arith muli to loop step and bound
int muli_factor = *mlir::getConstantIntValue(const_val);
if (auto sfo = dyn_cast<scf::ForOp>(containingOp)) {
if (!getStaticScfForTripCountAsInt(sfo))
return failure();
int tripCount = *getStaticScfForTripCountAsInt(sfo);
int new_ub =
*mlir::getConstantIntValue(sfo.getUpperBound()) * muli_factor;
int new_lb =
*mlir::getConstantIntValue(sfo.getLowerBound()) * muli_factor;
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
if (auto exec = dyn_cast<air::ExecuteOp>(op->getParentOp())) {
rewriter.setInsertionPoint(exec);
exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar());
exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]);
rewriter.eraseOp(exec);
} else {
rewriter.setInsertionPoint(op);
op.getResult().replaceAllUsesWith(sfo.getInductionVar());
rewriter.eraseOp(op);
}
rewriter.setInsertionPoint(sfo);
updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(sfo);
} else if (auto afo = dyn_cast<affine::AffineForOp>(containingOp)) {
if (!afo.hasConstantBounds())
return failure();
int tripCount = *getStaticAffineForTripCountAsInt(afo);
int new_ub = afo.getConstantUpperBound() * muli_factor;
int new_lb = afo.getConstantLowerBound() * muli_factor;
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
rewriter.setInsertionPoint(afo);
op.getResult().replaceAllUsesWith(afo.getInductionVar());
rewriter.eraseOp(op);
updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(afo);
} else
return failure();

return success();
}

scf::ForOp updateScfForBounds(OpBuilder builder, IRMapping &remap,
scf::ForOp loop_op, int lb, int ub,
int step) const {
auto loc = loop_op->getLoc();
SmallVector<Value, 1> deps =
loop_op.getOperands().drop_front(loop_op.getNumControlOperands());
scf::ForOp new_loop_op = builder.create<scf::ForOp>(
builder.getUnknownLoc(),
builder.create<arith::ConstantIndexOp>(loc, lb),
builder.create<arith::ConstantIndexOp>(loc, ub),
builder.create<arith::ConstantIndexOp>(loc, step), deps);
remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar());
for (unsigned i = 0; i < loop_op.getRegionIterArgs().size(); i++)
remap.map(loop_op.getRegionIterArgs()[i],
new_loop_op.getRegionIterArgs()[i]);
auto insertionCheckpoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(new_loop_op.getBody());
for (Operation &child_op : loop_op.getBody()->getOperations()) {
if (&child_op == loop_op.getBody()->getTerminator()) {
if (!new_loop_op.getBody()->mightHaveTerminator())
builder.clone(child_op, remap);
} else
builder.clone(child_op, remap);
private:
};

// Fold arith.addi op operating on loop induction variable into loop bounds.
struct CanonicalizeArithAddiOpOnLoopInductionVar
: public OpRewritePattern<arith::AddIOp> {
using OpRewritePattern<arith::AddIOp>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::AddIOp op,
PatternRewriter &rewriter) const override {
Operation *containingOp = nullptr;
Value const_val = nullptr;
Value var_val = nullptr;
for (auto val : SmallVector<Value>{op.getLhs(), op.getRhs()}) {
if (getConstantIntValue(val)) {
const_val = val;
continue;
}
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg)
continue;
if (!ivArg.getOwner())
continue;
if (!val.hasOneUse())
continue;
if (op.getResult().use_empty())
continue;
if (auto exec_addi = dyn_cast<air::ExecuteOp>(op->getParentOp()))
if (exec_addi->getResult(1).use_empty())
continue;
if (isa<scf::ForOp>(ivArg.getOwner()->getParentOp())) {
containingOp = ivArg.getOwner()->getParentOp();
var_val = val;
} else if (isa<affine::AffineForOp>(ivArg.getOwner()->getParentOp())) {
containingOp = ivArg.getOwner()->getParentOp();
var_val = val;
}
}
for (unsigned i = 0; i < loop_op->getNumResults(); i++)
loop_op->getResult(i).replaceAllUsesWith(new_loop_op->getResult(i));
builder.restoreInsertionPoint(insertionCheckpoint);
return new_loop_op;
if (!containingOp)
return failure();
if (!const_val)
return failure();
if (!var_val)
return failure();

// Apply arith muli to loop step and bound
int addi_operand = *mlir::getConstantIntValue(const_val);
if (auto sfo = dyn_cast<scf::ForOp>(containingOp)) {
if (!getStaticScfForTripCountAsInt(sfo))
return failure();
int tripCount = *getStaticScfForTripCountAsInt(sfo);
int new_ub =
*mlir::getConstantIntValue(sfo.getUpperBound()) + addi_operand;
int new_lb =
*mlir::getConstantIntValue(sfo.getLowerBound()) + addi_operand;
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
if (auto exec = dyn_cast<air::ExecuteOp>(op->getParentOp())) {
rewriter.setInsertionPoint(exec);
exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar());
exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]);
rewriter.eraseOp(exec);
} else {
rewriter.setInsertionPoint(op);
op.getResult().replaceAllUsesWith(sfo.getInductionVar());
rewriter.eraseOp(op);
}
rewriter.setInsertionPoint(sfo);
updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(sfo);
} else if (auto afo = dyn_cast<affine::AffineForOp>(containingOp)) {
if (!afo.hasConstantBounds())
return failure();
int tripCount = *getStaticAffineForTripCountAsInt(afo);
int new_ub = afo.getConstantUpperBound() + addi_operand;
int new_lb = afo.getConstantLowerBound() + addi_operand;
int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount);
IRMapping remap;
rewriter.setInsertionPoint(afo);
op.getResult().replaceAllUsesWith(afo.getInductionVar());
rewriter.eraseOp(op);
updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt);
rewriter.eraseOp(afo);
} else
return failure();

return success();
}

private:
};

struct AIRSpecializeChannelWrapAndStrideInScfFor
Expand All @@ -1772,11 +1949,12 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
auto hasNElements = [](Block *block, unsigned N) {
unsigned counter = 0;
for (auto &o : block->getOperations()) {
if (o.mightHaveTrait<OpTrait::IsTerminator>())
continue;
if (isa<air::WaitAllOp>(o))
continue;
counter++;
if (isa<air::ChannelInterface>(o))
counter++;
else if (isa<LoopLikeOpInterface>(o))
counter++;
else if (isa<mlir::linalg::LinalgOp>(o))
counter++;
}
return counter == N;
};
Expand All @@ -1786,10 +1964,6 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor

if (!hasNElements(for_op.getBody(), 1))
return failure();
if (isa<air::ChannelInterface>(for_op.getBody()->begin())) {
} else if (isa<scf::ForOp>(for_op.getBody()->begin())) {
} else
return failure();

// Check if the loop nest contains exactly one channel op
SmallVector<air::ChannelInterface> channel_ops;
Expand All @@ -1813,10 +1987,6 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
// Check for perfect loop nest containing only air.channel ops
if (!hasNElements(o.getBody(), 1))
return failure();
if (isa<air::ChannelInterface>(o.getBody()->begin())) {
} else if (isa<scf::ForOp>(o.getBody()->begin())) {
} else
return failure();
if (!getStaticScfForTripCountAsInt(o))
return failure();
}
Expand Down Expand Up @@ -1886,11 +2056,12 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
auto hasNElements = [](Block *block, unsigned N) {
unsigned counter = 0;
for (auto &o : block->getOperations()) {
if (o.mightHaveTrait<OpTrait::IsTerminator>())
continue;
if (isa<air::WaitAllOp>(o))
continue;
counter++;
if (isa<air::ChannelInterface>(o))
counter++;
else if (isa<LoopLikeOpInterface>(o))
counter++;
else if (isa<mlir::linalg::LinalgOp>(o))
counter++;
}
return counter == N;
};
Expand All @@ -1900,10 +2071,6 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor

if (!hasNElements(for_op.getBody(), 1))
return failure();
if (isa<air::ChannelInterface>(for_op.getBody()->begin())) {
} else if (isa<affine::AffineForOp>(for_op.getBody()->begin())) {
} else
return failure();

// Check if the loop nest contains exactly one channel op
SmallVector<air::ChannelInterface> channel_ops;
Expand All @@ -1927,10 +2094,6 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
// Check for perfect loop nest containing only air.channel ops
if (!hasNElements(o.getBody(), 1))
return failure();
if (isa<air::ChannelInterface>(o.getBody()->begin())) {
} else if (isa<affine::AffineForOp>(o.getBody()->begin())) {
} else
return failure();
if (!getStaticAffineForTripCountAsInt(o))
return failure();
}
Expand Down Expand Up @@ -2832,14 +2995,20 @@ class AIRSpecializeChannelWrapAndStridePattern

void runOptPatterns(func::FuncOp funcOp) {
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet preproc_patterns(&getContext());
preproc_patterns.insert<UnrollScfParallel, CanonicalizeAIRExecute,
CanonicalizeAffineApplyOnLoopInductionVar,
CanonicalizeArithMuliOpOnLoopInductionVar,
CanonicalizeArithAddiOpOnLoopInductionVar>(ctx);
// Canonicalize constant operands in affine.apply.
mlir::affine::AffineApplyOp::getCanonicalizationPatterns(preproc_patterns,
ctx);
air::WaitAllOp::getCanonicalizationPatterns(preproc_patterns, ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(preproc_patterns));

RewritePatternSet patterns(&getContext());
patterns.insert<UnrollScfParallel, CanonicalizeAIRExecute,
CanonicalizeAffineApplyOnLoopInductionVar,
AIRSpecializeChannelWrapAndStrideInScfFor,
patterns.insert<AIRSpecializeChannelWrapAndStrideInScfFor,
AIRSpecializeChannelWrapAndStrideInAffineFor>(ctx);
// Canonicalize constant operands in affine.apply.
mlir::affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
air::WaitAllOp::getCanonicalizationPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

// Canonicalize wrap and stride list to remove redundant dimensions
Expand Down
Loading

0 comments on commit debcde7

Please sign in to comment.