Skip to content

Commit

Permalink
Slight improvement to fusion (#86)
Browse files Browse the repository at this point in the history
As discussed, this is a small improvement to the fusion algorithm to account
for the case where a definition has multiple uses, but all of those uses
already belong to the group we are in. This allows us to fuse "diamond
patterns" of multiple uses into a single group. See the update to the
lit test for the improvement.

Note that this is still not a "maximal" fusion — a maximal fusion could, for
example, create groups with multiple returns.

---------

Co-authored-by: Srinath Avadhanula <[email protected]>
  • Loading branch information
srinathava and Srinath Avadhanula authored Jul 29, 2024
1 parent 7a5f782 commit 7c50225
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 17 deletions.
47 changes: 39 additions & 8 deletions lib/Dialect/Transforms/FusionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ LogicalResult
GenericBottomUpFuser::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Operation *use = op;
bool opIsInsideGroup = op->getParentOfType<tcp::GroupOp>() != nullptr;
bool isChanged = false;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Expand All @@ -36,22 +37,48 @@ GenericBottomUpFuser::matchAndRewrite(Operation *op,
continue;
}

// We only support fusing def ops that have exactly one use, for
// now. Special-case the uses of the def in
// tcp.bind_symbolic_shape
bool cannotFuse = false;
SmallVector<tcp::BindSymbolicShapeOp> bindSymbolicUsersOfDef;
SmallVector<Operation *> otherUses;
for (auto otherUserOfDef : def->getUsers()) {
if (auto bindSymbolicShapeOp =
dyn_cast<tcp::BindSymbolicShapeOp>(otherUserOfDef)) {
bindSymbolicUsersOfDef.push_back(bindSymbolicShapeOp);
} else if (otherUserOfDef != use) {
cannotFuse = true;
break;
} else {
otherUses.push_back(otherUserOfDef);
}
}

if (cannotFuse)
// Check that all the uses of this def are still valid after we
// move the def op. If there's a single use, its always safe to
// fuse with the def. For the case when we have more than 1 use,
// see below for when it is safe to fuse with the def.
bool areUsesValidForFusion = false;
if (otherUses.size() > 1) {
// If we have more than one use, either
// 1. All those uses are used by the current op
if (llvm::all_of(otherUses,
[&](Operation *userOp) { return userOp == op; }))
areUsesValidForFusion = true;

// 2. All those uses are in the same group as the current op.
// NOTE: We are checking here that the original op is already
// inside a group and that all the other uses of this def are in
// that group. That means that we can safely move this def to the
// beginning of the group.
//
// We cannot do this if the use is not inside a group because
// then we are creating a new group.
if (opIsInsideGroup &&
llvm::all_of(otherUses, [&](Operation *userOp) {
return userOp->getParentRegion() == op->getParentRegion();
}))
areUsesValidForFusion = true;
} else if (otherUses.size() == 1) {
// If we have exactly one use, then we can fuse.
areUsesValidForFusion = true;
}

if (!areUsesValidForFusion)
continue;

// Fuse the def and use ops into a group.
Expand Down Expand Up @@ -84,6 +111,10 @@ GenericBottomUpFuser::matchAndRewrite(Operation *op,
def->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
// We already know that all other uses are in the same group
// and because we are doing this bottom up, this is the "first"
// use of this op in this group. So its OK to move it to just
// before this use.
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
Expand Down
34 changes: 25 additions & 9 deletions test/Dialect/tcp_fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,13 @@ func.func @test_multiple_fusions(%arg0 : tensor<?x?xf32>,

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[V0:.+]] = tcp.group {
// CHECK: %[[V2:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.add %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: %[[V1:.+]] = tcp.group {
// CHECK: %[[V2]] = tcp.sub %[[V0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3]] = tcp.mul %[[V0]], %[[V2]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]] : tensor<?x?xf32>
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V4]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: return %[[V1]] : tensor<?x?xf32>
// CHECK: return %[[V0]] : tensor<?x?xf32>
// CHECK: }
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
Expand Down Expand Up @@ -207,3 +204,22 @@ func.func @buggy_tcp_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
%6 = tcp.custom_op("test.op") %5 : tensor<?x?xf32> -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}

// -----

// Make sure that things do not break if a value is used twice by the same
// op.

// CHECK: func.func @test_multi_use_fusion_same_op_uses(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[V0:.+]] = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.mul %[[V1]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V2]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: return %[[V0]] : tensor<?x?xf32>
// CHECK: }
func.func @test_multi_use_fusion_same_op_uses(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
%3 = tcp.mul %0, %0 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
return %3 : tensor<?x?xf32>
}

0 comments on commit 7c50225

Please sign in to comment.