Skip to content

Commit

Permalink
Add air-loop-fusion pass to fuse scf.for loops in air.segment (Xili…
Browse files Browse the repository at this point in the history
…nx#426)

* Fixup a missing condition when erasing async events

* Perform memalloc hoisting before converting air.dma to air.channel ops; simplify logic for air.dma memory space demotion

* Silence warnings

* Update unit test to reflect a real gemm scenario

* Add a pass which fuses scf.for loops in air.segment, and generates the candidate for loop structure for pingpong buffering
  • Loading branch information
erwei-xilinx authored Feb 15, 2024
1 parent 59d80b2 commit edb8e02
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 212 deletions.
2 changes: 2 additions & 0 deletions mlir/include/air/Transform/AIRDependencyScheduleOpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ std::unique_ptr<mlir::Pass> createAIRFuseChannels();

std::unique_ptr<mlir::Pass> createAIRIsolateAsyncDmaLoopNests();

std::unique_ptr<mlir::Pass> createAIRSegmentLoopFusion();

// Populate patterns for canonicalizing index operations on loop index
// variables. At the moment, only affine.apply computations on induction
// variables are canonicalized
Expand Down
1 change: 1 addition & 0 deletions mlir/include/air/Transform/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ namespace air {
#define GEN_PASS_DEF_AIRUNROLLCHANNELBYFACTORPATTERN
#define GEN_PASS_DEF_AIRUNROLLLOOPFORPIPELININGPATTERN
#define GEN_PASS_DEF_AFFINELOOPOPTPASS
#define GEN_PASS_DEF_AIRSEGMENTLOOPFUSION
#include "air/Transform/Passes.h.inc"

} // namespace air
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/air/Transform/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,14 @@ def AIRIsolateAsyncDmaLoopNests: Pass<"air-isolate-async-dma-loop-nests", "Modul
}];
}

def AIRSegmentLoopFusion: Pass<"air-loop-fusion", "func::FuncOp"> {
let summary = "Hoist dma ops into perfectly nested loop";
let constructor = "xilinx::air::createAIRSegmentLoopFusion()";
let description = [{
This pass performs loop fusion within air.segment op's region.
}];
}

