Skip to content

Commit

Permalink
[XLA:GPU] Return instruction from FindInstruction in HLO query helpers.
Browse files Browse the repository at this point in the history
The function previously returned the index of encounter, which is
non-deterministic and misleading. Also, remove `IsBeforeInComputation`, which
is based on this index.

PiperOrigin-RevId: 681524796
  • Loading branch information
frgossen authored and Google-ML-Automation committed Oct 2, 2024
1 parent cd6e808 commit c88b612
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 115 deletions.
33 changes: 9 additions & 24 deletions xla/hlo/utils/hlo_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,36 +280,21 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) {
return *it;
}

std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, absl::string_view name) {
int current_index = 0;
for (auto* instruction : computation->instructions()) {
if (instruction->name() == name) {
return {instruction, current_index};
break;
}
current_index++;
HloInstruction* FindInstruction(const HloComputation* computation,
absl::string_view name) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->name() == name) return instruction;
}
return {nullptr, -1};
return nullptr;
}

std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, HloOpcode opcode) {
int current_index = 0;
HloInstruction* FindInstruction(const HloComputation* computation,
HloOpcode opcode) {
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == opcode) {
return {instruction, current_index};
break;
}
current_index++;
if (instruction->opcode() == opcode) return instruction;
}
return {nullptr, -1};
return nullptr;
}

bool IsBeforeInComputation(const HloComputation* computation,
absl::string_view inst1, absl::string_view inst2) {
return FindFirstInstruction(computation, inst1).second <
FindFirstInstruction(computation, inst2).second;
}
} // namespace hlo_query
} // namespace xla
28 changes: 11 additions & 17 deletions xla/hlo/utils/hlo_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,23 +156,17 @@ HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand,

// Gets the computation from the given module with the given name.
HloComputation* FindComputation(HloModule* module, absl::string_view name);
// Gets the first instruction and its index from the given computation with the
// given instruction name. The function returns {nullptr, -1} if the instruction
// cannot be found.
std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, absl::string_view name);
// Gets the first instruction and its index from the given computation with the
// given instruction opcode. The function returns {nullptr, -1} if the
// instruction cannot be found.
std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, HloOpcode opcode);

// Check that one instruction comes before another one for a given computation.
// The function returns true if the first instruction comes before the second
// one, and false otherwise. This is useful for partial checks on the
// transformed IR without going through a full file check.
bool IsBeforeInComputation(const HloComputation* computation,
absl::string_view inst1, absl::string_view inst2);

// Gets the instruction from the given computation with the given instruction
// name. Returns nullptr if no such instruction can be found.
HloInstruction* FindInstruction(const HloComputation* computation,
absl::string_view name);

// Gets any instruction from the given computation with the given opcode.
// Returns nullptr if no such instruction can be found.
HloInstruction* FindInstruction(const HloComputation* computation,
HloOpcode opcode);

} // namespace hlo_query
} // namespace xla

Expand Down
69 changes: 25 additions & 44 deletions xla/hlo/utils/hlo_query_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,52 +157,42 @@ TEST_F(HloQueryTest, FindInstructionUsingNameTest) {
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr);
EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr);
EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr);
EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr);
EXPECT_NE(hlo_query::FindInstruction(main, "zero"), nullptr);
EXPECT_NE(hlo_query::FindInstruction(main, "five"), nullptr);
EXPECT_NE(hlo_query::FindInstruction(main, "out"), nullptr);
EXPECT_EQ(hlo_query::FindInstruction(main, "foo"), nullptr);
}

std::pair<HloInstruction*, int> FindFirst(const HloComputation* main,
absl::string_view opcode) {
return hlo_query::FindFirstInstruction(main,
StringToHloOpcode(opcode).value());
}

