From 982001f77105a0fb38469315fc03dbf65eebfc75 Mon Sep 17 00:00:00 2001 From: Jeff Fifield Date: Mon, 4 Mar 2024 16:08:34 -0700 Subject: [PATCH] Support removal of loop carried events in airrt-to-ipu (#470) Check for uses before deleting airrt.wait_all and try again after unroll. --- mlir/lib/Conversion/AIRRtToIpuPass.cpp | 25 +++++++++++++------ .../Conversion/AIRRtToIpu/airrt_to_ipu.mlir | 18 +++++++++++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/AIRRtToIpuPass.cpp b/mlir/lib/Conversion/AIRRtToIpuPass.cpp index d03743f0f..7787db950 100644 --- a/mlir/lib/Conversion/AIRRtToIpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToIpuPass.cpp @@ -795,6 +795,10 @@ struct AIRRtToIpuPass : public impl::AIRRtToIpuBase { (void)applyPatternsAndFoldGreedily(module, std::move(canoPatterns_1)); unrollSCFFors(module); + // Purge all wait ops again after unroll, in case there were loop carried + // events which couldn't be purged before + purgeWaitAlls(module); + // Purge dma ops' async tokens purgeDmaAsyncTokens(module); @@ -901,13 +905,20 @@ struct AIRRtToIpuPass : public impl::AIRRtToIpuBase { } void purgeWaitAlls(ModuleOp module) { - SmallVector waits; - module.walk([&](WaitAllOp w) { waits.push_back(w); }); - for (auto w : waits) { - w->eraseOperands(0, w->getNumOperands()); - } - for (auto w : waits) { - w.erase(); + int size = 0; + int last_size = 1; + while (size < last_size) { + SmallVector waits; + module.walk([&](WaitAllOp w) { waits.push_back(w); }); + size = waits.size(); + last_size = size; + for (auto &w : waits) { + if (!w->use_empty()) + continue; + w->eraseOperands(0, w->getNumOperands()); + w.erase(); + size--; + } } } diff --git a/mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir b/mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir index bb6414e1b..55ffff845 100644 --- a/mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir +++ b/mlir/test/Conversion/AIRRtToIpu/airrt_to_ipu.mlir @@ -602,4 +602,22 @@ module { } } +// ----- + +// Loop carried event +// CHECK-LABEL: func.func @func14 +// CHECK-NEXT: return +module { + func.func @func14() { + %c0 = arith.constant 0 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + %9 = airrt.wait_all : !airrt.event + %11:1 = scf.for %arg6 = %c0 to %c2048 step %c512 iter_args(%arg7 = %9) -> (!airrt.event) { + %12 = airrt.wait_all : !airrt.event + scf.yield %12 : !airrt.event + } + return + } +}