diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index e866022eb9a9..e4037823f37a 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -402,6 +403,64 @@ static bool areOpsFusable(Operation *producer, Operation *consumer,
   return true;
 }

+/// The logic to decide fusability (using `hasCompatibleOuterParallelLoops`)
+/// currently only works when the indexing map corresponding to the producer
+/// result and the indexing map corresponding to the consumer operand are not
+/// transposed with respect to each other. To find more fusion opportunities
+/// with elementwise consumers, the indexing maps of the consumer can be
+/// interchanged to "align" with the indexing map of the producer.
+static bool areOpsFusableAfterInterchangeOfConsumer(
+    OpOperand &fusableOperand,
+    const llvm::SmallBitVector &rootOuterParallelLoops) {
+  Operation *producer = fusableOperand.get().getDefiningOp();
+  if (!producer) {
+    return false;
+  }
+
+  Operation *consumer = fusableOperand.getOwner();
+  auto genericOp = dyn_cast<linalg::GenericOp>(consumer);
+  if (!genericOp) {
+    return false;
+  }
+  assert(genericOp.getNumDpsInputs() > 0 &&
+         "expected consumer to have at least one input");
+
+  if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1) {
+    return false;
+  }
+
+  // The input map must be a permutation (i.e. this is not a broadcasting
+  // access), but not identity.
+  AffineMap inputMap = genericOp.getMatchingIndexingMap(&fusableOperand);
+  if (!inputMap.isPermutation() || inputMap.isIdentity()) {
+    return false;
+  }
+
+  // The output map should also be a permutation. That is expected for a
+  // parallel elementwise op, but check it here anyway.
+  OpResult result = cast<OpResult>(genericOp.getResult(0));
+  if (!genericOp.getIndexingMapMatchingResult(result).isPermutation()) {
+    return false;
+  }
+
+  // Make the input map identity.
+  auto perm =
+      llvm::map_to_vector(inputMap.getResults(), [](AffineExpr e) -> unsigned {
+        return cast<AffineDimExpr>(e).getPosition();
+      });
+  IRRewriter rewriter(consumer->getContext());
+  FailureOr<linalg::GenericOp> interchangedOp =
+      linalg::interchangeGenericOp(rewriter, genericOp, perm);
+  (void)interchangedOp;
+  assert(succeeded(interchangedOp) && "expected interchange to succeed");
+  assert(interchangedOp.value() == genericOp &&
+         "expected interchange to happen in place");
+  assert(
+      areOpsFusable(producer, interchangedOp.value(), rootOuterParallelLoops) &&
+      "expected the interchanged op to be fusable");
+  return true;
+}
+
 /// For the fusion of root op -> elementwise operation to be bufferized
 /// in-place without use of extra memory, the result of the root operation
 /// must be able to reuse the buffer for the result of the elementwise
@@ -561,7 +620,10 @@ isFusableWithConsumer(OpOperand &fusedOperand,
   }

   if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
-    return false;
+    if (!areOpsFusableAfterInterchangeOfConsumer(fusedOperand,
+                                                 rootOuterParallelLoops)) {
+      return false;
+    }
   }

   // Check if the iteration spaces of the producer and consumer are same.
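As a rough illustration of the case the new `areOpsFusableAfterInterchangeOfConsumer` check targets (a sketch, not taken from the patch or its tests; the function name, value names, and shapes are made up), consider an elementwise consumer that reads the producer result through a transposed, non-broadcasting map:

util.func @interchange_sketch(%producer_result : tensor<?x?xf32>,
                              %init : tensor<?x?xf16>) -> tensor<?x?xf16> {
  // The input map is a non-identity permutation; the output map is a permutation.
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%producer_result : tensor<?x?xf32>) outs(%init : tensor<?x?xf16>) {
    ^bb0(%in: f32, %out: f16):
      %1 = arith.truncf %in : f32 to f16
      linalg.yield %1 : f16
  } -> tensor<?x?xf16>
  util.return %0 : tensor<?x?xf16>
}

The permutation derived from the input map's results is [1, 0]; after linalg::interchangeGenericOp the indexing maps become [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], so the fusable operand is now read through the identity map and the existing `areOpsFusable` check can line up the outer parallel loops of producer and consumer.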
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp
index 158775571e30..ba51c7a207ab 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp
@@ -38,44 +38,6 @@ namespace mlir::iree_compiler::DispatchCreation {

 namespace {

-//===----------------------------------------------------------------------===//
-// ElementwiseOpInterchangePattern
-//===----------------------------------------------------------------------===//
-
-// If possible, interchange indexing maps to make input maps all identity.
-struct ElementwiseOpInterchangePattern final
-    : public OpRewritePattern<linalg::GenericOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
-        genericOp.getNumDpsInputs() == 0)
-      return failure();
-
-    // All input maps must be equal and non-identity. All maps, including
-    // output, must be permutations. Permutation maps are checked by
-    // isElementwise but may be removed.
-    AffineMap inputMap = genericOp.getIndexingMapsArray().front();
-    auto *initOperand = genericOp.getDpsInitOperand(0);
-    if (inputMap.isIdentity() || !inputMap.isPermutation() ||
-        !genericOp.getMatchingIndexingMap(initOperand).isPermutation()) {
-      return failure();
-    }
-    for (auto *operand : genericOp.getDpsInputOperands()) {
-      if (genericOp.getMatchingIndexingMap(operand) != inputMap) {
-        return failure();
-      }
-    }
-
-    // Make all inputs identity.
-    ArrayRef<AffineExpr> exprs = inputMap.getResults();
-    auto perm = llvm::map_to_vector(exprs, [](AffineExpr e) -> unsigned {
-      return cast<AffineDimExpr>(e).getPosition();
-    });
-    return linalg::interchangeGenericOp(rewriter, genericOp, perm);
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // FoldSuccessiveTensorInsertSliceOps
 //===----------------------------------------------------------------------===//
@@ -153,8 +115,7 @@ struct FusionPreprocessingPass final
     : public impl::FusionPreprocessingPassBase<FusionPreprocessingPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    patterns.add<ElementwiseOpInterchangePattern,
-                 FoldSuccessiveTensorInsertSliceOps>(&getContext());
+    patterns.add<FoldSuccessiveTensorInsertSliceOps>(&getContext());

     // Fold away `tensor.dim` operations that can be resolved in terms of its
     // operand shapes.
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index b29f43ed47fa..a2a720922e0d 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -910,3 +910,35 @@ util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
 // CHECK-SAME:     ins(%[[DISPATCH1]],
 // CHECK:        flow.return %[[CUSTOM_OP]]
 // CHECK:      util.return %[[DISPATCH2]]
+
+// -----
+
+util.func @fuse_transposed_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0: f32
+  %m = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %n = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %empty = tensor.empty(%m, %n) : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %empty2 = tensor.empty(%n, %m) : tensor<?x?xf32>
+  %generic = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%matmul, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%empty2 : tensor<?x?xf32>) {
+  ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+    %0 = arith.addf %b0, %b1 : f32
+    linalg.yield %0 : f32
+  } -> tensor<?x?xf32>
+  util.return %generic : tensor<?x?xf32>
+}
+// CHECK-LABEL: func public @fuse_transposed_op
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[MATMUL:.+]] = linalg.matmul
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:          ins(%[[MATMUL]],
+// CHECK:           flow.return %[[GENERIC]]
+// CHECK:         return %[[DISPATCH]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir
index 14e7df57c127..4563dcc2ed5e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fusion_preprocessing.mlir
@@ -30,77 +30,3 @@ util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
 // CHECK:        %[[RETURN:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
 // CHECK-SAME:       [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]]
 // CHECK:        util.return %[[RETURN]]
-
-// -----
-
-#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
-util.func @single_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
-  %0 = tensor.empty() : tensor<2x320x128x128xf16>
-  %1 = linalg.generic {indexing_maps = [#perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
-  ^bb0(%in: f32, %out: f16):
-    %2 = arith.truncf %in : f32 to f16
-    linalg.yield %2 : f16
-  } -> tensor<2x320x128x128xf16>
-  util.return %1 : tensor<2x320x128x128xf16>
-}
-
-// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
-// CHECK-LABEL: util.func public @single_input_interchange
-// CHECK-SAME:    %[[ARG0:.*]]: tensor<2x128x128x320xf32>
-// CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16>
-// CHECK:         linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP]]]
-// CHECK-SAME:      ins(%[[ARG0]] : tensor<2x128x128x320xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<2x320x128x128xf16>)
-
-// -----
-
-#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
-util.func @multi_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
-  %0 = tensor.empty() : tensor<2x320x128x128xf16>
-  %1 = linalg.generic {indexing_maps = [#perm, #perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
-  ^bb0(%in: f32, %in_1: f32, %out: f16):
-    %2 = arith.addf %in, %in_1 : f32
-    %3 = arith.truncf %2 : f32 to f16
-    linalg.yield %3 : f16
-  } -> tensor<2x320x128x128xf16>
-  util.return %1 : tensor<2x320x128x128xf16>
-}
-
-// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
-// CHECK-LABEL: util.func public @multi_input_interchange
-// CHECK-SAME:    %[[ARG0:.*]]: tensor<2x128x128x320xf32>
-// CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16>
-// CHECK:         linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[$IDENT_MAP]], #[[$IDENT_MAP]], #[[$PERM_MAP]]]
-// CHECK-SAME:      ins(%[[ARG0]], %[[ARG0]] : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<2x320x128x128xf16>)
-
-// -----
-
-#ident = affine_map<(d0, d1) -> (d0, d1)>
-#perm0 = affine_map<(d0, d1) -> (d1, d0)>
-util.func @multi_input_no_interchange(%arg0: tensor<10x10xf32>) -> tensor<10x10xf16> {
-  %0 = tensor.empty() : tensor<10x10xf16>
-  %1 = linalg.generic {indexing_maps = [#ident, #perm0, #perm0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<10x10xf32>, tensor<10x10xf32>) outs(%0 : tensor<10x10xf16>) {
-  ^bb0(%in: f32, %in_1: f32, %out: f16):
-    %2 = arith.addf %in, %in_1 : f32
-    %3 = arith.truncf %2 : f32 to f16
-    linalg.yield %3 : f16
-  } -> tensor<10x10xf16>
-  util.return %1 : tensor<10x10xf16>
-}
-
-// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$PERM_MAP0:.+]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-LABEL: util.func public @multi_input_no_interchange
-// CHECK-SAME:    %[[ARG0:.*]]: tensor<10x10xf32>
-// CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<10x10xf16>
-// CHECK:         linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP0]], #[[$PERM_MAP0]]]
-// CHECK-SAME:      ins(%[[ARG0]], %[[ARG0]] : tensor<10x10xf32>, tensor<10x10xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<10x10xf16>)
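For reference, a rough sketch of the end state for the new @fuse_transposed_op test above (illustrative only, and not something the CHECK lines spell out; SSA names reuse those of the input IR): the consumer is interchanged in place with permutation [1, 0], which turns both input maps into identities and moves the transpose onto the output map, after which it fuses into the matmul's dispatch region:

%dispatch = flow.dispatch.region -> (tensor<?x?xf32>{%n, %m}) {
  %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Interchanged consumer: inputs are now read through identity maps.
  %generic = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%matmul, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%empty2 : tensor<?x?xf32>) {
  ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
    %0 = arith.addf %b0, %b1 : f32
    linalg.yield %0 : f32
  } -> tensor<?x?xf32>
  flow.return %generic : tensor<?x?xf32>
}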