From 6234ede974105c1c36cad21898da71e5f8153a55 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Sun, 17 Jan 2021 19:56:06 +0100 Subject: [PATCH] initial shape overlap pass --- lib/Dialect/Stencil/ShapeOverlapPass.cpp | 113 ++++++++++++++++------- test/Dialect/Stencil/shape-overlap.mlir | 36 ++------ 2 files changed, 84 insertions(+), 65 deletions(-) diff --git a/lib/Dialect/Stencil/ShapeOverlapPass.cpp b/lib/Dialect/Stencil/ShapeOverlapPass.cpp index a3851ac..3b0f434 100644 --- a/lib/Dialect/Stencil/ShapeOverlapPass.cpp +++ b/lib/Dialect/Stencil/ShapeOverlapPass.cpp @@ -1,11 +1,11 @@ #include "Dialect/Stencil/Passes.h" #include "Dialect/Stencil/StencilDialect.h" #include "Dialect/Stencil/StencilOps.h" +#include "Dialect/Stencil/StencilUtils.h" #include "PassDetail.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" -#include -#include +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace stencil; @@ -22,11 +22,13 @@ class AccessSets { while (!queue.empty()) { auto curr = queue.back(); queue.pop_back(); + // Add field to the access set if (auto loadOp = dyn_cast(curr)) { accessSets[storeOp].insert(loadOp.field()); continue; } + // Search possible access sets for (auto operand : curr->getOperands()) { if (auto definingOp = operand.getDefiningOp()) { @@ -55,9 +57,14 @@ struct ShapeOverlapPass : public ShapeOverlapPassBase { protected: bool areRangesOverlapping(ShapeOp shapeOp1, ShapeOp shapeOp2); - SmallVector computeGroupValues(int dim, int64_t lb, int64_t ub, + SmallVector computeGroupValues(OpBuilder b, int dim, int64_t from, + int64_t to, ArrayRef lb, + ArrayRef ub, ArrayRef group); - void splitGroupDimension(int dim, ArrayRef group); + SmallVector splitGroupDimension(OpBuilder b, int dim, + ArrayRef lb, + ArrayRef ub, + ArrayRef group); }; // Two shapes overlap if at least half of the bounds are overlapping @@ -74,27 +81,47 @@ bool ShapeOverlapPass::areRangesOverlapping(ShapeOp shapeOp1, return count >= kIndexSize; } -SmallVector -ShapeOverlapPass::computeGroupValues(int dim, int64_t lb, int64_t ub, - ArrayRef group) { +SmallVector ShapeOverlapPass::computeGroupValues( + OpBuilder b, int dim, int64_t from, int64_t to, ArrayRef lb, + ArrayRef ub, ArrayRef group) { // Iterate all store operations of the group - SmallVector tempValues; - for (auto storeOp : group) { + SmallVector subGroup; + SmallVector subGroupIndexes; + for (auto en : llvm::enumerate(group)) { + auto storeOp = en.value(); auto shapeOp = cast(storeOp.getOperation()); - if (shapeOp.getLB()[dim] <= lb && shapeOp.getUB()[dim] >= ub) { - tempValues.push_back(storeOp.temp()); - } else { - tempValues.push_back(nullptr); + if (shapeOp.getLB()[dim] <= from && shapeOp.getUB()[dim] >= to) { + subGroup.push_back(storeOp); + subGroupIndexes.push_back(en.index()); } } + + // Split the subgroup recursively + auto subTempValues = splitGroupDimension(b, dim - 1, lb, ub, subGroup); + + // Return the resulting temp values or null otherwise + SmallVector tempValues(group.size(), nullptr); + for (auto en : llvm::enumerate(subGroupIndexes)) { + tempValues[en.value()] = subTempValues[en.index()]; + } return tempValues; } -void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef group) { +SmallVector ShapeOverlapPass::splitGroupDimension( + OpBuilder b, int dim, ArrayRef lb, ArrayRef ub, + ArrayRef group) { assert(dim < kIndexSize && "expected dimension to be lower than the index size"); + // Return the group temporary values if the dimension is smaller than zero + if (dim < 0) { + SmallVector tempValues; + llvm::transform(group, std::back_inserter(tempValues), + [](StoreOp storeOp) { return storeOp.temp(); }); + return tempValues; + } + // Compute the bounds of all subdomains - SmallVector limits; + SmallVector limits = {lb[dim], ub[dim]}; for (auto storeOp : group) { auto shapeOp = cast(storeOp.getOperation()); limits.push_back(shapeOp.getLB()[dim]); @@ -103,13 +130,6 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef group) { std::sort(limits.begin(), limits.end()); limits.erase(std::unique(limits.begin(), limits.end()), limits.end()); - // Setup the op builder - auto storeOp = *std::min_element(group.begin(), group.end(), - [](StoreOp storeOp1, StoreOp storeOp2) { - return storeOp1->isBeforeInBlock(storeOp2); - }); - OpBuilder b(storeOp); - // Compute the lower and upper bounds of all intervals except for the last SmallVector lowerBounds; SmallVector upperBounds; @@ -119,15 +139,15 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef group) { } // Initialize the temporary values to the values stored in the last intervall - auto tempValues = - computeGroupValues(dim, lowerBounds.back(), upperBounds.back(), group); + auto tempValues = computeGroupValues(b, dim, lowerBounds.back(), + upperBounds.back(), lb, ub, group); lowerBounds.pop_back(); upperBounds.pop_back(); // Introduce combine operations in backward order while (!lowerBounds.empty()) { - auto currValues = - computeGroupValues(dim, lowerBounds.back(), upperBounds.back(), group); + auto currValues = computeGroupValues(b, dim, lowerBounds.back(), + upperBounds.back(), lb, ub, group); // Compute the indexes of the lower upper and extra values SmallVector lower, lowerext, upperext; @@ -162,10 +182,16 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef group) { llvm::transform(upperext, std::back_inserter(upperextOperands), [&](int64_t x) { return tempValues[x]; }); + // Compute the fused location + SmallVector locs; + llvm::transform(group, std::back_inserter(locs), + [](StoreOp storeOp) { return storeOp->getLoc(); }); + // Create a combine operation auto combineOp = b.create( - storeOp.getLoc(), resultTypes, dim, upperBounds.back(), lowerOperands, - upperOperands, lowerextOperands, upperextOperands, nullptr, nullptr); + locs.empty() ? b.getUnknownLoc() : b.getFusedLoc(locs), resultTypes, + dim, upperBounds.back(), lowerOperands, upperOperands, lowerextOperands, + upperextOperands, nullptr, nullptr); // Update the temporary values unsigned resultIdx = 0; @@ -180,10 +206,7 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef group) { upperBounds.pop_back(); } - // Update the store operands - for(auto en : llvm::enumerate(group)) { - en.value()->setOperand(0, tempValues[en.index()]); - } + return tempValues; } } // namespace @@ -198,8 +221,6 @@ void ShapeOverlapPass::runOnFunction() { // Compute the extent analysis AccessSets &sets = getAnalysis(); - // TODO check before shape inference - // Walk all store operations SmallVector, 4> groupList; funcOp->walk([&](StoreOp storeOp1) { @@ -219,7 +240,29 @@ void ShapeOverlapPass::runOnFunction() { // Split all groups for (auto &group : groupList) { - splitGroupDimension(0, group); + // Search the first store operation + auto storeOp = *std::min_element( + group.begin(), group.end(), [](StoreOp storeOp1, StoreOp storeOp2) { + return storeOp1->isBeforeInBlock(storeOp2); + }); + OpBuilder b(storeOp); + + // Compute the bounding box of all store operation shapes + auto shapeOp = cast(storeOp.getOperation()); + auto lb = shapeOp.getLB(); + auto ub = shapeOp.getUB(); + for (ShapeOp shapeOp : group) { + lb = applyFunElementWise(lb, shapeOp.getLB(), min); + ub = applyFunElementWise(ub, shapeOp.getUB(), max); + } + + // Split the group dimensions recursively + auto tempValues = splitGroupDimension(b, kIndexSize - 1, lb, ub, group); + + // Update the store operands + for (auto en : llvm::enumerate(group)) { + en.value()->setOperand(0, tempValues[en.index()]); + } } } diff --git a/test/Dialect/Stencil/shape-overlap.mlir b/test/Dialect/Stencil/shape-overlap.mlir index 3c9a834..fd55c1a 100644 --- a/test/Dialect/Stencil/shape-overlap.mlir +++ b/test/Dialect/Stencil/shape-overlap.mlir @@ -6,6 +6,8 @@ func @ioverlap(%arg0: !stencil.field, %arg1: !stencil.field) -> !stencil.field<70x70x60xf64> %2 = stencil.cast %arg2([-3, -3, 0] : [67, 67, 60]) : (!stencil.field) -> !stencil.field<70x70x60xf64> %3 = stencil.load %0 : (!stencil.field<70x70x60xf64>) -> !stencil.temp + // CHECK: [[TEMP1:%.*]] = stencil.apply (%{{.*}} = %{{.*}} : !stencil.temp) -> !stencil.temp { + // CHECK: [[TEMP2:%.*]] = stencil.apply (%{{.*}} = %{{.*}} : !stencil.temp) -> !stencil.temp { %4 = stencil.apply (%arg3 = %3 : !stencil.temp) -> !stencil.temp { %6 = stencil.access %arg3 [0, 0, 0] : (!stencil.temp) -> f64 %7 = stencil.store_result %6 : (f64) -> !stencil.result @@ -16,39 +18,13 @@ func @ioverlap(%arg0: !stencil.field, %arg1: !stencil.field !stencil.result stencil.return %7 : !stencil.result } + // CHECK: [[TEMP3:%.*]]:2 = stencil.combine 0 at 64 lower = ([[TEMP2]] : !stencil.temp) upper = ([[TEMP2]] : !stencil.temp) lowerext = ([[TEMP1]] : !stencil.temp) : !stencil.temp, !stencil.temp + // CHECK-NEXT: [[TEMP4:%.*]]:2 = stencil.combine 0 at 0 lower = ([[TEMP1]] : !stencil.temp) upper = ([[TEMP3]]#1 : !stencil.temp) upperext = ([[TEMP3]]#0 : !stencil.temp) : !stencil.temp, !stencil.temp + // CHECK-NEXT: stencil.store [[TEMP4:%.*]]#0 to %{{.*}}([-1, 0, 0] : [64, 64, 60]) : !stencil.temp to !stencil.field<70x70x60xf64> + // CHECK-NEXT: stencil.store [[TEMP4:%.*]]#1 to %{{.*}}([0, 0, 0] : [65, 64, 60]) : !stencil.temp to !stencil.field<70x70x60xf64> stencil.store %4 to %1([-1, 0, 0] : [64, 64, 60]) : !stencil.temp to !stencil.field<70x70x60xf64> stencil.store %5 to %2([0, 0, 0] : [65, 64, 60]) : !stencil.temp to !stencil.field<70x70x60xf64> return } // ----- - -// func @test(%arg0: !stencil.field, %arg1: !stencil.field, %arg2: !stencil.field) attributes {stencil.program} { -// %0 = stencil.cast %arg0([-3, -3, 0] : [67, 67, 60]) : (!stencil.field) -> !stencil.field<70x70x60xf64> -// %1 = stencil.cast %arg1([-3, -3, 0] : [67, 67, 60]) : (!stencil.field) -> !stencil.field<70x70x60xf64> -// %2 = stencil.cast %arg2([-3, -3, 0] : [67, 67, 60]) : (!stencil.field) -> !stencil.field<70x70x60xf64> -// %3 = stencil.load %0 : (!stencil.field<70x70x60xf64>) -> !stencil.temp -// %4 = stencil.apply (%arg3 = %3 : !stencil.temp) -> !stencil.temp { -// %6 = stencil.access %arg3 [0, 0, 0] : (!stencil.temp) -> f64 -// %7 = stencil.store_result %6 : (f64) -> !stencil.result -// stencil.return %7 : !stencil.result -// } -// %5 = stencil.apply (%arg3 = %3 : !stencil.temp) -> !stencil.temp { -// %6 = stencil.access %arg3 [0, 0, 0] : (!stencil.temp) -> f64 -// %7 = stencil.store_result %6 : (f64) -> !stencil.result -// stencil.return %7 : !stencil.result -// } - - -// %6,%7 = stencil.combine 0 at 0 lower = (%4 : !stencil.temp) -// upper = (%4 : !stencil.temp) -// upperext = (%5 : !stencil.temp) : !stencil.temp, !stencil.temp - -// %8,%9 = stencil.combine 0 at 64 lower = (%7 : !stencil.temp) -// upper = (%5 : !stencil.temp) -// lowerext = (%6 : !stencil.temp) : !stencil.temp, !stencil.temp - -// stencil.store %9 to %1([-1, 0, 0] : [64, 64, 60]) : !stencil.temp to !stencil.field<70x70x60xf64> -// stencil.store %8 to %2([0, 0, 0] : [65, 64, 60]) : !stencil.temp to !stencil.field<70x70x60xf64> -// return -// } \ No newline at end of file