Skip to content

Commit

Permalink
Refactor the alloc-dealloc-fusion logic to preserve the async depende…
Browse files Browse the repository at this point in the history
…ncies after loop fusion
  • Loading branch information
erwei-xilinx committed Dec 31, 2024
1 parent a2e4acc commit f4108d7
Showing 1 changed file with 103 additions and 68 deletions.
171 changes: 103 additions & 68 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4772,14 +4772,23 @@ scf::ForOp simpleScfForLoopTiling(scf::ForOp forOp, int original_step,
return new_for_op;
}

// Erase async op and replace with air wait all
void replaceAsyncOpWithWaitAll(OpBuilder builder, Operation *op) {
builder.setInsertionPoint(op);
auto waitAllReplaceAlloc = builder.create<air::WaitAllOp>(
builder.getUnknownLoc(), air::AsyncTokenType::get(builder.getContext()),
SmallVector<Value>{});
for (unsigned i = 0; i < op->getNumResults(); i++)
op->getResult(i).replaceAllUsesWith(waitAllReplaceAlloc.getAsyncToken());
// Replace async op with wait_all op
static air::WaitAllOp replaceAsyncOpWithWaitAll(OpBuilder builder,
IRMapping &remap, Operation *op,
bool cloneDepList = true) {
assert(air::isAsyncOp(op));
SmallVector<Value> dep_list_remap;
if (cloneDepList) {
for (auto dep : air::getAsyncDependenciesFromOp(op)) {
dep_list_remap.push_back(remap.lookupOrDefault(dep));
}
}
auto wa_op = builder.create<air::WaitAllOp>(
builder.getUnknownLoc(), air::AsyncTokenType::get(op->getContext()),
dep_list_remap);
wa_op->setAttr("hoist", StringAttr::get(op->getContext(), "dep"));
remap.map(air::getAsyncTokenFromOp(op), wa_op.getAsyncToken());
return wa_op;
}

// Fuse scf.for loops in region. "fuseWithAllocDeallocs" is a bool argument
Expand Down Expand Up @@ -4937,7 +4946,7 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter,
if (fusableForOps.size() <= 1)
return failure();

rewriter.setInsertionPointAfter(equalIterationForOps.back().front());
rewriter.setInsertionPoint(equalIterationForOps.front().front());
auto new_loop_op_init_arg =
rewriter
.create<air::WaitAllOp>(loc, air::AsyncTokenType::get(ctx),
Expand All @@ -4959,34 +4968,6 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter,
air::getLoopCarriedTokenFromScfOp(new_loop_op, "argument");
}
IRMapping remap;
if (fuseWithAllocDeallocs) {
SmallVector<air::ExecuteOp> erase_keys;
for (auto execOpPair : alloc_dealloc_execs) {
bool canMove = false;
air::ExecuteOp alloc_exec = execOpPair.first;
auto token_users = getTokenUsersOfType<scf::ForOp>(alloc_exec);
for (auto token_user : token_users)
if (llvm::any_of(equalIterationForOps,
[&](SmallVector<scf::ForOp> fusableForOpNest) {
return llvm::is_contained(fusableForOpNest,
token_user);
}))
canMove = true;
if (canMove) {
rewriter.setInsertionPointToEnd(new_loop_op.getBody());
auto new_alloc_exec = rewriter.clone(*alloc_exec, remap);
clearAsyncDependenciesOfAsyncOp(
dyn_cast<air::AsyncOpInterface>(new_alloc_exec));
for (unsigned i = 0; i < new_alloc_exec->getNumResults(); i++)
remap.map(alloc_exec->getResult(i), new_alloc_exec->getResult(i));
} else
erase_keys.push_back(alloc_exec);
}
for (auto e : erase_keys)
for (unsigned i = 0; i < alloc_dealloc_execs.size(); i++)
if (e == alloc_dealloc_execs[i].first)
alloc_dealloc_execs.erase(alloc_dealloc_execs.begin() + i);
}

// Loop fusion.
auto getParentScfForNest = [](Operation *op) {
Expand All @@ -5010,19 +4991,96 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter,
remap.map(std::get<0>(pair).getRegionIterArgs()[i],
std::get<1>(pair).getRegionIterArgs()[i]);
}
// Preserve the original outermost scf.for's iter_arg.
for (unsigned i = 0; i < forOp.getRegionIterArgs().size(); i++)
remap.map(forOp.getRegionIterArgs()[i], forOp.getInitArgs()[i]);
rewriter.setInsertionPointToEnd(new_loop_op.getBody());
for (auto &child_op : forOp.getBody()->without_terminator())
rewriter.clone(child_op, remap);
}

