Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DispatchCreation] Move the logic to transpose indexing maps into dispatch formation logic. #19412

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -402,6 +403,64 @@ static bool areOpsFusable(Operation *producer, Operation *consumer,
return true;
}

/// The logic to decide fusability (using the `hasCompatibleOuterParallelLoops`)
/// currently works when the indexing map corresponding to result of the
/// producer and indexing map corresponding to operand in the result are not
/// transposed with respect to each other. To find more fusion opportunities for
/// consumer elementwise operation, the indexing maps in the consumer can be
/// made to "align" with the indexing map of the producer to enhance fusion.
static bool areOpsFusableAfterInterchangeOfConsumer(
Copy link
Contributor

@IanWood1 IanWood1 Dec 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename this to something like makeConsumerFusableViaInterchange to signal that it modifies the operation? So it's more clear it has side-effects on the indexing maps of the op.

I'm not sure whats going on with the regression tests but otherwise this looks good. I'm glad ElementwiseOpInterchangePattern is getting removed, it was a bit weird and caused some problems

OpOperand &fusableOperand,
const llvm::SmallBitVector &rootOuterParallelLoops) {
Operation *producer = fusableOperand.get().getDefiningOp();
if (!producer) {
return false;
}

Operation *consumer = fusableOperand.getOwner();
auto genericOp = dyn_cast<linalg::GenericOp>(consumer);
if (!genericOp) {
return false;
}
assert(genericOp.getNumDpsInputs() > 0 &&
"expected consumer to have at least one input");

if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1) {
return false;
}

// The input map must be a permutation (i.e. this is not a broadcasting
// access), but not identity.
AffineMap inputMap = genericOp.getMatchingIndexingMap(&fusableOperand);
if (!inputMap.isPermutation() || inputMap.isIdentity()) {
return false;
}

// The output should also be a permutation. It should be for a parallel
// elementwise op, but checking here anyway.
OpResult result = cast<OpResult>(genericOp.getResult(0));
if (!genericOp.getIndexingMapMatchingResult(result).isPermutation()) {
return false;
}

// Make the input map identity.
auto perm =
llvm::map_to_vector(inputMap.getResults(), [](AffineExpr e) -> unsigned {
return cast<AffineDimExpr>(e).getPosition();
});
IRRewriter rewriter(consumer->getContext());
FailureOr<linalg::GenericOp> interchangedOp =
linalg::interchangeGenericOp(rewriter, genericOp, perm);
(void)interchangedOp;
assert(succeeded(interchangedOp) && "expected interchange to succeed");
assert(interchangedOp.value() == genericOp &&
"expected interchange to happen in place");
assert(
areOpsFusable(producer, interchangedOp.value(), rootOuterParallelLoops) &&
"expected the interchanged op to be fusable");
return true;
}

/// For the fusion of root op -> elementwise operation to be bufferized
/// in-place without use of extra memory, the result of the root operation
/// must be able to reuse the buffer for the result of the elementwise
Expand Down Expand Up @@ -561,7 +620,10 @@ isFusableWithConsumer(OpOperand &fusedOperand,
}

if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
return false;
if (!areOpsFusableAfterInterchangeOfConsumer(fusedOperand,
rootOuterParallelLoops)) {
return false;
}
}

// Check if the iteration spaces of the producer and consumer are same.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,44 +38,6 @@ namespace mlir::iree_compiler::DispatchCreation {

namespace {

//===----------------------------------------------------------------------===//
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//

// If possible, interchange indexing maps to make input maps all identity.
struct ElementwiseOpInterchangePattern final
    : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only single-result elementwise ops with at least one input are
    // candidates for interchange.
    if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
        genericOp.getNumDpsInputs() == 0)
      return failure();

    // All input maps must be equal and non-identity. All maps, including the
    // output, must be permutations. Permutation maps are checked by
    // isElementwise but may be removed.
    AffineMap firstInputMap = genericOp.getIndexingMapsArray().front();
    auto *initOperand = genericOp.getDpsInitOperand(0);
    if (firstInputMap.isIdentity() || !firstInputMap.isPermutation() ||
        !genericOp.getMatchingIndexingMap(initOperand).isPermutation()) {
      return failure();
    }
    if (!llvm::all_of(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
          return genericOp.getMatchingIndexingMap(operand) == firstInputMap;
        })) {
      return failure();
    }

    // Interchange with the permutation that turns every input map into the
    // identity map.
    SmallVector<unsigned> permutation;
    for (AffineExpr expr : firstInputMap.getResults()) {
      permutation.push_back(cast<AffineDimExpr>(expr).getPosition());
    }
    return linalg::interchangeGenericOp(rewriter, genericOp, permutation);
  }
};

