diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index 845485667d38..a78b6b83876b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -7,7 +7,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" @@ -108,6 +107,25 @@ static bool isEmptyFillContractionDAGRootOp( return true; } +/// Check that a given operation is "horizontal" to the group. The operation +/// is horizontal if the `slice` of the operation does not contain any op +/// from the group. +static bool isHorizontalToGroup(Operation *op, + const llvm::SetVector &currGroup, + const DominanceInfo &dominanceInfo, + Operation *seedOp) { + BackwardSliceOptions options; + // Limit the slice to the seed to make sure the slice is small. + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, seedOp); + }; + llvm::SetVector slice; + getBackwardSlice(op, &slice, options); + return !llvm::any_of(currGroup, [&](Operation *groupedOp) { + return slice.contains(groupedOp); + }); +} + /// Get user of operation that is a truncate operation. static std::optional getTruncateOp(Operation *op, @@ -131,8 +149,8 @@ getTruncateOp(Operation *op, if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) { return std::nullopt; } - if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(), - dominanceInfo, seedTruncateOp.value())) { + if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo, + seedTruncateOp.value())) { return std::nullopt; } } @@ -208,8 +226,7 @@ static std::optional getHorizontalFusionGroupMembers( if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) { return false; } - if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo, - seedOp)) { + if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) { return false; } return true; @@ -329,6 +346,40 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter, return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0); } +/// During horizontal fusion, there might be operands of the fused operations +/// whose definitions are interspersed between the fused operations. For groups +/// chosen to fuse horizontally, such operations can be moved before the +/// seed contraction operation (where the fused operation is generated). +template +static LogicalResult +moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, + Operation *insertionPoint, DominanceInfo &dominanceInfo, + ArrayRef ignoreOperations = {}) { + BackwardSliceOptions options; + llvm::DenseSet ignoreOperationsSet; + ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, insertionPoint) && + !ignoreOperationsSet.contains(op); + }; + // Set inclusive to true cause the slice is computed from the operand, and + // we want to include the defining op (which is the point here) + options.inclusive = true; + + llvm::SetVector slice; + for (auto op : operations) { + for (auto operand : op->getOperands()) { + getBackwardSlice(operand, &slice, options); + } + } + + mlir::topologicalSort(slice); + for (auto op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + /// On finding this pattern /// ``` /// %0 = linalg.matmul ins(%arg0, %arg1) diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index d79d5145e77d..9d9d477c9a57 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -16,13 +16,9 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -49,55 +45,25 @@ static llvm::cl::opt clLinalgMaxConstantFoldElements( llvm::cl::desc("Maximum number of elements to try to constant fold."), llvm::cl::init(0)); -static Operation *getMostDominantUse(Operation *op, - const DominanceInfo &dominanceInfo) { - auto uses = op->getUses(); - auto it = llvm::find_if(uses, [&](OpOperand &source) { - Operation *sourceOp = source.getOwner(); - - return llvm::all_of(uses, [&](OpOperand &target) { - Operation *targetOp = target.getOwner(); - return dominanceInfo.dominates(sourceOp, targetOp); - }); - }); - if (it != uses.end()) { - return it->getOwner(); - } - return nullptr; -} - /// Check if any of the use dominates all other uses of the operation. -static Operation *getFusableUse(Operation *op, - const DominanceInfo &dominanceInfo) { +static std::optional getFusableUse(Operation *op, + DominanceInfo &dominanceInfo) { auto uses = op->getUses(); - Operation *fusableUse = nullptr; for (OpOperand &source : uses) { Operation *sourceOp = source.getOwner(); - - bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) { + bool dominatesAllUsers = true; + for (OpOperand &target : uses) { Operation *targetOp = target.getOwner(); - return !isa(targetOp) || - dominanceInfo.dominates(sourceOp, targetOp); - }); - if (dominatesAllFusableOps) { - fusableUse = sourceOp; - break; + if (!dominanceInfo.dominates(sourceOp, targetOp)) { + dominatesAllUsers = false; + break; + } + } + if (dominatesAllUsers) { + return &source; } } - Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo); - if (!fusableUse || !mostDominantOp) { - return nullptr; - } - - // If `fusableUse` dominates all other users, there's nothing else to do. - if (fusableUse == mostDominantOp) { - return fusableUse; - } - - SmallVector users(op->getUsers().begin(), op->getUsers().end()); - return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp) - ? fusableUse - : nullptr; + return std::nullopt; } static OpOperand *getFirstUseInConsumer(Operation *producer, @@ -125,7 +91,6 @@ static SmallVector getAllUsesInConsumer(Operation *producer, /// using elementwise fusion. static LogicalResult doMultiUseFusion(Operation *rootOp, llvm::SetVector &fusableOps, - const DominanceInfo &dominanceInfo, RewriterBase &rewriter) { assert(rootOp && "root op cant be null"); @@ -147,20 +112,11 @@ static LogicalResult doMultiUseFusion(Operation *rootOp, Operation *consumerOp = rootOp; OpBuilder::InsertionGuard g(rewriter); for (Operation *producerOp : llvm::reverse(fusedOpsVec)) { - Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo); // Fuse all uses from producer -> consumer. It has been checked // before that all uses are fusable. while (OpOperand *fusedOperand = getFirstUseInConsumer(producerOp, consumerOp)) { rewriter.setInsertionPoint(consumerOp); - - if (consumerOp != mostDominantUser && - failed(moveOperandDefs(rewriter, ArrayRef{consumerOp}, - mostDominantUser, dominanceInfo))) { - return rewriter.notifyMatchFailure(consumerOp, - "failed to move operand defs"); - } - rewriter.moveOpBefore(consumerOp, mostDominantUser); FailureOr fusionResult = linalg::fuseElementwiseOps(rewriter, fusedOperand); if (failed(fusionResult)) { @@ -234,8 +190,9 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, } // 6. Check that the `genericOp` dominates all uses of `producer`. - Operation *fusableUse = getFusableUse(producer, dominanceInfo); - if (!fusableUse || fusableUse != genericOp) { + std::optional fusableUse = + getFusableUse(producer, dominanceInfo); + if (!fusableUse || fusableUse.value()->getOwner() != genericOp) { continue; } @@ -275,8 +232,7 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, IRRewriter rewriter(context); for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) { - if (failed( - doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) { + if (failed(doMultiUseFusion(it->first, it->second, rewriter))) { return funcOp->emitOpError("failed multi use fusion"); } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 3e3e6532af84..c428091f6cf8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -10,11 +10,7 @@ #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h" #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Transforms/RegionUtils.h" namespace mlir::iree_compiler::DispatchCreation { @@ -101,22 +97,4 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, - Operation *seedOp) { - assert(dominanceInfo.properlyDominates(seedOp, op) && - op->getParentRegion() == seedOp->getParentRegion()); - BackwardSliceOptions options; - options.omitUsesFromAbove = false; - // Limit the slice to the seed to make sure the slice is small. - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, seedOp); - }; - llvm::SetVector slice; - getBackwardSlice(op, &slice, options); - return !llvm::any_of(currGroup, [&](Operation *groupedOp) { - return slice.contains(groupedOp); - }); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h index a264db94010e..1d9c9306f7ae 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h @@ -10,10 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/TopologicalSortUtils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" namespace mlir::iree_compiler::DispatchCreation { @@ -23,45 +19,4 @@ namespace mlir::iree_compiler::DispatchCreation { bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand, bool fuseMultiReduction); -/// Check that a given operation is "horizontal" to the group. The operation -/// is horizontal if the program slice of the operation (from op back to seedOp) -/// does not contain any op from the group. -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, Operation *seedOp); - -/// Moves the operands and transitive defs for each op in `operations` directly -/// after `insertionPoint`. Note: this does not check if it is legal to move the -/// operands. -template -static LogicalResult -moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, - Operation *insertionPoint, const DominanceInfo &dominanceInfo, - ArrayRef ignoreOperations = {}) { - BackwardSliceOptions options; - options.omitUsesFromAbove = false; - llvm::DenseSet ignoreOperationsSet; - ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, insertionPoint) && - !ignoreOperationsSet.contains(op); - }; - // Set inclusive to true cause the slice is computed from the operand, and - // we want to include the defining op (which is the point here) - options.inclusive = true; - - llvm::SetVector slice; - for (auto op : operations) { - assert(insertionPoint->getBlock() == op->getBlock()); - for (auto operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } - } - - mlir::topologicalSort(slice); - for (auto op : slice) { - rewriter.moveOpBefore(op, insertionPoint); - } - return success(); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index c6af7b1e8ca5..cc3e159ca943 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -139,84 +139,3 @@ util.func public @math_sin() { // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0, // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1, - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 2.000000e+00 : f32 - %cst_1 = arith.constant 3.000000e+00 : f32 - %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.addf %arg2, %cst : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - // expected-note @below {{prior use here}} - %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> - %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.subf %arg2, %cst_0 : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32> -} -// CHECK-LABEL: util.func public @fuse_by_moving_consumer -// CHECK: linalg.generic -// CHECK-NOT: linalg.generic - - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 2.000000e+00 : f32 - %cst_1 = arith.constant 3.000000e+00 : f32 - %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%in: f32, %out: f32): - %2 = arith.addf %in, %cst : f32 - linalg.yield %2 : f32 - } -> tensor<5x5xf32> - %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> - %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%in: f32, %out: f32): - %c2 = arith.constant 2 : index - %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32> - %2 = arith.addf %extracted, %extracted : f32 - linalg.yield %2 : f32 - } -> tensor<5x5xf32> - util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32> -} - -// CHECK-LABEL: util.func public @dont_fuse_use_from_above -// CHECK: linalg.generic -// CHECK: linalg.generic - - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 2.000000e+00 : f32 - %cst_1 = arith.constant 3.000000e+00 : f32 - %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%in: f32, %out: f32): - %2 = arith.addf %in, %cst : f32 - linalg.yield %2 : f32 - } -> tensor<5x5xf32> - %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> - %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%in: f32, %out: f32): - %c2 = arith.constant 2 : index - %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32> - %2 = arith.addf %extracted, %extracted : f32 - linalg.yield %2 : f32 - } -> tensor<5x5xf32> - util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32> -} - -// CHECK-LABEL: util.func public @do_fuse_use_from_above -// CHECK: linalg.generic -// CHECK-NOT: linalg.generic