From 2b1c60936e690bb7633a1427e97fe0d73d63d3ce Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 17 Sep 2024 16:37:29 -0700 Subject: [PATCH] [XLA][HloDCE] Removal of unused outputs of fusions to consider multiple users of the same shape index (aka output) Changes: * Replaces check of users vs number of outputs with checking if number of unique outputs used is smaller than number of outputs * Before this change, if same shape index (aka same output) is used multiple times, we might not end up removing any of the unused. * Before this change, if an output has multiple users, and it is the only one used, we might not remove all the unused fusion outputs (aka will leave around outputs). PiperOrigin-RevId: 675754198 --- xla/service/hlo_dce.cc | 44 +++++++++------ xla/service/hlo_dce_test.cc | 109 ++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 17 deletions(-) diff --git a/xla/service/hlo_dce.cc b/xla/service/hlo_dce.cc index 1c6050554ed35..5617190d36b85 100644 --- a/xla/service/hlo_dce.cc +++ b/xla/service/hlo_dce.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -79,11 +80,9 @@ bool IsRemovableWhile(HloInstruction* instruction, !computation->root_instruction()->has_sharding() && fusion_instruction->output_operand_aliasing().empty() && !fusion_instruction->HasControlDependencies() && - fusion_instruction->user_count() < - computation->root_instruction()->operand_count() && !fusion_instruction->IsCustomFusion()) { - std::vector used_tuple_elements; - used_tuple_elements.reserve(fusion_instruction->user_count()); + // The order of the used outputs is relevant for the algorithm below. + std::set used_tuple_elements; // We only support this cleanup if all users of the fusion instruction are // GetTupleElement ops, and there is at least one user of // 'fusion_instruction'. @@ -93,10 +92,16 @@ bool IsRemovableWhile(HloInstruction* instruction, supported = false; break; } - used_tuple_elements.push_back(gte->tuple_index()); + used_tuple_elements.insert(gte->tuple_index()); } + + // If all outputs are used, nothing to clean up. + if (used_tuple_elements.size() == + computation->root_instruction()->operand_count()) { + supported = false; + } + if (supported) { - std::sort(used_tuple_elements.begin(), used_tuple_elements.end()); std::vector tuple_shapes; tuple_shapes.reserve(used_tuple_elements.size()); for (int64_t tuple_index : used_tuple_elements) { @@ -119,18 +124,23 @@ bool IsRemovableWhile(HloInstruction* instruction, gte->set_tuple_index(new_tuple_index); } } else { - HloInstruction* gte = fusion_instruction->users()[0]; - // Replace and change control successors to be dependent on the fusion - // instruction itself. - TF_ASSIGN_OR_RETURN(bool replaced, - gte->parent()->ReplaceInstruction( - gte, fusion_instruction, - /*preserve_sharding=*/true, - /*relay_control_dependency=*/true)); - if (replaced) { - changed |= replaced; + // Since we iterate over users while removing them .. make a local copy + // first. + std::vector users(fusion_instruction->users()); + for (HloInstruction* gte : users) { + // Replace and change control successors to be dependent on the fusion + // instruction itself. + TF_ASSIGN_OR_RETURN(bool replaced, + gte->parent()->ReplaceInstruction( + gte, fusion_instruction, + /*preserve_sharding=*/true, + /*relay_control_dependency=*/true)); + if (replaced) { + changed |= replaced; + } } } + // Update the root of the fusion computation. if (tuple_shapes.size() > 1) { std::vector new_operands; @@ -147,7 +157,7 @@ bool IsRemovableWhile(HloInstruction* instruction, TF_RETURN_IF_ERROR( computation->root_instruction()->ReplaceAllUsesWithDifferentShape( computation->root_instruction()->mutable_operand( - used_tuple_elements[0]))); + *used_tuple_elements.begin()))); } } } diff --git a/xla/service/hlo_dce_test.cc b/xla/service/hlo_dce_test.cc index 38a170ae77160..765af9fecbeb8 100644 --- a/xla/service/hlo_dce_test.cc +++ b/xla/service/hlo_dce_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -650,6 +651,114 @@ TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementsRemoveTuple) { EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); } +TEST_F( + HloDceTest, + MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleMultiUsersPerOutput) { + constexpr char kHloString[] = R"( + HloModule test_module + fused_add { + p0 = f32[32,32]{1,0} parameter(0) + p1 = f32[32,32]{1,0} parameter(1) + p2 = f32[32,32]{1,0} parameter(2) // becomes dead + add = f32[32,32]{1,0} add(p0, p1) + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2) + } + + ENTRY reduce { + param0 = f32[32,32]{1,0} parameter(0) + param1 = f32[32,32]{1,0} parameter(1) + param2 = f32[32,32]{1,0} parameter(2) + fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add + gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + HloDCE dce; + auto changed = dce.Run(module.get()); + ASSERT_TRUE(changed.ok()); + EXPECT_TRUE(*changed); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0"); + EXPECT_EQ(gte_0, nullptr); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1"); + EXPECT_EQ(gte_1, nullptr); + HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again"); + EXPECT_EQ(gte_1_again, nullptr); + + HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + ASSERT_NE(fusion, nullptr); + EXPECT_FALSE(fusion->shape().IsTuple()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand_count(), 2); + EXPECT_EQ(root->operand(0), fusion); + EXPECT_EQ(root->operand(1), fusion); +} + +TEST_F( + HloDceTest, + MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleNonContiguousRemoval) { + constexpr char kHloString[] = R"( + HloModule test_module + fused_add { + p0 = f32[32,32]{1,0} parameter(0) + p1 = f32[32,32]{1,0} parameter(1) + p2 = f32[32,32]{1,0} parameter(2) // becomes dead + add = f32[32,32]{1,0} add(p0, p1) + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2, p2) + } + + ENTRY reduce { + param0 = f32[32,32]{1,0} parameter(0) + param1 = f32[32,32]{1,0} parameter(1) + param2 = f32[32,32]{1,0} parameter(2) + fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add + gte.0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0 // dead + gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + gte.3 = f32[32,32]{1,0} get-tuple-element(fusion), index=3 + ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again, gte.3) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + HloDCE dce; + auto changed = dce.Run(module.get()); + ASSERT_TRUE(changed.ok()); + EXPECT_TRUE(*changed); + + // We expect that the dead parameter and the dead tuple entry are removed. + HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0"); + EXPECT_EQ(gte_0, nullptr); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1"); + EXPECT_NE(gte_1, nullptr); + EXPECT_EQ(static_cast(gte_1)->tuple_index(), + 0); + HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again"); + EXPECT_EQ( + static_cast(gte_1_again)->tuple_index(), + 0); + EXPECT_NE(gte_1_again, nullptr); + HloInstruction* gte_3 = FindInstruction(module.get(), "gte.3"); + EXPECT_NE(gte_3, nullptr); + EXPECT_EQ(static_cast(gte_3)->tuple_index(), + 1); + + HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + ASSERT_NE(fusion, nullptr); + EXPECT_TRUE(fusion->shape().IsTuple()); + EXPECT_EQ(fusion->shape().tuple_shapes_size(), 2); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(root->operand_count(), 3); + EXPECT_EQ(root->operand(0), gte_1); + EXPECT_EQ(root->operand(1), gte_1_again); + EXPECT_EQ(root->operand(2), gte_3); +} + TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementAdjustTuple) { constexpr char kHloString[] = R"( HloModule test_module