diff --git a/xla/service/gpu/gpu_hlo_schedule_test.cc b/xla/service/gpu/gpu_hlo_schedule_test.cc index aa669290f9c9d..4edab874aaf51 100644 --- a/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -777,30 +777,18 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) { HloModule test while_cond { - param = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) parameter(0) + param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 ub = u32[] constant(25) ROOT cond-result = pred[] compare(count, ub), direction=LT } while_body { - param = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) parameter(0) + param = (u32[], (f32[1,1024,1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - recv.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=1 - recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 - - send.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=2 - send-done.1 = token[] send-done(send.1.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } + recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 c1 = u32[] constant(1) new-count = u32[] add(count, c1) @@ -820,17 +808,24 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) { after-all.1 = token[] after-all() send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } 
recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } - - ROOT body-result = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) tuple(new-count, recv.1, send.1) + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.1 = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[]) + tuple(new-count, recv-done.1, send-done.1) } ENTRY main { @@ -841,35 +836,32 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) { after-all.2 = token[] after-all() recv.2 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.2), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } send.2 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.2), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } - - while-init = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) tuple(c0, recv.2, send.2) - while-result = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) while(while-init), - body=while_body, condition=while_cond, - backend_config={"known_trip_count":{"n":"25"}} - - recv.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=1 - recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2.q), channel_id=1, + _xla_send_recv_source_target_pairs="{{0,1}, 
{1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1, frontend_attributes={ - _xla_send_recv_pipeline="0" + _xla_send_recv_pipeline="0" } - - send.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=2 - send-done.2 = token[] send-done(send.2.q), channel_id=1, + send-done.2 = token[] send-done(send.2), channel_id=1, frontend_attributes={ - _xla_send_recv_pipeline="0" + _xla_send_recv_pipeline="0" } + while-init = (u32[], (f32[1,1024,1024], token[]), token[]) + tuple(c0, recv-done.2, send-done.2) + while-result = (u32[], (f32[1,1024,1024], token[]), token[]) + while(while-init), + body=while_body, condition=while_cond, + backend_config={"known_trip_count":{"n":"25"}} - ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0 + recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=1 + + ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 } )"; @@ -894,20 +886,23 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) { }) - instruction_sequence.begin(); }; - EXPECT_TRUE(HasValidFingerprint(module.get())); - // The pipelined Send-Recv in the main. - EXPECT_LT(get_index("recv.2", main), get_index("while-result", main)); - EXPECT_LT(get_index("send.2", main), get_index("while-result", main)); - EXPECT_LT(get_index("while-result", main), get_index("recv-done.2", main)); - EXPECT_LT(get_index("while-result", main), get_index("send-done.2", main)); - // The pipelined Send-Recv in the while-body. + // The pipelined Send-Recv in the main. A pipelined Recv is scheduled right + // after its corresponding Send due to kForceEarly. 
+ EXPECT_EQ(get_index("recv.2", main) + 1, get_index("send.2", main)); + EXPECT_LT(get_index("send.2", main), get_index("recv-done.2", main)); + EXPECT_LT(get_index("recv-done.2", main), get_index("send-done.2", main)); + EXPECT_LT(get_index("send-done.2", main), get_index("while-result", main)); + + // The pipelined Send-Recv in the while-body. A pipelined Recv is scheduled + // right after its corresponding Send due to kForceEarly. + EXPECT_EQ(get_index("recv.1", while_body) + 1, + get_index("send.1", while_body)); + EXPECT_LT(get_index("send.1", while_body), + get_index("recv-done.1", while_body)); EXPECT_LT(get_index("recv-done.1", while_body), get_index("send-done.1", while_body)); - EXPECT_LT(get_index("send-done.1", while_body), - get_index("recv.1", while_body)); - EXPECT_LT(get_index("recv.1", while_body), get_index("send.1", while_body)); } // Checks that with the dependence added by the gpu-hlo-scheduler, the @@ -917,45 +912,22 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { HloModule test while_cond { - param = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) parameter(0) + param = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 ub = u32[] constant(25) ROOT cond-result = pred[] compare(count, ub), direction=LT } while_body { - param = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) parameter(0) + param = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - recv.0.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=1 - recv-done.0 = (f32[1,1024,1024], token[]) recv-done(recv.0.q), channel_id=1, - frontend_attributes={ - 
_xla_send_recv_pipeline="0" - } - recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0), index=0 - - send.0.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=2 - send-done.0 = token[] send-done(send.0.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - - recv.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=3 - recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1.q), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } - recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 - - send.1.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(param), index=4 - send-done.1 = token[] send-done(send.1.q), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } + recv-done.0.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 + recv-data.0 = f32[1, 1024, 1024] get-tuple-element(recv-done.0.q), index=0 + recv-done.1.q = (f32[1,1024,1024], token[]) get-tuple-element(param), index=3 + recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 replica = u32[] replica-id() constant0 = u32[] constant(0) @@ -980,30 +952,46 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { after-all.0 = token[] after-all() send.0 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.0), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{3,0}}", - _xla_send_recv_pipeline="0" - } + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } recv.0 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.0), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{3,0}}", - _xla_send_recv_pipeline="0" - } + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.0 = (f32[1,1024,1024], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] 
send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } after-all.1 = token[] after-all() send.1 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.1), channel_id=2, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", - _xla_send_recv_pipeline="1" - } - recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2, - frontend_attributes={ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", _xla_send_recv_pipeline="1" - } + } + recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.1 = (f32[1,1024,1024], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1 = token[] send-done(send.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } - ROOT body-result = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) tuple(new-count, recv.0, send.0, recv.1, send.1) + ROOT body-result = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) + tuple(new-count, recv-done.0, send-done.0, recv-done.1, send-done.1) } ENTRY main { @@ -1022,6 +1010,14 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { _xla_send_recv_source_target_pairs="{{3,0}}", _xla_send_recv_pipeline="0" } + recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.2 = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } after-all.3 = token[] after-all() recv.3 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.3), channel_id=2, @@ -1034,41 +1030,26 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { 
_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}", _xla_send_recv_pipeline="1" } - - while-init = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) tuple(c0, recv.2, send.2, recv.3, send.3) - while-result = (u32[], (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), - (f32[1,1024,1024], u32[], token[])) while(while-init), - body=while_body, condition=while_cond, - backend_config={"known_trip_count":{"n":"25"}} - - recv.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=1 - recv-done.2 = (f32[1,1024,1024], token[]) recv-done(recv.2.q), channel_id=1, + recv-done.3 = (f32[1,1024,1024], token[]) recv-done(recv.3), channel_id=2, frontend_attributes={ - _xla_send_recv_pipeline="0" + _xla_send_recv_pipeline="1" } - recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0 - - send.2.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=2 - send-done.2 = token[] send-done(send.2.q), channel_id=1, + send-done.3 = token[] send-done(send.3), channel_id=2, frontend_attributes={ - _xla_send_recv_pipeline="0" + _xla_send_recv_pipeline="1" } - recv.3.q = (f32[1,1024,1024], u32[], token[]) get-tuple-element(while-result), index=3 - recv-done.3 = (f32[1,1024,1024], token[]) recv-done(recv.3.q), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } - recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3), index=0 + while-init = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) tuple(c0, recv-done.2, send-done.2, recv-done.3, send-done.3) + while-result = (u32[], (f32[1,1024,1024], token[]), token[], + (f32[1,1024,1024], token[]), token[]) while(while-init), + body=while_body, condition=while_cond, + backend_config={"known_trip_count":{"n":"25"}} - send.3.q = (f32[1,1024,1024], u32[], 
token[]) get-tuple-element(while-result), index=4 - send-done.3 = token[] send-done(send.3.q), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } + recv-done.2.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=1 + recv-data.2 = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 + recv-done.3.q = (f32[1,1024,1024], token[]) get-tuple-element(while-result), index=3 + recv-data.3 = f32[1, 1024, 1024] get-tuple-element(recv-done.3.q), index=0 replica = u32[] replica-id() constant0 = u32[] constant(0) @@ -1101,18 +1082,32 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { }; EXPECT_TRUE(HasValidFingerprint(module.get())); - // The pipelined Send-Recv in the main. - EXPECT_LT(get_index("recv.2", main), get_index("while-result", main)); - EXPECT_LT(get_index("send.2", main), get_index("while-result", main)); - EXPECT_LT(get_index("while-result", main), get_index("recv-done.2", main)); - EXPECT_LT(get_index("while-result", main), get_index("send-done.2", main)); - - // The pipelined Send-Recv in the while-body. + // The pipelined Send-Recv in the main. A pipelined Recv is scheduled right + // after its corresponding Send due to kForceEarly. + EXPECT_EQ(get_index("recv.2", main) + 1, get_index("send.2", main)); + EXPECT_LT(get_index("send.2", main), get_index("recv.3", main)); + EXPECT_EQ(get_index("recv.3", main) + 1, get_index("send.3", main)); + EXPECT_LT(get_index("send.3", main), get_index("recv-done.2", main)); + EXPECT_LT(get_index("recv-done.2", main), get_index("recv-done.3", main)); + EXPECT_LT(get_index("recv-done.3", main), get_index("send-done.2", main)); + EXPECT_LT(get_index("send-done.2", main), get_index("send-done.3", main)); + EXPECT_LT(get_index("send-done.3", main), get_index("while-result", main)); + + // The pipelined Send-Recv in the while-body. A pipelined Recv is scheduled + // right after its corresponding Send due to kForceEarly. 
+ EXPECT_EQ(get_index("recv.0", while_body) + 1, + get_index("send.0", while_body)); + EXPECT_LT(get_index("send.0", while_body), get_index("recv.1", while_body)); + EXPECT_EQ(get_index("recv.1", while_body) + 1, + get_index("send.1", while_body)); + EXPECT_LT(get_index("send.1", while_body), + get_index("recv-done.0", while_body)); + EXPECT_LT(get_index("recv-done.0", while_body), + get_index("recv-done.1", while_body)); EXPECT_LT(get_index("recv-done.1", while_body), + get_index("send-done.0", while_body)); + EXPECT_LT(get_index("send-done.0", while_body), get_index("send-done.1", while_body)); - EXPECT_LT(get_index("send-done.1", while_body), - get_index("recv.1", while_body)); - EXPECT_LT(get_index("recv.1", while_body), get_index("send.1", while_body)); } TEST_F(GpuHloScheduleTest, SkipAlreadyScheduled) { diff --git a/xla/service/gpu/gpu_p2p_pipeliner.cc b/xla/service/gpu/gpu_p2p_pipeliner.cc index 5f5028a43442c..f8cba55030c9f 100644 --- a/xla/service/gpu/gpu_p2p_pipeliner.cc +++ b/xla/service/gpu/gpu_p2p_pipeliner.cc @@ -38,7 +38,7 @@ namespace gpu { namespace { bool ShouldPipeline(const HloInstruction* instr) { - if (!HloPredicateIsOp(instr)) { + if (!HloPredicateIsOp(instr)) { return false; } @@ -48,10 +48,12 @@ bool ShouldPipeline(const HloInstruction* instr) { return false; } - // Check that the Send or Recv is used for non-trivial computation. This - // avoids repeatedly pipelining a loop. - return (instr->user_count() == 1 && instr->parent() != nullptr && - instr->users()[0] != instr->parent()->root_instruction()); + // Checks that the SendDone or RecvDone is used for non-trivial computation. + // This avoids repeatedly pipelining a loop. 
+ bool is_pipelined = + (instr->user_count() == 1 && instr->parent() != nullptr && + instr->users()[0] == instr->parent()->root_instruction()); + return !is_pipelined; } bool ShouldAllowLoopVariantParameterInChain(const HloInstruction* instr) { @@ -65,6 +67,14 @@ bool ShouldAllowLoopVariantParameterInChain(const HloInstruction* instr) { Status PostprocessP2PImpl( HloInstruction* instr, std::function&)> transformer) { + // The input instruction is a Done instruction. + if (!HloPredicateIsOp(instr)) { + return Internal("Expected SendDone/RecvDone as the pipelined collective"); + } + instr = instr->mutable_operand(0); + if (!HloPredicateIsOp(instr)) { + return Internal("Expected Send/Recv as the SendDone/RecvDone operand"); + } auto validation_it = instr->frontend_attributes().map().find(kSendRecvValidationAttr); if (validation_it == instr->frontend_attributes().map().end() || diff --git a/xla/service/p2p_schedule_preparation.cc b/xla/service/p2p_schedule_preparation.cc index 1771a99ea2227..782b807763f82 100644 --- a/xla/service/p2p_schedule_preparation.cc +++ b/xla/service/p2p_schedule_preparation.cc @@ -89,7 +89,7 @@ HloInstruction* GetStartOpForDoneOp(HloInstruction* op) { enum P2PGroupKind { kUnpipelined = 0, kPipelined = 1, kUnrecognized = 2 }; -enum P2PPipelineStream { kUnknown = 0, kPipeline0 = 1, kPipeline1 = 2 }; +enum P2PRuntimeStream { kUnknown = 0, kStream0 = 1, kStream1 = 2 }; // A P2P group node represents the P2P instructions that are in the same // computation and have the same channel ID. This includes one Send/SendDone @@ -164,14 +164,14 @@ struct P2PGroupNode { // Returns the pipeline stream used to execute the P2P instructions in the // group. 
- P2PPipelineStream GetPipelineStream(const HloInstruction* start) const { + P2PRuntimeStream GetRuntimeStream(const HloInstruction* start) const { auto it = start->frontend_attributes().map().find(kSendRecvPipelineAttr); if (it != start->frontend_attributes().map().end()) { if (it->second == "0") { - return kPipeline0; + return kStream0; } if (it->second == "1") { - return kPipeline1; + return kStream1; } } return kUnknown; @@ -180,9 +180,9 @@ struct P2PGroupNode { // Finds the pipeline stream from the frontend attribute of the Send/Recv in // the pipeline group node, verifies they both have the same value and returns // the stream. - P2PPipelineStream GetPipelineStream() const { - P2PPipelineStream send_stream = GetPipelineStream(send); - P2PPipelineStream recv_stream = GetPipelineStream(recv); + P2PRuntimeStream GetRuntimeStream() const { + P2PRuntimeStream send_stream = GetRuntimeStream(send); + P2PRuntimeStream recv_stream = GetRuntimeStream(recv); if (send_stream != recv_stream) { return kUnknown; } @@ -283,37 +283,45 @@ struct P2PGroup { // Finds the pipeline stream from the frontend attribute of the Send/Recv in // the pipeline group, verifies they all have the same value and records // the stream. - bool RecordPipelineStream() { - P2PPipelineStream child_stream = - nodes[kPipelinedChildNodeIdx].GetPipelineStream(); - P2PPipelineStream parent_stream = - nodes[kPipelinedParentNodeIdx].GetPipelineStream(); - if (child_stream != parent_stream || child_stream == kUnknown) { - return false; + bool RecordRuntimeStream() { + P2PRuntimeStream child_stream = + nodes[kPipelinedChildNodeIdx].GetRuntimeStream(); + if (kind == kPipelined) { + P2PRuntimeStream parent_stream = + nodes[kPipelinedParentNodeIdx].GetRuntimeStream(); + if (child_stream != parent_stream || child_stream == kUnknown) { + return false; + } } // Record the stream. 
- pipeline_stream = child_stream; + runtime_stream = child_stream; return true; } // Records the other group that forms a cycle with this group, assuming that - // we only pipepline at most two groups for a loop. + // we handle only two groups that form a cycle. Status RecordComplementGroup(P2PGroupMap& p2p_group_map) { + CHECK(complement_group == nullptr && runtime_stream == kStream1); for (auto& [channel, p2p_group] : p2p_group_map) { - if (&p2p_group == this || p2p_group.kind != kPipelined || - p2p_group.ChildComputation() != ChildComputation() || - p2p_group.ParentComputation() != ParentComputation()) { + if (&p2p_group == this || + p2p_group.ChildComputation() != ChildComputation()) { continue; } - // Found two pipeline group for the same while loop, verify that they have - // different valid pipeline stream. - if (pipeline_stream == p2p_group.pipeline_stream) { - return Internal( - "Expected different pipeline stream for complement group"); + if (p2p_group.kind == kPipelined && + p2p_group.ParentComputation() == ParentComputation()) { + // Found two pipelined groups for the same while loop, verify that they + // have different valid pipeline stream. + if (p2p_group.runtime_stream != kStream0) { + return Internal( + "Expected different pipeline stream for complement group"); + } + complement_group = &p2p_group; + p2p_group.complement_group = this; + } else if (p2p_group.kind == kUnpipelined && + p2p_group.runtime_stream != kStream1) { + complement_group = &p2p_group; + p2p_group.complement_group = this; } - complement_group = &p2p_group; - p2p_group.complement_group = this; - break; } return OkStatus(); } @@ -332,33 +340,32 @@ struct P2PGroup { } // Returns the start and end of a region marked by a pipelined chain in the - // given computation. For most of the cases, this is the region with the - // pipelined P2P instructions. 
The only exception is for a pipelined chain - // in the child computation, in which case, the region is from the end of the - // Send/Recv-done instructions block to the beginning of the Send/Recv - // instruction block start instruction block which is the region where other - // collectives should be scheduled to. + // given computation, which is the region with the pipelined P2P instructions. ChainStartEnd GetChainStartEnd(HloComputation* computation) const { if (kind == kUnpipelined) { - return std::make_pair(GetChild().recv, GetChild().send_done); + if (!InCycle()) { + return std::make_pair(GetChild().recv, GetChild().send_done); + } + CHECK(runtime_stream == kStream1); + return std::make_pair(complement_group->GetChild().recv, + GetChild().send_done); } CHECK(kind == kPipelined); if (computation == ChildComputation()) { - // For the child computation of a pipelined group, we return the start - // and end of the instruction where we can put other collectives. - if (complement_group == nullptr) { - return std::make_pair(GetChild().send_done, GetChild().recv); + if (!InCycle()) { + return std::make_pair(GetChild().recv, GetChild().send_done); } - CHECK(pipeline_stream == kPipeline1); - return std::make_pair(GetChild().send_done, GetChild().recv); + CHECK(runtime_stream == kStream1); + return std::make_pair(complement_group->GetChild().recv, + GetChild().send_done); } CHECK(computation == ParentComputation()); - if (complement_group == nullptr) { + if (!InCycle()) { return std::make_pair(GetParent().recv, GetParent().send_done); } - CHECK(pipeline_stream == kPipeline1); + CHECK(runtime_stream == kStream1); return std::make_pair(complement_group->GetParent().recv, GetParent().send_done); } @@ -367,9 +374,11 @@ struct P2PGroup { return nodes[kPipelinedParentNodeIdx].while_loop; } + bool InCycle() const { return complement_group != nullptr; } + P2PGroupKind kind = kUnpipelined; P2PGroupNode nodes[2]; - P2PPipelineStream pipeline_stream = kUnknown; + P2PRuntimeStream 
runtime_stream = kUnknown; // Another P2PGroup that forms a cycle with this group. P2PGroup* complement_group = nullptr; }; @@ -411,32 +420,7 @@ Status MayAddWhileOpToPipelinedGroup(HloInstruction* while_op, int pipelined_group = 0; // Check whether the while-op init contains a token from a Send result. for (auto hlo : while_op->while_init()->operands()) { - if (hlo->opcode() == HloOpcode::kTuple) { - // A send has a tuple as its result, the tuple contains a token. - // If a send is pipelined, then, the while-init either contains - // a send-result, or contains a tuple with a token element from the - // send result. As such, if a tuple represent a pipelined send, it is - // either a direct send result, or a tuple with this code pattern: - /// - // send = (..., token) send(...) - // send.token = token[] get-tuple-element(send) index=... - // send.tuple.reconstruct = tuple(..., send.token) - // while-init = tuple(..., send.tuple.reconstruct) - // while-result = while(while-init), ... - // - // So if the tuple contains a token, we make `hlo` point-to the producer - // of the token so that we can check whether the producer is a send after. - for (auto ele : hlo->operands()) { - if (ele->shape().IsToken()) { - // Assure that the token is part of an instruction result and not - // generated by a copy as we currently don't copy token. - CHECK(ele->opcode() == HloOpcode::kGetTupleElement); - hlo = ele->mutable_operand(0); - break; - } - } - } - if (hlo->opcode() != HloOpcode::kSend) { + if (hlo->opcode() != HloOpcode::kSendDone) { continue; } int64_t channel_id = hlo->channel_id().value(); @@ -463,11 +447,9 @@ Status OrderBefore(HloInstruction* i1, HloInstruction* i2) { return OkStatus(); } -// For an unpipelined Send-Recv chain, we add control dependence to enforce this -// ordering: +// Adds control dependence to enforce this ordering: // recv => send => recv-done => send-done. 
-Status ConnectUnpipelinedP2P(const P2PGroup& p2p_group) { - const P2PGroupNode& node = p2p_group.GetChild(); +Status ConnectP2P1NodeChain(const P2PGroupNode& node) { HloRecvDoneInstruction* recv_done = node.recv_done; HloRecvInstruction* recv = node.recv; HloSendDoneInstruction* send_done = node.send_done; @@ -478,28 +460,26 @@ Status ConnectUnpipelinedP2P(const P2PGroup& p2p_group) { return OkStatus(); } -// For a single pipelined Send-Recv chain in a while-body, we enforce this +// For an unpipelined Send-Recv chain, adds control dependence to enforce this // ordering: -// recv-done => send-done => recv => send +// recv => send => recv-done => send-done. +Status ConnectUnpipelinedP2P(const P2PGroup& p2p_group) { + return ConnectP2P1NodeChain(p2p_group.GetChild()); +} + +// For a single pipelined Send-Recv chain in a while-body, adds control +// dependence to enforce this ordering: +// recv => send => recv-done => send-done Status ConnectPipelined1P2PChild(const P2PGroup& p2p_group) { - const P2PGroupNode& node = p2p_group.GetChild(); - HloSendRecvInstruction* recv_done = node.recv_done; - HloRecvInstruction* recv = node.recv; - HloSendRecvInstruction* send_done = node.send_done; - HloSendInstruction* send = node.send; - TF_RETURN_IF_ERROR(OrderBefore(recv_done, send_done)); - TF_RETURN_IF_ERROR(OrderBefore(send_done, recv)); - TF_RETURN_IF_ERROR(OrderBefore(recv, send)); - return OkStatus(); + return ConnectP2P1NodeChain(p2p_group.GetChild()); } -// For two pipelined Send-Recv chains forming a cycle in a while-body -// computation, we enforce this ordering: -// recv-done.0 => send-done.0 => recv-done.1 => send-done.1 => -// recv.0 => send.0 => recv.1 => send.1 -Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group) { - const P2PGroupNode& node0 = p2p_group.complement_group->GetChild(); - const P2PGroupNode& node1 = p2p_group.GetChild(); +// For a Send-Recv chain involving two channels, adds control dependence to +// enforce this ordering: +// recv.0 => send.0 
=> recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectP2P2NodeChain(const P2PGroupNode& node0, + const P2PGroupNode& node1) { HloSendRecvInstruction* recv_done0 = node0.recv_done; HloRecvInstruction* recv0 = node0.recv; HloSendRecvInstruction* send_done0 = node0.send_done; @@ -509,54 +489,53 @@ Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group) { HloSendRecvInstruction* send_done1 = node1.send_done; HloSendInstruction* send1 = node1.send; - TF_RETURN_IF_ERROR(OrderBefore(recv_done0, send_done0)); - TF_RETURN_IF_ERROR(OrderBefore(send_done0, recv_done1)); - TF_RETURN_IF_ERROR(OrderBefore(recv_done1, send_done1)); - TF_RETURN_IF_ERROR(OrderBefore(send_done1, recv0)); + TF_RETURN_IF_ERROR(OrderBefore(recv_done0, recv_done1)); + TF_RETURN_IF_ERROR(OrderBefore(recv_done1, send_done0)); + TF_RETURN_IF_ERROR(OrderBefore(send_done0, send_done1)); + TF_RETURN_IF_ERROR(OrderBefore(recv0, send0)); TF_RETURN_IF_ERROR(OrderBefore(send0, recv1)); TF_RETURN_IF_ERROR(OrderBefore(recv1, send1)); + TF_RETURN_IF_ERROR(OrderBefore(send1, recv_done0)); + return OkStatus(); } -// For a single pipelined Send-Recv chain in the while-body calling computation, -// we enforce this ordering: -// recv => send => (while_op) => recv-done => send-done +// For a pipelined Send-Recv chain with two channel groups forming a cycle in a +// while-body computation, we enforce this ordering: +// recv.0 => send.0 => recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectPipelined2P2PChild(const P2PGroup& p2p_group) { + return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(), + p2p_group.GetChild()); +} + +// For a pipelined Send-Recv chain with one group in the while-body calling +// computation, we enforce this ordering: +// recv => send => recv-done => send-done Status ConnectPipelined1P2PParent(const P2PGroup& p2p_group) { - const P2PGroupNode& node = p2p_group.GetParent(); - 
HloSendRecvInstruction* recv_done = node.recv_done; - HloRecvInstruction* recv = node.recv; - HloSendRecvInstruction* send_done = node.send_done; - HloSendInstruction* send = node.send; - TF_RETURN_IF_ERROR(OrderBefore(recv, send)); - TF_RETURN_IF_ERROR(OrderBefore(recv_done, send_done)); - return OkStatus(); + return ConnectP2P1NodeChain(p2p_group.GetParent()); } -// For two pipelined Send-Recv chains forming a cycle in the while-body -// calling computation, we enforce this ordering: -// recv.0 => send.0 => recv.1 => send.1 => (while_op) => -// recv-done.0 => send-done.0 => recv-done.1 => send-done.1 +// For a pipelined Send-Recv chain with two channel groups forming a cycle +// in the while-body calling computation, we enforce this ordering: +// recv.0 => send.0 => recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 Status ConnectPipelined2P2PParent(const P2PGroup& p2p_group) { - const P2PGroupNode& node0 = p2p_group.complement_group->GetParent(); - const P2PGroupNode& node1 = p2p_group.GetParent(); - HloSendRecvInstruction* recv_done0 = node0.recv_done; - HloRecvInstruction* recv0 = node0.recv; - HloSendRecvInstruction* send_done0 = node0.send_done; - HloSendInstruction* send0 = node0.send; - HloSendRecvInstruction* recv_done1 = node1.recv_done; - HloRecvInstruction* recv1 = node1.recv; - HloSendRecvInstruction* send_done1 = node1.send_done; - HloSendInstruction* send1 = node1.send; + return ConnectP2P2NodeChain(p2p_group.complement_group->GetParent(), + p2p_group.GetParent()); +} - TF_RETURN_IF_ERROR(OrderBefore(recv0, send0)); - TF_RETURN_IF_ERROR(OrderBefore(send0, recv1)); - TF_RETURN_IF_ERROR(OrderBefore(recv1, send1)); - TF_RETURN_IF_ERROR(OrderBefore(recv_done0, send_done0)); - TF_RETURN_IF_ERROR(OrderBefore(send_done0, recv_done1)); - TF_RETURN_IF_ERROR(OrderBefore(recv_done1, send_done1)); - return OkStatus(); +// For a Send-Recv chain with two channel groups forming a cycle in a while-body
+// annotated for pipelining 
but not pipelined (due to skip pipelining pass), we +// enforce this ordering: +// recv.0 => send.0 => recv.1 => send.1 => +// recv-done.0 => recv-done.1 => send-done.0 => send-done.1 +Status ConnectUnpipelined2P2P(const P2PGroup& p2p_group) { + CHECK(p2p_group.runtime_stream == kStream1); + return ConnectP2P2NodeChain(p2p_group.complement_group->GetChild(), + p2p_group.GetChild()); } // Collects P2P send-done and recv-done instructions from the computation, @@ -571,16 +550,8 @@ Status GatherP2PGroupsAndCollectiveInfo( std::vector while_ops; for (auto hlo : computation->MakeInstructionPostOrder()) { // Record the use of collective operations. - if (IsCollectiveOp(hlo)) { + if (MayInvokeCollectiveOp(hlo, collective_in_computation)) { collective_in_computation[computation] = true; - } else { - // Propagate CollectiveInComputation from callees to callers. - for (auto callee : hlo->called_computations()) { - auto collective_in_comp = collective_in_computation.find(callee); - if (collective_in_comp != collective_in_computation.end()) { - collective_in_computation[computation] |= collective_in_comp->second; - } - } } if (hlo->opcode() == HloOpcode::kWhile) {
for (auto& [channel, p2p_group] : p2p_group_map) { if (p2p_group.kind == kUnpipelined) { - if (p2p_group.nodes[kUnpipelinedNodeIdx].Incomplete()) { + if (p2p_group.nodes[kUnpipelinedNodeIdx].Incomplete() || + !p2p_group.RecordRuntimeStream()) { p2p_group.kind = kUnrecognized; } } else if (p2p_group.kind == kPipelined) { if (p2p_group.nodes[kPipelinedChildNodeIdx].Incomplete() || p2p_group.nodes[kPipelinedParentNodeIdx] .IncompletePipelinedParent() || - !p2p_group.RecordPipelineStream()) { + !p2p_group.RecordRuntimeStream()) { p2p_group.kind = kUnrecognized; } } @@ -655,16 +627,18 @@ Status GatherP2PGroupsAndCollectiveInfo( return p2p_group.second.kind == kUnrecognized; }); - // Connect kPipelined groups that form cycles if the current computation is - // the calling computation for the loop being pipelined. We only build such a - // connection when we are processing the group for kPipeline1 stream. + // Connect two groups that form a cycle, both for pipelined and unpipelined + // cases for the current computation. We only build such a connection when we + // are processing the group for kStream1 stream, and for parent computation + // for a pipelined group. 
for (auto& [channel, p2p_group] : p2p_group_map) { - if (p2p_group.kind != kPipelined || - p2p_group.ParentComputation() != computation || + if ((p2p_group.kind == kPipelined && + p2p_group.ParentComputation() != computation) || p2p_group.complement_group != nullptr || - p2p_group.pipeline_stream != kPipeline1) { + p2p_group.runtime_stream != kStream1) { continue; } + TF_RETURN_IF_ERROR(p2p_group.RecordComplementGroup(p2p_group_map)); } @@ -693,7 +667,11 @@ absl::StatusOr> ConnectP2PChain( const P2PGroup& p2p_group = it->second; P2PGroupKind kind = p2p_group.kind; if (kind == P2PGroupKind::kUnpipelined) { - TF_RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group)); + if (!p2p_group.InCycle()) { + TF_RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group)); + } else if (p2p_group.runtime_stream == kStream1) { + TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group)); + } continue; } @@ -712,8 +690,8 @@ absl::StatusOr> ConnectP2PChain( } // A pipeline of two groups that form a cycle. We process the pipeline when - // we see the group with kPipeline1. - if (p2p_group.pipeline_stream != kPipeline1) { + // we see the group with kStream1. + if (p2p_group.runtime_stream != kStream1) { continue; } @@ -839,12 +817,7 @@ Status LinearizeCollectivesWithOtherP2P( // Adds control dependence to linearize other collective ops with respect to // the given pipelined P2P chain in the computation for the pipelined -// while-loop, which is ordered as follows: -// RecvDone => SendDone .... Recv => Send (1 pipelined chain) -// RecvDone.0 => SendDone.0 => RecvDone.1 => SendDone.1 .... Recv.0 => -// Send.0 => Recv.1 => Send.1 (2 pipelined chains) -// All collective ops should be scheduled between (SendDone, Recv) or -// (SendDone.1, Recv.0) +// while-loop. All Collective ops should be scheduled before the chain. 
Status LinearizeCollectivesWithPipelinedP2PChild( const P2PGroupMap& p2p_group_map, const P2PGroup& group, const CollectiveInComputation& collective_in_computation, @@ -852,9 +825,10 @@ Status LinearizeCollectivesWithPipelinedP2PChild( ChainStartEnd start_end = group.GetChainStartEnd(computation); // If an hlo may invoke collective operation, we add control dependence to - // make sure that the hlo is schedule between (start, end) marked by the - // pipelined P2P operation in a while-body. + // make sure that the hlo is scheduled before the pipelined chain starts. for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { + // For async collective ops, only the done version of the op passes this + // check, to avoid handling async ops twice. if (!MayInvokeCollectiveOp(hlo, collective_in_computation)) { continue; } @@ -864,7 +838,7 @@ Status LinearizeCollectivesWithPipelinedP2PChild( if (IsP2POp(hlo) && opcode != HloOpcode::kSendDone) { continue; } - if (opcode == HloOpcode::kSendDone) { + if (hlo->opcode() == HloOpcode::kSendDone) { auto group_it = p2p_group_map.find(hlo->channel_id().value()); if (group_it == p2p_group_map.end()) { continue; @@ -880,17 +854,13 @@ Status LinearizeCollectivesWithPipelinedP2PChild( ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation); TF_RETURN_IF_ERROR( - OrderBefore(reachability, start_end.first, cur_start_end.first)); - TF_RETURN_IF_ERROR( - OrderBefore(reachability, cur_start_end.second, start_end.second)); + OrderBefore(reachability, cur_start_end.second, start_end.first)); continue; } // Async done, CustomCall, or other ops that indirectly invoke collectives. 
- TF_RETURN_IF_ERROR( - OrderBefore(reachability, start_end.first, GetStartOpForDoneOp(hlo))); - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.second)); + TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } return OkStatus(); @@ -954,10 +924,9 @@ absl::StatusOr P2PSchedulePreparation::Run( std::unique_ptr reachability = HloReachabilityMap::Build(computation); if (result.second != nullptr) { - // The current compuation is a while-body with pipelined P2P chain. - // Order all other collectives in a pipelined while-body between the - // Send/Recv-done block and the Send/Recv block of the pipelined P2P - // chain. + // The current computation is a while-body with pipelined P2P chain. + // Order all other collectives in a pipelined while-body before the + // pipelined P2P chain. TF_RETURN_IF_ERROR(LinearizeCollectivesWithPipelinedP2PChild( p2p_group_map, *result.second, collective_in_computation, computation, reachability.get())); @@ -988,11 +957,9 @@ absl::StatusOr P2PSchedulePreparation::Run( // to other collectives. continue; } - if (kind == P2PGroupKind::kPipelined && - group.complement_group != nullptr && - group.pipeline_stream != kPipeline1) { + if (group.InCycle() && group.runtime_stream != kStream1) { // We process a chain with two groups when we see the group for - // kPipeline1. + // kStream1. 
continue; } ChainStartEnd start_end = group.GetChainStartEnd(computation); diff --git a/xla/service/p2p_schedule_preparation_test.cc b/xla/service/p2p_schedule_preparation_test.cc index f4d04a5655639..bcd2bedef7fd0 100644 --- a/xla/service/p2p_schedule_preparation_test.cc +++ b/xla/service/p2p_schedule_preparation_test.cc @@ -46,10 +46,9 @@ class P2PSchedulePreparationTest : public HloTestBase { EXPECT_EQ(send_done->control_predecessors().size(), 0); } - // Verifies that the control dependence enforces this ordering for an - // unpipelined Send-Recv chain: + // Verifies that the control dependence enforces this ordering: // recv => send => recv-done => send-done - void VerifyUnpipelinedP2P(HloModule* module, const std::string& suffix = "") { + void VerifyP2P1GroupChain(HloModule* module, const std::string& suffix) { HloInstruction* send = FindInstruction(module, "send" + suffix); HloInstruction* recv = FindInstruction(module, "recv" + suffix); HloInstruction* recv_done = FindInstruction(module, "recv-done" + suffix); @@ -59,23 +58,19 @@ class P2PSchedulePreparationTest : public HloTestBase { EXPECT_EQ(send_done->control_predecessors()[0], recv_done); } + // Verifies that the control dependence enforces this ordering for an + // unpipelined Send-Recv chain: + // recv => send => recv-done => send-done + void VerifyUnpipelinedP2P(HloModule* module, const std::string& suffix = "") { + VerifyP2P1GroupChain(module, suffix); + } + // Verifies that the control dependence enforces this ordering for a pipelined // Send-Recv chain in the while-body: - // recv-done => send-done => recv => send. 
+ // recv => send => recv-done => send-done void VerifyPipelinedP2PChild(HloModule* module, const std::string& suffix = "") { - HloInstruction* send = FindInstruction(module, "send" + suffix); - HloInstruction* recv = FindInstruction(module, "recv" + suffix); - HloInstruction* recv_done = FindInstruction(module, "recv-done" + suffix); - HloInstruction* send_done = FindInstruction(module, "send-done" + suffix); - // If the while-body has other P2P, the pipelined Recv should also have the - // Send-done of the other P2P as control predecessors. - EXPECT_EQ(1, absl::c_count(recv->control_predecessors(), send_done)); - EXPECT_EQ(recv_done->control_predecessors().size(), 0); - EXPECT_EQ(send_done->control_predecessors().size(), 1); - EXPECT_EQ(send_done->control_predecessors()[0], recv_done); - EXPECT_EQ(send->control_predecessors().size(), 1); - EXPECT_EQ(send->control_predecessors()[0], recv); + VerifyP2P1GroupChain(module, suffix); } // Verifies that the control dependence enforces this ordering for a pipelined @@ -83,22 +78,14 @@ class P2PSchedulePreparationTest : public HloTestBase { // recv => send => while-loop => recv-done => send-done. 
void VerifyPipelinedP2PParent(HloModule* module, const std::string& suffix = "") { - HloInstruction* send = FindInstruction(module, "send" + suffix); - HloInstruction* recv = FindInstruction(module, "recv" + suffix); - HloInstruction* recv_done = FindInstruction(module, "recv-done" + suffix); - HloInstruction* send_done = FindInstruction(module, "send-done" + suffix); - EXPECT_EQ(send_done->control_predecessors().size(), 1); - EXPECT_EQ(send_done->control_predecessors()[0], recv_done); - EXPECT_EQ(send->control_predecessors().size(), 1); - EXPECT_EQ(send->control_predecessors()[0], recv); + VerifyP2P1GroupChain(module, suffix); } - // Verifies that the control dependence enforces this ordering for a pipelined - // chain with two Send-Recv groups in a while-body: - // recv-done.0 => send-done.0 => recv-done.1 => send-done.1 => - // recv.0 => send.0 => recv.1 => send.1 - void VerifyPipelined2P2PChild(HloModule* module, const std::string& suffix0, - const std::string& suffix1) { + // Verifies that the control dependence enforces this ordering: + // recv.0 => send.0 => recv.1 => send.1 => + // recv-done.0 => recv-done.1 => send-done.0 => send-done.1 + void VerifyP2P2GroupChain(HloModule* module, const std::string& suffix0, + const std::string& suffix1) { HloInstruction* send0 = FindInstruction(module, "send" + suffix0); HloInstruction* recv0 = FindInstruction(module, "recv" + suffix0); HloInstruction* recv_done0 = FindInstruction(module, "recv-done" + suffix0); @@ -108,37 +95,33 @@ class P2PSchedulePreparationTest : public HloTestBase { HloInstruction* recv_done1 = FindInstruction(module, "recv-done" + suffix1); HloInstruction* send_done1 = FindInstruction(module, "send-done" + suffix1); - EXPECT_EQ(send_done0->control_predecessors()[0], recv_done0); - EXPECT_EQ(recv_done1->control_predecessors()[0], send_done0); - EXPECT_EQ(send_done1->control_predecessors()[0], recv_done1); + EXPECT_EQ(recv_done1->control_predecessors()[0], recv_done0); + 
EXPECT_EQ(send_done0->control_predecessors()[0], recv_done1); + EXPECT_EQ(send_done1->control_predecessors()[0], send_done0); EXPECT_EQ(send0->control_predecessors()[0], recv0); EXPECT_EQ(recv1->control_predecessors()[0], send0); EXPECT_EQ(send1->control_predecessors()[0], recv1); + + EXPECT_EQ(recv_done0->control_predecessors()[0], send1); + } + + // Verifies that the control dependence enforces this ordering for a pipelined + // chain with two Send-Recv groups in a while-body: + // recv.0 => send.0 => recv.1 => send.1 => + // recv-done.0 => send-done.0 => recv-done.1 => send-done.1 + void VerifyPipelined2P2PChild(HloModule* module, const std::string& suffix0, + const std::string& suffix1) { + VerifyP2P2GroupChain(module, suffix0, suffix1); } // Verifies that the control dependence enforces this ordering for a pipelined // chain with two Send-Recv groups in the while-loop calling computation: - // recv.0 => send.0 => recv.1 => send.1 => while-loop + // recv.0 => send.0 => recv.1 => send.1 => // => recv-done.0 => send-done.0 => recv-done.1 => send-done.1 void VerifyPipelined2P2PParent(HloModule* module, const std::string& suffix0, const std::string& suffix1) { - HloInstruction* send0 = FindInstruction(module, "send" + suffix0); - HloInstruction* recv0 = FindInstruction(module, "recv" + suffix0); - HloInstruction* recv_done0 = FindInstruction(module, "recv-done" + suffix0); - HloInstruction* send_done0 = FindInstruction(module, "send-done" + suffix0); - HloInstruction* send1 = FindInstruction(module, "send" + suffix1); - HloInstruction* recv1 = FindInstruction(module, "recv" + suffix1); - HloInstruction* recv_done1 = FindInstruction(module, "recv-done" + suffix1); - HloInstruction* send_done1 = FindInstruction(module, "send-done" + suffix1); - - EXPECT_EQ(send0->control_predecessors()[0], recv0); - EXPECT_EQ(recv1->control_predecessors()[0], send0); - EXPECT_EQ(send1->control_predecessors()[0], recv1); - - EXPECT_EQ(send_done0->control_predecessors()[0], 
recv_done0); - EXPECT_EQ(recv_done1->control_predecessors()[0], send_done0); - EXPECT_EQ(send_done1->control_predecessors()[0], recv_done1); + VerifyP2P2GroupChain(module, suffix0, suffix1); } }; @@ -385,52 +368,38 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, // while-loop with nested P2P chains. constexpr char kUnnestedResult[] = R"( while-result-1 = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 - ROOT collective-permute.2 = f32[1, 1024, 1024] collective-permute(while-result-1), + collective-permute.2 = f32[1, 1024, 1024] collective-permute(init), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + ROOT entry-result = f32[1, 1024, 1024] add(while-result-1, collective-permute.2) )"; // Similar to the above, but for test_custom_call = true. constexpr char kUnnestedResultWithCustomCall[] = R"( while-result-1 = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 - ROOT custom-call = f32[1, 1024, 1024] custom-call(while-result-1), + custom-call = f32[1, 1024, 1024] custom-call(init), custom_call_target="my_custom_call" + ROOT entry-result = f32[1, 1024, 1024] add(while-result-1, custom-call) )"; // This is the result for the main computation, if it has another while-loop // with nested P2P chains. 
constexpr char kNestedResult[] = R"( while-result-1 = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 - while-init-2 = (u32[], f32[1, 1024, 1024]) tuple(c0, while-result-1) - while-result-2 = (u32[], f32[1, 1024, 1024]) while(while-init-2), + while-init-2 = (u32[], f32[1, 1024, 1024]) tuple(c0, init) + while-2 = (u32[], f32[1, 1024, 1024]) while(while-init-2), body=while-body-2, condition=while-cond-2, backend_config={"known_trip_count":{"n":"25"}} - ROOT entry-result = f32[1, 1024, 1024] get-tuple-element(while-result-2), index=1 + while-result-2 = f32[1, 1024, 1024] get-tuple-element(while-2), index=1 + ROOT entry-result = f32[1, 1024, 1024] add(while-result-1, while-result-2) )"; constexpr char kPipelinedWhileBodyWithoutOtherP2P[] = R"( while-body { - param = (u32[], (f32[1, 1024, 1024], token[]), - (f32[1, 1024, 1024], token[])) parameter(0) + param = (u32[], (f32[1, 1024, 1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - // Mimic the code transformation done by copy-insertion to complicate - // the code pattern. 
- send.1.q.t = (f32[1,1024,1024], token[]) get-tuple-element(param), index=1 - send.1.q.data = f32[1,1024,1024] get-tuple-element(send.1.q.t), index=0 - send.1.q.data.copy = f32[1,1024,1024] copy(send.1.q.data) - send.1.q.token = token[] get-tuple-element(send.1.q.t), index=1 - send.1.q = (f32[1, 1024, 1024], token[]) tuple(send.1.q.data.copy, send.1.q.token) - - recv.1.q = (f32[1, 1024, 1024], token[])get-tuple-element(param), index=1 - send-done.1 = token[] send-done(send.1.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-done.1 = token[] recv-done(recv.1.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 + recv-done.1.q = (f32[1, 1024, 1024], token[]) get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 c1 = u32[] constant(1) new-count = u32[] add(count, c1) @@ -455,39 +424,31 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", _xla_send_recv_pipeline="0" } - - // Mimic the code transformation done by copy-insertion to complicate - // the code pattern. 
+ send-done.1 = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } recv.1 = (f32[1, 1024, 1024], token[]) recv(after-all.1), channel_id=1, frontend_attributes={ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", _xla_send_recv_pipeline="0" } - recv.1.data = f32[1,1024,1024] get-tuple-element(recv.1), index=0 - recv.1.data.copy = f32[1,1024,1024] copy(recv.1.data) - recv.1.token = token[] get-tuple-element(recv.1), index=1 - recv.1.tuple = (f32[1,1024,1024], token[]) tuple(recv.1.data.copy, recv.1.token) + recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } - ROOT body-result = (u32[], (f32[1, 1024, 1024], token[]), - (f32[1, 1024, 1024], token[])) tuple(new-count, recv.1, send.1) + ROOT body-result = (u32[], (f32[1, 1024, 1024], token[]), token[]) + tuple(new-count, recv-done.1, send-done.1) } )"; constexpr char kPipelinedWhileBodyWithOtherP2P[] = R"( while-body { - param = (u32[], (f32[1, 1024, 1024], token[]), (f32[1, 1024, 1024], token[])) parameter(0) + param = (u32[], (f32[1, 1024, 1024], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 - send.1.q = (f32[1, 1024, 1024], token[]) get-tuple-element(param), index=2 - recv.1.q = (f32[1, 1024, 1024], token[])get-tuple-element(param), index=1 - send-done.1 = token[] send-done(send.1.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-done.1 = token[] recv-done(recv.1.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 + recv-done.1.q = (f32[1, 1024, 1024], token[])get-tuple-element(param), index=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done.1.q), index=0 c1 = u32[] constant(1) new-count = u32[] add(count, c1) @@ -509,30 +470,37 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, 
after-all.4 = token[] after-all() send.4 = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all.4), channel_id=4, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" + } send-done.4 = token[] send-done(send.4), channel_id=4 recv.4 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.4), channel_id=4, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" - } + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}" + } recv-done.4 = (f32[1, 1024, 1024], token[]) recv-done(recv.4), channel_id=4 new-data = f32[1, 1024, 1024] get-tuple-element(recv-done.4), index=0 after-all.1 = token[] after-all() send.1 = (f32[1, 1024, 1024], token[]) send(new-data, after-all.1), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } - recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=1, + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + send-done.1 = token[] send-done(send.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + recv.1 = (f32[1, 1024, 1024], token[]) recv(after-all.1), channel_id=1, frontend_attributes={ _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", _xla_send_recv_pipeline="0" } - - ROOT body-result = (u32[], (f32[1, 1024, 1024], token[]), - (f32[1, 1024, 1024], token[])) tuple(new-count, recv.1, send.1) + recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + ROOT body-result = (u32[], (f32[1, 1024, 1024], token[]), token[]) + tuple(new-count, recv-done.1, send-done.1) } )"; @@ -540,7 +508,7 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, HloModule test while-cond { - param = (u32[], (f32[1, 
1024, 1024], u32[], token[]), (f32[1, 1024, 1024], u32[], token[])) parameter(0) + param = (u32[], (f32[1, 1024, 1024], u32[], token[]), token[]) parameter(0) count = get-tuple-element(param), index=0 ub = u32[] constant(25) ROOT cond-result = pred[] compare(count, ub), direction=LT @@ -560,41 +528,33 @@ std::string GetPipelinedP2PModuleString(bool nested_p2p_in_main = false, after-all.2 = token[] after-all() recv.2 = (f32[1, 1024, 1024], token[]) recv(after-all.2), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } - - // Mimic the code transformation done by copy-insertion to complicate - // the code pattern. + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", + _xla_send_recv_pipeline="0" + } + recv-done.2 = (f32[1, 1024, 1024], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } send.2 = (f32[1, 1024, 1024], token[]) send(init, after-all.2), channel_id=1, frontend_attributes={ - _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", - _xla_send_recv_pipeline="0" - } - send.2.data = f32[1,1024,1024] get-tuple-element(send.2), index=0 - send.2.data.copy = f32[1,1024,1024] copy(send.2.data) - send.2.token = token[] get-tuple-element(send.2), index=1 - send.2.tuple = (f32[1,1024,1024], token[]) tuple(send.2.data.copy, send.2.token) - - while-init = (u32[], (f32[1, 1024, 1024], token[]), - (f32[1, 1024, 1024], token[])) tuple(c0, recv.2, send.2.tuple) - while-result = (u32[], (f32[1, 1024, 1024], token[]), - (f32[1, 1024, 1024], token[])) while(while-init), - body=while-body, condition=while-cond, - backend_config={"known_trip_count":{"n":"25"}} - - recv.2.q = (f32[1, 1024, 1024], token[]) get-tuple-element(while-result), index=1 - recv-done.2 = (f32[1, 1024, 1024], token[]) recv-done(recv.2.q), channel_id=1, - frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}", 
_xla_send_recv_pipeline="0" } - recv-data.2.q = f32[1, 1024, 1024] get-tuple-element(recv-done.2), index=0 - send.2.q = (f32[1, 1024, 1024], token[]) get-tuple-element(while-result), index=2 - send-done.2 = token[] send-done(send.2.q), channel_id=1, + send-done.2 = token[] send-done(send.2), channel_id=1, frontend_attributes={ _xla_send_recv_pipeline="0" } + while-init = (u32[], (f32[1, 1024, 1024], token[]), token[]) + tuple(c0, recv-done.2, send-done.2) + while-result = (u32[], (f32[1, 1024, 1024], token[]), token[]) + while(while-init), + body=while-body, condition=while-cond, + backend_config={"known_trip_count":{"n":"25"}} + + recv-done.2.q = (f32[1, 1024, 1024], token[]) get-tuple-element(while-result), index=1 + recv-data.2.q = f32[1, 1024, 1024] get-tuple-element(recv-done.2.q), index=0 + // The code for the computation result goes here. %s } @@ -626,18 +586,22 @@ TEST_F(P2PSchedulePreparationTest, UnnestedPipelinedP2PChainTransformed) { // Verify the pipelined P2P chain in the main computation. VerifyPipelinedP2PParent(module.get(), ".2"); - // Verify in the while-body collective-permute is scheduled after Send-done. - HloInstruction* send_done_1 = FindInstruction(module.get(), "send-done.1"); + // Verify in the while-body collective-permute is scheduled before recv. + HloInstruction* recv_1 = FindInstruction(module.get(), "recv.1"); HloInstruction* collective_1 = FindInstruction(module.get(), "collective-permute.1"); - EXPECT_EQ(collective_1->control_predecessors()[0], send_done_1); + EXPECT_EQ(recv_1->control_predecessors()[0], collective_1); - // Verify in the main computation collective-permute is scheduled after the - // Send-done for the pipelined while-loop. + // Verify in the main computation collective-permute is either scheduled + // after send-done or before recv of the pipelined P2P chain. 
HloInstruction* send_done_2 = FindInstruction(module.get(), "send-done.2"); + HloInstruction* recv_2 = FindInstruction(module.get(), "recv.2"); HloInstruction* collective_2 = FindInstruction(module.get(), "collective-permute.2"); - EXPECT_EQ(collective_2->control_predecessors()[0], send_done_2); + EXPECT_TRUE((!collective_2->control_predecessors().empty() && + collective_2->control_predecessors()[0] == send_done_2) || + (!recv_2->control_predecessors().empty() && + recv_2->control_predecessors()[0] == collective_2)); } TEST_F(P2PSchedulePreparationTest, NestedPipelinedP2PChainTransformed) { @@ -657,11 +621,15 @@ TEST_F(P2PSchedulePreparationTest, NestedPipelinedP2PChainTransformed) { // Verify the unpipelined P2P chain in the other while-body. VerifyUnpipelinedP2P(module.get(), ".3"); - // Verify that the while-loop with nested P2P is schedule after the last - // Send-done of the pipeline P2P chain. - HloInstruction* send_done = FindInstruction(module.get(), "send-done.2"); - HloInstruction* while_user = FindInstruction(module.get(), "while-result-2"); - EXPECT_EQ(while_user->control_predecessors()[0], send_done); + // Verify that the while-loop with nested P2P is either scheduled after + // send-done or before recv of the pipelined P2P chain. + HloInstruction* send_done_2 = FindInstruction(module.get(), "send-done.2"); + HloInstruction* recv_2 = FindInstruction(module.get(), "recv.2"); + HloInstruction* while_2 = FindInstruction(module.get(), "while-2"); + EXPECT_TRUE((!while_2->control_predecessors().empty() && + while_2->control_predecessors()[0] == send_done_2) || + (!recv_2->control_predecessors().empty() && + recv_2->control_predecessors()[0] == while_2)); } TEST_F(P2PSchedulePreparationTest, @@ -682,16 +650,11 @@ TEST_F(P2PSchedulePreparationTest, // Verify the other unpipelined P2P chain in the while-body.
VerifyUnpipelinedP2P(module.get(), ".4"); - // Verify that in the pipelined while-body, the pipelined Send-done is ordered - // before other P2P while the pipelined Recv is ordered after other P2P. - HloInstruction* pipelined_send_done = - FindInstruction(module.get(), "send-done.1"); + // Verify that in the pipelined while-body, the pipelined recv is ordered + // after other P2P. HloInstruction* pipelined_recv = FindInstruction(module.get(), "recv.1"); - HloInstruction* other_recv = FindInstruction(module.get(), "recv.4"); HloInstruction* other_send_done = FindInstruction(module.get(), "send-done.4"); - EXPECT_EQ(1, absl::c_count(other_recv->control_predecessors(), - pipelined_send_done)); EXPECT_EQ(1, absl::c_count(pipelined_recv->control_predecessors(), other_send_done)); } @@ -707,11 +670,15 @@ TEST_F(P2PSchedulePreparationTest, TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); EXPECT_TRUE(changed); - // Verify in the main computation custom-call is scheduled after the - // Send-done for the pipelined while-loop. + // Verify in the main computation, custom-call is either scheduled after + // send-done or before recv of the pipelined P2P chain.
HloInstruction* send_done_2 = FindInstruction(module.get(), "send-done.2"); + HloInstruction* recv_2 = FindInstruction(module.get(), "recv.2"); HloInstruction* custom_call = FindInstruction(module.get(), "custom-call"); - EXPECT_EQ(custom_call->control_predecessors()[0], send_done_2); + EXPECT_TRUE((!custom_call->control_predecessors().empty() && + custom_call->control_predecessors()[0] == send_done_2) || + (!recv_2->control_predecessors().empty() && + recv_2->control_predecessors()[0] == custom_call)); } TEST_F(P2PSchedulePreparationTest, PipelinedP2PChain2Transformed) { @@ -719,31 +686,22 @@ TEST_F(P2PSchedulePreparationTest, PipelinedP2PChain2Transformed) { HloModule test cond { - param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[]), - (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) + param = (u32[], (u32[2], token[]), (u32[2], token[]), + token[], token[]) parameter(0) count = get-tuple-element(%param), index=0 ub = u32[] constant(10) ROOT result = pred[] compare(count, ub), direction=LT } body { - param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[]), - (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) + param = (u32[], (u32[2], token[]), (u32[2], token[]), + token[], token[]) parameter(0) count = get-tuple-element(param), index=0 - recv.0.f = (u32[2], u32[], token[]) get-tuple-element(param), index=1 - recv-done.0 = (u32[2], token[]) recv-done(recv.0.f), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 - - recv.1.f = (u32[2], u32[], token[]) get-tuple-element(param), index=2 - recv-done.1 = (u32[2], token[]) recv-done(recv.1.f), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } - recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + recv-done.0.f = (u32[2], token[]) get-tuple-element(param), index=1 + recv-data.0 = u32[2] get-tuple-element(recv-done.0.f), index=0 + recv-done.1.f = 
(u32[2], token[]) get-tuple-element(param), index=2 + recv-data.1 = u32[2] get-tuple-element(recv-done.1.f), index=0 replica = u32[] replica-id() constant0 = u32[] constant(0) @@ -757,17 +715,6 @@ body { r = u32[2] broadcast(c1), dimensions={} s = u32[2] add(r, recv-data) - send.0.f = (u32[2], u32[], token[]) get-tuple-element(param), index=3 - send-done.0 = token[] send-done(send.0.f), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - send.1.f = (u32[2], u32[], token[]) get-tuple-element(param), index=4 - send-done.1 = token[] send-done(send.1.f), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } - // The Recv "rotated" from the beginning of the loop to the end of the loop. after-all.0.n = token[] after-all() recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, @@ -781,6 +728,14 @@ body { _xla_send_recv_source_target_pairs="{{3,0}}", _xla_send_recv_pipeline="0" } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] send-done(send.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } after-all.1.n = token[] after-all() recv.1 = (u32[2], u32[], token[]) recv(after-all.1.n), channel_id=2, @@ -794,9 +749,16 @@ body { _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", _xla_send_recv_pipeline="1" } - - ROOT result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[]), - (u32[2], u32[], token[]), (u32[2], u32[], token[])) tuple(new_count, recv.0, recv.1, send.0, send.1) + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1 = token[] send-done(send.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + ROOT result = (u32[], (u32[2], token[]), (u32[2], token[]), token[], token[]) + tuple(new_count, recv-done.0, recv-done.1, send-done.0, send-done.1) } ENTRY test_computation { 
@@ -819,7 +781,14 @@ body { _xla_send_recv_source_target_pairs="{{3,0}}", _xla_send_recv_pipeline="0" } - + recv-done.2 = (u32[2], token[]) recv-done(recv.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.2 = token[] send-done(send.2), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } after-all.1.p = token[] after-all() recv.3 = (u32[2], u32[], token[]) recv(after-all.1.p), channel_id=2, frontend_attributes={ @@ -832,30 +801,28 @@ body { _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", _xla_send_recv_pipeline="1" } - + recv-done.3 = (u32[2], token[]) recv-done(recv.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.3 = token[] send-done(send.3), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } // This is the pipelined loop. - while_init = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[]), - (u32[2], u32[], token[]), (u32[2], u32[], token[])) tuple(c0, recv.2, recv.3, send.2, send.3) + while_init = (u32[], (u32[2], token[]), (u32[2], token[]), + token[], token[]) tuple(c0, recv-done.2, recv-done.3, send-done.2, send-done.3) while_result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[]), - (u32[2], u32[], token[]), (u32[2], u32[], token[])) while(while_init), body=body, condition=cond, + token[], token[]) while(while_init), body=body, condition=cond, backend_config={"known_trip_count":{"n":"10"}} // This is the remaining Send/Send-done/Recv-done for the pipeline. // Use .q as suffix for HLO name. 
+ recv-done.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 + recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 - recv.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 - recv-done.2 = (u32[2], token[]) recv-done(recv.0.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data.0.q = u32[2] get-tuple-element(recv-done.2), index=0 - - recv.1.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=2 - recv-done.3 = (u32[2], token[]) recv-done(recv.1.q), channel_id=2, - frontend_attributes={ - _xla_send_recv_pipeline="1" - } - recv-data.1.q = u32[2] get-tuple-element(recv-done.2), index=0 + recv-done.1.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=2 + recv-data.1.q = u32[2] get-tuple-element(recv-done.1.q), index=0 replica = u32[] replica-id() constant0 = u32[] constant(0) @@ -865,18 +832,109 @@ body { s = u32[2] add(c1, recv-data) - send.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=3 - send-done.2 = token[] send-done(send.0.q), channel_id=1, + ROOT result = u32[2] add(s, recv-data) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnUnverifiedModule((kModuleStr))); + P2PSchedulePreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + + // Verify the pipelined P2P chain in the while-body. + VerifyPipelined2P2PChild(module.get(), ".0", ".1"); + // Verify the pipelined P2P chain in the main computation. 
+ VerifyPipelined2P2PParent(module.get(), ".2", ".3"); +} + +TEST_F(P2PSchedulePreparationTest, UnpipelinedP2PChain2Transformed) { + const char* const kModuleStr = R"( + HloModule test + +cond { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(11) + ROOT result = pred[] compare(count, ub), direction=LT + } + +body { + param = (u32[], u32[2]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = u32[2] get-tuple-element(param), index=1 + + after-all.0.n = token[] after-all() + recv.0 = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + send.0 = (u32[2], u32[], token[]) send(send-data, after-all.0.n), + channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{3,0}}", + _xla_send_recv_pipeline="0" + } + recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } + send-done.0 = token[] send-done(send.0), channel_id=1, frontend_attributes={ _xla_send_recv_pipeline="0" } - send.1.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=4 - send-done.3 = token[] send-done(send.1.q), channel_id=2, + + after-all.1 = token[] after-all() + recv.1 = (u32[2], u32[], token[]) recv(after-all.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + send.1 = (u32[2], u32[], token[]) send(send-data, after-all.1), + channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3}}", + _xla_send_recv_pipeline="1" + } + recv-done.1 = (u32[2], token[]) recv-done(recv.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_pipeline="1" + } + send-done.1 = token[] send-done(send.1), channel_id=2, frontend_attributes={ _xla_send_recv_pipeline="1" } - ROOT result = u32[2] add(s, recv-data) + 
recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 + recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + + replica = u32[] replica-id() + constant0 = u32[] constant(0) + compare0 = pred[] compare(replica, constant0), direction=EQ + compare = pred[2] broadcast(compare0), dimensions={} + recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + + r = u32[2] broadcast(c1), dimensions={} + s = u32[2] add(r, recv-data) + + ROOT result = (u32[], u32[2]) tuple(new_count, s) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + c1 = u32[] constant(1) + r = u32[] replica-id() + a = u32[] add(c1, r) + init = u32[2] broadcast(a), dimensions={} + while_init = (u32[], u32[2]) tuple(c0, init) + while_result = (u32[], u32[2]) while(while_init), body=body, condition=cond, + backend_config={"known_trip_count":{"n":"11"}} + ROOT recv-data = u32[2] get-tuple-element(while_result), index=1 } )"; @@ -884,13 +942,10 @@ body { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnUnverifiedModule((kModuleStr))); P2PSchedulePreparation preparation; TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); - VLOG(10) << module->ToString(); EXPECT_TRUE(changed); - // Verify the pipelined P2P chain in the while-body. - VerifyPipelined2P2PChild(module.get(), ".0", ".1"); - // Verify the pipelined P2P chain in the main computation. - VerifyPipelined2P2PParent(module.get(), ".2", ".3"); + // Verify the unpipelined P2P chain with two channels in the while-body. + VerifyP2P2GroupChain(module.get(), ".0", ".1"); } } // namespace