From b886ebec3171e3cf07f9130738e766da577305e2 Mon Sep 17 00:00:00 2001
From: Frederik Gossen
Date: Wed, 18 Sep 2024 11:51:35 -0700
Subject: [PATCH] [XLA:GPU] Support partially pipelined async send/recv ops

This is needed for pipeline parallelism on GPU, where send/recv operations are
issued in one loop iteration and completed in the next. The same buffer must
stay alive throughout, so no copies can be inserted between a send/recv op and
its corresponding done op.

Avoid copies for these partially pipelined async send/recv ops, and instead
insert the required copies and control flow constraints around the send/recv
ops separately. This ensures that the live times of the buffers do not
overlap.

Send: A copy is inserted on the operand, starting a new live range. By
enforcing this copy to come after the corresponding send-done, buffer live
times are disjoint.

Recv: A copy is inserted after recv-done, ending the live time of the buffer.
By enforcing this copy to come before the corresponding recv, buffer live
times are disjoint.

PiperOrigin-RevId: 676075106
---
 xla/hlo/utils/hlo_query.cc         |   3 +
 xla/service/BUILD                  |   2 +
 xla/service/copy_insertion.cc      | 188 ++++++++++++++----
 xla/service/copy_insertion.h       |   4 +
 xla/service/copy_insertion_test.cc | 294 +++++++++++++++++++++++++++++
 5 files changed, 452 insertions(+), 39 deletions(-)

diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc
index 147f54822aef9..5e1b182f531e4 100644
--- a/xla/hlo/utils/hlo_query.cc
+++ b/xla/hlo/utils/hlo_query.cc
@@ -280,6 +280,7 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) {
   return *it;
 }
 
+// TODO: Make this return only the instruction.
 std::pair<HloInstruction*, int> FindFirstInstruction(
     const HloComputation* computation, absl::string_view name) {
   int current_index = 0;
@@ -293,6 +294,7 @@ std::pair<HloInstruction*, int> FindFirstInstruction(
   return {nullptr, -1};
 }
 
+// TODO: Make this return only the instruction.
 std::pair<HloInstruction*, int> FindFirstInstruction(
     const HloComputation* computation, HloOpcode opcode) {
   int current_index = 0;
@@ -306,6 +308,7 @@ std::pair<HloInstruction*, int> FindFirstInstruction(
   return {nullptr, -1};
 }
 
+// TODO: Remove this. It could be misleading as there is no linear order.
 bool IsBeforeInComputation(const HloComputation* computation,
                            absl::string_view inst1, absl::string_view inst2) {
   return FindFirstInstruction(computation, inst1).second <
diff --git a/xla/service/BUILD b/xla/service/BUILD
index 5f1c0407d99fd..643427c445cc6 100644
--- a/xla/service/BUILD
+++ b/xla/service/BUILD
@@ -5234,6 +5234,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/hlo/ir:hlo_reachability",
         "//xla/hlo/pass:hlo_pass",
+        "//xla/hlo/utils:hlo_query",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -5295,6 +5296,7 @@ xla_cc_test(
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_matchers",
+        "//xla/hlo/utils:hlo_query",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/log",
diff --git a/xla/service/copy_insertion.cc b/xla/service/copy_insertion.cc
index 6e2fc858d0958..ddc53dbac9ba4 100644
--- a/xla/service/copy_insertion.cc
+++ b/xla/service/copy_insertion.cc
@@ -16,10 +16,12 @@ limitations under the License.
 #include "xla/service/copy_insertion.h"
 
 #include <algorithm>
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <string>
+#include <utility>
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
@@ -35,12 +37,15 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/frontend_attributes.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_reachability.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/compile_time_cap.h" @@ -186,6 +191,22 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, return std::make_pair(from_deep_copy, to_deep_copy); } +bool IsSendRecv(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv; +} + +bool IsSendRecvDone(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecvDone; +} + +bool IsSendRecvInInit(const HloInstruction* init, const ShapeIndex& index) { + if (index.empty()) return false; + int64_t i = index.front(); + return i < init->operand_count() && IsSendRecv(init->operand(i)); +} + // Compute the indices of the loop state which need copies in order to avoid // live range interference. Generally, an element in the loop state does not // need to be copied if the element is passed through transparently through the @@ -202,9 +223,14 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, for (auto& pair : *indices_to_copy) { const ShapeIndex& index = pair.first; bool& should_copy = pair.second; - // If there is any ambiguity, then loop state must be copied. - if (dataflow.GetValueSet(init, index).values().size() > 1 || - dataflow.GetValueSet(xla_while, index).values().size() > 1) { + if (IsSendRecvInInit(init, index)) { + // Do not copy partially pipelined send/recv ops. The required copies will + // be inserted specifically for the send/recv ops. + should_copy = false; + continue; + } else if (dataflow.GetValueSet(init, index).values().size() > 1 || + dataflow.GetValueSet(xla_while, index).values().size() > 1) { + // If there is any ambiguity, then loop state must be copied. should_copy = true; } else { // If the output of the while instruction is not the same as the init @@ -1307,42 +1333,6 @@ class CopyRemover { if (buffer.values().at(0)->defining_instruction()->IsFused()) { continue; } - if (check_live_range_ordering) { - // Skip checking if execution thread is not included. - auto should_skip_value = [&execution_threads](const HloValue* value) { - return value->defining_instruction()->parent() != nullptr && - !HloInstruction::IsThreadIncluded(value->defining_instruction() - ->parent() - ->execution_thread(), - execution_threads); - }; - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - if (value_a->shape().IsToken()) { - // Token values have no representation and cannot interfere. 
-            continue;
-          }
-          if (should_skip_value(value_a)) {
-            continue;
-          }
-          for (const HloValue* value_b : buffer.values()) {
-            if (!should_skip_value(value_b) && value_a != value_b) {
-              DCHECK(ordering_->LiveRangeStrictlyBefore(
-                         *value_a, *value_b, dataflow_,
-                         /*use_is_always_before_def_in_same_instr=*/true) ||
-                     ordering_->LiveRangeStrictlyBefore(
-                         *value_b, *value_a, dataflow_,
-                         /*use_is_always_before_def_in_same_instr=*/true))
-                  << value_a->ToString() << " and " << value_b->ToString()
-                  << " are not ordered";
-            }
-          }
-        }
-      }
       std::vector<const HloValue*> values = buffer.values();
       absl::c_sort(values, [this, &instruction_ids](const HloValue* a,
@@ -2014,6 +2004,123 @@ absl::Status CopyInsertion::AddCopiesForConditional(
   return absl::OkStatus();
 }
 
+HloInstruction* FindAsyncSendRecvDoneInWhileBody(
+    const HloComputation* while_body, const HloInstruction* start_op) {
+  // Partially pipelined send/recv must have a single user.
+  if (start_op->user_count() != 1) return nullptr;
+  HloInstruction* unique_user = start_op->users().front();
+  // Send/recv must be consumed by a send/recv-done op or be passed through
+  // the loop.
+  if (IsSendRecvDone(unique_user)) return unique_user;
+  if (unique_user->opcode() != HloOpcode::kTuple || !unique_user->IsRoot())
+    return nullptr;
+  int64_t index = unique_user->operand_index(start_op);
+  for (const HloInstruction* it :
+       while_body->parameter_instruction(0)->users()) {
+    const auto* gte = DynCast<HloGetTupleElementInstruction>(it);
+    if (gte == nullptr) continue;
+    if (gte->tuple_index() == index) {
+      CHECK_EQ(gte->user_count(), 1) << "send/recv in next loop iteration "
+                                        "must be consumed by a unique "
+                                        "send/recv-done.";
+      HloInstruction* next_unique_user = gte->users().front();
+      if (IsSendRecvDone(next_unique_user)) return next_unique_user;
+    }
+  }
+  return nullptr;
+}
+
+// Add copies for partially pipelined async send/recv.
+// Copies are added before starting the send and after finishing the recv.
+// This is to prevent overlapping live times of the buffers. The control edges
+// between the added copies and the recv or send-done ops guarantee disjoint
+// live times.
+//
+//
+// Before:
+//
+//          kParameter                 kParameter
+//              |                          |
+//          kSendDone                  kRecvDone
+//                                         |
+//             ...                      consumer
+//
+//          producer                      ...
+//              |
+//            kSend                      kRecv
+//              |                          |
+//         (body root)                (body root)
+//
+//
+// After:
+//
+//          kParameter                 kParameter
+//              |                          |
+//          kSendDone ----+            kRecvDone
+//                        |                |
+//                       ctrl           kCopy ----+
+//          producer     edge             |       |
+//              |         |            consumer  ctrl
+//            kCopy <-----+                      edge
+//              |                                  |
+//            kSend                      kRecv <---+
+//              |                          |
+//         (body root)                (body root)
+//
+absl::Status CopyInsertion::AddCopiesForAsyncSendRecv(
+    const HloAliasAnalysis& alias_analysis, HloInstruction* start_op) {
+  // If the start op has multiple users, this must be a synchronous use of
+  // send/recv.
+  // TODO: Disambiguate sync and async use of send/recv b/369589022
+  if (start_op->users().size() != 1) return absl::OkStatus();
+
+  // If start feeds directly into its done op, the live time is contained and
+  // we don't need to add any copies.
+  HloInstruction* unique_user = start_op->users().front();
+  const HloOpcode done_opcode = start_op->opcode() == HloOpcode::kSend
+                                    ? HloOpcode::kSendDone
+                                    : HloOpcode::kRecvDone;
+  if (unique_user->opcode() == done_opcode) {
+    return absl::OkStatus();
+  }
+
+  // For send/recv outside of a while loop body, live times are disjoint. No
+  // copies are needed.
+  HloComputation* while_body = start_op->parent();
+  if (!while_body->IsWhileBodyComputation()) return absl::OkStatus();
+
+  HloInstruction* done_op =
+      FindAsyncSendRecvDoneInWhileBody(while_body, start_op);
+  // TODO: Disambiguate sync and async use of send/recv b/369589022
+  if (done_op == nullptr) return absl::OkStatus();
+
+  // Handle the send case.
+  if (start_op->opcode() == HloOpcode::kSend) {
+    HloInstruction* operand = start_op->mutable_operand(0);
+    HloInstruction* copied_operand =
+        while_body->AddInstruction(HloInstruction::CreateUnary(
+            operand->shape(), HloOpcode::kCopy, operand));
+    TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand));
+    TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(copied_operand));
+    return absl::OkStatus();
+  }
+
+  // Handle the recv case.
+  CHECK_EQ(start_op->opcode(), HloOpcode::kRecv);
+  PtrVec<HloInstruction*> done_op_users = done_op->users();
+  ShapeTree<HloInstruction*> copies_added(done_op->shape());
+  TF_ASSIGN_OR_RETURN(HloInstruction * done_op_copy,
+                      while_body->DeepCopyInstruction(
+                          done_op, /*indices_to_copy=*/nullptr, &copies_added));
+  for (auto [shape_index, instr] : copies_added) {
+    if (instr != nullptr)
+      TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(start_op));
+  }
+  TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(start_op));
+  for (HloInstruction* it : done_op_users) {
+    TF_RETURN_IF_ERROR(done_op->ReplaceUseWith(it, done_op_copy));
+  }
+  return absl::OkStatus();
+}
+
 // Add kCopy instructions to the given module to guarantee there is no
 // live-range interference. Generally interference can only occur around kWhile
 // instructions which have update-in-place semantics.
@@ -2034,6 +2141,10 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference(
     } else if (instruction->opcode() == HloOpcode::kConditional) {
       TF_RETURN_IF_ERROR(
           AddCopiesForConditional(*alias_analysis, instruction));
+    } else if (IsSendRecv(instruction)) {
+      // TODO: Generalize this to all async collectives.
+      TF_RETURN_IF_ERROR(
+          AddCopiesForAsyncSendRecv(*alias_analysis, instruction));
     } else {
       // When an operand is a tuple, we avoid copying the operand multiple
       // times by recording and checking the operand number of operands that
diff --git a/xla/service/copy_insertion.h b/xla/service/copy_insertion.h
index b76d47cd1a871..be0874f64cf9a 100644
--- a/xla/service/copy_insertion.h
+++ b/xla/service/copy_insertion.h
@@ -107,6 +107,10 @@ class CopyInsertion : public HloModulePass {
   virtual absl::Status AddCopiesForConditional(
       const HloAliasAnalysis& alias_analysis, HloInstruction* conditional);
 
+  // Add copies for async send/recv instructions.
+  absl::Status AddCopiesForAsyncSendRecv(const HloAliasAnalysis& alias_analysis,
+                                         HloInstruction* async);
+
   // Backend specific function that decides whether an instruction can share
   // buffer with its operand.
   HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
diff --git a/xla/service/copy_insertion_test.cc b/xla/service/copy_insertion_test.cc
index 5250e9842895d..f9c7efc99fbe6 100644
--- a/xla/service/copy_insertion_test.cc
+++ b/xla/service/copy_insertion_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal_util.h" @@ -3869,5 +3870,298 @@ ENTRY main { EXPECT_EQ(CountCopies(*module), 0); } +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncRecv) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={()->f32[16]{0}}, num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[])) parameter(0) + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=1 + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16]{0}, u32[], token[])) tuple(recv) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16]{0}, u32[], token[])) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[])) tuple(recv) + while = ((f32[16]{0}, u32[], token[])) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies are removable. + EXPECT_EQ(CountCopies(*module), 0); + + // Expect control dependency from recv-done to recv. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + HloInstruction* recv_done = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecvDone).first; + HloInstruction* recv = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecv).first; + EXPECT_THAT(recv->control_predecessors(), UnorderedElementsAre(recv_done)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncRecvMultipleUses) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=1 + recv_data = f32[16]{0} get-tuple-element(recv_done), index=0 + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + // `recv_data` is again here, which extends it's live range. + ROOT tuple = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(recv, + recv_data) + } + + // Infinite loop to keep IR small. 
+    while_condition {
+      param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0)
+      ROOT infinite_loop = pred[] constant(true)
+    }
+
+    ENTRY main_spmd {
+      data = f32[16]{0} parameter(0)
+      after_all = token[] after-all()
+      recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+      init = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(recv, data)
+      while = ((f32[16]{0}, u32[], token[]), f32[16]{0}) while(init),
+          condition=while_condition, body=while_body
+      recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0
+      recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1
+      ROOT result = f32[16]{0} get-tuple-element(recv_done), index=0
+    }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  CopyInsertion copy_insertion(nullptr,
+                               /*use_region_based_live_range_analysis=*/-1);
+
+  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+  VLOG(2) << module->ToString();
+
+  // All async start/end will be ordered so that all copies, except for the
+  // extra use of the recv result, are removable. Additionally, there will be
+  // one copy leading into the loop.
+  HloComputation* while_body =
+      hlo_query::FindComputation(module.get(), "while_body");
+  EXPECT_EQ(CountCopies(*module), 2);
+  EXPECT_EQ(CountCopies(*while_body), 1);
+
+  // Expect control dependency from recv-done to recv.
+  HloInstruction* recv_done =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecvDone).first;
+  HloInstruction* recv =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecv).first;
+  HloInstruction* recv_done_copy =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kCopy).first;
+  EXPECT_THAT(recv_done_copy, op::Copy(op::GetTupleElement(recv_done)));
+  EXPECT_THAT(recv->control_predecessors(),
+              UnorderedElementsAre(recv_done, recv_done_copy));
+}
+
+TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncSendMultipleUses) {
+  constexpr absl::string_view kModuleString = R"(
+    HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}},
+        num_partitions=4
+
+    while_body {
+      param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0)
+      prev_send = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0
+      data = f32[16]{0} get-tuple-element(param), index=1
+      send_done = (f32[16]{0}, token[]) send-done(prev_send), channel_id=1
+      after_all = token[] after-all()
+      send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+
+      // `data` is used again here, which extends its live range beyond `send`.
+      ROOT tuple = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(send, data)
+    }
+
+    // Infinite loop to keep IR small.
+    while_condition {
+      param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0)
+      ROOT infinite_loop = pred[] constant(true)
+    }
+
+    ENTRY main_spmd {
+      data = f32[16]{0} parameter(0)
+      after_all = token[] after-all()
+      send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+      init = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(send, data)
+      while = ((f32[16]{0}, u32[], token[]), f32[16]{0}) while(init),
+          condition=while_condition, body=while_body
+      send_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0
+      send_done = (f32[16]{0}, token[]) send-done(send_ctx), channel_id=1
+      ROOT data_ = f32[16]{0} get-tuple-element(while), index=1
+    }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  CopyInsertion copy_insertion(nullptr,
+                               /*use_region_based_live_range_analysis=*/-1);
+
+  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+  VLOG(2) << module->ToString();
+
+  // All async start/end will be ordered so that all copies, except for the
+  // extra use of the send operand, are removable. Additionally, there will be
+  // two copies outside the loop: one entering it and one for the result.
+  HloComputation* while_body =
+      hlo_query::FindComputation(module.get(), "while_body");
+  EXPECT_EQ(CountCopies(*module), 3);
+  EXPECT_EQ(CountCopies(*while_body), 1);
+
+  // Expect control dependency from send-done to the copy of the send operand.
+  HloInstruction* send_done =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kSendDone).first;
+  HloInstruction* send =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kSend).first;
+  HloInstruction* send_operand_copy =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kCopy).first;
+  EXPECT_THAT(send, op::Send(send_operand_copy, op::AfterAll()));
+  EXPECT_THAT(send_operand_copy->control_predecessors(),
+              UnorderedElementsAre(send_done));
+}
+
+TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncSendRecvPipelineParallelism) {
+  constexpr absl::string_view kModuleString = R"(
+    HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}},
+        num_partitions=4
+
+    while_body {
+      param = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]),
+          f32[16]{0}, f32[16]{0}) parameter(0)
+
+      prev_fwd = f32[16]{0} get-tuple-element(param), index=3
+
+      prev_send = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0
+      send_done = (f32[16]{0}, token[]) send-done(prev_send), channel_id=1
+      prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=1
+      recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=2
+
+      fwd = f32[16]{0} get-tuple-element(recv_done), index=0
+
+      after_all = token[] after-all()
+      send = (f32[16]{0}, u32[], token[]) send(prev_fwd, after_all),
+          channel_id=1,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+      recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=2,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+
+      // Both the data that was sent and the data that was received are live
+      // until the end of the while loop.
+      ROOT tuple = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]),
+          f32[16]{0}, f32[16]{0}) tuple(send, recv, prev_fwd, fwd)
+    }
+
+    // Infinite loop to keep IR small.
+    while_condition {
+      param = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]),
+          f32[16]{0}, f32[16]{0}) parameter(0)
+      ROOT infinite_loop = pred[] constant(true)
+    }
+
+    ENTRY main_spmd {
+      data = f32[16]{0} parameter(0)
+      after_all = token[] after-all()
+      recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=2,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+      send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1,
+          frontend_attributes={
+          _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}
+      init = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]),
+          f32[16]{0}, f32[16]{0}) tuple(send, recv, data, data)
+      while = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]),
+          f32[16]{0}, f32[16]{0}) while(init), condition=while_condition,
+          body=while_body
+      recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=1
+      recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=2
+      send_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0
+      send_done = (f32[16]{0}, token[]) send-done(send_ctx), channel_id=1
+      ROOT data_ = f32[16]{0} get-tuple-element(recv_done), index=0
+    }
+    )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kModuleString));
+  CopyInsertion copy_insertion(nullptr,
+                               /*use_region_based_live_range_analysis=*/-1);
+
+  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+  VLOG(2) << module->ToString();
+
+  // All async start/end will be ordered so that all copies but two are
+  // removable:
+  //   - The copy for the extra use of the send operand.
+  //   - The copy for the extra use of the recv result.
+  // The copy removal heuristic fails to remove one more data copy, so the
+  // total number of copies in the while loop body is 3.
+  HloComputation* while_body =
+      hlo_query::FindComputation(module.get(), "while_body");
+  EXPECT_EQ(CountCopies(*module), 6);
+  EXPECT_EQ(CountCopies(*while_body), 3);
+
+  // Expect control dependency from send-done to the copy of the send operand.
+  HloInstruction* send_done =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kSendDone).first;
+  HloInstruction* send =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kSend).first;
+  HloInstruction* send_operand_copy = send->mutable_operand(0);
+  EXPECT_THAT(send_operand_copy, op::Copy());
+  EXPECT_THAT(send, op::Send(send_operand_copy, op::AfterAll()));
+  EXPECT_THAT(send_operand_copy->control_predecessors(),
+              UnorderedElementsAre(send_done));
+
+  // Expect control dependency from recv-done to recv.
+  HloInstruction* recv_done =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecvDone).first;
+  HloInstruction* recv =
+      hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecv).first;
+  HloInstruction* recv_done_copy = *absl::c_find_if(
+      recv->control_predecessors(), HloPredicateIsOp<HloOpcode::kCopy>);
+  EXPECT_THAT(recv_done_copy, op::Copy(op::GetTupleElement(recv_done)));
+  EXPECT_THAT(recv->control_predecessors(),
+              UnorderedElementsAre(recv_done, recv_done_copy));
+}
+
 }  // namespace
 }  // namespace xla
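
Reviewer note (not part of the patch): the new tests drive the pass exactly as
sketched below. This is a minimal usage sketch for trying the change locally,
assuming an XLA build environment; the driver function RunCopyInsertionOn is
hypothetical, and header paths follow the XLA tree at the time of this patch.

#include <iostream>
#include <memory>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tsl/platform/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/copy_insertion.h"
#include "xla/service/hlo_parser.h"

namespace xla {

// Parse textual HLO (e.g. one of the test modules above), run CopyInsertion
// with the same configuration the new tests use, and print the result.
absl::Status RunCopyInsertionOn(absl::string_view hlo_text) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      ParseAndReturnUnverifiedModule(hlo_text));
  // nullptr: no backend-specific can-share-buffer callback.
  // -1: region-based live range analysis without a fuel limit.
  CopyInsertion copy_insertion(/*can_share_buffer=*/nullptr,
                               /*use_region_based_live_range_analysis=*/-1);
  TF_ASSIGN_OR_RETURN(bool changed, copy_insertion.Run(module.get()));
  std::cout << "changed: " << changed << "\n" << module->ToString() << "\n";
  return absl::OkStatus();
}

}  // namespace xla

After running this on the PartiallyPipelinedAsyncRecv module above, the dumped
while body should show the recv-done feeding a kCopy and a control edge from
recv-done (and the copy) to the next recv, matching the test expectations.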