From cff1fb1ef6ef68d9dbbaea49fa2d2c8ce9c4e8c2 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Wed, 11 Dec 2024 07:19:33 -0800 Subject: [PATCH] Drop the unit dims on scatter ops. Signed-off-by: Ian Wood --- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 33 ++------ .../Dialect/LinalgExt/IR/test/invalid.mlir | 17 ---- .../LinalgExt/Transforms/ReshapeFusion.cpp | 77 +++++++++++++++++++ .../Dialect/LinalgExt/Transforms/Transforms.h | 7 ++ .../DispatchCreation/FoldUnitExtentDims.cpp | 7 ++ .../DispatchCreation/test/fold_unit_dims.mlir | 30 ++++++++ 6 files changed, 126 insertions(+), 45 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index d66c5f2162f6..e6eff1b3d49d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -166,20 +166,14 @@ LogicalResult ScatterOp::verify() { "update value rank exceeds the rank of the original value"); } - // indexDepth + update dims should cover the original dims. The first dim of - // update is the number of updates. - if (originalType.getRank() > indexDepth + updateType.getRank() - 1) { - return op->emitOpError( - "index depth and update value does not cover rank of original value"); - } - // Validate the non-indexed update dims cover the full slice size of the // original tensor. int64_t fullSliceDims = originalType.getRank() - indexDepth; - for (auto it : - llvm::zip_equal(llvm::seq(indexDepth, originalType.getRank()), - llvm::seq(updateType.getRank() - fullSliceDims, - updateType.getRank()))) { + for (auto it : llvm::zip( + llvm::reverse( + llvm::seq(indexDepth, originalType.getRank())), + llvm::reverse(llvm::seq( + updateType.getRank() - fullSliceDims, updateType.getRank())))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); if (!originalType.isDynamicDim(originalDim) && @@ -190,23 +184,6 @@ LogicalResult ScatterOp::verify() { } } - // Check that the remaining update indices do not exceed the update length. - int64_t insertDims = originalType.getRank() - updateType.getRank() + 1; - for (auto it : llvm::zip_equal( - llvm::seq(insertDims, indexDepth), - llvm::seq(1, updateType.getRank() - fullSliceDims))) { - int64_t originalDim = std::get<0>(it); - int64_t updateDim = std::get<1>(it); - if (!originalType.isDynamicDim(originalDim) && - updateType.getDimSize(updateDim) > - originalType.getDimSize(originalDim)) { - return op->emitOpError("indexed shape of update value dim#") - << updateDim << " exceeds original value at dim#" << originalDim - << " " << updateType.getDimSize(updateDim) << " " - << originalType.getDimSize(originalDim); - } - } - Region ®ion = this->getRegion(); Block *body = ®ion.front(); if (body->getNumArguments() != 2) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 32ed87e1311d..c8617d4adbef 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -321,23 +321,6 @@ func.func @scatter_index_depth_dynamic( // ----- -func.func @scatter_original_rank_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{op index depth and update value does not cover rank of original value}} - %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64): - %1 = arith.addi %arg1, %arg2 : i64 - %2 = arith.trunci %1 : i64 to i32 - iree_linalg_ext.yield %1, %2 : i64, i32 - } -> tensor - return %0 : tensor -} - -// ----- - func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) { // expected-error@+1 {{expected one or two input operands}} %0:2 = iree_linalg_ext.topk diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp index e87efd9f2099..0d0fb1ec1afb 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final linalg::ControlFusionFn controlFoldingReshapes; }; +/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand. +/// The `update` tensor is scanned from left to right, starting from the second +/// element. The number of unit dimensions are counted until reaching a non unit +/// dim. +struct FoldScatterUnitDims final : public OpRewritePattern { + FoldScatterUnitDims(MLIRContext *context, linalg::ControlDropUnitDims options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(std::move(options)) {} + + LogicalResult matchAndRewrite(ScatterOp scatterOp, + PatternRewriter &rewriter) const override { + if (options.rankReductionStrategy != + linalg::ControlDropUnitDims::RankReductionStrategy:: + ReassociativeReshape) { + return rewriter.notifyMatchFailure( + scatterOp, "Only reassociative reshape strategy supported"); + } + llvm::SmallVector canDrop = options.controlFn(scatterOp); + + // TODO: use the actual rank once it has been added. + constexpr int64_t batchRank = 1; + const ArrayRef updateShape = scatterOp.getUpdateType().getShape(); + + // Find the first `numDimsToDrop` unit dimensions in the update tensor, + // these are the ones that can be dropped. + int64_t numDimsToDrop = 0; + for (auto val : updateShape.drop_front(batchRank)) { + if (val != 1) { + break; + } + ++numDimsToDrop; + } + llvm::erase_if(canDrop, [&](unsigned dimPos) { + return dimPos < batchRank || dimPos >= batchRank + numDimsToDrop; + }); + if (canDrop.empty()) { + return failure(); + } + + SmallVector droppedUpdateShape; + droppedUpdateShape.reserve(updateShape.size() - canDrop.size()); + for (auto [idx, dimLen] : llvm::enumerate(updateShape)) { + if (!llvm::is_contained(canDrop, idx)) { + droppedUpdateShape.push_back(dimLen); + } + } + + auto reassoc = + getReassociationIndicesForCollapse(updateShape, droppedUpdateShape); + assert(reassoc.has_value() && "expected reassociation to be valid"); + auto collapseOp = rewriter.create( + scatterOp.getLoc(), + RankedTensorType::get(droppedUpdateShape, + scatterOp.getUpdateType().getElementType()), + scatterOp.getUpdates(), reassoc.value()); + + constexpr int64_t kUpdateOpNum = 0; + rewriter.modifyOpInPlace(scatterOp, [&]() { + scatterOp.setOperand(kUpdateOpNum, collapseOp.getResult()); + }); + return success(); + } + + linalg::ControlDropUnitDims options; +}; + } // namespace /// Return the `reassociation` indices to use to collapse the operand when the @@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns( patterns.getContext(), controlFoldingReshapes); } +SmallVector defaultControlDropUnitDims(Operation *op) { + auto fusionOp = cast(op); + return llvm::to_vector(llvm::seq(0, fusionOp.getNumLoops())); +} + +void populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) { + patterns.add(patterns.getContext(), options); +} + } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h index 8bf84cab2574..8da0225e27ef 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps( RewritePatternSet &patterns, const linalg::ControlFusionFn &controlFusionFn); +/// Default function to drop unit dims for for linalgext ops. +SmallVector defaultControlDropUnitDims(Operation *op); + +/// Drop unit extent dims from linalg ext ops +void populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options); + /// Helper struct to hold the results of collapsing an operation. struct CollapseResult { SmallVector results; diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp index f802f0b9742b..40fabc56bcf7 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/Analysis/Explorer.h" #include "iree/compiler/DispatchCreation/Passes.h" @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() { if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) { return SmallVector{}; } + if (isa(op)) { + return IREE::LinalgExt::defaultControlDropUnitDims(op); + } return defaultFn(op); }; linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); + IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, + options); linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir index 249a8b1cba4b..513221c1f9c5 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir @@ -106,3 +106,33 @@ module @fold_stream_parameter { // CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32> // CHECK: util.func public @fold_stream_parameter // CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32> + +// ----- + +util.func public @scatter0(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter0 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-SAME: to tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSE]] + +// ----- + +util.func public @scatter1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter1 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-SAME: to tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSE]]