Fix two problems from sharktank tests
(1) The pattern wasn't checking whether a new reassociation group had been
appended before calling `.back()`. (2) Unit dims cannot be dropped if they
weren't created by the expansion.

Signed-off-by: Ian Wood <[email protected]>
IanWood1 committed Dec 4, 2024
1 parent 80efa36 commit 740a981
Showing 2 changed files with 40 additions and 9 deletions.
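For problem (1), the fix makes `appendDroppedReassocation` report whether it actually appended a group, so callers only consult `.back()` when one exists. Below is a minimal standalone sketch of that idiom, using plain `std::vector` in place of the pass's `SmallVector<ReassociationIndices, 4>`; the names are illustrative, not the pass's actual API:

#include <cassert>
#include <cstdint>
#include <vector>

using Group = std::vector<int64_t>;

// Appends `group` unless every index in it was dropped; returns whether it
// appended, so the caller knows if `.back()` is safe to consult.
static bool appendIfNonEmpty(std::vector<Group> &reassoc, Group group) {
  if (group.empty())
    return false;
  reassoc.push_back(std::move(group));
  return true;
}

int main() {
  std::vector<Group> reassoc;
  int64_t nextDim = 0;
  // The middle group is fully dropped. Before the fix, the caller advanced
  // via reassoc.back() unconditionally, reading a stale group here.
  for (Group g : {Group{0, 1}, Group{}, Group{2}}) {
    if (appendIfNonEmpty(reassoc, std::move(g)))
      nextDim += reassoc.back().size();
  }
  assert(reassoc.size() == 2 && nextDim == 3);
  return 0;
}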
33 changes: 24 additions & 9 deletions compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -70,6 +71,14 @@ struct DropUnitDimsFromCollapseOfExpand
}
}

const auto expandReassoc = expandOp.getReassociationIndices();
for (const auto &[inDim, indicies] : llvm::enumerate(expandReassoc)) {
// Can't drop unit dim if it isn't from an expansion.
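// A one-element reassociation group means the expand_shape forwards this
// source dim unchanged, so its unit extent already existed in the input
// and must be preserved.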
if (indicies.size() == 1) {
toDrop.erase(indicies[0]);
}
}

if (toDrop.empty()) {
return rewriter.notifyMatchFailure(collapseOp,
"Didn't find any unit dims to drop");
@@ -83,6 +92,7 @@ struct DropUnitDimsFromCollapseOfExpand
}
}

/// Returns true if new `ReassociationIndices` were appended to `reassoc`.
auto appendDroppedReassocation =
[&toDrop](SmallVector<ReassociationIndices, 4> &reassoc, int64_t start,
int64_t count, int64_t origStart) {
@@ -99,7 +109,9 @@ struct DropUnitDimsFromCollapseOfExpand
// All indices have been dropped.
if (indicies.empty()) {
reassoc.pop_back();
return false;
}
return true;
};

auto dropOutputOfr = [&toDrop](const SmallVector<OpFoldResult> &sizes) {
@@ -118,21 +130,24 @@
SmallVector<ReassociationIndices, 4> newCollapseReassoc;
int64_t collapsedDim = 0;
for (auto dim : llvm::seq<int64_t>(0, outShape.size())) {
appendDroppedReassocation(newCollapseReassoc, collapsedDim,
collapseReassoc[dim].size(),
collapseReassoc[dim].front());
collapsedDim += newCollapseReassoc.back().size();
bool changed = appendDroppedReassocation(newCollapseReassoc, collapsedDim,
collapseReassoc[dim].size(),
collapseReassoc[dim].front());
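// Only advance using .back() when a group was actually appended; otherwise
// the last group (if any) belongs to an earlier output dim.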
if (changed) {
collapsedDim += newCollapseReassoc.back().size();
}
}

const auto expandReassoc = expandOp.getReassociationIndices();
SmallVector<ReassociationIndices, 4> newExpandReassoc;
ArrayRef<int64_t> srcShape = expandOp.getSrcType().getShape();
int64_t expandedDim = 0;
for (auto dim : llvm::seq<int64_t>(0, srcShape.size())) {
appendDroppedReassocation(newExpandReassoc, expandedDim,
expandReassoc[dim].size(),
expandReassoc[dim].front());
expandedDim += newExpandReassoc.back().size();
bool changed = appendDroppedReassocation(newExpandReassoc, expandedDim,
expandReassoc[dim].size(),
expandReassoc[dim].front());
if (changed) {
expandedDim += newExpandReassoc.back().size();
}
}

auto outputSizes = getMixedValues(expandOp.getStaticOutputShape(),
@@ -115,6 +115,7 @@ util.func @collapse_of_expand_0(%arg0: tensor<?x128xf16>, %arg1: index) -> tensor<4x?x128xf16> {
%collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3], [4]] : tensor<4x?x1x1x128xf16> into tensor<4x?x128xf16>
util.return %collapsed : tensor<4x?x128xf16>
}

// CHECK-LABEL: util.func public @collapse_of_expand_0
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x128xf16>, %[[ARG1:.+]]: index
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
@@ -128,6 +129,7 @@ util.func @collapse_of_expand_1(%arg0: tensor<?x128xf16>, %arg1: index) -> tensor<4x?x64xf16> {
%collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3], [4]] : tensor<4x?x1x2x64xf16> into tensor<4x?x64xf16>
util.return %collapsed : tensor<4x?x64xf16>
}

// CHECK-LABEL: util.func public @collapse_of_expand_1
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x128xf16>, %[[ARG1:.+]]: index
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
@@ -175,3 +177,17 @@ util.func @collapse_of_expand_4(%arg0: tensor<1x1xf16>, %arg1: index, %arg2: index) -> tensor<1xf16> {
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: tensor<1x1xf16> into tensor<1xf16>
// CHECK: util.return %[[COLLAPSED]] : tensor<1xf16>

// -----

util.func @collapse_of_expand_5(%arg0: tensor<1x?x4x32xf16>, %arg1: index) -> tensor<?x4x32xf16> {
%expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [1, %arg1, 4, 1, 32] : tensor<1x?x4x32xf16> into tensor<1x?x4x1x32xf16>
%collapsed = tensor.collapse_shape %expanded [[0, 1], [2, 3], [4]] : tensor<1x?x4x1x32xf16> into tensor<?x4x32xf16>
util.return %collapsed : tensor<?x4x32xf16>
}

// CHECK-LABEL: util.func public @collapse_of_expand_5
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x4x32xf16>
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-SAME: tensor<1x?x4x32xf16> into tensor<?x4x32xf16>
// CHECK: util.return %[[COLLAPSED]] : tensor<?x4x32xf16>
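Read together, the CHECK lines for `collapse_of_expand_5` expect the expand/collapse pair to fold into a single reshape that keeps the unexpanded unit dim. Spelled out as IR, a sketch of the expected output (inferred from the CHECK lines, not taken verbatim from the test):

%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3]] : tensor<1x?x4x32xf16> into tensor<?x4x32xf16>
util.return %collapsed : tensor<?x4x32xf16>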
