Skip to content

Commit

Permalink
Refactor air.execute's getBody() and getChildOp() methods (Xili…
Browse files Browse the repository at this point in the history
…nx#773)

* Rename the single region in air.execute to region instead of body; add getBody method to get single block

* Refactor getChildOp method to getChildOps method

* Remove assertions in air.execute verifier

* Relax uses of getChildOp method to iterate through a range of getChildOps instead

* Fixup iteration through getChildOps
  • Loading branch information
erwei-xilinx authored Nov 14, 2024
1 parent ddbb630 commit c60d7bd
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 142 deletions.
12 changes: 10 additions & 2 deletions mlir/include/air/Dialect/AIR/AIR.td
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def air_ExecuteOp : air_Op<"execute", [SingleBlockImplicitTerminator<"ExecuteTer
Variadic<AnyType>:$results
);
let summary = "Asynchronous code region";
let regions = (region SizedRegion<1>:$body);
let regions = (region SizedRegion<1>:$region);
let description = [{
Defines a code region to be dispatched asynchronously at runtime. All operations in
the region must be executed sequentially.
Expand All @@ -477,7 +477,15 @@ def air_ExecuteOp : air_Op<"execute", [SingleBlockImplicitTerminator<"ExecuteTer
}];

let extraClassDeclaration = [{
Operation * getChildOp();
Block &getBody() { return getRegion().front(); }
llvm::iplist<Operation> &getChildOps() { return getBody().getOperations(); }
SmallVector<Operation *> getYieldedChildOps() {
SmallVector<Operation *> ops;
for (auto oper : getBody().getTerminator()->getOperands())
if (oper.getDefiningOp() && getRegion().isAncestor(oper.getDefiningOp()->getParentRegion()))
ops.push_back(oper.getDefiningOp());
return ops;
}
int32_t getId() {
if (auto id_attr = (*this)->getAttrOfType<IntegerAttr>("id")) {
return id_attr.getInt();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ LogicalResult lowerAirExecute(Operation *op) {

llvm::SmallSet<Operation *, 8> erased;
module->walk([&](air::ExecuteOp exe) {
auto &bb = exe.getBody().front();
auto &bb = exe.getRegion().front();
unsigned idx = 0;

OpBuilder builder(exe);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AIRToAIEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ struct LowerAIRExecutePattern : public OpRewritePattern<air::ExecuteOp> {

LogicalResult matchAndRewrite(air::ExecuteOp op,
PatternRewriter &rewriter) const override {
auto &bb = op.getBody().front();
auto &bb = op.getRegion().front();
unsigned idx = 0;
for (auto &arg : bb.getArguments()) {
arg.replaceAllUsesWith(op.getOperand(idx));
Expand Down
14 changes: 6 additions & 8 deletions mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,8 +1026,12 @@ uint64_t HerdOp::getNumRows() {
//

LogicalResult ExecuteOp::verify() {
assert(getOperation()->getNumRegions() == 1 && "ExecuteOp has zero region!");
assert(!getBody().empty() && "ExecuteOp should have non-empty body");
if (getOperation()->getNumRegions() != 1)
return emitOpError("ExecuteOp has zero region.");
if (getRegion().empty())
return emitOpError("ExecuteOp should have non-empty region.");
if (getBody().empty())
return emitOpError("ExecuteOp should have non-empty body.");

return success();
}
Expand Down Expand Up @@ -1086,12 +1090,6 @@ void ExecuteOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add(CanonicalizeAsyncOpDeps<ExecuteOp>);
}

Operation *ExecuteOp::getChildOp() {
auto child_op =
&getOperation()->getRegion(0).getBlocks().front().getOperations().front();
return child_op;
}

//
// WaitAllOp
//
Expand Down
17 changes: 7 additions & 10 deletions mlir/lib/Transform/AIRDependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,9 @@ class AIRDependency
f.walk([&](Operation *op) {
Operation *sink_op = nullptr;
if (auto async_execute_op = dyn_cast<air::ExecuteOp>(op)) {
for (auto &bb : async_execute_op.getBody()) {
for (auto &child_op : bb.getOperations()) {
if (!dyn_cast<air::ExecuteTerminatorOp>(child_op))
sink_op = &child_op;
}
}
for (auto &child_op : async_execute_op.getChildOps())
if (!dyn_cast<air::ExecuteTerminatorOp>(child_op))
sink_op = &child_op;
} else if (isa<xilinx::air::DmaMemcpyNdOp>(op)) {
sink_op = op;
} else if (isa<xilinx::air::ChannelInterface>(op)) {
Expand Down Expand Up @@ -694,7 +691,7 @@ class AIRDependency
mlir::IntegerType::get(op->getContext(), 32), ++ExecuteOpID));

// Insert op to the new async execute region's body.
Block *async_region_bb = builder.createBlock(&async_region.getBody());
Block *async_region_bb = builder.createBlock(&async_region.getRegion());
builder.setInsertionPointToStart(async_region_bb);

// Handle cases when the operand(s) of the given op that is
Expand Down Expand Up @@ -757,7 +754,7 @@ class AIRDependency
mlir::IntegerType::get(op->getContext(), 32), ++ExecuteOpID));

// Insert op to the new async execute region's body.
Block *async_region_bb = builder.createBlock(&async_region.getBody());
Block *async_region_bb = builder.createBlock(&async_region.getRegion());
builder.setInsertionPointToStart(async_region_bb);
auto op_cloned = builder.clone(*op);
builder.create<xilinx::air::ExecuteTerminatorOp>(builder.getUnknownLoc(),
Expand Down Expand Up @@ -1989,7 +1986,7 @@ class AIRDependency
}
bool isNotLoopCarriedOp(air::AsyncOpInterface op) {
if (auto exec_op = dyn_cast<air::ExecuteOp>(op.getOperation())) {
auto &bb = exec_op.getBody().front();
auto &bb = exec_op.getRegion().front();
Operation &child_op = bb.getOperations().front();
return isNotLoopCarriedOp(&child_op);
} else
Expand All @@ -2005,7 +2002,7 @@ class AIRDependency
for (auto user : token.getUsers()) {
if (user->getBlock() == block) {
if (auto async_user = dyn_cast<air::ExecuteOp>(user)) {
auto &bb = async_user.getBody().front();
auto &bb = async_user.getRegion().front();
Operation &child_op = bb.getOperations().front();
if (!isNotLoopCarriedOp(&child_op))
isOnlyUsedByNoCarryOps = false;
Expand Down
159 changes: 85 additions & 74 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,13 @@ air::ExecuteOp getRegionOfAllocOpForOp(Operation *op) {
auto dependency_list = current_async_op.getAsyncDependencies();
if (dependency_list.size()) {
for (auto dep_op : dependency_list) {
if (dep_op.getDefiningOp() &&
dyn_cast<air::ExecuteOp>(dep_op.getDefiningOp())) {
// Found air.ExecuteOp in upstream dependency
auto exec_op = dyn_cast<air::ExecuteOp>(dep_op.getDefiningOp());
auto child_op = exec_op.getChildOp();
if (auto alloc_op = dyn_cast<memref::AllocOp>(child_op)) {
// Found memref.allocOp inside air.ExecuteOp
return exec_op;
}
}
auto exec_op = dep_op.getDefiningOp<air::ExecuteOp>();
if (!exec_op)
continue;
if (llvm::any_of(exec_op.getChildOps(), [](Operation &child_op) {
return isa<memref::AllocOp>(child_op);
}))
return exec_op;
}
}
return nullptr;
Expand All @@ -87,14 +84,13 @@ air::ExecuteOp getRegionOfDeallocOpForOp(Operation *op) {
air::AsyncOpInterface current_async_op = dyn_cast<air::AsyncOpInterface>(op);
auto dependency_token = current_async_op.getAsyncToken();
for (auto user : dependency_token.getUsers()) {
if (auto exec_op = dyn_cast<air::ExecuteOp>(user)) {
// Found air.ExecuteOp in downstream dependency
auto child_op = exec_op.getChildOp();
if (auto dealloc_op = dyn_cast<memref::DeallocOp>(child_op)) {
// Found memref.deallocOp inside air.ExecuteOp
return exec_op;
}
}
auto exec_op = dyn_cast<air::ExecuteOp>(user);
if (!exec_op)
continue;
if (llvm::any_of(exec_op.getChildOps(), [](Operation &child_op) {
return isa<memref::DeallocOp>(child_op);
}))
return exec_op;
}
return nullptr;
}
Expand Down Expand Up @@ -273,11 +269,11 @@ struct HoistDmaInAccumPattern : public OpRewritePattern<scf::ForOp> {
} else if (auto exec_op =
dyn_cast<air::ExecuteOp>(dep_op.getDefiningOp())) {
// Found air.ExecuteOp in upstream dependency
auto child_op = exec_op.getChildOp();
if (auto alloc_op = dyn_cast<memref::AllocOp>(child_op)) {
// Found memref.allocOp inside air.ExecuteOp
foundMemrefAllocDep = true;
}
if (llvm::any_of(exec_op.getChildOps(), [](Operation &child_op) {
return isa<memref::AllocOp>(child_op);
}))
foundMemrefAllocDep =
true; // Found memref.allocOp inside air.ExecuteOp
}
}
}
Expand All @@ -294,18 +290,18 @@ struct HoistDmaInAccumPattern : public OpRewritePattern<scf::ForOp> {
dyn_cast<air::AsyncOpInterface>(current_op);
auto dependency_token = current_async_op.getAsyncToken();
for (auto user : dependency_token.getUsers()) {
if (auto exec_op = dyn_cast<air::ExecuteOp>(user)) {
// Found air.ExecuteOp in downstream dependency
auto child_op = exec_op.getChildOp();
if (auto dealloc_op = dyn_cast<memref::DeallocOp>(child_op)) {
// Found memref.deallocOp inside air.ExecuteOp
foundDepToMemrefDealloc = true;
}
}
if (dyn_cast<air::WaitAllOp>(user)) {
foundDepToWaitall = true;
}
}
auto exec_op = dyn_cast<air::ExecuteOp>(user);
if (!exec_op)
continue;
// Found air.ExecuteOp in downstream dependency
if (llvm::any_of(exec_op.getChildOps(), [](Operation &child_op) {
return isa<memref::DeallocOp>(child_op);
}))
foundDepToMemrefDealloc = true;
}
if (llvm::any_of(dependency_token.getUsers(),
[](Operation *user) { return isa<air::WaitAllOp>(user); }))
foundDepToWaitall = true;
return foundDepToWaitall & foundDepToMemrefDealloc;
}

Expand Down Expand Up @@ -364,7 +360,7 @@ struct HoistAIRChannelInAccumPattern : public OpRewritePattern<scf::ForOp> {
for (auto get : for_op.getOps<air::ChannelGetOp>())
dataProducers.push_back(get);
for (auto exec : for_op.getOps<air::ExecuteOp>()) {
auto child_op = exec.getChildOp();
auto child_op = &exec.getChildOps().front();
if (isa<linalg::FillOp>(child_op))
dataProducers.push_back(exec);
}
Expand Down Expand Up @@ -530,7 +526,7 @@ struct HoistAIRChannelInAccumPattern : public OpRewritePattern<scf::ForOp> {
}
Operation *actual_op_1 = op_1;
if (auto exec = dyn_cast<air::ExecuteOp>(op_1)) {
actual_op_1 = exec.getChildOp();
actual_op_1 = &exec.getChildOps().front();
}
Value op_1_memref = nullptr;
Value op_2_memref = nullptr;
Expand Down Expand Up @@ -585,7 +581,7 @@ struct AnnotateFrontAndBackOpsInForPattern
}
}
}
auto child_op = exec_op.getChildOp();
auto child_op = &exec_op.getChildOps().front();
if (isa<memref::AllocOp>(child_op) && isFrontCandidate) {
iterTokens.push_back(op.getResult(0));
// Note: skipping over alloc ops, since they will get hoisted out of
Expand Down Expand Up @@ -641,7 +637,7 @@ struct AnnotateFrontAndBackOpsInForPattern
for (auto token : yielded_tokens) {
auto back_candidate = token.getDefiningOp();
if (auto exec_op = dyn_cast<air::ExecuteOp>(back_candidate)) {
auto child_op = exec_op.getChildOp();
auto child_op = &exec_op.getChildOps().front();
if (isa<memref::DeallocOp>(child_op)) {
for (auto d : exec_op.getAsyncDependencies()) {
back_candidates.push_back(
Expand Down Expand Up @@ -1607,8 +1603,7 @@ struct CanonicalizeAIRExecute : public OpRewritePattern<air::ExecuteOp> {
LogicalResult matchAndRewrite(air::ExecuteOp exec,
PatternRewriter &rewriter) const override {

auto childOp = exec.getChildOp();
assert(childOp && "air.execute op has no child op");
auto childOp = &exec.getChildOps().front();
// Canonicalize air.execute with empty region.
if (!childOp->mightHaveTrait<OpTrait::IsTerminator>())
return failure();
Expand Down Expand Up @@ -3317,7 +3312,8 @@ class AIRDeAliasMemref
auto async_exec = builder.create<xilinx::air::ExecuteOp>(
user->getLoc(), air::AsyncTokenType::get(alloc->getContext()),
SmallVector<Value>{});
Block *async_exec_bb = builder.createBlock(&async_exec.getBody());
Block *async_exec_bb =
builder.createBlock(&async_exec.getRegion());
builder.setInsertionPointToStart(async_exec_bb);
builder.create<memref::DeallocOp>(user->getLoc(), new_memref);
builder.create<air::ExecuteTerminatorOp>(user->getLoc());
Expand Down Expand Up @@ -4303,7 +4299,7 @@ void identifyTargetOpsInSCFFor(
continue;
if (for_op->isProperAncestor(memrefDefOp)) {
if (auto exec = dyn_cast<air::ExecuteOp>(memrefDefOp))
memrefDefOp = exec.getChildOp();
memrefDefOp = &exec.getChildOps().front();
memrefDefOp->setAttr(
"hoist_alloc",
mlir::BoolAttr::get(memrefDefOp->getContext(), true));
Expand Down Expand Up @@ -4492,7 +4488,7 @@ struct ShrinkMemrefSizesByAccessPattern
auto newExecOp = rewriter.create<air::ExecuteOp>(
execOp->getLoc(), air::AsyncTokenType::get(rewriter.getContext()),
newMemrefType, execOp.getAsyncDependencies());
Block *async_exec_bb = rewriter.createBlock(&newExecOp.getBody());
Block *async_exec_bb = rewriter.createBlock(&newExecOp.getRegion());
rewriter.setInsertionPointToStart(async_exec_bb);
auto newAlloc =
rewriter.create<memref::AllocOp>(alloc->getLoc(), newMemrefType);
Expand Down Expand Up @@ -4722,9 +4718,11 @@ struct ShrinkMemrefSizesByAccessPattern
continue;
if (auto exec_to_herd_iv =
dyn_cast<air::ExecuteOp>((*subview_offsets).getDefiningOp())) {
for (auto oper : exec_to_herd_iv.getChildOp()->getOperands())
if (getHerdArgOwner(oper))
offsetIsHerdIndVar = true;
SetVector<Value> opers;
getUsedValuesDefinedAbove(exec_to_herd_iv.getRegion(), opers);
if (llvm::any_of(opers,
[](Value oper) { return getHerdArgOwner(oper); }))
offsetIsHerdIndVar = true;
}
if (offsetIsHerdIndVar)
if (auto updatedOffset =
Expand All @@ -4747,12 +4745,16 @@ struct ShrinkMemrefSizesByAccessPattern
return false;
if (auto exec = dyn_cast<air::ExecuteOp>(
(*subview_offsets).getDefiningOp())) {
for (auto oper : exec.getChildOp()->getOperands()) {
if (!getConstantIntValue(oper))
return false;
if (*getConstantIntValue(oper) != 0)
return false;
}
SetVector<Value> opers;
getUsedValuesDefinedAbove(exec.getRegion(), opers);
if (llvm::any_of(opers, [](Value oper) {
return !getConstantIntValue(oper);
}))
return false;
if (llvm::any_of(opers, [](Value oper) {
return *getConstantIntValue(oper) != 0;
}))
return false;
} else
return false;
} else if (*getConstantIntValue(*subview_offsets) != 0)
Expand Down Expand Up @@ -4820,13 +4822,16 @@ struct ShrinkMemrefSizesByAccessPattern
return nullptr;
if (index.getDefiningOp()) {
if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp())) {
for (auto oper : execOp.getChildOp()->getOperands()) {
if (auto herdOp = air::getHerdArgOwner(oper)) {
rewriter.setInsertionPointToStart(&herdOp.getBody().front());
execOp.getChildOp()->replaceUsesOfWith(
oper, rewriter.create<arith::ConstantIndexOp>(
rewriter.getUnknownLoc(), 0));
}
SetVector<Value> opers;
getUsedValuesDefinedAbove(execOp.getRegion(), opers);
for (auto oper : opers) {
auto herdOp = air::getHerdArgOwner(oper);
if (!herdOp)
continue;
rewriter.setInsertionPointToStart(&herdOp.getBody().front());
Value constZero = rewriter.create<arith::ConstantIndexOp>(
rewriter.getUnknownLoc(), 0);
replaceAllUsesInRegionWith(oper, constZero, execOp.getRegion());
}
}
} else if (auto herdOp = air::getHerdArgOwner(index)) {
Expand Down Expand Up @@ -4869,25 +4874,31 @@ struct AIRSegmentLoopFusionPattern : public OpRewritePattern<air::SegmentOp> {
// dealloc.
std::vector<std::pair<air::ExecuteOp, air::ExecuteOp>> alloc_dealloc_execs;
for (auto execOp : op.getOps<air::ExecuteOp>()) {
if (!execOp.getChildOp())
if (llvm::none_of(execOp.getChildOps(), [](Operation &child_op) {
return isa<memref::AllocOp>(child_op);
}))
continue;
if (!isa<memref::AllocOp>(execOp.getChildOp()))
SmallVector<Value> memrefs;
for (auto res : execOp->getResults())
if (isa<MemRefType>(res.getType()))
memrefs.push_back(res);
// Skip over any memref results used by other than air.channel.put/get ops
// in loops.
if (llvm::any_of(memrefs, [](Value v) {
return llvm::any_of(v.getUsers(), [](Operation *user) {
return isa<air::ChannelInterface>(user) &&
!isa<scf::ForOp>(user->getParentOp());
});
}))
continue;
auto memref = execOp->getResult(1);
bool allChannelUsersAreInScfFor = true;
for (auto user : memref.getUsers())
if (isa<air::ChannelInterface>(user))
if (!isa<scf::ForOp>(user->getParentOp()))
allChannelUsersAreInScfFor = false;
if (allChannelUsersAreInScfFor)
alloc_dealloc_execs.push_back(std::make_pair(execOp, nullptr));
alloc_dealloc_execs.push_back(std::make_pair(execOp, nullptr));
}
for (auto execOp : op.getOps<air::ExecuteOp>()) {
if (!execOp.getChildOp())
continue;
if (!isa<memref::DeallocOp>(execOp.getChildOp()))
if (llvm::none_of(execOp.getChildOps(), [](Operation &child_op) {
return isa<memref::DeallocOp>(child_op);
}))
continue;
auto dealloc = dyn_cast<memref::DeallocOp>(execOp.getChildOp());
auto dealloc = dyn_cast<memref::DeallocOp>(execOp.getChildOps().front());
for (auto &pair : alloc_dealloc_execs) {
if (dealloc.getMemref() == pair.first.getResult(1)) {
pair.second = execOp;
Expand Down
Loading

0 comments on commit c60d7bd

Please sign in to comment.