Skip to content

Commit

Permalink
fix unroll plus combine
Browse files Browse the repository at this point in the history
  • Loading branch information
gysit committed Dec 23, 2020
1 parent e09e396 commit 4f2dc2d
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions lib/Dialect/Stencil/StencilUnrollingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/Support/raw_ostream.h"
#include <bits/stdint-intn.h>
#include <cstddef>
#include <cstdint>

Expand All @@ -37,8 +36,7 @@ struct StencilUnrollingPass
void unrollStencilApply(stencil::ApplyOp applyOp);
void addPeelIteration(stencil::ApplyOp applyOp);

stencil::ReturnOp makePeelIteration(stencil::ReturnOp returnOp,
unsigned tripCount);
void makePeelIteration(stencil::ReturnOp returnOp, unsigned tripCount);
stencil::ReturnOp cloneBody(stencil::ApplyOp from, stencil::ApplyOp to,
OpBuilder &builder);
};
Expand Down Expand Up @@ -103,29 +101,15 @@ void StencilUnrollingPass::unrollStencilApply(stencil::ApplyOp applyOp) {
b.getI64ArrayAttr(unroll));
}

stencil::ReturnOp
StencilUnrollingPass::makePeelIteration(stencil::ReturnOp returnOp,
unsigned tripCount) {
// Setup the builder and
OpBuilder b(returnOp);

void StencilUnrollingPass::makePeelIteration(stencil::ReturnOp returnOp,
unsigned tripCount) {
// Create empty store for all iterations that exceed the trip count
SmallVector<Value, 16> newOperands;
for (auto en : llvm::enumerate(returnOp.getOperands())) {
if (en.index() % unrollFactor >= tripCount) {
newOperands.push_back(b.create<stencil::StoreResultOp>(
returnOp.getLoc(), returnOp.getOperand(0).getType(), ValueRange()));
en.value().getDefiningOp()->erase();
} else {
newOperands.push_back(en.value());
en.value().getDefiningOp()->setOperands({});
}
}

// Replace the return op
auto newOp = b.create<stencil::ReturnOp>(returnOp.getLoc(), newOperands,
returnOp.unrollAttr());
returnOp.erase();
return newOp;
}

void StencilUnrollingPass::addPeelIteration(stencil::ApplyOp applyOp) {
Expand All @@ -145,7 +129,7 @@ void StencilUnrollingPass::addPeelIteration(stencil::ApplyOp applyOp) {
auto peelOp = cast<stencil::ApplyOp>(b.clone(*applyOp.getOperation()));
auto bodyOp = cast<stencil::ApplyOp>(b.clone(*applyOp.getOperation()));

// Adapt the shape of the two apply ops
// // Adapt the shape of the two apply ops
auto lb = shapeOp.getLB();
auto ub = shapeOp.getUB();
int64_t split = ub[unrollIndex] - domainSize % unrollFactor;
Expand Down Expand Up @@ -185,22 +169,26 @@ void StencilUnrollingPass::runOnFunction() {
}

// Check shape inference has been executed
bool hasShapeOpWithoutShape = false;
funcOp.walk([&](stencil::ShapeOp shapeOp) {
auto result = funcOp->walk([&](stencil::ShapeOp shapeOp) {
if (!shapeOp.hasShape())
hasShapeOpWithoutShape = true;
return WalkResult::interrupt();
return WalkResult::advance();
});
if (hasShapeOpWithoutShape) {
if (result.wasInterrupted()) {
funcOp.emitOpError("execute shape inference before stencil unrolling");
signalPassFailure();
return;
}

// Unroll all stencil apply ops
funcOp.walk([&](stencil::ApplyOp applyOp) {
// Collect the stencil apply operations
SmallVector<stencil::ApplyOp, 16> workList;
funcOp.walk([&](stencil::ApplyOp applyOp) { workList.push_back(applyOp); });

// Unroll the stencil apply operations
for (auto applyOp : workList) {
unrollStencilApply(applyOp);
addPeelIteration(applyOp);
});
}
}

} // namespace
Expand Down

0 comments on commit 4f2dc2d

Please sign in to comment.