diff --git a/lib/Dialect/Transforms/FusionPatterns.cpp b/lib/Dialect/Transforms/FusionPatterns.cpp index 56c90ef..fcd3b03 100644 --- a/lib/Dialect/Transforms/FusionPatterns.cpp +++ b/lib/Dialect/Transforms/FusionPatterns.cpp @@ -19,6 +19,7 @@ LogicalResult GenericBottomUpFuser::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Operation *use = op; + bool opIsInsideGroup = op->getParentOfType() != nullptr; bool isChanged = false; for (auto operand : op->getOperands()) { if (operand.getDefiningOp()) { @@ -36,22 +37,48 @@ GenericBottomUpFuser::matchAndRewrite(Operation *op, continue; } - // We only support fusing def ops that have exactly one use, for - // now. Special-case the uses of the def in - // tcp.bind_symbolic_shape - bool cannotFuse = false; SmallVector bindSymbolicUsersOfDef; + SmallVector otherUses; for (auto otherUserOfDef : def->getUsers()) { if (auto bindSymbolicShapeOp = dyn_cast(otherUserOfDef)) { bindSymbolicUsersOfDef.push_back(bindSymbolicShapeOp); - } else if (otherUserOfDef != use) { - cannotFuse = true; - break; + } else { + otherUses.push_back(otherUserOfDef); } } - if (cannotFuse) + // Check that all the uses of this def are still valid after we + // move the def op. If there's a single use, it's always safe to + // fuse with the def. For the case when we have more than 1 use, + // see below for when it is safe to fuse with the def. + bool areUsesValidForFusion = false; + if (otherUses.size() > 1) { + // If we have more than one use, either + // 1. All those uses are by the current op + if (llvm::all_of(otherUses, + [&](Operation *userOp) { return userOp == op; })) + areUsesValidForFusion = true; + + // 2. All those uses are in the same group as the current op. + // NOTE: We are checking here that the original op is already + // inside a group and that all the other uses of this def are in + // that group. That means that we can safely move this def to the + // beginning of the group. 
+ // + // We cannot do this if the use is not inside a group because + // then we are creating a new group. + if (opIsInsideGroup && + llvm::all_of(otherUses, [&](Operation *userOp) { + return userOp->getParentRegion() == op->getParentRegion(); + })) + areUsesValidForFusion = true; + } else if (otherUses.size() == 1) { + // If we have exactly one use, then we can fuse. + areUsesValidForFusion = true; + } + + if (!areUsesValidForFusion) continue; // Fuse the def and use ops into a group. @@ -84,6 +111,10 @@ GenericBottomUpFuser::matchAndRewrite(Operation *op, def->moveBefore(use); } } else if (auto groupOp = dyn_cast(use->getParentOp())) { + // We already know that all other uses are in the same group + // and because we are doing this bottom up, this is the "first" + // use of this op in this group. So it's OK to move it to just + // before this use. def->moveBefore(use); } else { llvm_unreachable("Unhandled case during fusion"); diff --git a/test/Dialect/tcp_fusion.mlir b/test/Dialect/tcp_fusion.mlir index de50ad5..fbdc433 100644 --- a/test/Dialect/tcp_fusion.mlir +++ b/test/Dialect/tcp_fusion.mlir @@ -57,16 +57,13 @@ func.func @test_multiple_fusions(%arg0 : tensor, // CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { // CHECK: %[[V0:.+]] = tcp.group { -// CHECK: %[[V2:.+]] = tcp.tanh %[[ARG0]] : tensor -> tensor -// CHECK: %[[V3:.+]] = tcp.add %[[V2]], %[[ARG1]] : tensor, tensor -> tensor -// CHECK: tcp.yield %[[V3]] : tensor -// CHECK: } : tensor -// CHECK: %[[V1:.+]] = tcp.group { -// CHECK: %[[V2]] = tcp.sub %[[V0]], %[[ARG1]] : tensor, tensor -> tensor -// CHECK: %[[V3]] = tcp.mul %[[V0]], %[[V2]] : tensor, tensor -> tensor -// CHECK: tcp.yield %[[V3]] : tensor +// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor -> tensor +// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor, tensor -> tensor +// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor, tensor -> tensor +// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], 
%[[V3]] : tensor, tensor -> tensor +// CHECK: tcp.yield %[[V4]] : tensor // CHECK: } : tensor -// CHECK: return %[[V1]] : tensor +// CHECK: return %[[V0]] : tensor // CHECK: } func.func @test_multi_use_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { %0 = tcp.tanh %arg0 : tensor -> tensor @@ -207,3 +204,22 @@ func.func @buggy_tcp_fusion(%arg0 : tensor, %arg1 : tensor) -> %6 = tcp.custom_op("test.op") %5 : tensor -> tensor return %2 : tensor } + +// ----- + +// Make sure that things do not break if a value is used twice by the same +// op. + +// CHECK: func.func @test_multi_use_fusion_same_op_uses(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor { +// CHECK: %[[V0:.+]] = tcp.group { +// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor -> tensor +// CHECK: %[[V2:.+]] = tcp.mul %[[V1]], %[[V1]] : tensor, tensor -> tensor +// CHECK: tcp.yield %[[V2]] : tensor +// CHECK: } : tensor +// CHECK: return %[[V0]] : tensor +// CHECK: } +func.func @test_multi_use_fusion_same_op_uses(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = tcp.tanh %arg0 : tensor -> tensor + %3 = tcp.mul %0, %0 : tensor, tensor -> tensor + return %3 : tensor +}