//===----------------------------------------------------------------------===//
// FoldSuccessiveTensorInsertSliceOps
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -153,8 +115,7 @@ struct FusionPreprocessingPass final
: public impl::FusionPreprocessingPassBase<FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<ElementwiseOpInterchangePattern,
FoldSuccessiveTensorInsertSliceOps>(&getContext());
patterns.add<FoldSuccessiveTensorInsertSliceOps>(&getContext());

// Fold away `tensor.dim` operations that can be resolved in terms of its
// operand shapes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,35 @@ util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
// CHECK-SAME: ins(%[[DISPATCH1]],
// CHECK: flow.return %[[CUSTOM_OP]]
// CHECK: util.return %[[DISPATCH2]]

// -----

// Tests that a matmul fuses with an elementwise consumer whose operand
// indexing map is transposed relative to the producer's result.
util.func @fuse_transposed_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0: f32
  %m = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %n = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %empty = tensor.empty(%m, %n) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %empty2 = tensor.empty(%n, %m) : tensor<?x?xf32>
  %generic = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%matmul, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%empty2 : tensor<?x?xf32>) {
    ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
      %0 = arith.addf %b0, %b1 : f32
      linalg.yield %0 : f32
  } -> tensor<?x?xf32>
  util.return %generic : tensor<?x?xf32>
}
// NOTE: the CHECK lines use `util.func public`/`util.return` to match the
// conventions of the other tests in this file.
// CHECK-LABEL: util.func public @fuse_transposed_op
// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK:           %[[MATMUL:.+]] = linalg.matmul
// CHECK:           %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:          ins(%[[MATMUL]],
// CHECK:           flow.return %[[GENERIC]]
// CHECK:         util.return %[[DISPATCH]]
Original file line number Diff line number Diff line change
Expand Up @@ -30,77 +30,3 @@ util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
// CHECK: %[[RETURN:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
// CHECK-SAME: [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]]
// CHECK: util.return %[[RETURN]]

// -----

#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
// The single input uses a non-identity permutation map, so the op is
// interchanged: the CHECKs expect the input map to become identity while the
// output map picks up the permutation.
util.func @single_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
%0 = tensor.empty() : tensor<2x320x128x128xf16>
%1 = linalg.generic {indexing_maps = [#perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %out: f16):
%2 = arith.truncf %in : f32 to f16
linalg.yield %2 : f16
} -> tensor<2x320x128x128xf16>
util.return %1 : tensor<2x320x128x128xf16>
}

// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
// CHECK-LABEL: util.func public @single_input_interchange
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x128x128x320xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<2x128x128x320xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x320x128x128xf16>)

// -----

#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
// Both inputs share the same non-identity permutation map, so interchange
// applies: the CHECKs expect both input maps to become identity while the
// output map picks up the permutation.
util.func @multi_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
%0 = tensor.empty() : tensor<2x320x128x128xf16>
%1 = linalg.generic {indexing_maps = [#perm, #perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %in_1: f32, %out: f16):
%2 = arith.addf %in, %in_1 : f32
%3 = arith.truncf %2 : f32 to f16
linalg.yield %3 : f16
} -> tensor<2x320x128x128xf16>
util.return %1 : tensor<2x320x128x128xf16>
}

// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
// CHECK-LABEL: util.func public @multi_input_interchange
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x128x128x320xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$IDENT_MAP]], #[[$PERM_MAP]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]] : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x320x128x128xf16>)

// -----

#ident = affine_map<(d0, d1) -> (d0, d1)>
#perm0 = affine_map<(d0, d1) -> (d1, d0)>
// Negative test: the two input maps differ (#ident vs #perm0), so no
// interchange happens — the CHECKs expect the original maps unchanged.
util.func @multi_input_no_interchange(%arg0: tensor<10x10xf32>) -> tensor<10x10xf16> {
%0 = tensor.empty() : tensor<10x10xf16>
%1 = linalg.generic {indexing_maps = [#ident, #perm0, #perm0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<10x10xf32>, tensor<10x10xf32>) outs(%0 : tensor<10x10xf16>) {
^bb0(%in: f32, %in_1: f32, %out: f16):
%2 = arith.addf %in, %in_1 : f32
%3 = arith.truncf %2 : f32 to f16
linalg.yield %3 : f16
} -> tensor<10x10xf16>
util.return %1 : tensor<10x10xf16>
}

// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$PERM_MAP0:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: util.func public @multi_input_no_interchange
// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x10xf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP0]], #[[$PERM_MAP0]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]] : tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x10xf16>)
Loading