Skip to content

Commit

Permalink
AIRDependency: fixup issues around op erase in `hoistTargetOpsToNewSC…
Browse files Browse the repository at this point in the history
…FFor` method (Xilinx#777)
  • Loading branch information
erwei-xilinx authored Nov 16, 2024
1 parent cb7758b commit 62f0de0
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
3 changes: 2 additions & 1 deletion mlir/include/air/Util/Dependency.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ void addAsyncDependencyIfNew(Operation *op, Value token);
bool isAsyncOp(Operation *op);
bool areAsyncDependent(Operation *a, Operation *b);
bool isAsyncDependent(Operation *a, Operation *b);
scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
scf::ForOp for_op,
SmallVector<Operation *> target_ops);
LogicalResult unrollAIRChannelPutGetInScfParallel(OpBuilder builder,
scf::ParallelOp par,
Expand Down
57 changes: 35 additions & 22 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,14 +660,27 @@ bool isAsyncDependent(Operation *a, Operation *b) {

// Splits an SCF for loop into two for loops, by hoisting target operations in
// for loop to a new for loop located at the same scope.
scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
scf::ForOp for_op,
SmallVector<Operation *> target_ops) {
auto loc = for_op->getLoc();
// If target ops are already perfectly nested, then skip
auto hasNChannelOps = [](Block *block, unsigned N) {
SmallVector<air::ChannelInterface> chanOps;
block->walk([&](air::ChannelInterface op) { chanOps.push_back(op); });
return chanOps.size() == N;
auto hasNChannelOps = [target_ops](Block *block, unsigned N) {
unsigned counter = 0;
block->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[target_ops, &counter](Operation *op) {
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
return WalkResult::skip();
if (llvm::is_contained(target_ops, op)) {
counter++;
return WalkResult::skip();
}
if (isa<air::ChannelInterface>(op))
counter++;
counter++;
return WalkResult::advance();
});
return counter == N;
};
if (hasNChannelOps(for_op.getBody(), 1))
return for_op;
Expand All @@ -686,20 +699,20 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
}
}

builder.setInsertionPoint(for_op);
rewriter.setInsertionPoint(for_op);
IRMapping remap;
auto new_for_op = builder.create<scf::ForOp>(
auto new_for_op = rewriter.create<scf::ForOp>(
loc, for_op.getLowerBound(), for_op.getUpperBound(), for_op.getStep(),
SmallVector<Value>{builder
.create<air::WaitAllOp>(
loc,
air::AsyncTokenType::get(builder.getContext()),
SmallVector<Value>{})
.getAsyncToken()});
SmallVector<Value>{
rewriter
.create<air::WaitAllOp>(
loc, air::AsyncTokenType::get(rewriter.getContext()),
SmallVector<Value>{})
.getAsyncToken()});
remap.map(for_op.getInductionVar(), new_for_op.getInductionVar());
remap.map(getLoopCarriedTokenFromScfOp(for_op, "argument"),
getLoopCarriedTokenFromScfOp(new_for_op, "argument"));
builder.setInsertionPointToStart(new_for_op.getBody());
rewriter.setInsertionPointToStart(new_for_op.getBody());
SmallVector<Value> yield_operands;
// Build up a log of ops to be cloned; using SetVector to avoid repetition.
llvm::SetVector<Operation *> ops_to_be_cloned;
Expand All @@ -719,14 +732,14 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
}
Operation *back_of_dep_chain;
for (auto o : ops_to_be_cloned)
back_of_dep_chain = builder.clone(*o, remap);
back_of_dep_chain = rewriter.clone(*o, remap);
yield_operands.push_back(getAsyncTokenFromOp(back_of_dep_chain));

builder.create<scf::YieldOp>(
rewriter.create<scf::YieldOp>(
loc, SmallVector<Value>{
builder
rewriter
.create<air::WaitAllOp>(
loc, air::AsyncTokenType::get(builder.getContext()),
loc, air::AsyncTokenType::get(rewriter.getContext()),
yield_operands)
->getResult(0)});

Expand All @@ -738,7 +751,7 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
}
for (auto erase_op : target_ops) {
// Reconnect returned tokens.
builder.setInsertionPoint(erase_op);
rewriter.setInsertionPoint(erase_op);
for (auto res : erase_op->getResults()) {
if (!isa<air::AsyncTokenType>(res.getType()))
continue;
Expand All @@ -752,17 +765,17 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
// User op doesn't have air::AsyncOpInterface. Replace uses with newly
// generated air.wait_all op.
u->replaceUsesOfWith(
res, builder
res, rewriter
.create<air::WaitAllOp>(
loc, air::AsyncTokenType::get(builder.getContext()),
loc, air::AsyncTokenType::get(rewriter.getContext()),
getAsyncDependenciesFromOp(erase_op))
.getAsyncToken());
}
}
}
}
for (auto erase_op : target_ops)
erase_op->erase();
rewriter.eraseOp(erase_op);
for (auto user : for_op.getResults().front().getUsers()) {
air::addAsyncDependencyIfNew(user, new_for_op.getResults().front());
}
Expand Down

0 comments on commit 62f0de0

Please sign in to comment.