def AIRDependencyScheduleOpt: Pass<"air-dependency-schedule-opt", "ModuleOp"> {
let summary = "Optimize scheduling based on air async dependency";
let constructor = "xilinx::air::createAIRDependencyScheduleOptPass()";
Expand Down
130 changes: 47 additions & 83 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,11 @@ T cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op,
} else if (externalGetPut && dyn_cast<affine::AffineIfOp>(child_op)) {
// If externalGetPut is not nullptr, then broadcast lowering mode is on
replaceAffineIfOpWithChannelOpAndClone(builder, remap, externalGetPut);
} else if (auto dma_op = dyn_cast<air::DmaMemcpyNdOp>(child_op)) {
if (child_op.hasAttr("loop-carried-dep"))
builder.clone(child_op, remap);
else
replaceAsyncOpWithWaitAllAndClone(builder, remap, &child_op, false);
} else if (getLinalgOpFromExecuteOp(&child_op)) {
replaceAsyncOpWithWaitAllAndClone(builder, remap, &child_op, false);
} else {
Expand Down Expand Up @@ -1066,16 +1071,16 @@ class AIRDmaToAIRChannelConversion
SmallVector<Value, 4> dst_strides = op.getDstStrides();

if (src_offsets.size()) {
if (src_sizes.size() != src_rank)
if (src_sizes.size() != (unsigned)src_rank)
return failure();
if (src_strides.size() != src_rank)
if (src_strides.size() != (unsigned)src_rank)
return failure();
}

if (dst_offsets.size()) {
if (dst_sizes.size() != dst_rank)
if (dst_sizes.size() != (unsigned)dst_rank)
return failure();
if (dst_strides.size() != dst_rank)
if (dst_strides.size() != (unsigned)dst_rank)
return failure();
}

Expand Down Expand Up @@ -1478,16 +1483,16 @@ class AIRDemoteDmaToAIRHierarchyConversion
SmallVector<Value, 4> dst_strides = op.getDstStrides();

if (src_offsets.size()) {
if (src_sizes.size() != src_rank)
if (src_sizes.size() != (unsigned)src_rank)
return failure();
if (src_strides.size() != src_rank)
if (src_strides.size() != (unsigned)src_rank)
return failure();
}

if (dst_offsets.size()) {
if (dst_sizes.size() != dst_rank)
if (dst_sizes.size() != (unsigned)dst_rank)
return failure();
if (dst_strides.size() != dst_rank)
if (dst_strides.size() != (unsigned)dst_rank)
return failure();
}

Expand All @@ -1496,71 +1501,40 @@ class AIRDemoteDmaToAIRHierarchyConversion
{
OpBuilder::InsertionGuard guard(rewriter);

SetVector<Operation *> backwardSlice;
BackwardSliceOptions bsOptions{
[&](Operation *o) { return o != hier_op; }};
getBackwardSlice(op.getOperation(), &backwardSlice, bsOptions);

for (auto parent = op->getParentOp();
!isa<air::HierarchyInterface>(parent);
parent = parent->getParentOp()) {
getBackwardSlice(parent, &backwardSlice, bsOptions);
backwardSlice.insert(parent);
}

for (auto b : backwardSlice) {
if (dyn_cast<air::ExecuteOp>(b)) {
for (auto &exec_child_op : b->getRegions().front().getOps()) {
getBackwardSlice(&exec_child_op, &backwardSlice, bsOptions);
backwardSlice.insert(&exec_child_op);
}
}
}
SmallVector<Operation *> backwardSlice;
backwardSlice.push_back(op);
if (isa<scf::ForOp>(op->getParentOp()))
backwardSlice.push_back(op->getParentOp());
for (auto o : backwardSlice)
for (auto oper : o->getOperands())
if (getConstantIntValue(oper))
backwardSlice.push_back(oper.getDefiningOp());

for (auto b : backwardSlice) {
b->setAttr("hoist", StringAttr::get(ctx, "dep"));
}
op->setAttr("hoist", StringAttr::get(op->getContext(), "dep"));
op->setAttr("loop-carried-dep",
StringAttr::get(op->getContext(), "external"));

// Hoist hierarchy op into scf op
Operation *scf_loop = nullptr;
mlir::OpBuilder::InsertPoint
insertionPointAtHierOp; // To keep a record of the insertion point as
// destination for hoisting
rewriter.setInsertionPoint(hier_op);
if (herd) {
SmallVector<int, 2> lbs;
SmallVector<int, 2> ubs;
auto size = herd.getSizeOperands();
for (auto s : size) {
lbs.push_back(0);
ubs.push_back(*mlir::getConstantIntValue(s));
}
scf::ParallelOp scf_par =
hoistHerdToAsyncParallel(rewriter, loc, ctx, herd, lbs, ubs);
scf_loop = scf_par.getOperation();
} else if (segment) {
// Since segment doesn't have iteration space, it doesn't hoist a loop
insertionPointAtHierOp = rewriter.saveInsertionPoint();
}

if (herd) {
auto scf_par = dyn_cast<scf::ParallelOp>(scf_loop);
// Get mapping for remapped ssa values entering the hoisted scf.parallel
IRMapping remap;
auto herd_size = herd.getSizeOperands();
remap.map(herd.getSize()[0], herd_size[0]);
remap.map(herd.getSize()[1], herd_size[1]);
remap.map(herd.getIds()[0], scf_par.getInductionVars()[0]);
remap.map(herd.getIds()[1], scf_par.getInductionVars()[1]);
if (auto for_op = dyn_cast<scf::ForOp>(op->getParentOp()))
for (auto init_arg : for_op.getInitArgs())
remap.map(init_arg,
rewriter
.create<air::WaitAllOp>(
loc, air::AsyncTokenType::get(op->getContext()),
SmallVector<Value>{})
.getAsyncToken());
int arg_idx = 0;
for (auto arg : herd.getKernelArguments())
remap.map(arg, herd.getKernelOperand(arg_idx++));

// Clone ops into hoisted scf.parallel
rewriter.setInsertionPointToStart(scf_par.getBody());
for (Operation &o :
herd->getRegions().front().getBlocks().front().getOperations()) {
if (isa<air::HerdTerminatorOp>(o))
Expand All @@ -1584,17 +1558,6 @@ class AIRDemoteDmaToAIRHierarchyConversion
} else
return failure();

if (scf_loop) {
scf_loop->walk([&](mlir::Operation *o) {
if (o == o->getBlock()->getTerminator()) {
return;
}
if (!o->hasAttr("hoist"))
erased.insert(o);
else
o->removeAttr("hoist");
});
}
hier_op.walk([&](mlir::Operation *o) {
if (o->hasAttr("hoist"))
o->removeAttr("hoist");
Expand Down Expand Up @@ -2396,13 +2359,13 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder,
// default order.
int max_dim_size =
std::max(std::max(offsets.size(), sizes.size()), strides.size());
if (max_dim_size && offsets.size() < memref.getRank()) {
if (max_dim_size && offsets.size() < (unsigned)memref.getRank()) {
for (unsigned i = offsets.size(); i < memref.getRank(); i++) {
offsets.insert(offsets.begin(), builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), 0));
}
}
if (max_dim_size && sizes.size() < memref.getRank()) {
if (max_dim_size && sizes.size() < (unsigned)memref.getRank()) {
for (unsigned i = sizes.size(); i < memref.getRank(); i++) {
sizes.insert(sizes.begin(), builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), 1));
Expand All @@ -2411,7 +2374,7 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder,
int memref_size = 1;
for (auto size : memref.getShape())
memref_size *= size;
if (max_dim_size && strides.size() < memref.getRank()) {
if (max_dim_size && strides.size() < (unsigned)memref.getRank()) {
for (unsigned i = strides.size(); i < memref.getRank(); i++) {
strides.insert(strides.begin(),
builder.create<arith::ConstantIndexOp>(
Expand All @@ -2420,12 +2383,13 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder,
}

// Reduce highest dimensions if more than memref size
while (strides.size() > memref.getRank() && getConstantIntValue(strides[0]) &&
while (strides.size() > (unsigned)memref.getRank() &&
getConstantIntValue(strides[0]) &&
*getConstantIntValue(strides[0]) == memref_size) {
strides.erase(strides.begin());
}
while (sizes.size() > memref.getRank() && getConstantIntValue(sizes[0]) &&
*getConstantIntValue(sizes[0]) == 1) {
while (sizes.size() > (unsigned)memref.getRank() &&
getConstantIntValue(sizes[0]) && *getConstantIntValue(sizes[0]) == 1) {
sizes.erase(sizes.begin());
}
while (offsets.size() > std::min(sizes.size(), strides.size()) &&
Expand Down Expand Up @@ -2755,16 +2719,6 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
SmallVector<func::FuncOp, 4> funcOps;
module.walk([&](func::FuncOp op) { funcOps.push_back(op); });

// Hoist broadcast pattern
for (auto f : funcOps) {
f.walk([&](affine::AffineIfOp op) {
if (!op->getParentOfType<affine::AffineIfOp>()) {
// Only hoist top-level affine if op with a nest of if ops
HoistingAffineIf(op);
}
});
}

// Demote memref alloc pattern
std::map<air::HierarchyInterface, std::vector<Operation *>> hier_to_allocs;
for (auto f : funcOps) {
Expand All @@ -2784,7 +2738,7 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
alloc->getParentOfType<air::ExecuteOp>()
? alloc->getParentOfType<air::ExecuteOp>().getOperation()
: alloc.getOperation();
if (memref_type.getMemorySpaceAsInt() < hierMemorySpace) {
if (memref_type.getMemorySpaceAsInt() < (unsigned)hierMemorySpace) {
hier_to_allocs[hier_op].push_back(alloc_op);
}
});
Expand All @@ -2794,6 +2748,16 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase<DmaToChannelPass> {
(void)AIRDemoteMemrefToAIRHierarchy(pair, builder);
}

// Hoist broadcast pattern
for (auto f : funcOps) {
f.walk([&](affine::AffineIfOp op) {
if (!op->getParentOfType<affine::AffineIfOp>()) {
// Only hoist top-level affine if op with a nest of if ops
HoistingAffineIf(op);
}
});
}

// First pattern to demote dma ops to corresponding air hierarchy
ConversionTarget target_0(*context);

Expand Down
Loading

0 comments on commit edb8e02

Please sign in to comment.