diff --git a/xla/service/hlo_dce.cc b/xla/service/hlo_dce.cc index 1c6050554ed35..5617190d36b85 100644 --- a/xla/service/hlo_dce.cc +++ b/xla/service/hlo_dce.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -79,11 +80,9 @@ bool IsRemovableWhile(HloInstruction* instruction, !computation->root_instruction()->has_sharding() && fusion_instruction->output_operand_aliasing().empty() && !fusion_instruction->HasControlDependencies() && - fusion_instruction->user_count() < - computation->root_instruction()->operand_count() && !fusion_instruction->IsCustomFusion()) { - std::vector used_tuple_elements; - used_tuple_elements.reserve(fusion_instruction->user_count()); + // The order of the used outputs is relevant for the algorithm below. + std::set used_tuple_elements; // We only support this cleanup if all users of the fusion instruction are // GetTupleElement ops, and there is at least one user of // 'fusion_instruction'. @@ -93,10 +92,16 @@ bool IsRemovableWhile(HloInstruction* instruction, supported = false; break; } - used_tuple_elements.push_back(gte->tuple_index()); + used_tuple_elements.insert(gte->tuple_index()); } + + // If all outputs are used, nothing to clean up. + if (used_tuple_elements.size() == + computation->root_instruction()->operand_count()) { + supported = false; + } + if (supported) { - std::sort(used_tuple_elements.begin(), used_tuple_elements.end()); std::vector tuple_shapes; tuple_shapes.reserve(used_tuple_elements.size()); for (int64_t tuple_index : used_tuple_elements) { @@ -119,18 +124,23 @@ bool IsRemovableWhile(HloInstruction* instruction, gte->set_tuple_index(new_tuple_index); } } else { - HloInstruction* gte = fusion_instruction->users()[0]; - // Replace and change control successors to be dependent on the fusion - // instruction itself. - TF_ASSIGN_OR_RETURN(bool replaced, - gte->parent()->ReplaceInstruction( - gte, fusion_instruction, - /*preserve_sharding=*/true, - /*relay_control_dependency=*/true)); - if (replaced) { - changed |= replaced; + // Since we iterate over users while removing them .. make a local copy + // first. + std::vector users(fusion_instruction->users()); + for (HloInstruction* gte : users) { + // Replace and change control successors to be dependent on the fusion + // instruction itself. + TF_ASSIGN_OR_RETURN(bool replaced, + gte->parent()->ReplaceInstruction( + gte, fusion_instruction, + /*preserve_sharding=*/true, + /*relay_control_dependency=*/true)); + if (replaced) { + changed |= replaced; + } } } + // Update the root of the fusion computation. if (tuple_shapes.size() > 1) { std::vector new_operands; @@ -147,7 +157,7 @@ bool IsRemovableWhile(HloInstruction* instruction, TF_RETURN_IF_ERROR( computation->root_instruction()->ReplaceAllUsesWithDifferentShape( computation->root_instruction()->mutable_operand( - used_tuple_elements[0]))); + *used_tuple_elements.begin()))); } } } diff --git a/xla/service/hlo_dce_test.cc b/xla/service/hlo_dce_test.cc index 38a170ae77160..765af9fecbeb8 100644 --- a/xla/service/hlo_dce_test.cc +++ b/xla/service/hlo_dce_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -650,6 +651,114 @@ TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementsRemoveTuple) { EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); } +TEST_F( + HloDceTest, + MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleMultiUsersPerOutput) { + constexpr char kHloString[] = R"( + HloModule test_module + fused_add { + p0 = f32[32,32]{1,0} parameter(0) + p1 = f32[32,32]{1,0} parameter(1) + p2 = f32[32,32]{1,0} parameter(2) // becomes dead + add = f32[32,32]{1,0} add(p0, p1) + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2) + } + + ENTRY reduce { + param0 = f32[32,32]{1,0} parameter(0) + param1 = f32[32,32]{1,0} parameter(1) + param2 = f32[32,32]{1,0} parameter(2) + fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add + gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + HloDCE dce; + auto changed = dce.Run(module.get()); + ASSERT_TRUE(changed.ok()); + EXPECT_TRUE(*changed); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0"); + EXPECT_EQ(gte_0, nullptr); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1"); + EXPECT_EQ(gte_1, nullptr); + HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again"); + EXPECT_EQ(gte_1_again, nullptr); + + HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + ASSERT_NE(fusion, nullptr); + EXPECT_FALSE(fusion->shape().IsTuple()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand_count(), 2); + EXPECT_EQ(root->operand(0), fusion); + EXPECT_EQ(root->operand(1), fusion); +} + +TEST_F( + HloDceTest, + MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleNonContiguousRemoval) { + constexpr char kHloString[] = R"( + HloModule test_module + fused_add { + p0 = f32[32,32]{1,0} parameter(0) + p1 = f32[32,32]{1,0} parameter(1) + p2 = f32[32,32]{1,0} parameter(2) // becomes dead + add = f32[32,32]{1,0} add(p0, p1) + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2, p2) + } + + ENTRY reduce { + param0 = f32[32,32]{1,0} parameter(0) + param1 = f32[32,32]{1,0} parameter(1) + param2 = f32[32,32]{1,0} parameter(2) + fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add + gte.0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0 // dead + gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + gte.3 = f32[32,32]{1,0} get-tuple-element(fusion), index=3 + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again, gte.3) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + HloDCE dce; + auto changed = dce.Run(module.get()); + ASSERT_TRUE(changed.ok()); + EXPECT_TRUE(*changed); + + // We expect that the dead parameter and the dead tuple entry are removed. + HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0"); + EXPECT_EQ(gte_0, nullptr); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1"); + EXPECT_NE(gte_1, nullptr); + EXPECT_EQ(static_cast(gte_1)->tuple_index(), + 0); + HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again"); + EXPECT_EQ( + static_cast(gte_1_again)->tuple_index(), + 0); + EXPECT_NE(gte_1_again, nullptr); + HloInstruction* gte_3 = FindInstruction(module.get(), "gte.3"); + EXPECT_NE(gte_3, nullptr); + EXPECT_EQ(static_cast(gte_3)->tuple_index(), + 1); + + HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + ASSERT_NE(fusion, nullptr); + EXPECT_TRUE(fusion->shape().IsTuple()); + EXPECT_EQ(fusion->shape().tuple_shapes_size(), 2); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand_count(), 3); + EXPECT_EQ(root->operand(0), gte_1); + EXPECT_EQ(root->operand(1), gte_1_again); + EXPECT_EQ(root->operand(2), gte_3); +} + TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementAdjustTuple) { constexpr char kHloString[] = R"( HloModule test_module