diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc index 147f54822aef97..220fd3da9a2951 100644 --- a/xla/hlo/utils/hlo_query.cc +++ b/xla/hlo/utils/hlo_query.cc @@ -280,36 +280,21 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) { return *it; } -std::pair FindFirstInstruction( - const HloComputation* computation, absl::string_view name) { - int current_index = 0; - for (auto* instruction : computation->instructions()) { - if (instruction->name() == name) { - return {instruction, current_index}; - break; - } - current_index++; +HloInstruction* FindInstruction(const HloComputation* computation, + absl::string_view name) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == name) return instruction; } - return {nullptr, -1}; + return nullptr; } -std::pair FindFirstInstruction( - const HloComputation* computation, HloOpcode opcode) { - int current_index = 0; +HloInstruction* FindInstruction(const HloComputation* computation, + HloOpcode opcode) { for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == opcode) { - return {instruction, current_index}; - break; - } - current_index++; + if (instruction->opcode() == opcode) return instruction; } - return {nullptr, -1}; + return nullptr; } -bool IsBeforeInComputation(const HloComputation* computation, - absl::string_view inst1, absl::string_view inst2) { - return FindFirstInstruction(computation, inst1).second < - FindFirstInstruction(computation, inst2).second; -} } // namespace hlo_query } // namespace xla diff --git a/xla/hlo/utils/hlo_query.h b/xla/hlo/utils/hlo_query.h index ec5c0b25804d10..0882e26cf6fe8c 100644 --- a/xla/hlo/utils/hlo_query.h +++ b/xla/hlo/utils/hlo_query.h @@ -156,23 +156,17 @@ HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, // Gets the computation from the given module with the given name. HloComputation* FindComputation(HloModule* module, absl::string_view name); -// Gets the first instruction and its index from the given computation with the -// given instruction name. The function returns {nullptr, -1} if the instruction -// cannot be found. -std::pair FindFirstInstruction( - const HloComputation* computation, absl::string_view name); -// Gets the first instruction and its index from the given computation with the -// given instruction opcode. The function returns {nullptr, -1} if the -// instruction cannot be found. -std::pair FindFirstInstruction( - const HloComputation* computation, HloOpcode opcode); - -// Check that one instruction comes before another one for a given computation. -// The function returns true if the first instruction comes before the second -// one, and false otherwise. This is useful for partial checks on the -// transformed IR without going through a full file check. -bool IsBeforeInComputation(const HloComputation* computation, - absl::string_view inst1, absl::string_view inst2); + +// Gets the instruction from the given computation with the given instruction +// name. Returns nullptr if no such instruction can be found. +HloInstruction* FindInstruction(const HloComputation* computation, + absl::string_view name); + +// Gets any instruction from the given computation with the given opcode. +// Returns nullptr if no such instruction can be found. +HloInstruction* FindInstruction(const HloComputation* computation, + HloOpcode opcode); + } // namespace hlo_query } // namespace xla diff --git a/xla/hlo/utils/hlo_query_test.cc b/xla/hlo/utils/hlo_query_test.cc index 1f715ad6815284..f17e6bbd27832c 100644 --- a/xla/hlo/utils/hlo_query_test.cc +++ b/xla/hlo/utils/hlo_query_test.cc @@ -157,31 +157,21 @@ TEST_F(HloQueryTest, FindInstructionUsingNameTest) { std::unique_ptr module, ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr); - EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr); - EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr); - EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, "zero"), nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, "five"), nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, "out"), nullptr); + EXPECT_EQ(hlo_query::FindInstruction(main, "foo"), nullptr); } -std::pair FindFirst(const HloComputation* main, - absl::string_view opcode) { - return hlo_query::FindFirstInstruction(main, - StringToHloOpcode(opcode).value()); -} - -// Assures that the string and opcode versions of FindFirstInstruction return +// Assures that the string and opcode versions of FindInstruction return // the same result -void FindFirstInstructionsAndExpectEqual(const HloComputation* main, - absl::string_view name, - absl::string_view opcode_str) { +void FindInstructionsAndExpectEqual(const HloComputation* main, + absl::string_view name, HloOpcode opcode) { SCOPED_TRACE(absl::StrCat("Comparing finding by name: ", name, - " and opcode: ", opcode_str)); - auto withString = hlo_query::FindFirstInstruction(main, name); - auto withOpCode = FindFirst(main, opcode_str); - EXPECT_EQ(withString.first, withOpCode.first); - EXPECT_EQ(withString.second, withOpCode.second); - if (withString.first != nullptr) - EXPECT_EQ(withString.first->ToString(), withOpCode.first->ToString()); + " and opcode: ", opcode)); + HloInstruction* by_name = hlo_query::FindInstruction(main, name); + HloInstruction* by_opcode = hlo_query::FindInstruction(main, opcode); + EXPECT_EQ(by_name, by_opcode); } TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) { @@ -189,9 +179,9 @@ TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) { std::unique_ptr module, ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - EXPECT_NE(FindFirst(main, "add").first, nullptr); - EXPECT_NE(FindFirst(main, "constant").first, nullptr); - EXPECT_EQ(FindFirst(main, "select").first, nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, HloOpcode::kConstant), nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, HloOpcode::kAdd), nullptr); + EXPECT_EQ(hlo_query::FindInstruction(main, HloOpcode::kSelect), nullptr); } TEST_F(HloQueryTest, FindInstructionUsingOpcodeAndNameEqualTest) { @@ -199,10 +189,10 @@ TEST_F(HloQueryTest, FindInstructionUsingOpcodeAndNameEqualTest) { std::unique_ptr module, ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - FindFirstInstructionsAndExpectEqual(main, "zero", "constant"); - FindFirstInstructionsAndExpectEqual(main, "out", "add"); + FindInstructionsAndExpectEqual(main, "zero", HloOpcode::kConstant); + FindInstructionsAndExpectEqual(main, "out", HloOpcode::kAdd); // both are not found - FindFirstInstructionsAndExpectEqual(main, "dummy", "select"); + FindInstructionsAndExpectEqual(main, "dummy", HloOpcode::kSelect); } TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) { @@ -211,21 +201,10 @@ TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) { ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); EXPECT_NE(main, nullptr); - auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef"); - auto find_nothing = hlo_query::FindFirstInstruction(main, ""); - EXPECT_EQ(find_beef.first, nullptr); - EXPECT_EQ(find_beef.second, -1); - EXPECT_EQ(find_nothing.first, nullptr); - EXPECT_EQ(find_nothing.second, -1); -} - -TEST_F(HloQueryTest, IsBeforeInComputationTest) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); - const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five")); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out")); + auto find_beef = hlo_query::FindInstruction(main, "deadbeef"); + auto find_nothing = hlo_query::FindInstruction(main, ""); + EXPECT_EQ(find_beef, nullptr); + EXPECT_EQ(find_nothing, nullptr); } TEST_F(HloQueryTest, NextChannelIdForModuleWithoutChannelIdTest) { @@ -253,8 +232,10 @@ TEST_F(HloQueryTest, NextChannelIdTwoIdsTest) { HloModule test ENTRY test_computation { p = u32[] partition-id() - l = u32[] collective-permute(p), channel_id=8, source_target_pairs={{0,1},{1,2}} - r = u32[] collective-permute(p), channel_id=9, source_target_pairs={{2,3},{3,0}} + l = u32[] collective-permute(p), channel_id=8, + source_target_pairs={{0,1},{1,2}} + r = u32[] collective-permute(p), channel_id=9, + source_target_pairs={{2,3},{3,0}} ROOT res = u32[] add(l,r) } )"; diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc index eac5ab0707418a..1fb86f4eb59948 100644 --- a/xla/service/collective_permute_decomposer_test.cc +++ b/xla/service/collective_permute_decomposer_test.cc @@ -345,7 +345,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { select = f32[2,2] select(broadcast, cp_back, cp_forward) - matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, rhs_contracting_dims={0} + matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, + rhs_contracting_dims={0} ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) } @@ -361,8 +362,10 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { start_iter = u32[] constant(0) input_data = f32[2,2] parameter(0) input_weights = f32[2,2] parameter(1) - input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, input_weights) - while_result = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, + input_weights) + while_result = (u32[], f32[2,2], f32[2,2]) while(input), + condition=while_cond, body=while_body ROOT data_out = f32[2,2] get-tuple-element(while_result), index=1 } )"; @@ -378,7 +381,9 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { // an XLA invariant that shouldn't be broken (see // https://openxla.org/xla/operation_semantics#send for details of the // semantics). - HloInstruction* recv_bwd = FindInstruction(transformed_module, "recv"); + HloComputation* while_body = + FindComputation(transformed_module, "while_body"); + HloInstruction* recv_bwd = hlo_query::FindInstruction(while_body, "recv"); EXPECT_EQ(recv_bwd->channel_id().value(), 1); auto recv_bwd_frontend_attributes = recv_bwd->frontend_attributes().map(); EXPECT_EQ(recv_bwd_frontend_attributes.size(), 3); @@ -388,12 +393,12 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{3,0}}"); - HloInstruction* send_bwd = FindInstruction(transformed_module, "send"); + HloInstruction* send_bwd = hlo_query::FindInstruction(while_body, "send"); auto send_bwd_frontend_attributes = send_bwd->frontend_attributes().map(); EXPECT_THAT(send_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{3,0}}"); - HloInstruction* recv_fwd = FindInstruction(transformed_module, "recv.1"); + HloInstruction* recv_fwd = hlo_query::FindInstruction(while_body, "recv.1"); EXPECT_EQ(recv_fwd->channel_id().value(), 2); auto recv_fwd_frontend_attributes = recv_fwd->frontend_attributes().map(); EXPECT_EQ(recv_fwd_frontend_attributes.size(), 3); @@ -401,31 +406,18 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{0,1},{1,2},{2,3}}"); - HloInstruction* send_fwd = FindInstruction(transformed_module, "send.1"); + HloInstruction* send_fwd = hlo_query::FindInstruction(while_body, "send.1"); auto send_fwd_frontend_attributes = send_fwd->frontend_attributes().map(); EXPECT_EQ(send_fwd_frontend_attributes.size(), 3); EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{0,1},{1,2},{2,3}}"); - HloComputation* while_body = - FindComputation(transformed_module, "while_body"); EXPECT_NE(while_body, nullptr); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv", "send")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "recv", "recv-done")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "send", "recv-done")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "send", "send-done")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "send-done", "send-done.1")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "recv-done", "send-done.1")); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv-done.1", - "send-done.1")); - auto recv_done_fwd = FindInstruction(transformed_module, "recv-done"); - auto recv_done_bwd = FindInstruction(transformed_module, "recv-done.1"); + HloInstruction* recv_done_fwd = + hlo_query::FindInstruction(while_body, "recv-done"); + HloInstruction* recv_done_bwd = + hlo_query::FindInstruction(while_body, "recv-done.1"); // TODO: b/356201477 - Investigate potential NCCL deadlock in // collective_permute_decomposer diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index ee8d4653b97089..3c3c529243a498 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -1097,9 +1097,9 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, HloInstruction* HloTestBase::FindInstruction(HloModule* module, absl::string_view name) { for (const HloComputation* computation : module->computations()) { - if (auto instruction = hlo_query::FindFirstInstruction(computation, name); - instruction.first != nullptr) { - return instruction.first; + if (HloInstruction* instruction = + hlo_query::FindInstruction(computation, name)) { + return instruction; } } return nullptr; @@ -1108,9 +1108,9 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module, HloInstruction* HloTestBase::FindInstruction(HloModule* module, HloOpcode opcode) { for (const HloComputation* computation : module->computations()) { - if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode); - instruction.first != nullptr) { - return instruction.first; + if (HloInstruction* instruction = + hlo_query::FindInstruction(computation, opcode)) { + return instruction; } } return nullptr;