Skip to content

Commit

Permalink
Drop the unit dims on scatter ops.
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 committed Dec 11, 2024
1 parent 7177c29 commit ce0ec3a
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 45 deletions.
33 changes: 5 additions & 28 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,14 @@ LogicalResult ScatterOp::verify() {
"update value rank exceeds the rank of the original value");
}

// indexDepth + update dims should cover the original dims. The first dim of
// update is the number of updates.
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
return op->emitOpError(
"index depth and update value does not cover rank of original value");
}

// Validate the non-indexed update dims cover the full slice size of the
// original tensor.
int64_t fullSliceDims = originalType.getRank() - indexDepth;
for (auto it :
llvm::zip_equal(llvm::seq<unsigned>(indexDepth, originalType.getRank()),
llvm::seq<unsigned>(updateType.getRank() - fullSliceDims,
updateType.getRank()))) {
for (auto it : llvm::zip(
llvm::reverse(
llvm::seq<unsigned>(indexDepth, originalType.getRank())),
llvm::reverse(llvm::seq<unsigned>(
updateType.getRank() - fullSliceDims, updateType.getRank())))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
if (!originalType.isDynamicDim(originalDim) &&
Expand All @@ -190,23 +184,6 @@ LogicalResult ScatterOp::verify() {
}
}

// Check that the remaining update indices do not exceed the update length.
int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
for (auto it : llvm::zip_equal(
llvm::seq<unsigned>(insertDims, indexDepth),
llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return op->emitOpError("indexed shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim
<< " " << updateType.getDimSize(updateDim) << " "
<< originalType.getDimSize(originalDim);
}
}

Region &region = this->getRegion();
Block *body = &region.front();
if (body->getNumArguments() != 2) {
Expand Down
17 changes: 0 additions & 17 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -321,23 +321,6 @@ func.func @scatter_index_depth_dynamic(

// -----

func.func @scatter_original_rank_mismatch(
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{op index depth and update value does not cover rank of original value}}
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
ins(%update, %indices : tensor<?xi64>, tensor<?x1xi32>)
outs(%original : tensor<?x?xi64>) {
^bb0(%arg1: i64, %arg2: i64):
%1 = arith.addi %arg1, %arg2 : i64
%2 = arith.trunci %1 : i64 to i32
iree_linalg_ext.yield %1, %2 : i64, i32
} -> tensor<?x?xi64>
return %0 : tensor<?x?xi64>
}

// -----

func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) {
// expected-error@+1 {{expected one or two input operands}}
%0:2 = iree_linalg_ext.topk
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final
linalg::ControlFusionFn controlFoldingReshapes;
};

/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
/// The `update` tensor is scanned from left to right, starting from the second
/// element. The number of unit dimensions are counted until reaching a non unit
/// dim.
struct FoldScatterUnitDims final : public OpRewritePattern<ScatterOp> {
FoldScatterUnitDims(MLIRContext *context, linalg::ControlDropUnitDims options,
PatternBenefit benefit = 1)
: OpRewritePattern<ScatterOp>(context, benefit),
options(std::move(options)) {}

LogicalResult matchAndRewrite(ScatterOp scatterOp,
PatternRewriter &rewriter) const override {
if (options.rankReductionStrategy !=
linalg::ControlDropUnitDims::RankReductionStrategy::
ReassociativeReshape) {
return rewriter.notifyMatchFailure(
scatterOp, "Only reassociative reshape strategy supported");
}
llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);

// TODO: use the actual rank once it has been added.
constexpr int64_t batchRank = 1;
const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();

// Find the first `numDimsToDrop` unit dimensions in the update tensor,
// these are the ones that can be dropped.
int64_t numDimsToDrop = 0;
for (auto val : updateShape.drop_front(batchRank)) {
if (val != 1) {
break;
}
++numDimsToDrop;
}
llvm::erase_if(canDrop, [&](unsigned dimPos) {
return dimPos < batchRank || dimPos >= batchRank + numDimsToDrop;
});
if (canDrop.empty()) {
return failure();
}

SmallVector<int64_t> droppedUpdateShape;
droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
if (!llvm::is_contained(canDrop, idx)) {
droppedUpdateShape.push_back(dimLen);
}
}

auto reassoc =
getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
assert(reassoc.has_value() && "expected reassociation to be valid");
auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
scatterOp.getLoc(),
RankedTensorType::get(droppedUpdateShape,
scatterOp.getUpdateType().getElementType()),
scatterOp.getUpdates(), reassoc.value());

constexpr int64_t kUpdateOpNum = 0;
rewriter.modifyOpInPlace(scatterOp, [&]() {
scatterOp.setOperand(kUpdateOpNum, collapseOp.getResult());
});
return success();
}

linalg::ControlDropUnitDims options;
};

} // namespace

/// Return the `reassociation` indices to use to collapse the operand when the
Expand Down Expand Up @@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns(
patterns.getContext(), controlFoldingReshapes);
}

SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
auto fusionOp = cast<LinalgFusionOpInterface>(op);
return llvm::to_vector(llvm::seq<unsigned>(0, fusionOp.getNumLoops()));
}

void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
patterns.add<FoldScatterUnitDims>(patterns.getContext(), options);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Default function to drop unit dims for for linalgext ops.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op);

/// Drop unit extent dims from linalg ext ops
void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options);

/// Helper struct to hold the results of collapsing an operation.
struct CollapseResult {
SmallVector<Value> results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/DispatchCreation/Passes.h"
Expand Down Expand Up @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() {
if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) {
return SmallVector<unsigned>{};
}
if (isa<IREE::LinalgExt::LinalgExtOp>(op)) {
return IREE::LinalgExt::defaultControlDropUnitDims(op);
}
return defaultFn(op);
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
if (failed(applyPatternsAndFoldGreedily(moduleOp,
std::move(foldUnitDimsPatterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,33 @@ module @fold_stream_parameter {
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
// CHECK: util.func public @fold_stream_parameter
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>

// -----

util.func public @scatter0(%arg0: tensor<?x1x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
^bb0(%arg3: f16, %arg4: f16):
iree_linalg_ext.yield %arg3 : f16
} -> tensor<?x2x16x4x128xf16>
util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter0
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-SAME: to tensor<?x16x4x128xf16>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[COLLAPSE]]

// -----

util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
^bb0(%arg3: f16, %arg4: f16):
iree_linalg_ext.yield %arg3 : f16
} -> tensor<?x2x16x4x128xf16>
util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter1
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-SAME: to tensor<?x16x4x128xf16>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[COLLAPSE]]

0 comments on commit ce0ec3a

Please sign in to comment.