Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA][HloDCE] Removal of unused outputs of fusions to consider multiple users of the same shape index (aka output) #17217

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions xla/service/hlo_dce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -79,11 +80,9 @@ bool IsRemovableWhile(HloInstruction* instruction,
!computation->root_instruction()->has_sharding() &&
fusion_instruction->output_operand_aliasing().empty() &&
!fusion_instruction->HasControlDependencies() &&
fusion_instruction->user_count() <
computation->root_instruction()->operand_count() &&
!fusion_instruction->IsCustomFusion()) {
std::vector<int64_t> used_tuple_elements;
used_tuple_elements.reserve(fusion_instruction->user_count());
// The order of the used outputs is relevant for the algorithm below.
std::set<int64_t> used_tuple_elements;
// We only support this cleanup if all users of the fusion instruction are
// GetTupleElement ops, and there is at least one user of
// 'fusion_instruction'.
Expand All @@ -93,10 +92,16 @@ bool IsRemovableWhile(HloInstruction* instruction,
supported = false;
break;
}
used_tuple_elements.push_back(gte->tuple_index());
used_tuple_elements.insert(gte->tuple_index());
}

// If all outputs are used, nothing to clean up.
if (used_tuple_elements.size() ==
computation->root_instruction()->operand_count()) {
supported = false;
}

if (supported) {
std::sort(used_tuple_elements.begin(), used_tuple_elements.end());
std::vector<Shape> tuple_shapes;
tuple_shapes.reserve(used_tuple_elements.size());
for (int64_t tuple_index : used_tuple_elements) {
Expand All @@ -119,18 +124,23 @@ bool IsRemovableWhile(HloInstruction* instruction,
gte->set_tuple_index(new_tuple_index);
}
} else {
HloInstruction* gte = fusion_instruction->users()[0];
// Replace and change control successors to be dependent on the fusion
// instruction itself.
TF_ASSIGN_OR_RETURN(bool replaced,
gte->parent()->ReplaceInstruction(
gte, fusion_instruction,
/*preserve_sharding=*/true,
/*relay_control_dependency=*/true));
if (replaced) {
changed |= replaced;
// Since we iterate over users while removing them .. make a local copy
// first.
std::vector<HloInstruction*> users(fusion_instruction->users());
for (HloInstruction* gte : users) {
// Replace and change control successors to be dependent on the fusion
// instruction itself.
TF_ASSIGN_OR_RETURN(bool replaced,
gte->parent()->ReplaceInstruction(
gte, fusion_instruction,
/*preserve_sharding=*/true,
/*relay_control_dependency=*/true));
if (replaced) {
changed |= replaced;
}
}
}

// Update the root of the fusion computation.
if (tuple_shapes.size() > 1) {
std::vector<HloInstruction*> new_operands;
Expand All @@ -147,7 +157,7 @@ bool IsRemovableWhile(HloInstruction* instruction,
TF_RETURN_IF_ERROR(
computation->root_instruction()->ReplaceAllUsesWithDifferentShape(
computation->root_instruction()->mutable_operand(
used_tuple_elements[0])));
*used_tuple_elements.begin())));
}
}
}
Expand Down
109 changes: 109 additions & 0 deletions xla/service/hlo_dce_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
Expand Down Expand Up @@ -650,6 +651,114 @@ TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementsRemoveTuple) {
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
}

TEST_F(
HloDceTest,
MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleMultiUsersPerOutput) {
constexpr char kHloString[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
p2 = f32[32,32]{1,0} parameter(2) // becomes dead
add = f32[32,32]{1,0} add(p0, p1)
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2)
}

ENTRY reduce {
param0 = f32[32,32]{1,0} parameter(0)
param1 = f32[32,32]{1,0} parameter(1)
param2 = f32[32,32]{1,0} parameter(2)
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add
gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString));
HloDCE dce;
auto changed = dce.Run(module.get());
ASSERT_TRUE(changed.ok());
EXPECT_TRUE(*changed);

HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0");
EXPECT_EQ(gte_0, nullptr);
HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1");
EXPECT_EQ(gte_1, nullptr);
HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again");
EXPECT_EQ(gte_1_again, nullptr);

HloInstruction* fusion = FindInstruction(module.get(), "fusion");
ASSERT_NE(fusion, nullptr);
EXPECT_FALSE(fusion->shape().IsTuple());

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
EXPECT_EQ(root->operand_count(), 2);
EXPECT_EQ(root->operand(0), fusion);
EXPECT_EQ(root->operand(1), fusion);
}

TEST_F(
HloDceTest,
MultiOutputFusionRemoveUnusedTupleElementsRemoveTupleNonContiguousRemoval) {
constexpr char kHloString[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
p2 = f32[32,32]{1,0} parameter(2) // becomes dead
add = f32[32,32]{1,0} add(p0, p1)
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p2, add, p2, p2)
}

ENTRY reduce {
param0 = f32[32,32]{1,0} parameter(0)
param1 = f32[32,32]{1,0} parameter(1)
param2 = f32[32,32]{1,0} parameter(2)
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(param0, param1, param2), kind=kLoop, calls=fused_add
gte.0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0 // dead
gte.1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
gte.1.again = f32[32,32]{1,0} get-tuple-element(fusion), index=1
gte.3 = f32[32,32]{1,0} get-tuple-element(fusion), index=3
ROOT res = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(gte.1, gte.1.again, gte.3)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString));
HloDCE dce;
auto changed = dce.Run(module.get());
ASSERT_TRUE(changed.ok());
EXPECT_TRUE(*changed);

// We expect that the dead parameter and the dead tuple entry are removed.
HloInstruction* gte_0 = FindInstruction(module.get(), "gte.0");
EXPECT_EQ(gte_0, nullptr);
HloInstruction* gte_1 = FindInstruction(module.get(), "gte.1");
EXPECT_NE(gte_1, nullptr);
EXPECT_EQ(static_cast<HloGetTupleElementInstruction*>(gte_1)->tuple_index(),
0);
HloInstruction* gte_1_again = FindInstruction(module.get(), "gte.1.again");
EXPECT_EQ(
static_cast<HloGetTupleElementInstruction*>(gte_1_again)->tuple_index(),
0);
EXPECT_NE(gte_1_again, nullptr);
HloInstruction* gte_3 = FindInstruction(module.get(), "gte.3");
EXPECT_NE(gte_3, nullptr);
EXPECT_EQ(static_cast<HloGetTupleElementInstruction*>(gte_3)->tuple_index(),
1);

HloInstruction* fusion = FindInstruction(module.get(), "fusion");
ASSERT_NE(fusion, nullptr);
EXPECT_TRUE(fusion->shape().IsTuple());
EXPECT_EQ(fusion->shape().tuple_shapes_size(), 2);

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
EXPECT_EQ(root->operand_count(), 3);
EXPECT_EQ(root->operand(0), gte_1);
EXPECT_EQ(root->operand(1), gte_1_again);
EXPECT_EQ(root->operand(2), gte_3);
}

TEST_F(HloDceTest, MultiOutputFusionRemoveUnusedTupleElementAdjustTuple) {
constexpr char kHloString[] = R"(
HloModule test_module
Expand Down
Loading