diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
index f802f0b9742b..fa5fb6bf13af 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -32,6 +33,162 @@ namespace mlir::iree_compiler::DispatchCreation {
 #define GEN_PASS_DEF_FOLDUNITEXTENTDIMSPASS
 #include "iree/compiler/DispatchCreation/Passes.h.inc"
 
+namespace {
+
+/// Simplify collapse_shape(expand_shape) by removing unneeded unit dimensions
+/// that get expanded and subsequently collapsed.
+///
+/// An intermediate dim is dropped only if it is statically 1, is produced by
+/// an expansion group with more than one result dim, and is merged by a
+/// collapse group with more than one dim. Every expansion group and every
+/// collapse group keeps at least one dim; a reshape whose new reassociation
+/// is the identity is not recreated.
+struct DropUnitDimsFromCollapseOfExpand
+    : OpRewritePattern<tensor::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandOp = collapseOp.getSrc().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandOp) {
+      return failure();
+    }
+
+    const auto collapseReassoc = collapseOp.getReassociationIndices();
+    ArrayRef<int64_t> interShape = expandOp.getType().getShape();
+    ArrayRef<int64_t> outShape = collapseOp.getType().getShape();
+    llvm::SmallDenseSet<int64_t> toDrop;
+    for (const auto &[outDim, indices] : llvm::enumerate(collapseReassoc)) {
+      for (auto [innerIdx, inDim] : llvm::enumerate(indices)) {
+        // Only a static unit dim that is merged with other dims can be dropped.
+        if (indices.size() == 1 || interShape[inDim] != 1) {
+          continue;
+        }
+
+        // When the whole group collapses to a unit dim, one dim has to survive:
+        // only the leading dim of the group is marked, the rest are kept.
+        if (outShape[outDim] == 1 && innerIdx != 0) {
+          continue;
+        }
+        toDrop.insert(inDim);
+      }
+    }
+
+    const auto expandReassoc = expandOp.getReassociationIndices();
+    for (const ReassociationIndices &indices : expandReassoc) {
+      // A unit dim that does not come from an expansion must stay.
+      if (indices.size() == 1) {
+        toDrop.erase(indices[0]);
+        continue;
+      }
+
+      // Each source dim has to keep at least one result dim; otherwise the
+      // rewritten reassociations would no longer cover the source rank and
+      // the replacement value would not have the type of `collapseOp`.
+      if (llvm::all_of(indices,
+                       [&toDrop](int64_t dim) { return toDrop.contains(dim); })) {
+        toDrop.erase(indices.front());
+      }
+    }
+
+    if (toDrop.empty()) {
+      return rewriter.notifyMatchFailure(collapseOp,
+                                         "Didn't find any unit dims to drop");
+    }
+
+    SmallVector<int64_t> newInterShape;
+    newInterShape.reserve(interShape.size() - toDrop.size());
+    for (auto [idx, length] : llvm::enumerate(interShape)) {
+      if (!toDrop.contains(idx)) {
+        newInterShape.push_back(length);
+      }
+    }
+
+    // Appends to `reassoc` the group `[origStart, origStart + count)` with the
+    // dropped dims removed and the survivors renumbered from `start`. Returns
+    // false, appending nothing, if no dim of the group survives.
+    auto appendDroppedReassociation =
+        [&toDrop](SmallVector<ReassociationIndices> &reassoc, int64_t start,
+                  int64_t count, int64_t origStart) {
+          reassoc.emplace_back();
+          auto &indices = reassoc.back();
+          indices.reserve(count);
+          int64_t dim = start;
+          for (int64_t idx : llvm::seq(origStart, origStart + count)) {
+            if (!toDrop.contains(idx)) {
+              indices.push_back(dim++);
+            }
+          }
+
+          // No dim of the group survived.
+          if (indices.empty()) {
+            reassoc.pop_back();
+            return false;
+          }
+          return true;
+        };
+
+    // Returns `sizes` without the entries that belong to dropped dims.
+    auto dropOutputOfr = [&toDrop](const SmallVector<OpFoldResult> &sizes) {
+      return llvm::map_to_vector(
+          llvm::make_filter_range(
+              llvm::enumerate(sizes),
+              [&toDrop](auto pair) { return !toDrop.contains(pair.index()); }),
+          [](auto pair) -> OpFoldResult { return pair.value(); });
+    };
+
+    auto isIdentityReassociation = [](ArrayRef<ReassociationIndices> reassoc) {
+      return llvm::all_of(reassoc,
+                          [](auto &indices) { return indices.size() == 1; });
+    };
+
+    SmallVector<ReassociationIndices> newCollapseReassoc;
+    int64_t collapsedDim = 0;
+    for (auto dim : llvm::seq<int64_t>(0, outShape.size())) {
+      bool changed = appendDroppedReassociation(
+          newCollapseReassoc, collapsedDim, collapseReassoc[dim].size(),
+          collapseReassoc[dim].front());
+      if (changed) {
+        collapsedDim += newCollapseReassoc.back().size();
+      }
+    }
+
+    SmallVector<ReassociationIndices> newExpandReassoc;
+    ArrayRef<int64_t> srcShape = expandOp.getSrcType().getShape();
+    int64_t expandedDim = 0;
+    for (auto dim : llvm::seq<int64_t>(0, srcShape.size())) {
+      bool changed = appendDroppedReassociation(
+          newExpandReassoc, expandedDim, expandReassoc[dim].size(),
+          expandReassoc[dim].front());
+      if (changed) {
+        expandedDim += newExpandReassoc.back().size();
+      }
+    }
+
+    auto outputSizes = getMixedValues(expandOp.getStaticOutputShape(),
+                                      expandOp.getOutputShape(), rewriter);
+    Value newExpanded = expandOp.getSrc();
+    if (!isIdentityReassociation(newExpandReassoc)) {
+      newExpanded = rewriter.create<tensor::ExpandShapeOp>(
+          expandOp.getLoc(),
+          RankedTensorType::get(newInterShape,
+                                expandOp.getType().getElementType()),
+          expandOp.getSrc(), newExpandReassoc, dropOutputOfr(outputSizes));
+    }
+
+    Value newCollapsed = newExpanded;
+    if (!isIdentityReassociation(newCollapseReassoc)) {
+      newCollapsed = rewriter.create<tensor::CollapseShapeOp>(
+          collapseOp.getLoc(), collapseOp.getType(), newExpanded,
+          newCollapseReassoc);
+    }
+    rewriter.replaceOp(collapseOp, newCollapsed);
+    return success();
+  }
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===// // Pass helpers //===----------------------------------------------------------------------===// @@ -155,6 +296,7 @@ void FoldUnitExtentDimsPass::runOnOperation() { }; linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); + foldUnitDimsPatterns.insert(context); if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) { return signalPassFailure(); diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir index 249a8b1cba4b..bd8ce1f11453 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir @@ -15,12 +15,13 @@ util.func public @no_fold_unit_dims_in_dispatches(%arg0 : tensor<1x1x10xf32>) -> } util.return %1 : tensor<1x1x10xf32> } -// CHECK: util.func public @no_fold_unit_dims_in_dispatches(%[[ARG0:.+]]: tensor<1x1x10xf32>) -// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region -// CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>) -// CHECK: flow.return %[[GENERIC]] -// CHECK: util.return %[[DISPATCH]] +// CHECK-LABEL: util.func public @no_fold_unit_dims_in_dispatches +// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x1x10xf32>) +// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : tensor<1x1x10xf32>) +// CHECK: flow.return %[[GENERIC]] +// CHECK: util.return %[[DISPATCH]] // ----- @@ -46,15 +47,15 @@ module @fold_unit_dims { } } -// CHECK: module @fold_unit_dims -// CHECK: util.global private mutable @[[GLOBAL:.+]] {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<32x64xf32> -// CHECK: util.global private mutable @[[UNIT_GLOBAL:.+]] = #util.uninitialized : 
tensor -// CHECK: util.func public @fold_global_unit_dims -// CHECK: %[[LOAD0:.+]] = util.global.load @[[GLOBAL]] : tensor<32x64xf32> -// CHECK: %[[LOAD1:.+]] = util.global.load @[[UNIT_GLOBAL]] : tensor -// CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]] -// CHECK: util.global.store %[[GENERIC]], @[[GLOBAL]] : tensor<32x64xf32> +// CHECK-LABEL: module @fold_unit_dims +// CHECK: util.global private mutable @[[GLOBAL:.+]] {inlining_policy = #util.inline.never} = #util.uninitialized : tensor<32x64xf32> +// CHECK: util.global private mutable @[[UNIT_GLOBAL:.+]] = #util.uninitialized : tensor +// CHECK: util.func public @fold_global_unit_dims +// CHECK: %[[LOAD0:.+]] = util.global.load @[[GLOBAL]] : tensor<32x64xf32> +// CHECK: %[[LOAD1:.+]] = util.global.load @[[UNIT_GLOBAL]] : tensor +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[LOAD0]], %[[LOAD1]] +// CHECK: util.global.store %[[GENERIC]], @[[GLOBAL]] : tensor<32x64xf32> // CHECK: util.return %[[GENERIC]] // ----- @@ -68,12 +69,12 @@ module @no_fold_immutable { } } -// CHECK: module @no_fold_immutable -// CHECK: util.global private @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32> -// CHECK: util.func public @no_fold_global_unit_dims -// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] -// CHECK: util.return %[[COLLAPSE]] +// CHECK-LABEL: module @no_fold_immutable +// CHECK: util.global private @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32> +// CHECK: util.func public @no_fold_global_unit_dims +// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] +// CHECK: util.return %[[COLLAPSE]] // ----- @@ -86,11 +87,11 @@ module @no_fold_public { } } -// CHECK: module @no_fold_public -// CHECK: util.global public mutable @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32> -// CHECK: util.func public 
@no_fold_global_unit_dims -// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] +// CHECK-LABEL: module @no_fold_public +// CHECK: util.global public mutable @[[GLOBAL:.+]] : tensor<1x32x1x1x64xf32> +// CHECK: util.func public @no_fold_global_unit_dims +// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<1x32x1x1x64xf32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[LOAD]] // ----- @@ -102,7 +103,91 @@ module @fold_stream_parameter { } } -// CHECK: module @fold_stream_parameter -// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32> -// CHECK: util.func public @fold_stream_parameter +// CHECK-LABEL: module @fold_stream_parameter +// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32> +// CHECK: util.func public @fold_stream_parameter // CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32> + +// ----- + +util.func @collapse_of_expand_0(%arg0: tensor, %arg1: index) -> tensor<4x?x128xf16> { + %expanded = tensor.expand_shape %arg0 [[0, 1, 2], [3, 4]] output_shape [4, %arg1, 1, 1, 128] : tensor into tensor<4x?x1x1x128xf16> + %collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3], [4]] : tensor<4x?x1x1x128xf16> into tensor<4x?x128xf16> + util.return %collapsed : tensor<4x?x128xf16> +} + +// CHECK-LABEL: util.func public @collapse_of_expand_0 +// CHECK-SAME: %[[ARG0:.+]]: tensor, %[[ARG1:.+]]: index +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: tensor into tensor<4x?x128xf16> +// CHECK: util.return %[[EXPAND]] : tensor<4x?x128xf16> + +// ----- + +util.func @collapse_of_expand_1(%arg0: tensor, %arg1: index) -> tensor<4x?x64xf16> { + %expanded = tensor.expand_shape %arg0 [[0, 1, 2], [3, 4]] output_shape [4, %arg1, 1, 2, 64] : tensor into tensor<4x?x1x2x64xf16> + %collapsed = tensor.collapse_shape 
%expanded [[0], [1, 2, 3], [4]] : tensor<4x?x1x2x64xf16> into tensor<4x?x64xf16> + util.return %collapsed : tensor<4x?x64xf16> +} + +// CHECK-LABEL: util.func public @collapse_of_expand_1 +// CHECK-SAME: %[[ARG0:.+]]: tensor, %[[ARG1:.+]]: index +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: tensor into tensor<4x?x2x64xf16> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] +// CHECK-SAME: tensor<4x?x2x64xf16> into tensor<4x?x64xf16> +// CHECK: util.return %[[COLLAPSE]] : tensor<4x?x64xf16> + +// ----- + +util.func @collapse_of_expand_2(%arg0: tensor, %arg1: index) -> tensor<4x?x1xf16> { + %expanded = tensor.expand_shape %arg0 [[0, 1, 2], [3, 4]] output_shape [4, %arg1, 1, 1, 1] : tensor into tensor<4x?x1x1x1xf16> + %collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3], [4]] : tensor<4x?x1x1x1xf16> into tensor<4x?x1xf16> + util.return %collapsed : tensor<4x?x1xf16> +} + +// CHECK-LABEL: util.func public @collapse_of_expand_2 +// CHECK-SAME: %[[ARG0:.+]]: tensor, %[[ARG1:.+]]: index +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: tensor into tensor<4x?x1xf16> +// CHECK: util.return %[[EXPAND]] : tensor<4x?x1xf16> + +// ----- + +util.func @collapse_of_expand_3(%arg0: tensor, %arg1: index, %arg2: index) -> tensor { + %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%arg1, 1, 1, %arg2] : tensor into tensor + %collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3]] : tensor into tensor + util.return %collapsed : tensor +} + +// CHECK-LABEL: util.func public @collapse_of_expand_3 +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: util.return %[[ARG0]] : tensor + +// ----- + +util.func @collapse_of_expand_4(%arg0: tensor<1x1xf16>, %arg1: index, %arg2: index) -> tensor<1xf16> { + %expanded = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [%arg1, 1, 1, %arg2] : tensor<1x1xf16> into tensor<1x1x1x1xf16> + %collapsed = tensor.collapse_shape %expanded [[0, 1, 2, 3]] : 
tensor<1x1x1x1xf16> into tensor<1xf16> + util.return %collapsed : tensor<1xf16> +} + +// CHECK-LABEL: util.func public @collapse_of_expand_4 +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xf16> +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK-SAME: tensor<1x1xf16> into tensor<1xf16> +// CHECK: util.return %[[COLLAPSED]] : tensor<1xf16> + +// ----- + +util.func @collapse_of_expand_5(%arg0: tensor<1x?x4x32xf16>, %arg1: index) -> tensor { + %expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [1, %arg1, 4, 1, 32] : tensor<1x?x4x32xf16> into tensor<1x?x4x1x32xf16> + %collapsed = tensor.collapse_shape %expanded [[0, 1], [2, 3], [4]] : tensor<1x?x4x1x32xf16> into tensor + util.return %collapsed : tensor +} + +// CHECK-LABEL: util.func public @collapse_of_expand_5 +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x4x32xf16> +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK-SAME: tensor<1x?x4x32xf16> into tensor +// CHECK: util.return %[[COLLAPSED]] : tensor