// Assures that the string and opcode versions of FindFirstInstruction return
// Assures that the string and opcode versions of FindInstruction return
// the same result
void FindFirstInstructionsAndExpectEqual(const HloComputation* main,
absl::string_view name,
absl::string_view opcode_str) {
void FindInstructionsAndExpectEqual(const HloComputation* main,
absl::string_view name, HloOpcode opcode) {
SCOPED_TRACE(absl::StrCat("Comparing finding by name: ", name,
" and opcode: ", opcode_str));
auto withString = hlo_query::FindFirstInstruction(main, name);
auto withOpCode = FindFirst(main, opcode_str);
EXPECT_EQ(withString.first, withOpCode.first);
EXPECT_EQ(withString.second, withOpCode.second);
if (withString.first != nullptr)
EXPECT_EQ(withString.first->ToString(), withOpCode.first->ToString());
" and opcode: ", opcode));
HloInstruction* by_name = hlo_query::FindInstruction(main, name);
HloInstruction* by_opcode = hlo_query::FindInstruction(main, opcode);
EXPECT_EQ(by_name, by_opcode);
}

TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
EXPECT_NE(FindFirst(main, "add").first, nullptr);
EXPECT_NE(FindFirst(main, "constant").first, nullptr);
EXPECT_EQ(FindFirst(main, "select").first, nullptr);
EXPECT_NE(hlo_query::FindInstruction(main, HloOpcode::kConstant), nullptr);
EXPECT_NE(hlo_query::FindInstruction(main, HloOpcode::kAdd), nullptr);
EXPECT_EQ(hlo_query::FindInstruction(main, HloOpcode::kSelect), nullptr);
}

TEST_F(HloQueryTest, FindInstructionUsingOpcodeAndNameEqualTest) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
FindFirstInstructionsAndExpectEqual(main, "zero", "constant");
FindFirstInstructionsAndExpectEqual(main, "out", "add");
FindInstructionsAndExpectEqual(main, "zero", HloOpcode::kConstant);
FindInstructionsAndExpectEqual(main, "out", HloOpcode::kAdd);
// both are not found
FindFirstInstructionsAndExpectEqual(main, "dummy", "select");
FindInstructionsAndExpectEqual(main, "dummy", HloOpcode::kSelect);
}

TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) {
Expand All @@ -211,21 +201,10 @@ TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) {
ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
EXPECT_NE(main, nullptr);
auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef");
auto find_nothing = hlo_query::FindFirstInstruction(main, "");
EXPECT_EQ(find_beef.first, nullptr);
EXPECT_EQ(find_beef.second, -1);
EXPECT_EQ(find_nothing.first, nullptr);
EXPECT_EQ(find_nothing.second, -1);
}

TEST_F(HloQueryTest, IsBeforeInComputationTest) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(kConstantAdditionHloString));
const HloComputation* main = hlo_query::FindComputation(module.get(), "main");
EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five"));
EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out"));
auto find_beef = hlo_query::FindInstruction(main, "deadbeef");
auto find_nothing = hlo_query::FindInstruction(main, "");
EXPECT_EQ(find_beef, nullptr);
EXPECT_EQ(find_nothing, nullptr);
}

TEST_F(HloQueryTest, NextChannelIdForModuleWithoutChannelIdTest) {
Expand Down Expand Up @@ -253,8 +232,10 @@ TEST_F(HloQueryTest, NextChannelIdTwoIdsTest) {
HloModule test
ENTRY test_computation {
p = u32[] partition-id()
l = u32[] collective-permute(p), channel_id=8, source_target_pairs={{0,1},{1,2}}
r = u32[] collective-permute(p), channel_id=9, source_target_pairs={{2,3},{3,0}}
l = u32[] collective-permute(p), channel_id=8,
source_target_pairs={{0,1},{1,2}}
r = u32[] collective-permute(p), channel_id=9,
source_target_pairs={{2,3},{3,0}}
ROOT res = u32[] add(l,r)
}
)";
Expand Down
40 changes: 16 additions & 24 deletions xla/service/collective_permute_decomposer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
select = f32[2,2] select(broadcast, cp_back, cp_forward)
matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, rhs_contracting_dims={0}
matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1},
rhs_contracting_dims={0}
ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights)
}
Expand All @@ -361,8 +362,10 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
start_iter = u32[] constant(0)
input_data = f32[2,2] parameter(0)
input_weights = f32[2,2] parameter(1)
input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, input_weights)
while_result = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body
input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data,
input_weights)
while_result = (u32[], f32[2,2], f32[2,2]) while(input),
condition=while_cond, body=while_body
ROOT data_out = f32[2,2] get-tuple-element(while_result), index=1
}
)";
Expand All @@ -378,7 +381,9 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
// an XLA invariant that shouldn't be broken (see
// https://openxla.org/xla/operation_semantics#send for details of the
// semantics).
HloInstruction* recv_bwd = FindInstruction(transformed_module, "recv");
HloComputation* while_body =
FindComputation(transformed_module, "while_body");
HloInstruction* recv_bwd = hlo_query::FindInstruction(while_body, "recv");
EXPECT_EQ(recv_bwd->channel_id().value(), 1);
auto recv_bwd_frontend_attributes = recv_bwd->frontend_attributes().map();
EXPECT_EQ(recv_bwd_frontend_attributes.size(), 3);
Expand All @@ -388,44 +393,31 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) {
EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr),
"{{3,0}}");

