diff --git a/lib/Dialect/Stencil/StencilUnrollingPass.cpp b/lib/Dialect/Stencil/StencilUnrollingPass.cpp index 8e86c89..34b621c 100644 --- a/lib/Dialect/Stencil/StencilUnrollingPass.cpp +++ b/lib/Dialect/Stencil/StencilUnrollingPass.cpp @@ -19,7 +19,6 @@ #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/Support/raw_ostream.h" -#include #include #include @@ -37,8 +36,7 @@ struct StencilUnrollingPass void unrollStencilApply(stencil::ApplyOp applyOp); void addPeelIteration(stencil::ApplyOp applyOp); - stencil::ReturnOp makePeelIteration(stencil::ReturnOp returnOp, - unsigned tripCount); + void makePeelIteration(stencil::ReturnOp returnOp, unsigned tripCount); stencil::ReturnOp cloneBody(stencil::ApplyOp from, stencil::ApplyOp to, OpBuilder &builder); }; @@ -103,29 +101,15 @@ void StencilUnrollingPass::unrollStencilApply(stencil::ApplyOp applyOp) { b.getI64ArrayAttr(unroll)); } -stencil::ReturnOp -StencilUnrollingPass::makePeelIteration(stencil::ReturnOp returnOp, - unsigned tripCount) { - // Setup the builder and - OpBuilder b(returnOp); - +void StencilUnrollingPass::makePeelIteration(stencil::ReturnOp returnOp, + unsigned tripCount) { // Create empty store for all iterations that exceed the trip count SmallVector newOperands; for (auto en : llvm::enumerate(returnOp.getOperands())) { if (en.index() % unrollFactor >= tripCount) { - newOperands.push_back(b.create( - returnOp.getLoc(), returnOp.getOperand(0).getType(), ValueRange())); - en.value().getDefiningOp()->erase(); - } else { - newOperands.push_back(en.value()); + en.value().getDefiningOp()->setOperands({}); } } - - // Replace the return op - auto newOp = b.create(returnOp.getLoc(), newOperands, - returnOp.unrollAttr()); - returnOp.erase(); - return newOp; } void StencilUnrollingPass::addPeelIteration(stencil::ApplyOp applyOp) { @@ -145,7 +129,7 @@ void StencilUnrollingPass::addPeelIteration(stencil::ApplyOp applyOp) { auto peelOp = cast(b.clone(*applyOp.getOperation())); auto bodyOp = cast(b.clone(*applyOp.getOperation())); - // Adapt the shape of the two apply ops + // // Adapt the shape of the two apply ops auto lb = shapeOp.getLB(); auto ub = shapeOp.getUB(); int64_t split = ub[unrollIndex] - domainSize % unrollFactor; @@ -185,22 +169,26 @@ void StencilUnrollingPass::runOnFunction() { } // Check shape inference has been executed - bool hasShapeOpWithoutShape = false; - funcOp.walk([&](stencil::ShapeOp shapeOp) { + auto result = funcOp->walk([&](stencil::ShapeOp shapeOp) { if (!shapeOp.hasShape()) - hasShapeOpWithoutShape = true; + return WalkResult::interrupt(); + return WalkResult::advance(); }); - if (hasShapeOpWithoutShape) { + if (result.wasInterrupted()) { funcOp.emitOpError("execute shape inference before stencil unrolling"); signalPassFailure(); return; } - // Unroll all stencil apply ops - funcOp.walk([&](stencil::ApplyOp applyOp) { + // Collect the stencil apply operations + SmallVector workList; + funcOp.walk([&](stencil::ApplyOp applyOp) { workList.push_back(applyOp); }); + + // Unroll the stencil apply operations + for (auto applyOp : workList) { unrollStencilApply(applyOp); addPeelIteration(applyOp); - }); + } } } // namespace