Skip to content

Commit

Permalink
Hotfix AIRSplitL2Memref: a number of bugfixes to get conv2d with stri…
Browse files Browse the repository at this point in the history
…de 2 to work (Xilinx#676)

* Fixup cascade in affine.map if overlapping; fixup wrap if overlapping

* Roll back split_dim_size for split_by_channel mode

* Offset dim to memref dim conversions

* Make test more flexible with CHECK-DAG
  • Loading branch information
erwei-xilinx authored Jul 22, 2024
1 parent 8727a53 commit 2754dcf
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 35 deletions.
107 changes: 78 additions & 29 deletions mlir/lib/Transform/AIRMiscPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ class AIRSplitL2MemrefForBufferConstraintPass
private:
void partitionMemref(SmallVector<air::ChannelPutOp> &puts,
SmallVector<air::ChannelGetOp> &gets, int dim,
std::string splitType);
Operation *allocOp);
SmallVector<memref::AllocOp>
getTargetMemrefAllocs(func::FuncOp func,
std::map<memref::AllocOp, SmallVector<int>>
Expand Down Expand Up @@ -888,25 +888,31 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor,
OpBuilder builder(originalChanOp);
SmallVector<Value> originalApplyOperands;
Operation *affineApplyOp = nullptr;
Value originalApplyOutput = nullptr;
Value zeroIdx = builder.create<arith::ConstantIndexOp>(loc, 0);
if (!originalChanOp.getOffsets().empty()) {
auto offsetDefOp = originalChanOp.getOffsets()[dim].getDefiningOp();
if ((offsetDefOp && isa<affine::AffineApplyOp>(offsetDefOp)) ||
isa<air::ExecuteOp>(offsetDefOp))
affineApplyOp = offsetDefOp;
}
if (affineApplyOp && isa<affine::AffineApplyOp>(affineApplyOp))
if (affineApplyOp && isa<affine::AffineApplyOp>(affineApplyOp)) {
originalApplyOperands = affineApplyOp->getOperands();
else if (affineApplyOp && isa<air::ExecuteOp>(affineApplyOp)) {
originalApplyOutput = affineApplyOp->getResult(0);
} else if (affineApplyOp && isa<air::ExecuteOp>(affineApplyOp)) {
auto execOp = dyn_cast<air::ExecuteOp>(affineApplyOp);
originalApplyOperands = execOp.getChildOp()->getOperands();
} else
originalApplyOperands.push_back(
builder.create<arith::ConstantIndexOp>(loc, 0));
originalApplyOutput = affineApplyOp->getResult(1);
} else {
originalApplyOperands.push_back(zeroIdx);
originalApplyOutput = originalChanOp.getOffsets().empty()
? zeroIdx
: originalChanOp.getOffsets()[dim];
}
SmallVector<Value> tokens;
for (int i = 0; i < factor; i++) {
SmallVector<Value> newIndices{
builder.create<arith::ConstantIndexOp>(loc, i),
builder.create<arith::ConstantIndexOp>(loc, 0)};
builder.create<arith::ConstantIndexOp>(loc, i), zeroIdx};
// Update y offset.
// Create affine.apply on induction variable.
auto checkpoint = builder.saveInsertionPoint();
Expand All @@ -920,15 +926,24 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor,
auto map = AffineMap::get(0, 1, add);
// If allocOp has "affine_map" attribute set, then use that map instead
// (potentially overlapping access pattern).
affine::AffineApplyOp newApplyOp = nullptr;
if (allocOp->hasAttr("affine_map")) {
auto original_map =
allocOp->getAttrOfType<AffineMapAttr>("affine_map").getAffineMap();
if (original_map.getNumInputs() == 2)
if (original_map.getNumInputs() == 2) {
// Overlapping data access.
map = original_map.replace(getAffineSymbolExpr(1, ctx),
getAffineConstantExpr(i, ctx), 0, 1);
}
auto newApplyOp =
builder.create<affine::AffineApplyOp>(loc, map, originalApplyOperands);
if (affineApplyOp)
builder.setInsertionPointAfter(affineApplyOp);
newApplyOp = builder.create<affine::AffineApplyOp>(
loc, map, SmallVector<Value>{originalApplyOutput});
} else // Non-overlapping data access.
newApplyOp = builder.create<affine::AffineApplyOp>(
loc, map, originalApplyOperands);
} else // Non-overlapping data access.
newApplyOp = builder.create<affine::AffineApplyOp>(loc, map,
originalApplyOperands);
if (affineApplyOp)
builder.restoreInsertionPoint(checkpoint);
SmallVector<Value> newOffsets = originalChanOp.getOffsets();
Expand All @@ -938,8 +953,13 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor,
air::populateDefaultWrapsAndStrides(builder, originalChanOp.getMemref(),
newOffsets, newWraps, newStrides);
newOffsets[dim] = newApplyOp.getResult();
newWraps[dim] = builder.create<arith::ConstantIndexOp>(
loc, llvm::divideCeilSigned(originalMemrefSize, factor));
// Get post-splitting size at split_dim from allocOp attributes.
if (allocOp->hasAttr("split_dim_size"))
newWraps[dim] = builder.create<arith::ConstantIndexOp>(
loc, allocOp->getAttrOfType<IntegerAttr>("split_dim_size").getInt());
else
newWraps[dim] = builder.create<arith::ConstantIndexOp>(
loc, llvm::divideCeilSigned(originalMemrefSize, factor));
auto deps = dyn_cast<air::AsyncOpInterface>(originalChanOp.getOperation())
.getAsyncDependencies();
SmallVector<Type, 4> tys = {air::AsyncTokenType::get(ctx)};
Expand Down Expand Up @@ -989,10 +1009,16 @@ scf::ForOp getScfForFromVal(Value val) {
// Partition L2 memref.
void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
SmallVector<air::ChannelPutOp> &puts, SmallVector<air::ChannelGetOp> &gets,
int dim, std::string splitType = "") {
int dim, Operation *allocOp) {
auto memref = puts.front().getMemref();
MemRefType ty = llvm::cast<MemRefType>(memref.getType());
auto allocOp = memref.getDefiningOp();
// allocOp attributes:
int split_dim_size = -1;
if (allocOp->hasAttr("split_dim_size"))
split_dim_size =
allocOp->getAttrOfType<IntegerAttr>("split_dim_size").getInt();
if (isa<air::ExecuteOp>(allocOp->getParentOp()))
allocOp = allocOp->getParentOp();
auto loc = allocOp->getLoc();
auto ctx = allocOp->getContext();
Operation *deallocOp = nullptr;
Expand Down Expand Up @@ -1081,18 +1107,24 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref(
continue;
if (op.getSizes().size() != newMemrefShape.size())
continue;
auto offset = getConstantIntValue(op.getOffsets()[*offsetDim]);
if (offset)
newMemrefShape[dim] = *getConstantIntValue(op.getSizes()[*offsetDim]);

// Get post-splitting size at split_dim from allocOp attributes.
if (split_dim_size >= 0)
newMemrefShape[dim] = split_dim_size;
else {
auto forOp = getScfForFromVal(op.getOffsets()[*offsetDim]);
if (!forOp)
continue;
auto trip_count = air::getStaticScfForTripCountAsInt(forOp);
if (!trip_count)
continue;
newMemrefShape[dim] =
*getConstantIntValue(op.getSizes()[*offsetDim]) * (*trip_count);
auto offset = getConstantIntValue(op.getOffsets()[*offsetDim]);
if (offset)
newMemrefShape[dim] = *getConstantIntValue(op.getSizes()[*offsetDim]);
else {
auto forOp = getScfForFromVal(op.getOffsets()[*offsetDim]);
if (!forOp)
continue;
auto trip_count = air::getStaticScfForTripCountAsInt(forOp);
if (!trip_count)
continue;
newMemrefShape[dim] =
*getConstantIntValue(op.getSizes()[*offsetDim]) * (*trip_count);
}
}
break;
}
Expand Down Expand Up @@ -1313,6 +1345,24 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs(
allocOp->setAttr("affine_map",
AffineMapAttr::get(apply.getAffineMap()));
}

// Infer the size at splitDim for both overlapping and non-overlapping
// access pattern.
auto offsetDimOpt = air::getOffsetDimFromMemrefDim(
splitDim, chanOp.getStrides(),
air::getTensorShape(memref.getType()));
auto constOffset =
getConstantIntValue(chanOp.getOffsets()[*offsetDimOpt]);
if (!constOffset)
if (auto forOp =
getScfForFromVal(chanOp.getOffsets()[*offsetDimOpt]))
if (auto trip_count = air::getStaticScfForTripCountAsInt(forOp))
allocOp->setAttr(
"split_dim_size",
IntegerAttr::get(
IntegerType::get(ctx, 32),
*getConstantIntValue(chanOp.getSizes()[*offsetDimOpt]) *
(*trip_count)));
}
// Tiling along the first (x) dimension of scf.parallel only, as one NPU
// memtile is located at the bottom of each column.
Expand Down Expand Up @@ -1534,8 +1584,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() {
if (allocOp->hasAttr("split_dim"))
dim = allocOp->getAttrOfType<IntegerAttr>("split_dim").getInt();

partitionMemref(puts, gets, dim,
allocOp->getAttrOfType<StringAttr>("split_type").str());
partitionMemref(puts, gets, dim, allocOp);
}
for (auto allocOp : allocOps) {
if (auto execOp = dyn_cast<air::ExecuteOp>(allocOp->getParentOp())) {
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1179,15 +1179,15 @@ module {

// Conv2d 3x3, stride 2 (overlapping l2 access).

// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<()[s0] -> (s0 + 2)>
// CHECK: [[$MAP1:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK: [[$MAP2:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 6)>
// CHECK-DAG: [[$MAP0:#map[0-9]*]] = affine_map<()[s0] -> (s0 + 2)>
// CHECK-DAG: [[$MAP1:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 6)>

// CHECK-LABEL: func.func @test9
// CHECK: air.launch
// CHECK: %[[VAL0:.*]] = affine.apply [[$MAP0]]()
// CHECK: %[[VAL1:.*]] = affine.apply [[$MAP1]]()
// CHECK: %[[VAL2:.*]] = affine.apply [[$MAP2]]()
// CHECK-DAG: %[[VAL0:.*]] = affine.apply [[$MAP0]]()
// CHECK-DAG: %[[VAL1:.*]] = affine.apply [[$MAP1]]()
// CHECK-DAG: %[[VAL2:.*]] = affine.apply [[$MAP2]]()
// CHECK: air.channel.put {{.*}} @channel_0[%c0, %c0]
// CHECK: air.channel.put {{.*}} @channel_0[%c1, %c0] (%{{.*}}[%c0, %[[VAL0]]
// CHECK: air.channel.put {{.*}} @channel_0[%c2, %c0] (%{{.*}}[%c0, %[[VAL1]]
Expand Down

0 comments on commit 2754dcf

Please sign in to comment.