HloInstruction* send_bwd = FindInstruction(transformed_module, "send");
HloInstruction* send_bwd = hlo_query::FindInstruction(while_body, "send");
auto send_bwd_frontend_attributes = send_bwd->frontend_attributes().map();
EXPECT_THAT(send_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr),
"{{3,0}}");

HloInstruction* recv_fwd = FindInstruction(transformed_module, "recv.1");
HloInstruction* recv_fwd = hlo_query::FindInstruction(while_body, "recv.1");
EXPECT_EQ(recv_fwd->channel_id().value(), 2);
auto recv_fwd_frontend_attributes = recv_fwd->frontend_attributes().map();
EXPECT_EQ(recv_fwd_frontend_attributes.size(), 3);
EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1");
EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr),
"{{0,1},{1,2},{2,3}}");

HloInstruction* send_fwd = FindInstruction(transformed_module, "send.1");
HloInstruction* send_fwd = hlo_query::FindInstruction(while_body, "send.1");
auto send_fwd_frontend_attributes = send_fwd->frontend_attributes().map();
EXPECT_EQ(send_fwd_frontend_attributes.size(), 3);
EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1");
EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr),
"{{0,1},{1,2},{2,3}}");

HloComputation* while_body =
FindComputation(transformed_module, "while_body");
EXPECT_NE(while_body, nullptr);
EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv", "send"));
EXPECT_TRUE(
hlo_query::IsBeforeInComputation(while_body, "recv", "recv-done"));
EXPECT_TRUE(
hlo_query::IsBeforeInComputation(while_body, "send", "recv-done"));
EXPECT_TRUE(
hlo_query::IsBeforeInComputation(while_body, "send", "send-done"));
EXPECT_TRUE(
hlo_query::IsBeforeInComputation(while_body, "send-done", "send-done.1"));
EXPECT_TRUE(
hlo_query::IsBeforeInComputation(while_body, "recv-done", "send-done.1"));
EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv-done.1",
"send-done.1"));
auto recv_done_fwd = FindInstruction(transformed_module, "recv-done");
auto recv_done_bwd = FindInstruction(transformed_module, "recv-done.1");
HloInstruction* recv_done_fwd =
hlo_query::FindInstruction(while_body, "recv-done");
HloInstruction* recv_done_bwd =
hlo_query::FindInstruction(while_body, "recv-done.1");

// TODO: b/356201477 - Investigate potential NCCL deadlock in
// collective_permute_decomposer
Expand Down
12 changes: 6 additions & 6 deletions xla/tests/hlo_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1097,9 +1097,9 @@ HloComputation* HloTestBase::FindComputation(HloModule* module,
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
absl::string_view name) {
for (const HloComputation* computation : module->computations()) {
if (auto instruction = hlo_query::FindFirstInstruction(computation, name);
instruction.first != nullptr) {
return instruction.first;
if (HloInstruction* instruction =
hlo_query::FindInstruction(computation, name)) {
return instruction;
}
}
return nullptr;
Expand All @@ -1108,9 +1108,9 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module,
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
HloOpcode opcode) {
for (const HloComputation* computation : module->computations()) {
if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode);
instruction.first != nullptr) {
return instruction.first;
if (HloInstruction* instruction =
hlo_query::FindInstruction(computation, opcode)) {
return instruction;
}
}
return nullptr;
Expand Down

0 comments on commit c88b612

Please sign in to comment.