diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc index 147f54822aef9..5e1b182f531e4 100644 --- a/xla/hlo/utils/hlo_query.cc +++ b/xla/hlo/utils/hlo_query.cc @@ -280,6 +280,7 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) { return *it; } +// TODO: Make this return only the instruction. std::pair FindFirstInstruction( const HloComputation* computation, absl::string_view name) { int current_index = 0; @@ -293,6 +294,7 @@ std::pair FindFirstInstruction( return {nullptr, -1}; } +// TODO: Make this return only the instruction. std::pair FindFirstInstruction( const HloComputation* computation, HloOpcode opcode) { int current_index = 0; @@ -306,6 +308,7 @@ std::pair FindFirstInstruction( return {nullptr, -1}; } +// TODO: Remove this. It could be misleading as there is no linear order. bool IsBeforeInComputation(const HloComputation* computation, absl::string_view inst1, absl::string_view inst2) { return FindFirstInstruction(computation, inst1).second < diff --git a/xla/service/BUILD b/xla/service/BUILD index 5f1c0407d99fd..643427c445cc6 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -5234,6 +5234,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_reachability", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5295,6 +5296,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", diff --git a/xla/service/copy_insertion.cc b/xla/service/copy_insertion.cc index 6e2fc858d0958..ddc53dbac9ba4 100644 --- a/xla/service/copy_insertion.cc +++ b/xla/service/copy_insertion.cc @@ -16,10 +16,12 @@ limitations under the License. 
#include "xla/service/copy_insertion.h" #include +#include #include #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -35,12 +37,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/frontend_attributes.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_reachability.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/compile_time_cap.h" @@ -186,6 +191,22 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, return std::make_pair(from_deep_copy, to_deep_copy); } +bool IsSendRecv(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv; +} + +bool IsSendRecvDone(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecvDone; +} + +bool IsSendRecvInInit(const HloInstruction* init, const ShapeIndex& index) { + if (index.empty()) return false; + int64_t i = index.front(); + return i < init->operand_count() && IsSendRecv(init->operand(i)); +} + // Compute the indices of the loop state which need copies in order to avoid // live range interference. Generally, an element in the loop state does not // need to be copied if the element is passed through transparently through the @@ -202,9 +223,14 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, for (auto& pair : *indices_to_copy) { const ShapeIndex& index = pair.first; bool& should_copy = pair.second; - // If there is any ambiguity, then loop state must be copied. 
- if (dataflow.GetValueSet(init, index).values().size() > 1 || - dataflow.GetValueSet(xla_while, index).values().size() > 1) { + if (IsSendRecvInInit(init, index)) { + // Do not copy partially pipelined send/recv ops. The required copies will + // be inserted specifically for the send/recv ops. + should_copy = false; + continue; + } else if (dataflow.GetValueSet(init, index).values().size() > 1 || + dataflow.GetValueSet(xla_while, index).values().size() > 1) { + // If there is any ambiguity, then loop state must be copied. should_copy = true; } else { // If the output of the while instruction is not the same as the init @@ -1307,42 +1333,6 @@ class CopyRemover { if (buffer.values().at(0)->defining_instruction()->IsFused()) { continue; } - if (check_live_range_ordering) { - // Skip checking if execution thread is not included. - auto should_skip_value = [&execution_threads](const HloValue* value) { - return value->defining_instruction()->parent() != nullptr && - !HloInstruction::IsThreadIncluded(value->defining_instruction() - ->parent() - ->execution_thread(), - execution_threads); - }; - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - if (value_a->shape().IsToken()) { - // Token values have no representation and cannot interfere. 
- continue; - } - if (should_skip_value(value_a)) { - continue; - } - for (const HloValue* value_b : buffer.values()) { - if (!should_skip_value(value_b) && value_a != value_b) { - DCHECK(ordering_->LiveRangeStrictlyBefore( - *value_a, *value_b, dataflow_, - /*use_is_always_before_def_in_same_instr=*/true) || - ordering_->LiveRangeStrictlyBefore( - *value_b, *value_a, dataflow_, - /*use_is_always_before_def_in_same_instr=*/true)) - << value_a->ToString() << " and " << value_b->ToString() - << " are not ordered"; - } - } - } - } std::vector<const HloValue*> values = buffer.values(); absl::c_sort(values, [this, &instruction_ids](const HloValue* a, @@ -2014,6 +2004,122 @@ absl::Status CopyInsertion::AddCopiesForConditional( return absl::OkStatus(); } +HloInstruction* FindAsyncSendRecvDoneInWhileBody( + const HloComputation* while_body, const HloInstruction* start_op) { + // Partially pipelined send/recv must have a single user. + if (start_op->user_count() != 1) return nullptr; + HloInstruction* unique_user = start_op->users().front(); + // Send/recv must be consumed by send/recv-done op or be passed through the + // loop. + if (IsSendRecvDone(unique_user)) return unique_user; + if (unique_user->opcode() != HloOpcode::kTuple || !unique_user->IsRoot()) + return nullptr; + int64_t index = unique_user->operand_index(start_op); + for (const HloInstruction* it : + while_body->parameter_instruction(0)->users()) { + const auto* gte = DynCast<HloGetTupleElementInstruction>(it); + if (gte != nullptr && gte->tuple_index() == index) { + CHECK_EQ(gte->user_count(), 1) << "send/recv in next loop iteration must " + "be consumed by unique send/recv-done."; + HloInstruction* next_unique_user = gte->users().front(); + if (IsSendRecvDone(next_unique_user)) return next_unique_user; + } + } + return nullptr; +} + +// Add copies for partially pipelined async send/recv. +// Copies are added before starting to send and after finishing to recv. +// This is to prevent overlapping live times of the buffers. 
The control edges +// from the added copy to the recv or send-done operation guarantee disjoint +// live times. +// +// +// Before: +// +// kParameter kParameter +// | | +// kSendDone kRecvDone +// | +// ... consumer +// +// producer ... +// | +// kSend kRecv +// | | +// (body root) (body root) +// +// +// After: +// +// kParameter kParameter +// | | +// kSendDone ----+ kRecvDone +// | | +// ctrl kCopy ----+ +// producer edge | | +// | | consumer ctrl +// kCopy <-----+ edge +// | | +// kSend kRecv <---+ +// | | +// (body root) (body root) +// +absl::Status CopyInsertion::AddCopiesForAsyncSendRecv( + const HloAliasAnalysis& alias_analysis, HloInstruction* start_op) { + // If start op has multiple users, this must be the synchronous use of + // send/recv. + // TODO: Disambiguate sync and async use of send/recv b/369589022 + if (start_op->users().size() != 1) return absl::OkStatus(); + + // If start feeds directly into done, the live time is contained and we don't + // need to add any copies. + HloInstruction* unique_user = start_op->users().front(); + const HloOpcode done_opcode = start_op->opcode() == HloOpcode::kSend + ? HloOpcode::kSendDone + : HloOpcode::kRecvDone; + if (unique_user->opcode() == done_opcode) { + return absl::OkStatus(); + } + + // For send/recv outside of the while loop, live times are disjoint. No copies + // needed. + HloComputation* while_body = start_op->parent(); + if (!while_body->IsWhileBodyComputation()) return absl::OkStatus(); + + // Handle send case. 
+ HloInstruction* done_op = + FindAsyncSendRecvDoneInWhileBody(while_body, start_op); + // TODO: Disambiguate sync and async use of send/recv b/369589022 + if (done_op == nullptr) return absl::OkStatus(); + if (start_op->opcode() == HloOpcode::kSend) { + HloInstruction* operand = start_op->mutable_operand(0); + HloInstruction* copied_operand = + while_body->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand)); + TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(copied_operand)); + return absl::OkStatus(); + } + + // Handle recv case. + CHECK_EQ(start_op->opcode(), HloOpcode::kRecv); + PtrVec done_op_users = done_op->users(); + ShapeTree copies_added(done_op->shape()); + TF_ASSIGN_OR_RETURN(HloInstruction * done_op_copy, + while_body->DeepCopyInstruction( + done_op, /*indices_to_copy=*/nullptr, &copies_added)); + for (auto [shape_index, instr] : copies_added) { + if (instr != nullptr) + TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(start_op)); + } + TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(start_op)); + for (HloInstruction* it : done_op_users) { + TF_RETURN_IF_ERROR(done_op->ReplaceUseWith(it, done_op_copy)); + } + return absl::OkStatus(); +} + // Add kCopy instructions to the given module to guarantee there is no // live-range interference. Generally interference can only occur around kWhile // instructions which have update-in-place semantics. @@ -2034,6 +2140,10 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference( } else if (instruction->opcode() == HloOpcode::kConditional) { TF_RETURN_IF_ERROR( AddCopiesForConditional(*alias_analysis, instruction)); + } else if (IsSendRecv(instruction)) { + // TODO: Generalize this to all async collectives. 
+ TF_RETURN_IF_ERROR( + AddCopiesForAsyncSendRecv(*alias_analysis, instruction)); } else { // When an operand is a tuple, we avoid copying the operand multiple // times by recording and checking the operand number of operands that diff --git a/xla/service/copy_insertion.h b/xla/service/copy_insertion.h index b76d47cd1a871..be0874f64cf9a 100644 --- a/xla/service/copy_insertion.h +++ b/xla/service/copy_insertion.h @@ -107,6 +107,10 @@ class CopyInsertion : public HloModulePass { virtual absl::Status AddCopiesForConditional( const HloAliasAnalysis& alias_analysis, HloInstruction* conditional); + // Add copies for async send/recv instructions. + absl::Status AddCopiesForAsyncSendRecv(const HloAliasAnalysis& alias_analysis, + HloInstruction* async); + // Backend specific function that decides whether an instruction can share // buffer with its operand. HloDataflowAnalysis::CanShareBuffer can_share_buffer_; diff --git a/xla/service/copy_insertion_test.cc b/xla/service/copy_insertion_test.cc index 5250e9842895d..f9c7efc99fbe6 100644 --- a/xla/service/copy_insertion_test.cc +++ b/xla/service/copy_insertion_test.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal_util.h" @@ -3869,5 +3870,298 @@ ENTRY main { EXPECT_EQ(CountCopies(*module), 0); } +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncRecv) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={()->f32[16]{0}}, num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[])) parameter(0) + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=1 + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16]{0}, u32[], token[])) tuple(recv) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16]{0}, u32[], token[])) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[])) tuple(recv) + while = ((f32[16]{0}, u32[], token[])) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async 
start/end will be ordered so that all copies are removable. + EXPECT_EQ(CountCopies(*module), 0); + + // Expect control dependency from recv-done to recv. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + HloInstruction* recv_done = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecvDone).first; + HloInstruction* recv = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecv).first; + EXPECT_THAT(recv->control_predecessors(), UnorderedElementsAre(recv_done)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncRecvMultipleUses) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=1 + recv_data = f32[16]{0} get-tuple-element(recv_done), index=0 + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + // `recv_data` is again here, which extends it's live range. + ROOT tuple = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(recv, + recv_data) + } + + // Infinite loop to keep IR small. 
+ while_condition { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16]{0} parameter(0) + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(recv, data) + while = ((f32[16]{0}, u32[], token[]), f32[16]{0}) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies, except for an extra + // use of the recv result, are removable. Additionally, there will be one copy + // leading into the loop. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + EXPECT_EQ(CountCopies(*module), 2); + EXPECT_EQ(CountCopies(*while_body), 1); + + // Expect control dependency from recv-done to recv. 
+ HloInstruction* recv_done = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecvDone).first; + HloInstruction* recv = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecv).first; + HloInstruction* recv_done_copy = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kCopy).first; + EXPECT_THAT(recv_done_copy, op::Copy(op::GetTupleElement(recv_done))); + EXPECT_THAT(recv->control_predecessors(), + UnorderedElementsAre(recv_done, recv_done_copy)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncSendMultipleUses) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + prev_send = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + data = f32[16]{0} get-tuple-element(param), index=1 + send_done = (f32[16]{0}, token[]) send-done(prev_send), channel_id=1 + after_all = token[] after-all() + send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + + // `data` is used again here, which extends it's live range beyond `send`. + ROOT tuple = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(send, data) + } + + // Infinite loop to keep IR small. 
+ while_condition { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16]{0} parameter(0) + after_all = token[] after-all() + send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(send, data) + while = ((f32[16]{0}, u32[], token[]), f32[16]{0}) while(init), + condition=while_condition, body=while_body + send_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + send_done = (f32[16]{0}, token[]) send-done(send_ctx), channel_id=1 + ROOT data_ = f32[16]{0} get-tuple-element(while), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies, except for an extra + // use of the send operand, are removable. Additionally, there will be 2 + // copies leading into the loop and returning copying the result. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*while_body), 1); + + // Expect control dependency from send-done to send. 
+ HloInstruction* send_done = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kSendDone).first; + HloInstruction* send = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kSend).first; + HloInstruction* send_operand_copy = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kCopy).first; + EXPECT_THAT(send, op::Send(send_operand_copy, op::AfterAll())); + EXPECT_THAT(send_operand_copy->control_predecessors(), + UnorderedElementsAre(send_done)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncSendRecvPipelineParallelism) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) parameter(0) + + prev_fwd = f32[16]{0} get-tuple-element(param), index=3 + + prev_send = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + send_done = (f32[16]{0}, token[]) send-done(prev_send), channel_id=1 + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=1 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=2 + + fwd = f32[16]{0} get-tuple-element(recv_done), index=0 + + after_all = token[] after-all() + send = (f32[16]{0}, u32[], token[]) send(prev_fwd, after_all), + channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=2, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + + // Both, the data that was sent and the data that was received are live + // until the end of the while loop. + ROOT tuple = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) tuple(send, recv, prev_fwd, fwd) + } + + // Infinite loop to keep IR small. 
+ while_condition { + param = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16]{0} parameter(0) + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=2, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) tuple(send, recv, data, data) + while = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) while(init), condition=while_condition, + body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + send_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + send_done = (f32[16]{0}, token[]) send-done(send_ctx), channel_id=2 + ROOT data_ = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies but two are + // removable: + // - The copy for the extra use of the send operand. + // - The copy for the extra use of the recv result. + // The copy removal heuristic fails on removing one data copy, so the total + // number of copies in the while loop body is 3. 
+ HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + EXPECT_EQ(CountCopies(*module), 6); + EXPECT_EQ(CountCopies(*while_body), 3); + + // Expect control dependency from send-done to send. + HloInstruction* send_done = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kSendDone).first; + HloInstruction* send = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kSend).first; + HloInstruction* send_operand_copy = send->mutable_operand(0); + EXPECT_THAT(send_operand_copy, op::Copy()); + EXPECT_THAT(send, op::Send(send_operand_copy, op::AfterAll())); + EXPECT_THAT(send_operand_copy->control_predecessors(), + UnorderedElementsAre(send_done)); + + // Expect control dependency from recv-done to recv. + HloInstruction* recv_done = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecvDone).first; + HloInstruction* recv = + hlo_query::FindFirstInstruction(while_body, HloOpcode::kRecv).first; + HloInstruction* recv_done_copy = *absl::c_find_if( + recv->control_predecessors(), HloPredicateIsOp); + EXPECT_THAT(recv_done_copy, op::Copy(op::GetTupleElement(recv_done))); + EXPECT_THAT(recv->control_predecessors(), + UnorderedElementsAre(recv_done, recv_done_copy)); +} + } // namespace } // namespace xla