initial shape overlap pass
gysit committed Jan 17, 2021
1 parent a1672b9 commit 6234ede
Showing 2 changed files with 84 additions and 65 deletions.
113 changes: 78 additions & 35 deletions lib/Dialect/Stencil/ShapeOverlapPass.cpp
@@ -1,11 +1,11 @@
#include "Dialect/Stencil/Passes.h"
#include "Dialect/Stencil/StencilDialect.h"
#include "Dialect/Stencil/StencilOps.h"
#include "Dialect/Stencil/StencilUtils.h"
#include "PassDetail.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <bits/stdint-intn.h>
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace stencil;
@@ -22,11 +22,13 @@ class AccessSets {
while (!queue.empty()) {
auto curr = queue.back();
queue.pop_back();

// Add field to the access set
if (auto loadOp = dyn_cast<LoadOp>(curr)) {
accessSets[storeOp].insert(loadOp.field());
continue;
}

// Search possible access sets
for (auto operand : curr->getOperands()) {
if (auto definingOp = operand.getDefiningOp()) {
@@ -55,9 +57,14 @@ struct ShapeOverlapPass : public ShapeOverlapPassBase<ShapeOverlapPass> {

protected:
bool areRangesOverlapping(ShapeOp shapeOp1, ShapeOp shapeOp2);
SmallVector<Value, 10> computeGroupValues(int dim, int64_t lb, int64_t ub,
SmallVector<Value, 10> computeGroupValues(OpBuilder b, int dim, int64_t from,
int64_t to, ArrayRef<int64_t> lb,
ArrayRef<int64_t> ub,
ArrayRef<StoreOp> group);
void splitGroupDimension(int dim, ArrayRef<StoreOp> group);
SmallVector<Value, 10> splitGroupDimension(OpBuilder b, int dim,
ArrayRef<int64_t> lb,
ArrayRef<int64_t> ub,
ArrayRef<StoreOp> group);
};

// Two shapes overlap if at least half of the bounds are overlapping
@@ -74,27 +81,47 @@ bool ShapeOverlapPass::areRangesOverlapping(ShapeOp shapeOp1,
return count >= kIndexSize;
}
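The body of areRangesOverlapping is collapsed in the hunk above; only the comment and the final comparison against kIndexSize are visible. The standalone sketch below illustrates one plausible reading of the criterion: per dimension, count how many bounds of the first shape fall inside the second shape's range, and report an overlap once at least half of the 2 * kIndexSize bounds do. The helper name, the plain-array signature, and the exact counting rule are editorial assumptions, not code from this commit.

// Sketch only (assumed logic); the committed implementation is collapsed above.
#include <array>
#include <cstdint>

constexpr int kIndexSize = 3; // assumed 3-dimensional index size

bool rangesOverlapSketch(const std::array<int64_t, kIndexSize> &lb1,
                         const std::array<int64_t, kIndexSize> &ub1,
                         const std::array<int64_t, kIndexSize> &lb2,
                         const std::array<int64_t, kIndexSize> &ub2) {
  int count = 0;
  for (int d = 0; d < kIndexSize; ++d) {
    // A bound of the first shape lying inside the second shape's range
    // counts as one overlapping bound.
    if (lb1[d] >= lb2[d] && lb1[d] < ub2[d])
      ++count;
    if (ub1[d] > lb2[d] && ub1[d] <= ub2[d])
      ++count;
  }
  return count >= kIndexSize;
}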

SmallVector<Value, 10>
ShapeOverlapPass::computeGroupValues(int dim, int64_t lb, int64_t ub,
ArrayRef<StoreOp> group) {
SmallVector<Value, 10> ShapeOverlapPass::computeGroupValues(
OpBuilder b, int dim, int64_t from, int64_t to, ArrayRef<int64_t> lb,
ArrayRef<int64_t> ub, ArrayRef<StoreOp> group) {
// Iterate all store operations of the group
SmallVector<Value, 10> tempValues;
for (auto storeOp : group) {
SmallVector<StoreOp, 10> subGroup;
SmallVector<size_t, 10> subGroupIndexes;
for (auto en : llvm::enumerate(group)) {
auto storeOp = en.value();
auto shapeOp = cast<ShapeOp>(storeOp.getOperation());
if (shapeOp.getLB()[dim] <= lb && shapeOp.getUB()[dim] >= ub) {
tempValues.push_back(storeOp.temp());
} else {
tempValues.push_back(nullptr);
if (shapeOp.getLB()[dim] <= from && shapeOp.getUB()[dim] >= to) {
subGroup.push_back(storeOp);
subGroupIndexes.push_back(en.index());
}
}

// Split the subgroup recursively
auto subTempValues = splitGroupDimension(b, dim - 1, lb, ub, subGroup);

// Return the resulting temp values, with null entries for stores not in the subgroup
SmallVector<Value, 10> tempValues(group.size(), nullptr);
for (auto en : llvm::enumerate(subGroupIndexes)) {
tempValues[en.value()] = subTempValues[en.index()];
}
return tempValues;
}
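As a concrete reading of the scatter step above: if the group holds three stores and only the first and the third cover the interval [from, to] in the current dimension, subGroupIndexes is {0, 2}; the recursive call returns two values, which are written back to positions 0 and 2 of tempValues, while position 1 keeps its nullptr initialization.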

void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef<StoreOp> group) {
SmallVector<Value, 10> ShapeOverlapPass::splitGroupDimension(
OpBuilder b, int dim, ArrayRef<int64_t> lb, ArrayRef<int64_t> ub,
ArrayRef<StoreOp> group) {
assert(dim < kIndexSize &&
"expected dimension to be lower than the index size");
// Return the group temporary values if the dimension is smaller than zero
if (dim < 0) {
SmallVector<Value, 10> tempValues;
llvm::transform(group, std::back_inserter(tempValues),
[](StoreOp storeOp) { return storeOp.temp(); });
return tempValues;
}

// Compute the bounds of all subdomains
SmallVector<int64_t, 10> limits;
SmallVector<int64_t, 10> limits = {lb[dim], ub[dim]};
for (auto storeOp : group) {
auto shapeOp = cast<ShapeOp>(storeOp.getOperation());
limits.push_back(shapeOp.getLB()[dim]);
@@ -103,13 +130,6 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef<StoreOp> group) {
std::sort(limits.begin(), limits.end());
limits.erase(std::unique(limits.begin(), limits.end()), limits.end());

// Setup the op builder
auto storeOp = *std::min_element(group.begin(), group.end(),
[](StoreOp storeOp1, StoreOp storeOp2) {
return storeOp1->isBeforeInBlock(storeOp2);
});
OpBuilder b(storeOp);

// Compute the lower and upper bounds of all intervals except for the last
SmallVector<int64_t, 10> lowerBounds;
SmallVector<int64_t, 10> upperBounds;
@@ -119,15 +139,15 @@
}

// Initialize the temporary values to the values stored in the last interval
auto tempValues =
computeGroupValues(dim, lowerBounds.back(), upperBounds.back(), group);
auto tempValues = computeGroupValues(b, dim, lowerBounds.back(),
upperBounds.back(), lb, ub, group);
lowerBounds.pop_back();
upperBounds.pop_back();

// Introduce combine operations in backward order
while (!lowerBounds.empty()) {
auto currValues =
computeGroupValues(dim, lowerBounds.back(), upperBounds.back(), group);
auto currValues = computeGroupValues(b, dim, lowerBounds.back(),
upperBounds.back(), lb, ub, group);

// Compute the indexes of the lower, upper, and extra values
SmallVector<int, 10> lower, lowerext, upperext;
@@ -162,10 +182,16 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef<StoreOp> group) {
llvm::transform(upperext, std::back_inserter(upperextOperands),
[&](int64_t x) { return tempValues[x]; });

// Compute the fused location
SmallVector<Location, 10> locs;
llvm::transform(group, std::back_inserter(locs),
[](StoreOp storeOp) { return storeOp->getLoc(); });

// Create a combine operation
auto combineOp = b.create<stencil::CombineOp>(
storeOp.getLoc(), resultTypes, dim, upperBounds.back(), lowerOperands,
upperOperands, lowerextOperands, upperextOperands, nullptr, nullptr);
locs.empty() ? b.getUnknownLoc() : b.getFusedLoc(locs), resultTypes,
dim, upperBounds.back(), lowerOperands, upperOperands, lowerextOperands,
upperextOperands, nullptr, nullptr);

// Update the temporary values
unsigned resultIdx = 0;
@@ -180,10 +206,7 @@ void ShapeOverlapPass::splitGroupDimension(int dim, ArrayRef<StoreOp> group) {
upperBounds.pop_back();
}

// Update the store operands
for(auto en : llvm::enumerate(group)) {
en.value()->setOperand(0, tempValues[en.index()]);
}
return tempValues;
}

} // namespace
@@ -198,8 +221,6 @@ void ShapeOverlapPass::runOnFunction() {
// Compute the extent analysis
AccessSets &sets = getAnalysis<AccessSets>();

// TODO check before shape inference

// Walk all store operations
SmallVector<SmallVector<StoreOp, 10>, 4> groupList;
funcOp->walk([&](StoreOp storeOp1) {
@@ -219,7 +240,29 @@

// Split all groups
for (auto &group : groupList) {
splitGroupDimension(0, group);
// Search the first store operation
auto storeOp = *std::min_element(
group.begin(), group.end(), [](StoreOp storeOp1, StoreOp storeOp2) {
return storeOp1->isBeforeInBlock(storeOp2);
});
OpBuilder b(storeOp);

// Compute the bounding box of all store operation shapes
auto shapeOp = cast<ShapeOp>(storeOp.getOperation());
auto lb = shapeOp.getLB();
auto ub = shapeOp.getUB();
for (ShapeOp shapeOp : group) {
lb = applyFunElementWise(lb, shapeOp.getLB(), min);
ub = applyFunElementWise(ub, shapeOp.getUB(), max);
}

// Split the group dimensions recursively
auto tempValues = splitGroupDimension(b, kIndexSize - 1, lb, ub, group);

// Update the store operands
for (auto en : llvm::enumerate(group)) {
en.value()->setOperand(0, tempValues[en.index()]);
}
}
}
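Traced on the test case in shape-overlap.mlir below: the two stores write the ranges [-1, 0, 0] : [64, 64, 60] and [0, 0, 0] : [65, 64, 60], so the bounding box computed here is [-1, 0, 0] : [65, 64, 60]. Only dimension 0 has more than one interval; its sorted and deduplicated limits are -1, 0, 64, and 65, giving the intervals [-1, 0), [0, 64), and [64, 65). The last interval seeds the temporary values, and combine operations are then emitted in backward order at the remaining upper bounds 64 and 0, which matches the two stencil.combine operations checked in the updated test.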

36 changes: 6 additions & 30 deletions test/Dialect/Stencil/shape-overlap.mlir
@@ -6,6 +6,8 @@ func @ioverlap
%1 = stencil.cast %arg1([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
%2 = stencil.cast %arg2([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
%3 = stencil.load %0 : (!stencil.field<70x70x60xf64>) -> !stencil.temp<?x?x?xf64>
// CHECK: [[TEMP1:%.*]] = stencil.apply (%{{.*}} = %{{.*}} : !stencil.temp<?x?x?xf64>) -> !stencil.temp<?x?x?xf64> {
// CHECK: [[TEMP2:%.*]] = stencil.apply (%{{.*}} = %{{.*}} : !stencil.temp<?x?x?xf64>) -> !stencil.temp<?x?x?xf64> {
%4 = stencil.apply (%arg3 = %3 : !stencil.temp<?x?x?xf64>) -> !stencil.temp<?x?x?xf64> {
%6 = stencil.access %arg3 [0, 0, 0] : (!stencil.temp<?x?x?xf64>) -> f64
%7 = stencil.store_result %6 : (f64) -> !stencil.result<f64>
@@ -16,39 +18,13 @@
%7 = stencil.store_result %6 : (f64) -> !stencil.result<f64>
stencil.return %7 : !stencil.result<f64>
}
// CHECK: [[TEMP3:%.*]]:2 = stencil.combine 0 at 64 lower = ([[TEMP2]] : !stencil.temp<?x?x?xf64>) upper = ([[TEMP2]] : !stencil.temp<?x?x?xf64>) lowerext = ([[TEMP1]] : !stencil.temp<?x?x?xf64>) : !stencil.temp<?x?x?xf64>, !stencil.temp<?x?x?xf64>
// CHECK-NEXT: [[TEMP4:%.*]]:2 = stencil.combine 0 at 0 lower = ([[TEMP1]] : !stencil.temp<?x?x?xf64>) upper = ([[TEMP3]]#1 : !stencil.temp<?x?x?xf64>) upperext = ([[TEMP3]]#0 : !stencil.temp<?x?x?xf64>) : !stencil.temp<?x?x?xf64>, !stencil.temp<?x?x?xf64>
// CHECK-NEXT: stencil.store [[TEMP4:%.*]]#0 to %{{.*}}([-1, 0, 0] : [64, 64, 60]) : !stencil.temp<?x?x?xf64> to !stencil.field<70x70x60xf64>
// CHECK-NEXT: stencil.store [[TEMP4:%.*]]#1 to %{{.*}}([0, 0, 0] : [65, 64, 60]) : !stencil.temp<?x?x?xf64> to !stencil.field<70x70x60xf64>
stencil.store %4 to %1([-1, 0, 0] : [64, 64, 60]) : !stencil.temp<?x?x?xf64> to !stencil.field<70x70x60xf64>
stencil.store %5 to %2([0, 0, 0] : [65, 64, 60]) : !stencil.temp<?x?x?xf64> to !stencil.field<70x70x60xf64>
return
}

// -----

// func @test(%arg0: !stencil.field<?x?x?xf64>, %arg1: !stencil.field<?x?x?xf64>, %arg2: !stencil.field<?x?x?xf64>) attributes {stencil.program} {
// %0 = stencil.cast %arg0([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
// %1 = stencil.cast %arg1([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
// %2 = stencil.cast %arg2([-3, -3, 0] : [67, 67, 60]) : (!stencil.field<?x?x?xf64>) -> !stencil.field<70x70x60xf64>
// %3 = stencil.load %0 : (!stencil.field<70x70x60xf64>) -> !stencil.temp<?x?x?xf64>
// %4 = stencil.apply (%arg3 = %3 : !stencil.temp<?x?x?xf64>) -> !stencil.temp<?x?x?xf64> {
// %6 = stencil.access %arg3 [0, 0, 0] : (!stencil.temp<?x?x?xf64>) -> f64
// %7 = stencil.store_result %6 : (f64) -> !stencil.result<f64>
// stencil.return %7 : !stencil.result<f64>
// }
// %5 = stencil.apply (%arg3 = %3 : !stencil.temp<?x?x?xf64>) -> !stencil.temp<?x?x?xf64> {
// %6 = stencil.access %arg3 [0, 0, 0] : (!stencil.temp<?x?x?xf64>) -> f64
// %7 = stencil.store_result %6 : (f64) -> !stencil.result<f64>
// stencil.return %7 : !stencil.result<f64>
// }


// %6,%7 = stencil.combine 0 at 0 lower = (%4 : !stencil.temp<?x?x?xf64>)
// upper = (%4 : !stencil.temp<?x?x?xf64>)
// upperext = (%5 : !stencil.temp<?x?x?xf64>) : !stencil.temp<?x?x?xf64>, !stencil.temp<?x?x?xf64>

// %8,%9 = stencil.combine 0 at 64 lower = (%7 : !stencil.temp<?x?x?xf64>)
// upper = (%5 : !stencil.temp<?x?x?xf64>)
// lowerext = (%6 : !stencil.temp<?x?x?xf64>) : !stencil.temp<?x?x?xf64>, !stencil.temp<?x?x?xf64>

// stencil.store %9 to %1([-1, 0, 0] : [64, 64, 60]) : !stencil.temp<?x?x?xf64> to !stencil.field<70x70x60xf64>
// stencil.store %8 to %2([0, 0, 0] : [65, 64, 60]) : !stencil.temp<?x?x?xf64> to !stencil.field<70x70x60xf64>
// return
// }
