Skip to content

Commit

Permalink
Support removal of loop carried events in airrt-to-ipu (Xilinx#470)
Browse files Browse the repository at this point in the history
Check for uses before deleting airrt.wait_all and try again after unroll.
  • Loading branch information
fifield authored Mar 4, 2024
1 parent cb1fde0 commit 982001f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
25 changes: 18 additions & 7 deletions mlir/lib/Conversion/AIRRtToIpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,10 @@ struct AIRRtToIpuPass : public impl::AIRRtToIpuBase<AIRRtToIpuPass> {
(void)applyPatternsAndFoldGreedily(module, std::move(canoPatterns_1));
unrollSCFFors(module);

// Purge all wait ops again after unroll, in case there were loop carried
// events which couldn't be purged before
purgeWaitAlls(module);

// Purge dma ops' async tokens
purgeDmaAsyncTokens(module);

Expand Down Expand Up @@ -901,13 +905,20 @@ struct AIRRtToIpuPass : public impl::AIRRtToIpuBase<AIRRtToIpuPass> {
}

void purgeWaitAlls(ModuleOp module) {
SmallVector<WaitAllOp> waits;
module.walk([&](WaitAllOp w) { waits.push_back(w); });
for (auto w : waits) {
w->eraseOperands(0, w->getNumOperands());
}
for (auto w : waits) {
w.erase();
int size = 0;
int last_size = 1;
while (size < last_size) {
SmallVector<WaitAllOp> waits;
module.walk([&](WaitAllOp w) { waits.push_back(w); });
size = waits.size();
last_size = size;
for (auto &w : waits) {
if (!w->use_empty())
continue;
w->eraseOperands(0, w->getNumOperands());
w.erase();
size--;
}
}
}

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -602,4 +602,22 @@ module {
}
}

// -----

// Loop carried event

// CHECK-LABEL: func.func @func14
// CHECK-NEXT: return
module {
func.func @func14() {
%c0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c2048 = arith.constant 2048 : index
%9 = airrt.wait_all : !airrt.event
%11:1 = scf.for %arg6 = %c0 to %c2048 step %c512 iter_args(%arg7 = %9) -> (!airrt.event) {
%12 = airrt.wait_all : !airrt.event
scf.yield %12 : !airrt.event
}
return
}
}

0 comments on commit 982001f

Please sign in to comment.