diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 5dd952f16..6414f8b02 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -1621,6 +1621,53 @@ struct CanonicalizeAIRExecute : public OpRewritePattern { private: }; +affine::AffineForOp updateAffineForBounds(OpBuilder builder, IRMapping &remap, + affine::AffineForOp loop_op, int lb, + int ub, int step) { + affine::AffineForOp new_loop_op = builder.create( + builder.getUnknownLoc(), lb, ub, step); + remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar()); + // remap.map(old_apply.getResult(), new_loop_op.getInductionVar()); + auto insertionCheckpoint = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(new_loop_op.getBody()); + for (Operation &child_op : loop_op.getBody()->getOperations()) { + if (&child_op == loop_op.getBody()->getTerminator()) { /*Skip*/ + } else + builder.clone(child_op, remap); + } + builder.restoreInsertionPoint(insertionCheckpoint); + return new_loop_op; +} + +scf::ForOp updateScfForBounds(OpBuilder builder, IRMapping &remap, + scf::ForOp loop_op, int lb, int ub, int step) { + auto loc = loop_op->getLoc(); + SmallVector deps = + loop_op.getOperands().drop_front(loop_op.getNumControlOperands()); + scf::ForOp new_loop_op = builder.create( + builder.getUnknownLoc(), builder.create(loc, lb), + builder.create(loc, ub), + builder.create(loc, step), deps); + remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar()); + for (unsigned i = 0; i < loop_op.getRegionIterArgs().size(); i++) + remap.map(loop_op.getRegionIterArgs()[i], + new_loop_op.getRegionIterArgs()[i]); + auto insertionCheckpoint = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(new_loop_op.getBody()); + for (Operation &child_op : loop_op.getBody()->getOperations()) { + if (&child_op == loop_op.getBody()->getTerminator()) { + if (!new_loop_op.getBody()->mightHaveTerminator()) + builder.clone(child_op, remap); + } else + builder.clone(child_op, remap); + } + for (unsigned i = 0; i < loop_op->getNumResults(); i++) + loop_op->getResult(i).replaceAllUsesWith(new_loop_op->getResult(i)); + builder.restoreInsertionPoint(insertionCheckpoint); + return new_loop_op; +} + +// Fold affine.apply op operating on loop induction variable into loop bounds. struct CanonicalizeAffineApplyOnLoopInductionVar : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1709,54 +1756,184 @@ struct CanonicalizeAffineApplyOnLoopInductionVar output = simplifyAffineMap(newmap).getSingleConstantResult(); return output; } +}; - affine::AffineForOp updateAffineForBounds(OpBuilder builder, IRMapping &remap, - affine::AffineForOp loop_op, int lb, - int ub, int step) const { - affine::AffineForOp new_loop_op = builder.create( - builder.getUnknownLoc(), lb, ub, step); - remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar()); - // remap.map(old_apply.getResult(), new_loop_op.getInductionVar()); - auto insertionCheckpoint = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(new_loop_op.getBody()); - for (Operation &child_op : loop_op.getBody()->getOperations()) { - if (&child_op == loop_op.getBody()->getTerminator()) { /*Skip*/ - } else - builder.clone(child_op, remap); +// Fold arith.muli op operating on loop induction variable into loop bounds. +struct CanonicalizeArithMuliOpOnLoopInductionVar + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::MulIOp op, + PatternRewriter &rewriter) const override { + Operation *containingOp = nullptr; + Value const_val = nullptr; + Value var_val = nullptr; + for (auto val : SmallVector{op.getLhs(), op.getRhs()}) { + if (getConstantIntValue(val)) { + const_val = val; + continue; + } + auto ivArg = llvm::dyn_cast(val); + if (!ivArg) + continue; + if (!ivArg.getOwner()) + continue; + if (!val.hasOneUse()) + continue; + if (op.getResult().use_empty()) + continue; + if (auto exec_muli = dyn_cast(op->getParentOp())) + if (exec_muli->getResult(1).use_empty()) + continue; + if (isa(ivArg.getOwner()->getParentOp())) { + containingOp = ivArg.getOwner()->getParentOp(); + var_val = val; + } else if (isa(ivArg.getOwner()->getParentOp())) { + containingOp = ivArg.getOwner()->getParentOp(); + var_val = val; + } } - builder.restoreInsertionPoint(insertionCheckpoint); - return new_loop_op; + if (!containingOp) + return failure(); + if (!const_val) + return failure(); + if (!var_val) + return failure(); + + // Apply arith muli to loop step and bound + int muli_factor = *mlir::getConstantIntValue(const_val); + if (auto sfo = dyn_cast(containingOp)) { + if (!getStaticScfForTripCountAsInt(sfo)) + return failure(); + int tripCount = *getStaticScfForTripCountAsInt(sfo); + int new_ub = + *mlir::getConstantIntValue(sfo.getUpperBound()) * muli_factor; + int new_lb = + *mlir::getConstantIntValue(sfo.getLowerBound()) * muli_factor; + int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); + IRMapping remap; + if (auto exec = dyn_cast(op->getParentOp())) { + rewriter.setInsertionPoint(exec); + exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + rewriter.eraseOp(exec); + } else { + rewriter.setInsertionPoint(op); + op.getResult().replaceAllUsesWith(sfo.getInductionVar()); + rewriter.eraseOp(op); + } + rewriter.setInsertionPoint(sfo); + updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt); + rewriter.eraseOp(sfo); + } else if (auto afo = dyn_cast(containingOp)) { + if (!afo.hasConstantBounds()) + return failure(); + int tripCount = *getStaticAffineForTripCountAsInt(afo); + int new_ub = afo.getConstantUpperBound() * muli_factor; + int new_lb = afo.getConstantLowerBound() * muli_factor; + int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); + IRMapping remap; + rewriter.setInsertionPoint(afo); + op.getResult().replaceAllUsesWith(afo.getInductionVar()); + rewriter.eraseOp(op); + updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt); + rewriter.eraseOp(afo); + } else + return failure(); + + return success(); } - scf::ForOp updateScfForBounds(OpBuilder builder, IRMapping &remap, - scf::ForOp loop_op, int lb, int ub, - int step) const { - auto loc = loop_op->getLoc(); - SmallVector deps = - loop_op.getOperands().drop_front(loop_op.getNumControlOperands()); - scf::ForOp new_loop_op = builder.create( - builder.getUnknownLoc(), - builder.create(loc, lb), - builder.create(loc, ub), - builder.create(loc, step), deps); - remap.map(loop_op.getInductionVar(), new_loop_op.getInductionVar()); - for (unsigned i = 0; i < loop_op.getRegionIterArgs().size(); i++) - remap.map(loop_op.getRegionIterArgs()[i], - new_loop_op.getRegionIterArgs()[i]); - auto insertionCheckpoint = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(new_loop_op.getBody()); - for (Operation &child_op : loop_op.getBody()->getOperations()) { - if (&child_op == loop_op.getBody()->getTerminator()) { - if (!new_loop_op.getBody()->mightHaveTerminator()) - builder.clone(child_op, remap); - } else - builder.clone(child_op, remap); +private: +}; + +// Fold arith.addi op operating on loop induction variable into loop bounds. +struct CanonicalizeArithAddiOpOnLoopInductionVar + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddIOp op, + PatternRewriter &rewriter) const override { + Operation *containingOp = nullptr; + Value const_val = nullptr; + Value var_val = nullptr; + for (auto val : SmallVector{op.getLhs(), op.getRhs()}) { + if (getConstantIntValue(val)) { + const_val = val; + continue; + } + auto ivArg = llvm::dyn_cast(val); + if (!ivArg) + continue; + if (!ivArg.getOwner()) + continue; + if (!val.hasOneUse()) + continue; + if (op.getResult().use_empty()) + continue; + if (auto exec_addi = dyn_cast(op->getParentOp())) + if (exec_addi->getResult(1).use_empty()) + continue; + if (isa(ivArg.getOwner()->getParentOp())) { + containingOp = ivArg.getOwner()->getParentOp(); + var_val = val; + } else if (isa(ivArg.getOwner()->getParentOp())) { + containingOp = ivArg.getOwner()->getParentOp(); + var_val = val; + } } - for (unsigned i = 0; i < loop_op->getNumResults(); i++) - loop_op->getResult(i).replaceAllUsesWith(new_loop_op->getResult(i)); - builder.restoreInsertionPoint(insertionCheckpoint); - return new_loop_op; + if (!containingOp) + return failure(); + if (!const_val) + return failure(); + if (!var_val) + return failure(); + + // Apply arith muli to loop step and bound + int addi_operand = *mlir::getConstantIntValue(const_val); + if (auto sfo = dyn_cast(containingOp)) { + if (!getStaticScfForTripCountAsInt(sfo)) + return failure(); + int tripCount = *getStaticScfForTripCountAsInt(sfo); + int new_ub = + *mlir::getConstantIntValue(sfo.getUpperBound()) + addi_operand; + int new_lb = + *mlir::getConstantIntValue(sfo.getLowerBound()) + addi_operand; + int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); + IRMapping remap; + if (auto exec = dyn_cast(op->getParentOp())) { + rewriter.setInsertionPoint(exec); + exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + rewriter.eraseOp(exec); + } else { + rewriter.setInsertionPoint(op); + op.getResult().replaceAllUsesWith(sfo.getInductionVar()); + rewriter.eraseOp(op); + } + rewriter.setInsertionPoint(sfo); + updateScfForBounds(rewriter, remap, sfo, new_lb, new_ub, newStepInInt); + rewriter.eraseOp(sfo); + } else if (auto afo = dyn_cast(containingOp)) { + if (!afo.hasConstantBounds()) + return failure(); + int tripCount = *getStaticAffineForTripCountAsInt(afo); + int new_ub = afo.getConstantUpperBound() + addi_operand; + int new_lb = afo.getConstantLowerBound() + addi_operand; + int newStepInInt = llvm::divideCeilSigned(new_ub - new_lb, tripCount); + IRMapping remap; + rewriter.setInsertionPoint(afo); + op.getResult().replaceAllUsesWith(afo.getInductionVar()); + rewriter.eraseOp(op); + updateAffineForBounds(rewriter, remap, afo, new_lb, new_ub, newStepInInt); + rewriter.eraseOp(afo); + } else + return failure(); + + return success(); } + +private: }; struct AIRSpecializeChannelWrapAndStrideInScfFor @@ -1772,11 +1949,12 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor auto hasNElements = [](Block *block, unsigned N) { unsigned counter = 0; for (auto &o : block->getOperations()) { - if (o.mightHaveTrait()) - continue; - if (isa(o)) - continue; - counter++; + if (isa(o)) + counter++; + else if (isa(o)) + counter++; + else if (isa(o)) + counter++; } return counter == N; }; @@ -1786,10 +1964,6 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor if (!hasNElements(for_op.getBody(), 1)) return failure(); - if (isa(for_op.getBody()->begin())) { - } else if (isa(for_op.getBody()->begin())) { - } else - return failure(); // Check if the loop nest contains exactly one channel op SmallVector channel_ops; @@ -1813,10 +1987,6 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor // Check for perfect loop nest containing only air.channel ops if (!hasNElements(o.getBody(), 1)) return failure(); - if (isa(o.getBody()->begin())) { - } else if (isa(o.getBody()->begin())) { - } else - return failure(); if (!getStaticScfForTripCountAsInt(o)) return failure(); } @@ -1886,11 +2056,12 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor auto hasNElements = [](Block *block, unsigned N) { unsigned counter = 0; for (auto &o : block->getOperations()) { - if (o.mightHaveTrait()) - continue; - if (isa(o)) - continue; - counter++; + if (isa(o)) + counter++; + else if (isa(o)) + counter++; + else if (isa(o)) + counter++; } return counter == N; }; @@ -1900,10 +2071,6 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor if (!hasNElements(for_op.getBody(), 1)) return failure(); - if (isa(for_op.getBody()->begin())) { - } else if (isa(for_op.getBody()->begin())) { - } else - return failure(); // Check if the loop nest contains exactly one channel op SmallVector channel_ops; @@ -1927,10 +2094,6 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor // Check for perfect loop nest containing only air.channel ops if (!hasNElements(o.getBody(), 1)) return failure(); - if (isa(o.getBody()->begin())) { - } else if (isa(o.getBody()->begin())) { - } else - return failure(); if (!getStaticAffineForTripCountAsInt(o)) return failure(); } @@ -2832,14 +2995,20 @@ class AIRSpecializeChannelWrapAndStridePattern void runOptPatterns(func::FuncOp funcOp) { MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet preproc_patterns(&getContext()); + preproc_patterns.insert(ctx); + // Canonicalize constant operands in affine.apply. + mlir::affine::AffineApplyOp::getCanonicalizationPatterns(preproc_patterns, + ctx); + air::WaitAllOp::getCanonicalizationPatterns(preproc_patterns, ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(preproc_patterns)); + RewritePatternSet patterns(&getContext()); - patterns.insert(ctx); - // Canonicalize constant operands in affine.apply. - mlir::affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx); - air::WaitAllOp::getCanonicalizationPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Canonicalize wrap and stride list to remove redundant dimensions diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 6cf57d22a..245a7404b 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -860,7 +860,6 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder, continue; auto const_stride_next = getConstantIntValue(strides[i + 1]); assert(const_stride_next && "non-static stride, NYI."); - erased |= multiplyAdjWraps(builder, i, sizes); if (const_offset) { offsets[i + 1] = builder.create( builder.getUnknownLoc(), @@ -890,6 +889,8 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder, offset_producer = exec.getChildOp(); auto affine_apply = dyn_cast(offset_producer); assert(affine_apply && "ssa offset not produced by affine.apply, NYI."); + if (affine_apply->getNumOperands() > 1) + continue; // Compose affine map auto offset_expr = getAffineSymbolExpr(0, builder.getContext()); auto stride_expr = @@ -908,6 +909,7 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder, affine_apply.setMap(next_offset_map); offsets[i + 1] = offsets[i]; } + erased |= multiplyAdjWraps(builder, i, sizes); offsets.erase(offsets.begin() + i); sizes.erase(sizes.begin() + i); strides.erase(strides.begin() + i); @@ -1010,34 +1012,27 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides( else if (isa(parent)) for_loops.push_back(parent); } + + // First traversal inserting new dimensions from loops for (auto o : for_loops) { uint64_t ind_var_factor = 0; for (int i = offsets.size() - 1; i >= 0; i--) { Value iv = nullptr; - int loop_lower_bound = 0; - if (auto afo = dyn_cast(o)) { + if (auto afo = dyn_cast(o)) iv = afo.getInductionVar(); - loop_lower_bound = afo.getConstantLowerBound(); - } else if (auto sfo = dyn_cast(o)) { + else if (auto sfo = dyn_cast(o)) iv = sfo.getInductionVar(); - if (auto cst_lower_bound = - mlir::getConstantIntValue(sfo.getLowerBound())) - loop_lower_bound = *cst_lower_bound; - } if (iv && offsets[i] == iv) { - // Replace for loop induction vars in offsets with zero - offsets[i] = builder.template create( - loc, loop_lower_bound); ind_var_factor = *getConstantIntValue(strides[i]); break; } else if (iv && offsets[i].getDefiningOp()) { - if (isa(offsets[i].getDefiningOp()) && - offsets[i].getDefiningOp()->getOperand(0) == iv) { - offsets[i] = builder.template create( - loc, loop_lower_bound); + Operation *iv_consumer = offsets[i].getDefiningOp(); + if (auto exec = dyn_cast(iv_consumer)) + iv_consumer = exec.getChildOp(); + if (llvm::is_contained(iv_consumer->getOperands(), iv)) { ind_var_factor = *getConstantIntValue(strides[i]); break; - }; + } } } int trip_count = -1; @@ -1074,6 +1069,38 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides( wraps.insert(wraps.begin(), new_wrap); strides.insert(strides.begin(), new_stride); } + + // Second traversal updating existing offsets + for (auto o : for_loops) { + for (int i = offsets.size() - 1; i >= 0; i--) { + Value iv = nullptr; + int loop_lower_bound = 0; + if (auto afo = dyn_cast(o)) { + iv = afo.getInductionVar(); + loop_lower_bound = afo.getConstantLowerBound(); + } else if (auto sfo = dyn_cast(o)) { + iv = sfo.getInductionVar(); + if (auto cst_lower_bound = + mlir::getConstantIntValue(sfo.getLowerBound())) + loop_lower_bound = *cst_lower_bound; + } + if (iv && offsets[i] == iv) { + // Replace offset with for loop lower bound + offsets[i] = builder.template create( + loc, loop_lower_bound); + break; + } else if (iv && offsets[i].getDefiningOp()) { + Operation *iv_consumer = offsets[i].getDefiningOp(); + if (auto exec = dyn_cast(iv_consumer)) + iv_consumer = exec.getChildOp(); + if (llvm::is_contained(iv_consumer->getOperands(), iv)) { + offsets[i] = builder.template create( + loc, loop_lower_bound); + break; + } + } + } + } return success(); } diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir index c34106d84..7ff64a7ca 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir @@ -11,6 +11,7 @@ #map = affine_map<()[s0] -> (s0 * 32)> #map1 = affine_map<()[s0, s1] -> (s0 + s1)> +#map2 = affine_map<(d0, d1) -> (d0 + d1)> module { // CHECK-LABEL: test0 @@ -375,4 +376,78 @@ module { } return } + + // Affine.apply with map joining two for loops in a loop nest. + // CHECK-LABEL: test11 + + // CHECK: air.channel.put async [%{{.*}}] @channel_26[%c0, %c0] (%{{.*}}[%c0, %c0, %c0] [%c4_0, %c18, %c4_0] [%c96, %c16, %c1]) : (memref<1x6x6x16xbf16, 1>) + + func.func @test11() { + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c3, %arg7=%c3, %arg8=%c4) { + %1 = air.segment @segment_0 async { + %c576 = arith.constant 576 : index + %c96 = arith.constant 96 : index + %c3_0 = arith.constant 3 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c6 = arith.constant 6 : index + %c0 = arith.constant 0 : index + %c4_1 = arith.constant 4 : index + %async_token, %results = air.execute -> (memref<1x6x6x16xbf16, 1>) { + %alloc = memref.alloc() : memref<1x6x6x16xbf16, 1> + air.execute_terminator %alloc : memref<1x6x6x16xbf16, 1> + } + %4 = scf.for %arg9 = %c0 to %c4_1 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) { + %2 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg11 = %arg13) -> (!air.async.token) { + %async_token_2, %results_3 = air.execute [%arg11] -> (index) { + %4 = affine.apply #map2(%arg9, %arg10) + air.execute_terminator %4 : index + } + %3 = air.channel.put async [%async_token_2] @channel_26[%c0, %c0] (%results[%c0, %results_3, %c0, %c0] [%c1, %c1, %c6, %c4_1] [%c576, %c96, %c16, %c1]) : (memref<1x6x6x16xbf16, 1>) + scf.yield %3 : !air.async.token + } + scf.yield %2 : !air.async.token + } + } + } + return + } + + // Arith.muli and addi folding into loops. + // CHECK-LABEL: test12 + + // CHECK: air.channel.put async [%{{.*}}] @channel_27[%c0, %c0] (%{{.*}}[%c0, %c0, %c8] [%c2, %c2, %c128] [%c8, %c16, %c1]) : (memref<32x16xi32>) + + func.func @test12(%arg0: memref<32x16xi32>) { + %0 = air.launch async () in () args(%arg2=%arg0) : memref<32x16xi32> { + %1 = air.segment @seg async args(%arg3=%arg2) : memref<32x16xi32> { + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %2 = air.wait_all async + %3 = scf.for %arg4 = %c0 to %c2 step %c1 iter_args(%arg5 = %2) -> (!air.async.token) { + %4 = scf.for %arg6 = %c0 to %c2 step %c1 iter_args(%arg7 = %arg5) -> (!air.async.token) { + %async_token, %results = air.execute [%arg7] -> (index) { + %6 = arith.addi %arg4, %c1 : index + air.execute_terminator %6 : index + } + %async_token_0, %results_1 = air.execute [%arg7] -> (index) { + %6 = arith.muli %arg6, %c16 : index + air.execute_terminator %6 : index + } + %5 = air.channel.put async [%async_token_0, %async_token, %arg7] @channel_27[%c0, %c0] (%arg3[%results, %results_1] [%c16, %c8] [%c8, %c1]) : (memref<32x16xi32>) + scf.yield %5 : !air.async.token + } + scf.yield %4 : !air.async.token + } + } + } + return + } + }