diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index b31eb1f00..b0ed81701 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -4772,14 +4772,23 @@ scf::ForOp simpleScfForLoopTiling(scf::ForOp forOp, int original_step, return new_for_op; } -// Erase async op and replace with air wait all -void replaceAsyncOpWithWaitAll(OpBuilder builder, Operation *op) { - builder.setInsertionPoint(op); - auto waitAllReplaceAlloc = builder.create( - builder.getUnknownLoc(), air::AsyncTokenType::get(builder.getContext()), - SmallVector{}); - for (unsigned i = 0; i < op->getNumResults(); i++) - op->getResult(i).replaceAllUsesWith(waitAllReplaceAlloc.getAsyncToken()); +// Replace async op with wait_all op +static air::WaitAllOp replaceAsyncOpWithWaitAll(OpBuilder builder, + IRMapping &remap, Operation *op, + bool cloneDepList = true) { + assert(air::isAsyncOp(op)); + SmallVector dep_list_remap; + if (cloneDepList) { + for (auto dep : air::getAsyncDependenciesFromOp(op)) { + dep_list_remap.push_back(remap.lookupOrDefault(dep)); + } + } + auto wa_op = builder.create( + builder.getUnknownLoc(), air::AsyncTokenType::get(op->getContext()), + dep_list_remap); + wa_op->setAttr("hoist", StringAttr::get(op->getContext(), "dep")); + remap.map(air::getAsyncTokenFromOp(op), wa_op.getAsyncToken()); + return wa_op; } // Fuse scf.for loops in region. "fuseWithAllocDeallocs" is a bool argument @@ -4937,7 +4946,7 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter, if (fusableForOps.size() <= 1) return failure(); - rewriter.setInsertionPointAfter(equalIterationForOps.back().front()); + rewriter.setInsertionPoint(equalIterationForOps.front().front()); auto new_loop_op_init_arg = rewriter .create(loc, air::AsyncTokenType::get(ctx), @@ -4959,34 +4968,6 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter, air::getLoopCarriedTokenFromScfOp(new_loop_op, "argument"); } IRMapping remap; - if (fuseWithAllocDeallocs) { - SmallVector erase_keys; - for (auto execOpPair : alloc_dealloc_execs) { - bool canMove = false; - air::ExecuteOp alloc_exec = execOpPair.first; - auto token_users = getTokenUsersOfType(alloc_exec); - for (auto token_user : token_users) - if (llvm::any_of(equalIterationForOps, - [&](SmallVector fusableForOpNest) { - return llvm::is_contained(fusableForOpNest, - token_user); - })) - canMove = true; - if (canMove) { - rewriter.setInsertionPointToEnd(new_loop_op.getBody()); - auto new_alloc_exec = rewriter.clone(*alloc_exec, remap); - clearAsyncDependenciesOfAsyncOp( - dyn_cast(new_alloc_exec)); - for (unsigned i = 0; i < new_alloc_exec->getNumResults(); i++) - remap.map(alloc_exec->getResult(i), new_alloc_exec->getResult(i)); - } else - erase_keys.push_back(alloc_exec); - } - for (auto e : erase_keys) - for (unsigned i = 0; i < alloc_dealloc_execs.size(); i++) - if (e == alloc_dealloc_execs[i].first) - alloc_dealloc_execs.erase(alloc_dealloc_execs.begin() + i); - } // Loop fusion. auto getParentScfForNest = [](Operation *op) { @@ -5010,19 +4991,96 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter, remap.map(std::get<0>(pair).getRegionIterArgs()[i], std::get<1>(pair).getRegionIterArgs()[i]); } + // Preserve the original outermost scf.for's iter_arg. + for (unsigned i = 0; i < forOp.getRegionIterArgs().size(); i++) + remap.map(forOp.getRegionIterArgs()[i], forOp.getInitArgs()[i]); rewriter.setInsertionPointToEnd(new_loop_op.getBody()); for (auto &child_op : forOp.getBody()->without_terminator()) rewriter.clone(child_op, remap); } - // Fuse dealloc ops. - if (fuseWithAllocDeallocs) - for (auto execOpPair : alloc_dealloc_execs) { - air::ExecuteOp dealloc_exec = execOpPair.second; - clearAsyncDependenciesOfAsyncOp(dealloc_exec); - rewriter.setInsertionPointToEnd(new_loop_op.getBody()); - rewriter.clone(*dealloc_exec, remap); + // Erase original scf.for ops. + for (auto forOp : fusableForOps) { + auto fusableBandHead = getParentScfForNest(forOp).back(); + for (unsigned i = 0; i < fusableBandHead.getNumResults(); i++) { + fusableBandHead.getResult(i).replaceAllUsesWith(new_loop_op.getResult(i)); + } + rewriter.eraseOp(fusableBandHead); + } + + // Fuse allocs and deallocs into the created scf.for loop. + if (fuseWithAllocDeallocs) { + SmallVector erase_keys; + for (auto &[alloc_exec, dealloc_exec] : alloc_dealloc_execs) { + Value alloc_token = alloc_exec.getAsyncToken(); + Value dealloc_token = dealloc_exec.getAsyncToken(); + Value memref = alloc_exec->getResult(1); + bool canMove = llvm::all_of(memref.getUsers(), [&new_loop_op]( + Operation *user) { + return new_loop_op.getRegion().isAncestor(user->getParentRegion()) || + isa(user); + }); + + if (canMove) { + rewriter.setInsertionPointToStart(new_loop_op.getBody()); + auto new_alloc_exec = rewriter.clone(*alloc_exec, remap); + /// TODO: Do we still need below? + clearAsyncDependenciesOfAsyncOp( + dyn_cast(new_alloc_exec)); + rewriter.setInsertionPointToEnd(new_loop_op.getBody()); + auto new_dealloc_exec = rewriter.clone(*dealloc_exec, remap); + + if (air::isAsyncOp(new_loop_op)) + air::addAsyncDependencyIfNew( + new_alloc_exec, + air::getLoopCarriedTokenFromScfOp(new_loop_op, "argument")); + // Replace all uses of tokens + alloc_token.replaceUsesWithIf( + air::getAsyncTokenFromOp(new_loop_op), + [&new_loop_op](OpOperand &u) { + return !new_loop_op.getRegion().isAncestor( + u.getOwner()->getParentRegion()); + }); + replaceAllUsesInRegionWith(alloc_token, + air::getAsyncTokenFromOp(new_alloc_exec), + new_loop_op.getRegion()); + // Replace all uses of values + for (unsigned i = 1; i < new_alloc_exec->getNumResults(); i++) + alloc_exec->getResult(i).replaceAllUsesWith( + new_alloc_exec->getResult(i)); + + // Replace all uses of tokens + dealloc_token.replaceUsesWithIf( + air::getAsyncTokenFromOp(new_loop_op), + [&new_loop_op](OpOperand &u) { + return !new_loop_op.getRegion().isAncestor( + u.getOwner()->getParentRegion()); + }); + replaceAllUsesInRegionWith(dealloc_token, + air::getAsyncTokenFromOp(new_dealloc_exec), + new_loop_op.getRegion()); + } else + erase_keys.push_back(alloc_exec); + } + llvm::SetVector outstandingTokens; + getUsedValuesDefinedAbove(new_loop_op.getRegion(), outstandingTokens); + for (auto token : outstandingTokens) + if (isa(token.getType())) + replaceAllUsesInRegionWith( + token, air::getLoopCarriedTokenFromScfOp(new_loop_op, "argument"), + new_loop_op.getRegion()); + for (auto e : erase_keys) + for (unsigned i = 0; i < alloc_dealloc_execs.size(); i++) + if (e == alloc_dealloc_execs[i].first) + alloc_dealloc_execs.erase(alloc_dealloc_execs.begin() + i); + // Erase allocs/deallocs upon fusion. + for (auto &[alloc, dealloc] : alloc_dealloc_execs) { + assert(alloc->use_empty()); + rewriter.eraseOp(alloc); + assert(dealloc->use_empty()); + rewriter.eraseOp(dealloc); } + } // Scf.yield op. rewriter.setInsertionPointToEnd(new_loop_op.getBody()); @@ -5072,10 +5130,7 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter, continue; if (put_parent->isBeforeInBlock(get_parent)) put_parent->moveAfter(get_parent); - Value get_parent_token = nullptr; - for (auto res : get_parent->getResults()) - if (isa(res.getType())) - get_parent_token = res; + Value get_parent_token = air::getAsyncTokenFromOp(get_parent); for (unsigned i = 0; i < put_parent->getNumOperands(); i++) if (get_parent_token && isa(put_parent->getOperand(i).getType())) { @@ -5083,26 +5138,6 @@ LogicalResult fuseLoopsInRegion(Region *region, PatternRewriter &rewriter, } } - // Erase allocs/deallocs - if (fuseWithAllocDeallocs) - for (auto execOpPair : alloc_dealloc_execs) { - auto alloc = execOpPair.first; - replaceAsyncOpWithWaitAll(rewriter, alloc); - rewriter.eraseOp(alloc); - auto dealloc = execOpPair.second; - replaceAsyncOpWithWaitAll(rewriter, dealloc); - rewriter.eraseOp(dealloc); - } - - // Erase original scf.for ops. - for (auto forOp : fusableForOps) { - auto fusableBandHead = getParentScfForNest(forOp).back(); - for (unsigned i = 0; i < fusableBandHead.getNumResults(); i++) { - fusableBandHead.getResult(i).replaceAllUsesWith(new_loop_op.getResult(i)); - } - rewriter.eraseOp(fusableBandHead); - } - return success(); }