Revert multi-use fusion algorithm #85

Merged

lib/Dialect/Transforms/FusionPatterns.cpp: 236 changes (75 additions, 161 deletions)

@@ -12,175 +12,89 @@
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/Debug.h"

#ifndef NDEBUG
#define DEBUG_TYPE "tcp-fusion-patterns"
#endif

namespace mlir::tcp {

LogicalResult
GenericBottomUpFuser::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(op->getParentOp()))
return failure();

if (op->use_empty())
return failure();

// We can only fuse a def with multiple uses if all the uses belong to the
// same region and can be fused with the defining op
Region *usesParentRegion = nullptr;
SmallVector<Operation *> uses;
llvm::DenseSet<Operation *> usesSet;
llvm::DenseSet<tcp::BindSymbolicShapeOp> bindShapeUses;

LLVM_DEBUG(llvm::dbgs() << "Processing op: " << *op << "\n");
for (auto &use : op->getUses()) {
if (auto bindShapeOp = dyn_cast<tcp::BindSymbolicShapeOp>(use.getOwner())) {
bindShapeUses.insert(bindShapeOp);
continue;
}

auto parentRegion = use.getOwner()->getParentRegion();
if (usesParentRegion && usesParentRegion != parentRegion)
return failure();
usesParentRegion = parentRegion;

if (!canFuse(op, use.getOwner()))
return failure();

if (usesSet.insert(use.getOwner()).second)
uses.push_back(use.getOwner());
}

// All its uses are tcp.bind_symbolic_shape ops.
if (uses.empty())
return failure();

// Sorting by dominance ensures that the first element of this vector is
// the first use of the def. Used below when we want to move the op into
// an existing group.
LLVM_DEBUG(llvm::dbgs() << "Processing op: " << *op << " with " << uses.size()
<< " uses\n");
DominanceInfo domInfo;
llvm::stable_sort(uses, [&](Operation *a, Operation *b) {
return domInfo.dominates(a, b);
});

#ifndef NDEBUG
for (auto use : uses) {
LLVM_DEBUG(llvm::dbgs() << "Use: " << *use << "\n");
}
#endif

if (op->getParentRegion() == usesParentRegion) {
LLVM_DEBUG(llvm::dbgs() << "Creating new group\n");
// This case can only happen when all ops belong to the function.
SmallVector<Type> allResultTypes;
SmallVector<Value> allResults;
for (auto use : uses) {
allResultTypes.append(use->getResultTypes().begin(),
use->getResultTypes().end());
allResults.append(use->getResults().begin(), use->getResults().end());
}

auto groupOp = rewriter.create<tcp::GroupOp>(op->getLoc(), allResultTypes);
if (postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);

// First move all uses into the group in the dominance order
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp = rewriter.create<tcp::YieldOp>(op->getLoc(), allResults);
// This is where we are using the sorted-ness of `uses`. We are
// guaranteed that if the users of the op themselves depend on each
// other, then we'll move them in the correct order.
for (auto use : uses) {
use->moveBefore(yieldOp);
}
op->moveBefore(*uses.begin());
for (auto bindShapeOp : bindShapeUses) {
bindShapeOp->moveAfter(op);
Operation *use = op;
bool isChanged = false;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {
// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to
// check if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested
// regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}

// We only support fusing def ops that have exactly one use, for now.
// Special-case uses of the def in tcp.bind_symbolic_shape ops, which
// do not block fusion and are moved along with the def.
bool cannotFuse = false;
SmallVector<tcp::BindSymbolicShapeOp> bindSymbolicUsersOfDef;
for (auto otherUserOfDef : def->getUsers()) {
if (auto bindSymbolicShapeOp =
dyn_cast<tcp::BindSymbolicShapeOp>(otherUserOfDef)) {
bindSymbolicUsersOfDef.push_back(bindSymbolicShapeOp);
} else if (otherUserOfDef != use) {
cannotFuse = true;
break;
}
}

if (cannotFuse)
continue;
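// For example (illustrative IR), given
//
//   %0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
//   tcp.bind_symbolic_shape %0, [%s0, %s1], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
//   %1 = tcp.add %0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
//
// the bind op is the only user of %0 besides %1, so %0 and %1 can still
// be fused; the bind op is re-anchored after %0 by the moveAfter below.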

// Fuse the def and use ops into a group.
//
// * If both ops have the same parent region, they must be part of the
//   top-level func, so we need to create a new group.
// * The only other case is when the def op is part of the top-level
//   func and the use is already inside a group.
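// For example (illustrative IR, in the shape of the tests below), fusing
//
//   %0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
//   %1 = tcp.add %0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
//
// under the first case produces
//
//   %2 = tcp.group {
//     %3 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
//     %4 = tcp.add %3, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
//     tcp.yield %4 : tensor<?x?xf32>
//   } : tensor<?x?xf32>
//
// with all uses of the original %1 replaced by the group result %2.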
isChanged = true;
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp = rewriter.create<tcp::GroupOp>(use->getLoc(),
use->getResultTypes());
if (postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<tcp::YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
def->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}

for (auto bindSymbolicShapeOp : bindSymbolicUsersOfDef) {
bindSymbolicShapeOp->moveAfter(def);
}
}
}

// We then replace all uses of the fused ops that lie outside the group
// with the group's results. We should not replace uses inside the
// group; otherwise, ops inside the group would end up depending on the
// group's results, causing dominance issues.
size_t groupResultNum = 0;
for (auto use : uses) {
for (unsigned num = 0; num < use->getNumResults(); ++num) {
auto useIsOutsideGroup = [&](OpOperand &operand) {
return operand.getOwner()->getParentOp() != groupOp;
};
rewriter.replaceUsesWithIf(use->getResult(num),
groupOp->getResult(groupResultNum),
useIsOutsideGroup);
groupResultNum++;
}
}
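// For example (illustrative), if an op inside the group consumed another
// op from the same group and we also replaced that use, we would get
//
//   %3 = tcp.group {
//     %1 = tcp.tanh ...
//     %2 = tcp.add %3, ...   // uses the group's own result
//     tcp.yield %2 : ...
//   }
//
// where the group consumes its own result, which breaks dominance.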

} else if (auto groupOp =
dyn_cast<tcp::GroupOp>(usesParentRegion->getParentOp())) {
// Given that we iterate over the FuncOp in a bottom-up manner, when moving
// into an existing group, we are guaranteed that this op does not use
// any of the ops already in the group. So we can move it to the very
// beginning of the group. This ensures that the order of operands is
// preserved when creating a group. For example, if we start with
// something like:
//
// %0 = op1(%in1)
// %1 = op2(%in2)
// %2 = op3(%0, %1)
//
// we'll first create a group containing %1 and %2:
//
// %0 = op1(%in1)
// %3 = tcp.group {
// %1 = op2(%in2)
// %2 = op3(%0, %1)
// }
//
// if we try to move %0 to right before its use in the group, then we'd
// end up with:
//
// %3 = tcp.group {
// %1 = op2(%in2)
// %0 = op1(%in1)
// %2 = op3(%0, %1)
// }
//
// While this is not incorrect, it is a bit annoying that the MLIR gets
// reordered.
auto &firstOp = *usesParentRegion->getOps().begin();
op->moveBefore(&firstOp);
for (auto bindShapeOp : bindShapeUses) {
bindShapeOp->moveBefore(&firstOp);
}
} else {
op->emitError("Unhandled case during fusion");
llvm_unreachable("Unhandled case during fusion");
}
LLVM_DEBUG(llvm::dbgs() << "Function after transformation:\n"
<< op->getParentOfType<func::FuncOp>() << "\n");
return success();
return isChanged ? success() : failure();
}

} // namespace mlir::tcp
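
For reference, a pattern like GenericBottomUpFuser is applied through MLIR's greedy pattern rewrite driver, whose default bottom-up traversal within a block is what the comments in matchAndRewrite rely on. The sketch below is illustrative only: the constructor arguments (a canFuse callback and an optional postFunc hook) are inferred from the member usage above, and the example fusability policy is an assumption, not the pass's actual one.

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::tcp;

// Illustrative driver; the real pass wiring is outside this diff.
static LogicalResult runFusionPatterns(func::FuncOp func) {
  MLIRContext *context = func.getContext();
  RewritePatternSet patterns(context);
  patterns.add<GenericBottomUpFuser>(
      context,
      // Assumed canFuse callback: only fuse element-wise ops.
      [](Operation *def, Operation *use) {
        return def->hasTrait<OpTrait::Elementwise>() &&
               use->hasTrait<OpTrait::Elementwise>();
      },
      // Assumed postFunc hook, invoked on each newly created tcp.group.
      /*postFunc=*/nullptr);
  return applyPatternsAndFoldGreedily(func.getOperation(), std::move(patterns));
}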

test/Dialect/tcp_fusion.mlir: 100 changes (70 additions, 30 deletions)

@@ -54,18 +54,19 @@ func.func @test_multiple_fusions(%arg0 : tensor<?x?xf32>,

// -----

// Fusion with multiple uses where the def with multiple uses moves into an
// already created group.

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[V0:.+]] = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V4]] : tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.add %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: %[[V1:.+]] = tcp.group {
// CHECK: %[[V2]] = tcp.sub %[[V0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3]] = tcp.mul %[[V0]], %[[V2]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: return %[[V0]] : tensor<?x?xf32>
// CHECK: return %[[V1]] : tensor<?x?xf32>
// CHECK: }
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
@@ -77,18 +78,20 @@ func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)

// -----

// Fusion with multiple uses where the def and the multiple uses create a
// new group. Here we test that the moves use the dominance correctly.
// This and the previous test used to create a single fused group in
// earlier versions of the fusion algorithm. However, that algorithm had a
// bug, so we reverted to a simpler algorithm which does not create a
// single group for this sequence.

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: %[[V0:.+]]:2 = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]], %[[V4]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: return %[[V0]]#0, %[[V0]]#1 : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: %[[V0:.+]] = tcp.group {
// CHECK: %[[V3:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.add %[[V3]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V4]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: %[[V1:.+]] = tcp.sub %[[V0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.mul %[[V0]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[V1]], %[[V2]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: }
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
@@ -139,18 +142,18 @@ func.func @test_fusion_with_symbolic_shape(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>
// CHECK: %[[V0:.+]] = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
// CHECK: %[[V1:.+]] = tcp.symbolic_int "s1" {min_val = 2, max_val = 9223372036854775806} : i64
// CHECK: tcp.bind_symbolic_shape %[[ARG0]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: %[[V2:.+]]:2 = tcp.group {
// CHECK: %[[V3:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V3]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.add %[[V3]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V4]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: %[[V5:.+]] = tcp.sub %[[V4]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V6:.+]] = tcp.mul %[[V4]], %[[V5]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V5]], %[[V6]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V2]]#0, [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V2]]#1, [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: return %[[V2]]#0, %[[V2]]#1 : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.group {
// CHECK: %[[V5:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V5]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: %[[V6:.+]] = tcp.add %[[V5]], %[[V5]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V6]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V2]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V3]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.bind_symbolic_shape %[[V4]], [%[[V0]], %[[V1]]], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
// CHECK: return %[[V3]], %[[V4]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: }
func.func @test_multi_use_fusion_with_sym_shapes(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%s0 = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
@@ -167,3 +170,40 @@
tcp.bind_symbolic_shape %3, [%s0, %s1], affine_map<()[s0, s1] -> (s0, s1)> : tensor<?x?xf32>
"func.return" (%2, %3) : (tensor<?x?xf32>, tensor<?x?xf32>) -> ()
}


// -----

// This test shows why iterating over all the users of an op and then
// fusing them together can lead to bugs. Here, %0 is used only by %2 and
// %5, and all three are element-wise ops. However, if we created a
// tcp.group for them, there would be no correct place to put the new
// tcp.group without violating dominance for the other operands and uses
// of %2 and %5.
//
// This change shows the need to start from an op and only look at its
// operands when forming a fusion group.
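//
// Concretely, a group holding %0, %2, and %5 would have to be placed
// after %4 (an operand of %5) and before %3 (a user of %2). An
// illustrative placement (not produced by any pass) would be:
//
//   %1 = tcp.custom_op("test.op") %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
//   %4 = tcp.custom_op("test.op") %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
//   %7:2 = tcp.group {
//     %0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
//     %2 = tcp.add %0, %1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
//     %5 = tcp.mul %0, %4 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
//     tcp.yield %2, %5 : tensor<?x?xf32>, tensor<?x?xf32>
//   } : tensor<?x?xf32>, tensor<?x?xf32>
//   %3 = tcp.custom_op("test.op") %7#0 : tensor<?x?xf32> -> tensor<?x?xf32>
//   %6 = tcp.custom_op("test.op") %7#1 : tensor<?x?xf32> -> tensor<?x?xf32>
//
// but this requires moving %4 above %3, which the fusion pattern does not
// do; with the ops left in their original order there is no legal
// insertion point for the group.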


// CHECK: func.func @buggy_tcp_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[V0:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V1:.+]] = tcp.custom_op("test.op") %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V0]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.custom_op("test.op") %[[V2]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.custom_op("test.op") %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V5:.+]] = tcp.mul %[[V0]], %[[V4]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V6:.+]] = tcp.custom_op("test.op") %[[V5]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: return %[[V2]] : tensor<?x?xf32>
// CHECK: }
func.func @buggy_tcp_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> (tensor<?x?xf32>) {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>

%1 = tcp.custom_op("test.op") %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
%2 = tcp.add %0, %1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%3 = tcp.custom_op("test.op") %2 : tensor<?x?xf32> -> tensor<?x?xf32>

%4 = tcp.custom_op("test.op") %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
%5 = tcp.mul %0, %4 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%6 = tcp.custom_op("test.op") %5 : tensor<?x?xf32> -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}