From 2754dcf53e9dce76df4e1a8d5f34bcb2acd1871e Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Mon, 22 Jul 2024 08:44:53 -0700 Subject: [PATCH] Hotfix AIRSplitL2Memref: a number of bugfixes to get conv2d with stride 2 to work (#676) * Fixup cascade in affine.map if overlapping; fixup wrap if overlapping * Roll back split_dim_size for split_by_channel mode * Offset dim to memref dim conversions * Make test more flexible with CHECK-DAG --- mlir/lib/Transform/AIRMiscPasses.cpp | 107 +++++++++++++----- .../AIRMiscPasses/air_split_l2_memref.mlir | 12 +- 2 files changed, 84 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index 18cf6ff61..76ba6bf8f 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -851,7 +851,7 @@ class AIRSplitL2MemrefForBufferConstraintPass private: void partitionMemref(SmallVector &puts, SmallVector &gets, int dim, - std::string splitType); + Operation *allocOp); SmallVector getTargetMemrefAllocs(func::FuncOp func, std::map> @@ -888,25 +888,31 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor, OpBuilder builder(originalChanOp); SmallVector originalApplyOperands; Operation *affineApplyOp = nullptr; + Value originalApplyOutput = nullptr; + Value zeroIdx = builder.create(loc, 0); if (!originalChanOp.getOffsets().empty()) { auto offsetDefOp = originalChanOp.getOffsets()[dim].getDefiningOp(); if ((offsetDefOp && isa(offsetDefOp)) || isa(offsetDefOp)) affineApplyOp = offsetDefOp; } - if (affineApplyOp && isa(affineApplyOp)) + if (affineApplyOp && isa(affineApplyOp)) { originalApplyOperands = affineApplyOp->getOperands(); - else if (affineApplyOp && isa(affineApplyOp)) { + originalApplyOutput = affineApplyOp->getResult(0); + } else if (affineApplyOp && isa(affineApplyOp)) { auto execOp = dyn_cast(affineApplyOp); originalApplyOperands = execOp.getChildOp()->getOperands(); - } else - originalApplyOperands.push_back( - builder.create(loc, 0)); + originalApplyOutput = affineApplyOp->getResult(1); + } else { + originalApplyOperands.push_back(zeroIdx); + originalApplyOutput = originalChanOp.getOffsets().empty() + ? zeroIdx + : originalChanOp.getOffsets()[dim]; + } SmallVector tokens; for (int i = 0; i < factor; i++) { SmallVector newIndices{ - builder.create(loc, i), - builder.create(loc, 0)}; + builder.create(loc, i), zeroIdx}; // Update y offset. // Create affine.apply on induction variable. auto checkpoint = builder.saveInsertionPoint(); @@ -920,15 +926,24 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor, auto map = AffineMap::get(0, 1, add); // If allocOp has "affine_map" attribute set, then use that map instead // (potentially overlapping access pattern). + affine::AffineApplyOp newApplyOp = nullptr; if (allocOp->hasAttr("affine_map")) { auto original_map = allocOp->getAttrOfType("affine_map").getAffineMap(); - if (original_map.getNumInputs() == 2) + if (original_map.getNumInputs() == 2) { + // Overlapping data access. map = original_map.replace(getAffineSymbolExpr(1, ctx), getAffineConstantExpr(i, ctx), 0, 1); - } - auto newApplyOp = - builder.create(loc, map, originalApplyOperands); + if (affineApplyOp) + builder.setInsertionPointAfter(affineApplyOp); + newApplyOp = builder.create( + loc, map, SmallVector{originalApplyOutput}); + } else // Non-overlapping data access. + newApplyOp = builder.create( + loc, map, originalApplyOperands); + } else // Non-overlapping data access. + newApplyOp = builder.create(loc, map, + originalApplyOperands); if (affineApplyOp) builder.restoreInsertionPoint(checkpoint); SmallVector newOffsets = originalChanOp.getOffsets(); @@ -938,8 +953,13 @@ Value tileChannelOpByFactor(air::ChannelInterface originalChanOp, int factor, air::populateDefaultWrapsAndStrides(builder, originalChanOp.getMemref(), newOffsets, newWraps, newStrides); newOffsets[dim] = newApplyOp.getResult(); - newWraps[dim] = builder.create( - loc, llvm::divideCeilSigned(originalMemrefSize, factor)); + // Get post-splitting size at split_dim from allocOp attributes. + if (allocOp->hasAttr("split_dim_size")) + newWraps[dim] = builder.create( + loc, allocOp->getAttrOfType("split_dim_size").getInt()); + else + newWraps[dim] = builder.create( + loc, llvm::divideCeilSigned(originalMemrefSize, factor)); auto deps = dyn_cast(originalChanOp.getOperation()) .getAsyncDependencies(); SmallVector tys = {air::AsyncTokenType::get(ctx)}; @@ -989,10 +1009,16 @@ scf::ForOp getScfForFromVal(Value val) { // Partition L2 memref. void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref( SmallVector &puts, SmallVector &gets, - int dim, std::string splitType = "") { + int dim, Operation *allocOp) { auto memref = puts.front().getMemref(); MemRefType ty = llvm::cast(memref.getType()); - auto allocOp = memref.getDefiningOp(); + // allocOp attributes: + int split_dim_size = -1; + if (allocOp->hasAttr("split_dim_size")) + split_dim_size = + allocOp->getAttrOfType("split_dim_size").getInt(); + if (isa(allocOp->getParentOp())) + allocOp = allocOp->getParentOp(); auto loc = allocOp->getLoc(); auto ctx = allocOp->getContext(); Operation *deallocOp = nullptr; @@ -1081,18 +1107,24 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref( continue; if (op.getSizes().size() != newMemrefShape.size()) continue; - auto offset = getConstantIntValue(op.getOffsets()[*offsetDim]); - if (offset) - newMemrefShape[dim] = *getConstantIntValue(op.getSizes()[*offsetDim]); + + // Get post-splitting size at split_dim from allocOp attributes. + if (split_dim_size >= 0) + newMemrefShape[dim] = split_dim_size; else { - auto forOp = getScfForFromVal(op.getOffsets()[*offsetDim]); - if (!forOp) - continue; - auto trip_count = air::getStaticScfForTripCountAsInt(forOp); - if (!trip_count) - continue; - newMemrefShape[dim] = - *getConstantIntValue(op.getSizes()[*offsetDim]) * (*trip_count); + auto offset = getConstantIntValue(op.getOffsets()[*offsetDim]); + if (offset) + newMemrefShape[dim] = *getConstantIntValue(op.getSizes()[*offsetDim]); + else { + auto forOp = getScfForFromVal(op.getOffsets()[*offsetDim]); + if (!forOp) + continue; + auto trip_count = air::getStaticScfForTripCountAsInt(forOp); + if (!trip_count) + continue; + newMemrefShape[dim] = + *getConstantIntValue(op.getSizes()[*offsetDim]) * (*trip_count); + } } break; } @@ -1313,6 +1345,24 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( allocOp->setAttr("affine_map", AffineMapAttr::get(apply.getAffineMap())); } + + // Infer the size at splitDim for both overlapping and non-overlapping + // access pattern. + auto offsetDimOpt = air::getOffsetDimFromMemrefDim( + splitDim, chanOp.getStrides(), + air::getTensorShape(memref.getType())); + auto constOffset = + getConstantIntValue(chanOp.getOffsets()[*offsetDimOpt]); + if (!constOffset) + if (auto forOp = + getScfForFromVal(chanOp.getOffsets()[*offsetDimOpt])) + if (auto trip_count = air::getStaticScfForTripCountAsInt(forOp)) + allocOp->setAttr( + "split_dim_size", + IntegerAttr::get( + IntegerType::get(ctx, 32), + *getConstantIntValue(chanOp.getSizes()[*offsetDimOpt]) * + (*trip_count))); } // Tiling along the first (x) dimension of scf.parallel only, as one NPU // memtile is located at the bottom of each column. @@ -1534,8 +1584,7 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { if (allocOp->hasAttr("split_dim")) dim = allocOp->getAttrOfType("split_dim").getInt(); - partitionMemref(puts, gets, dim, - allocOp->getAttrOfType("split_type").str()); + partitionMemref(puts, gets, dim, allocOp); } for (auto allocOp : allocOps) { if (auto execOp = dyn_cast(allocOp->getParentOp())) { diff --git a/mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir b/mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir index 57749050e..40e564136 100644 --- a/mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir +++ b/mlir/test/Transform/AIRMiscPasses/air_split_l2_memref.mlir @@ -1179,15 +1179,15 @@ module { // Conv2d 3x3, stride 2 (overlapping l2 access). -// CHECK: [[$MAP0:#map[0-9]*]] = affine_map<()[s0] -> (s0 + 2)> -// CHECK: [[$MAP1:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 4)> -// CHECK: [[$MAP2:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 6)> +// CHECK-DAG: [[$MAP0:#map[0-9]*]] = affine_map<()[s0] -> (s0 + 2)> +// CHECK-DAG: [[$MAP1:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 4)> +// CHECK-DAG: [[$MAP2:#map[0-9]+]] = affine_map<()[s0] -> (s0 + 6)> // CHECK-LABEL: func.func @test9 // CHECK: air.launch -// CHECK: %[[VAL0:.*]] = affine.apply [[$MAP0]]() -// CHECK: %[[VAL1:.*]] = affine.apply [[$MAP1]]() -// CHECK: %[[VAL2:.*]] = affine.apply [[$MAP2]]() +// CHECK-DAG: %[[VAL0:.*]] = affine.apply [[$MAP0]]() +// CHECK-DAG: %[[VAL1:.*]] = affine.apply [[$MAP1]]() +// CHECK-DAG: %[[VAL2:.*]] = affine.apply [[$MAP2]]() // CHECK: air.channel.put {{.*}} @channel_0[%c0, %c0] // CHECK: air.channel.put {{.*}} @channel_0[%c1, %c0] (%{{.*}}[%c0, %[[VAL0]] // CHECK: air.channel.put {{.*}} @channel_0[%c2, %c0] (%{{.*}}[%c0, %[[VAL1]]