// Fuse dealloc ops.
if (fuseWithAllocDeallocs)
for (auto execOpPair : alloc_dealloc_execs) {
air::ExecuteOp dealloc_exec = execOpPair.second;
clearAsyncDependenciesOfAsyncOp(dealloc_exec);
rewriter.setInsertionPointToEnd(new_loop_op.getBody());
rewriter.clone(*dealloc_exec, remap);
// Erase original scf.for ops.
for (auto forOp : fusableForOps) {
auto fusableBandHead = getParentScfForNest(forOp).back();
for (unsigned i = 0; i < fusableBandHead.getNumResults(); i++) {
fusableBandHead.getResult(i).replaceAllUsesWith(new_loop_op.getResult(i));
}
rewriter.eraseOp(fusableBandHead);
}

// Fuse allocs and deallocs into the created scf.for loop.
if (fuseWithAllocDeallocs) {
SmallVector<air::ExecuteOp> erase_keys;
for (auto &[alloc_exec, dealloc_exec] : alloc_dealloc_execs) {
Value alloc_token = alloc_exec.getAsyncToken();
Value dealloc_token = dealloc_exec.getAsyncToken();
Value memref = alloc_exec->getResult(1);
bool canMove = llvm::all_of(memref.getUsers(), [&new_loop_op](
Operation *user) {
return new_loop_op.getRegion().isAncestor(user->getParentRegion()) ||
isa<memref::DeallocOp>(user);
});

if (canMove) {
rewriter.setInsertionPointToStart(new_loop_op.getBody());
auto new_alloc_exec = rewriter.clone(*alloc_exec, remap);
/// TODO: Do we still need below?
clearAsyncDependenciesOfAsyncOp(
dyn_cast<air::AsyncOpInterface>(new_alloc_exec));
rewriter.setInsertionPointToEnd(new_loop_op.getBody());
auto new_dealloc_exec = rewriter.clone(*dealloc_exec, remap);

if (air::isAsyncOp(new_loop_op))
air::addAsyncDependencyIfNew(
new_alloc_exec,
air::getLoopCarriedTokenFromScfOp(new_loop_op, "argument"));
// Replace all uses of tokens
alloc_token.replaceUsesWithIf(
air::getAsyncTokenFromOp(new_loop_op),
[&new_loop_op](OpOperand &u) {
return !new_loop_op.getRegion().isAncestor(
u.getOwner()->getParentRegion());
});
replaceAllUsesInRegionWith(alloc_token,
air::getAsyncTokenFromOp(new_alloc_exec),
new_loop_op.getRegion());
// Replace all uses of values
for (unsigned i = 1; i < new_alloc_exec->getNumResults(); i++)
alloc_exec->getResult(i).replaceAllUsesWith(
new_alloc_exec->getResult(i));

// Replace all uses of tokens
dealloc_token.replaceUsesWithIf(
air::getAsyncTokenFromOp(new_loop_op),
[&new_loop_op](OpOperand &u) {
return !new_loop_op.getRegion().isAncestor(
u.getOwner()->getParentRegion());
});
replaceAllUsesInRegionWith(dealloc_token,
air::getAsyncTokenFromOp(new_dealloc_exec),
new_loop_op.getRegion());
} else
erase_keys.push_back(alloc_exec);
}
llvm::SetVector<Value> outstandingTokens;
getUsedValuesDefinedAbove(new_loop_op.getRegion(), outstandingTokens);
for (auto token : outstandingTokens)
if (isa<air::AsyncTokenType>(token.getType()))
replaceAllUsesInRegionWith(
token, air::getLoopCarriedTokenFromScfOp(new_loop_op, "argument"),
new_loop_op.getRegion());
for (auto e : erase_keys)
for (unsigned i = 0; i < alloc_dealloc_execs.size(); i++)
if (e == alloc_dealloc_execs[i].first)
alloc_dealloc_execs.erase(alloc_dealloc_execs.begin() + i);
// Erase allocs/deallocs upon fusion.
for (auto &[alloc, dealloc] : alloc_dealloc_execs) {
assert(alloc->use_empty());
rewriter.eraseOp(alloc);
assert(dealloc->use_empty());
rewriter.eraseOp(dealloc);
}
}

// Scf.yield op.
rewriter.setInsertionPointToEnd(new_loop_op.getBody());
Expand Down Expand Up @@ -5072,37 +5130,14 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter,
continue;
if (put_parent->isBeforeInBlock(get_parent))
put_parent->moveAfter(get_parent);
Value get_parent_token = nullptr;
for (auto res : get_parent->getResults())
if (isa<AsyncTokenType>(res.getType()))
get_parent_token = res;
Value get_parent_token = air::getAsyncTokenFromOp(get_parent);
for (unsigned i = 0; i < put_parent->getNumOperands(); i++)
if (get_parent_token &&
isa<AsyncTokenType>(put_parent->getOperand(i).getType())) {
put_parent->getOpOperand(i).assign(get_parent_token);
}
}

// Erase allocs/deallocs
if (fuseWithAllocDeallocs)
for (auto execOpPair : alloc_dealloc_execs) {
auto alloc = execOpPair.first;
replaceAsyncOpWithWaitAll(rewriter, alloc);
rewriter.eraseOp(alloc);
auto dealloc = execOpPair.second;
replaceAsyncOpWithWaitAll(rewriter, dealloc);
rewriter.eraseOp(dealloc);
}

// Erase original scf.for ops.
for (auto forOp : fusableForOps) {
auto fusableBandHead = getParentScfForNest(forOp).back();
for (unsigned i = 0; i < fusableBandHead.getNumResults(); i++) {
fusableBandHead.getResult(i).replaceAllUsesWith(new_loop_op.getResult(i));
}
rewriter.eraseOp(fusableBandHead);
}

return success();
}

Expand Down

0 comments on commit f4108d7

Please sign in to comment.