Skip to content

Commit

Permalink
fix shape inference during combine tree adaptation
Browse files Browse the repository at this point in the history
  • Loading branch information
gysit committed Feb 1, 2021
1 parent 7d032f3 commit c5d38d9
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 16 deletions.
54 changes: 38 additions & 16 deletions lib/Dialect/Stencil/PeelOddIterationsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ struct PeelRewrite : public stencil::ApplyOpPattern {

// Remove stores that exceed the domain
auto peelOp = peelSize < 0 ? leftOp : rightOp;
makePeelIteration(peelOp, domainSize % unrollFac, rewriter);
makePeelIteration(peelOp, peelSize, rewriter);

// Introduce a stencil combine to replace the apply operation
auto combineOp = rewriter.create<stencil::CombineOp>(
Expand Down Expand Up @@ -212,25 +212,23 @@ struct FuseRewrite : public stencil::CombineOpPattern {
// Introduce a peel loop if the shape is not a multiple of the unroll factor
LogicalResult matchAndRewrite(stencil::CombineOp combineOp,
PatternRewriter &rewriter) const override {
// Store the left and right leaf combines
stencil::CombineOp leftCombineOp = combineOp;
stencil::CombineOp rightCombineOp = combineOp;

// Search the lower and upper defining ops and exit if none exists
Operation *currLeftCombineOp = getLowerDefiningOp(combineOp);
Operation *currRightCombineOp = getUpperDefiningOp(combineOp);
if (!currLeftCombineOp || !currRightCombineOp)
return failure();

// Walk up the combine tree
SmallVector<Operation *, 4> leftCombineOps;
SmallVector<Operation *, 4> rightCombineOps;
while (auto combineOp =
dyn_cast_or_null<stencil::CombineOp>(currLeftCombineOp)) {
leftCombineOp = combineOp;
leftCombineOps.push_back(combineOp);
currLeftCombineOp = getUpperDefiningOp(combineOp);
}
while (auto combineOp =
dyn_cast_or_null<stencil::CombineOp>(currRightCombineOp)) {
rightCombineOp = combineOp;
rightCombineOps.push_back(combineOp);
currRightCombineOp = getLowerDefiningOp(combineOp);
}

Expand All @@ -250,23 +248,41 @@ struct FuseRewrite : public stencil::CombineOpPattern {

// Merge the two apply operations in case they overlap
auto newOp = fusePeelIterations(leftOp, rightOp, rewriter);
auto newShape = cast<ShapeOp>(newOp.getOperation());

// Update the shape of the left and right combines
for (auto leftCombineOp : leftCombineOps) {
auto leftShape = cast<ShapeOp>(leftCombineOp);
auto ub = leftShape.getUB();
ub[returnOp.getUnrollDim()] = newShape.getLB()[returnOp.getUnrollDim()];
leftShape.updateShape(leftShape.getLB(), ub);
}
for (auto rightCombineOp : rightCombineOps) {
auto rightShape = cast<ShapeOp>(rightCombineOp);
auto lb = rightShape.getLB();
lb[returnOp.getUnrollDim()] = newShape.getUB()[returnOp.getUnrollDim()];
rightShape.updateShape(lb, rightShape.getUB());
}

// Disconnect the left and right apply operations from the combine tree
SmallVector<Value, 10> leftOperands;
SmallVector<Value, 10> rightOperands;
if (leftCombineOp != combineOp) {
rewriter.replaceOp(leftCombineOp, leftCombineOp.lower());
if (!leftCombineOps.empty()) {
rewriter.replaceOp(
leftCombineOps.back(),
cast<stencil::CombineOp>(leftCombineOps.back()).lower());
leftOperands = combineOp.lower();
}
if (rightCombineOp != combineOp) {
rewriter.replaceOp(rightCombineOp, rightCombineOp.upper());
if (!rightCombineOps.empty()) {
rewriter.replaceOp(
rightCombineOps.back(),
cast<stencil::CombineOp>(rightCombineOps.back()).upper());
rightOperands = combineOp.upper();
}

// Replace the combine op by the results computed by the fused apply
SmallVector<Value, 10> newResults = newOp.getResults();
auto currShape = cast<ShapeOp>(newOp.getOperation());
auto fullShape = cast<ShapeOp>(combineOp.getOperation());
auto unrollDim = returnOp.getUnrollDim();
if (!leftOperands.empty()) {
// Introduce a combine ob to connect to the left combine subtree
Expand All @@ -276,10 +292,9 @@ struct FuseRewrite : public stencil::CombineOpPattern {
ValueRange(), combineOp.lbAttr(), combineOp.ubAttr());
newResults = newCombineOp.getResults();

// Update the shape of the newly introduced combine
auto lb = fullShape.getLB();
auto ub = fullShape.getUB();
ub[unrollDim] = currShape.getUB()[unrollDim];
// Get the lower and upper bounds of the children
auto lb = cast<ShapeOp>(getLowerDefiningOp(newCombineOp)).getLB();
auto ub = cast<ShapeOp>(getUpperDefiningOp(newCombineOp)).getUB();
currShape = cast<ShapeOp>(newCombineOp.getOperation());
currShape.updateShape(lb, ub);
}
Expand All @@ -290,7 +305,14 @@ struct FuseRewrite : public stencil::CombineOpPattern {
currShape.getUB()[unrollDim], newResults, rightOperands, ValueRange(),
ValueRange(), combineOp.lbAttr(), combineOp.ubAttr());
newResults = newCombineOp.getResults();

// Get the lower and upper bounds of the children
auto lb = cast<ShapeOp>(getLowerDefiningOp(newCombineOp)).getLB();
auto ub = cast<ShapeOp>(getUpperDefiningOp(newCombineOp)).getUB();
currShape = cast<ShapeOp>(newCombineOp.getOperation());
currShape.updateShape(lb, ub);
}

rewriter.replaceOp(combineOp, newResults);
rewriter.eraseOp(leftOp);
rewriter.eraseOp(rightOp);
Expand Down
40 changes: 40 additions & 0 deletions test/Dialect/Stencil/peel-odd-iterations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,44 @@ func @combine_tree(%arg0: !stencil.field<?x?x?xf64>, %arg1: !stencil.field<?x?x?
%5 = stencil.combine 1 at 62 lower = (%3 : !stencil.temp<64x62x60xf64>) upper = (%4 : !stencil.temp<64x1x60xf64>) ([0, 0, 0] : [64, 63, 60]) : !stencil.temp<64x63x60xf64>
stencil.store %5 to %1([0, 0, 0] : [64, 63, 60]) : !stencil.temp<64x63x60xf64> to !stencil.field<70x70x60xf64>
return
}

// -----

// CHECK-LABEL: func @test
func @test(%arg0: !stencil.field<?x?x?xf64>, %arg1: !stencil.field<?x?x?xf64>) attributes {stencil.program} {
%0 = stencil.cast %arg0([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
%1 = stencil.cast %arg1([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
%2 = stencil.load %0([0, 0, 0] : [64, 63, 60]) : (!stencil.field<70x70x60xf64>) -> !stencil.temp<64x63x60xf64>
%3 = stencil.apply (%arg2 = %2 : !stencil.temp<64x63x60xf64>) -> !stencil.temp<64x10x60xf64> {
%8 = stencil.access %arg2 [0, 0, 0] : (!stencil.temp<64x63x60xf64>) -> f64
%9 = stencil.store_result %8 : (f64) -> !stencil.result<f64>
%10 = stencil.access %arg2 [0, 1, 0] : (!stencil.temp<64x63x60xf64>) -> f64
%11 = stencil.store_result %10 : (f64) -> !stencil.result<f64>
%12 = stencil.access %arg2 [0, 2, 0] : (!stencil.temp<64x63x60xf64>) -> f64
%13 = stencil.store_result %12 : (f64) -> !stencil.result<f64>
%14 = stencil.access %arg2 [0, 3, 0] : (!stencil.temp<64x63x60xf64>) -> f64
%15 = stencil.store_result %14 : (f64) -> !stencil.result<f64>
stencil.return unroll [1, 4, 1] %9, %11, %13, %15 : !stencil.result<f64>, !stencil.result<f64>, !stencil.result<f64>, !stencil.result<f64>
} to ([0, 0, 0] : [64, 10, 60])
%4 = stencil.apply (%arg2 = %2 : !stencil.temp<64x63x60xf64>) -> !stencil.temp<64x52x60xf64> {
%cst = constant 2.000000e+00 : f64
%8 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
%9 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
%10 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
%11 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
stencil.return unroll [1, 4, 1] %8, %9, %10, %11 : !stencil.result<f64>, !stencil.result<f64>, !stencil.result<f64>, !stencil.result<f64>
} to ([0, 10, 0] : [64, 62, 60])
%5 = stencil.combine 1 at 10 lower = (%3 : !stencil.temp<64x10x60xf64>) upper = (%4 : !stencil.temp<64x52x60xf64>) ([0, 0, 0] : [64, 62, 60]) : !stencil.temp<64x62x60xf64>
%6 = stencil.apply (%arg2 = %2 : !stencil.temp<64x63x60xf64>) -> !stencil.temp<64x1x60xf64> {
%cst = constant 1.000000e+00 : f64
%8 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
%9 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
%10 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
%11 = stencil.store_result %cst : (f64) -> !stencil.result<f64>
stencil.return unroll [1, 4, 1] %8, %9, %10, %11 : !stencil.result<f64>, !stencil.result<f64>, !stencil.result<f64>, !stencil.result<f64>
} to ([0, 62, 0] : [64, 63, 60])
%7 = stencil.combine 1 at 62 lower = (%5 : !stencil.temp<64x62x60xf64>) upper = (%6 : !stencil.temp<64x1x60xf64>) ([0, 0, 0] : [64, 63, 60]) : !stencil.temp<64x63x60xf64>
stencil.store %7 to %1([0, 0, 0] : [64, 63, 60]) : !stencil.temp<64x63x60xf64> to !stencil.field<70x70x60xf64>
return
}

0 comments on commit c5d38d9

Please sign in to comment.