Skip to content

Commit

Permalink
[XLA][HloDCE] Removal of unused outputs of fusions to consider multip…
Browse files Browse the repository at this point in the history
…le users of the same shape index (aka output)

Changes:
* Replaces check of users vs number of outputs with checking if number of unique outputs used is smaller than number of outputs
* Before this change, if same shape index (aka same output) is used multiple times, we might not end up removing any of the unused.
* Before this change, if an output has multiple users, and it is the only one used, we might not remove all the unused fusion outputs (aka will leave around <number of users> outputs).
PiperOrigin-RevId: 675754198
  • Loading branch information
Google-ML-Automation committed Sep 17, 2024
1 parent 3406c60 commit 2b1c609
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 17 deletions.
44 changes: 27 additions & 17 deletions xla/service/hlo_dce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -79,11 +80,9 @@ bool IsRemovableWhile(HloInstruction* instruction,
!computation->root_instruction()->has_sharding() &&
fusion_instruction->output_operand_aliasing().empty() &&
!fusion_instruction->HasControlDependencies() &&
fusion_instruction->user_count() <
computation->root_instruction()->operand_count() &&
!fusion_instruction->IsCustomFusion()) {
std::vector<int64_t> used_tuple_elements;
used_tuple_elements.reserve(fusion_instruction->user_count());
// The order of the used outputs is relevant for the algorithm below.
std::set<int64_t> used_tuple_elements;
// We only support this cleanup if all users of the fusion instruction are
// GetTupleElement ops, and there is at least one user of
// 'fusion_instruction'.
Expand All @@ -93,10 +92,16 @@ bool IsRemovableWhile(HloInstruction* instruction,
supported = false;
break;
}
used_tuple_elements.push_back(gte->tuple_index());
used_tuple_elements.insert(gte->tuple_index());
}

// If all outputs are used, nothing to clean up.
if (used_tuple_elements.size() ==
computation->root_instruction()->operand_count()) {
supported = false;
}

if (supported) {
std::sort(used_tuple_elements.begin(), used_tuple_elements.end());
std::vector<Shape> tuple_shapes;
tuple_shapes.reserve(used_tuple_elements.size());
for (int64_t tuple_index : used_tuple_elements) {
Expand All @@ -119,18 +124,23 @@ bool IsRemovableWhile(HloInstruction* instruction,
gte->set_tuple_index(new_tuple_index);
}
} else {
HloInstruction* gte = fusion_instruction->users()[0];
// Replace and change control successors to be dependent on the fusion
// instruction itself.
TF_ASSIGN_OR_RETURN(bool replaced,
gte->parent()->ReplaceInstruction(
gte, fusion_instruction,
/*preserve_sharding=*/true,
/*relay_control_dependency=*/true));
if (replaced) {
changed |= replaced;
// Since we iterate over users while removing them .. make a local copy
// first.
std::vector<HloInstruction*> users(fusion_instruction->users());
for (HloInstruction* gte : users) {
// Replace and change control successors to be dependent on the fusion
// instruction itself.
TF_ASSIGN_OR_RETURN(bool replaced,
gte->parent()->ReplaceInstruction(
gte, fusion_instruction,
/*preserve_sharding=*/true,
/*relay_control_dependency=*/true));
if (replaced) {
changed |= replaced;
}
}
}

// Update the root of the fusion computation.
if (tuple_shapes.size() > 1) {
std::vector<HloInstruction*> new_operands;
Expand All @@ -147,7 +157,7 @@ bool IsRemovableWhile(HloInstruction* instruction,
TF_RETURN_IF_ERROR(
computation->root_instruction()->ReplaceAllUsesWithDifferentShape(
computation->root_instruction()->mutable_operand(
used_tuple_elements[0])));
*used_tuple_elements.begin())));
}
}
}
Expand Down
109 changes: 109 additions & 0 deletions xla/service/hlo_dce_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
Expand Down Expand Up @@ -650,6 +651,114 @@ TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementsRemoveTuple) {
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
}

TEST_F(
HloDceTest,
MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleMultiUsersPerOutput) {
constexpr char kHloString[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
p2 = f32[32,32]{1,0} parameter(2) // becomes dead
add = f32[32,32]{1,0} add(p0, p1)
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2)
}
ENTRY reduce {
param0 = f32[32,32]{1,0} parameter(0)
param1 = f32[32,32]{1,0} parameter(1)
param2 = f32[32,32]{1,0} parameter(2)
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add
gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString));
HloDCE dce;
auto changed = dce.Run(module.get());
ASSERT_TRUE(changed.ok());
EXPECT_TRUE(*changed);

HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0");
EXPECT_EQ(gte_0, nullptr);
HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1");
EXPECT_EQ(gte_1, nullptr);
HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again");
EXPECT_EQ(gte_1_again, nullptr);

HloInstruction* fusion = FindInstruction(module.get(), "fusion");
ASSERT_NE(fusion, nullptr);
EXPECT_FALSE(fusion->shape().IsTuple());

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
EXPECT_EQ(root->operand_count(), 2);
EXPECT_EQ(root->operand(0), fusion);
EXPECT_EQ(root->operand(1), fusion);
}

TEST_F(
HloDceTest,
MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleNonContiguousRemoval) {
constexpr char kHloString[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
p2 = f32[32,32]{1,0} parameter(2) // becomes dead
add = f32[32,32]{1,0} add(p0, p1)
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2, p2)
}
ENTRY reduce {
param0 = f32[32,32]{1,0} parameter(0)
param1 = f32[32,32]{1,0} parameter(1)
param2 = f32[32,32]{1,0} parameter(2)
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add
gte.0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0 // dead
gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1
gte.3 = f32[32,32]{1,0} get-tuple-element(fusion), index=3
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again, gte.3)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString));
HloDCE dce;
auto changed = dce.Run(module.get());
ASSERT_TRUE(changed.ok());
EXPECT_TRUE(*changed);

// We expect that the dead parameter and the dead tuple entry are removed.
HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0");
EXPECT_EQ(gte_0, nullptr);
HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1");
EXPECT_NE(gte_1, nullptr);
EXPECT_EQ(static_cast<HloGetTupleElementInstruction*>(gte_1)->tuple_index(),
0);
HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again");
EXPECT_EQ(
static_cast<HloGetTupleElementInstruction*>(gte_1_again)->tuple_index(),
0);
EXPECT_NE(gte_1_again, nullptr);
HloInstruction* gte_3 = FindInstruction(module.get(), "gte.3");
EXPECT_NE(gte_3, nullptr);
EXPECT_EQ(static_cast<HloGetTupleElementInstruction*>(gte_3)->tuple_index(),
1);

HloInstruction* fusion = FindInstruction(module.get(), "fusion");
ASSERT_NE(fusion, nullptr);
EXPECT_TRUE(fusion->shape().IsTuple());
EXPECT_EQ(fusion->shape().tuple_shapes_size(), 2);

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
EXPECT_EQ(root->operand_count(), 3);
EXPECT_EQ(root->operand(0), gte_1);
EXPECT_EQ(root->operand(1), gte_1_again);
EXPECT_EQ(root->operand(2), gte_3);
}

TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementAdjustTuple) {
constexpr char kHloString[] = R"(
HloModule test_module
Expand Down

0 comments on commit 2b1c609

Please sign in to comment.