diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index dee81c0b08419..5db71de5162f0 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -465,7 +465,8 @@ absl::StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kSendDone: TF_RET_CHECK(DynCast(operands(0)) != nullptr) << "SendDone must take the context operand from Send"; - instruction = CreateSendDone(operands(0), proto.is_host_transfer()); + instruction = CreateSendDone(operands(0), proto.channel_id(), + proto.is_host_transfer()); break; case HloOpcode::kRecv: instruction = CreateRecv(shape.tuple_shapes(0), operands(0), @@ -474,7 +475,8 @@ absl::StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kRecvDone: TF_RET_CHECK(DynCast(operands(0)) != nullptr) << "RecvDone must take the context operand from Recv"; - instruction = CreateRecvDone(operands(0), proto.is_host_transfer()); + instruction = CreateRecvDone(operands(0), proto.channel_id(), + proto.is_host_transfer()); break; case HloOpcode::kReverse: instruction = @@ -1814,45 +1816,31 @@ HloInstruction::CreateCollectivePermuteStart( } /* static */ std::unique_ptr HloInstruction::CreateSend( - HloInstruction* operand, HloInstruction* token, int64_t channel_id, - bool is_host_transfer) { + HloInstruction* operand, HloInstruction* token, + std::optional channel_id, bool is_host_transfer) { return std::make_unique(operand, token, channel_id, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( - HloInstruction* operand, bool is_host_transfer) { - auto send_operand = DynCast(operand); - CHECK(send_operand != nullptr) - << "SendDone must take the context operand from Send"; - return std::make_unique(send_operand, - is_host_transfer); -} - -/* static */ std::unique_ptr HloInstruction::CreateSendDone( - HloInstruction* operand, int64_t channel_id, bool is_host_transfer) { + HloInstruction* operand, std::optional channel_id, + bool is_host_transfer) { + CHECK(operand->channel_id() == channel_id); return std::make_unique(operand, channel_id, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( - const Shape& shape, HloInstruction* token, int64_t channel_id, - bool is_host_transfer) { + const Shape& shape, HloInstruction* token, + std::optional channel_id, bool is_host_transfer) { return std::make_unique(shape, token, channel_id, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( - HloInstruction* operand, bool is_host_transfer) { - auto recv_operand = DynCast(operand); - CHECK(recv_operand != nullptr) - << "RecvDone must take the context operand from Recv"; - return std::make_unique(recv_operand, - is_host_transfer); -} - -/* static */ std::unique_ptr HloInstruction::CreateRecvDone( - HloInstruction* operand, int64_t channel_id, bool is_host_transfer) { + HloInstruction* operand, std::optional channel_id, + bool is_host_transfer) { + CHECK(operand->channel_id() == channel_id); return std::make_unique(operand, channel_id, is_host_transfer); } diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 9ef72fd89591b..23a8fd9b96161 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -1191,16 +1191,13 @@ class HloInstruction { // another computation that has the same channel id. If is_host_transfer is // true, then this Send operation transfers data to the host. static std::unique_ptr CreateSend( - HloInstruction* operand, HloInstruction* token, int64_t channel_id, - bool is_host_transfer = false); + HloInstruction* operand, HloInstruction* token, + std::optional channel_id, bool is_host_transfer = false); // Blocks until data transfer for the Send instruction (operand) is complete. // The operand must be kSend. static std::unique_ptr CreateSendDone( - HloInstruction* operand, bool is_host_transfer = false); - // Similar to the above, but the operand doesn't have to be a kSend. - static std::unique_ptr CreateSendDone( - HloInstruction* operand, int64_t channel_id, + HloInstruction* operand, std::optional channel_id, bool is_host_transfer = false); // Creates an asynchronous receive instruction with the given channel id, @@ -1209,16 +1206,13 @@ class HloInstruction { // is_host_transfer is true, then this Recv operation transfers data from the // host. static std::unique_ptr CreateRecv( - const Shape& shape, HloInstruction* token, int64_t channel_id, - bool is_host_transfer = false); + const Shape& shape, HloInstruction* token, + std::optional channel_id, bool is_host_transfer = false); // Blocks until data transfer for the Recv instruction (operand) is complete // and returns the receive buffer. The operand must be kRecv. static std::unique_ptr CreateRecvDone( - HloInstruction* operand, bool is_host_transfer = false); - // Similar to the above, but the operand doesn't have to be a kRecv. - static std::unique_ptr CreateRecvDone( - HloInstruction* operand, int64_t channel_id, + HloInstruction* operand, std::optional channel_id, bool is_host_transfer = false); // Creates a slice instruction, where the operand is sliced by the given diff --git a/xla/hlo/ir/hlo_instructions.cc b/xla/hlo/ir/hlo_instructions.cc index cc3d81a846a94..89ec9286d730e 100644 --- a/xla/hlo/ir/hlo_instructions.cc +++ b/xla/hlo/ir/hlo_instructions.cc @@ -769,10 +769,9 @@ bool HloTopKInstruction::IdenticalSlowPath( return k() == casted_other.k() && largest() == casted_other.largest(); } -HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, - const Shape& shape, - int64_t channel_id, - bool is_host_transfer) +HloSendRecvInstruction::HloSendRecvInstruction( + HloOpcode opcode, const Shape& shape, std::optional channel_id, + bool is_host_transfer) : HloChannelInstruction(opcode, shape, channel_id), is_host_transfer_(is_host_transfer) {} @@ -802,7 +801,7 @@ bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues( // Send instruction produces a tuple of {aliased operand, U32 context}. HloSendInstruction::HloSendInstruction(HloInstruction* operand, HloInstruction* token, - int64_t channel_id, + std::optional channel_id, bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kSend, @@ -818,21 +817,20 @@ std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return std::make_unique( - new_operands[0], new_operands[1], *channel_id(), is_host_transfer()); + return std::make_unique(new_operands[0], new_operands[1], + channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, bool is_host_transfer) : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), - CHECK_NOTNULL(operand)->channel_id().value(), - is_host_transfer) { + operand->channel_id(), is_host_transfer) { AppendOperand(operand); } -HloSendDoneInstruction::HloSendDoneInstruction(HloInstruction* operand, - int64_t channel_id, - bool is_host_transfer) +HloSendDoneInstruction::HloSendDoneInstruction( + HloInstruction* operand, std::optional channel_id, + bool is_host_transfer) : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), channel_id, is_host_transfer) { AppendOperand(operand); @@ -848,14 +846,14 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( return std::make_unique(send, is_host_transfer()); } - return std::make_unique( - new_operands[0], channel_id().value(), is_host_transfer()); + return std::make_unique(new_operands[0], channel_id(), + is_host_transfer()); } // Recv instruction produces a tuple of {receive buffer, U32 context}. HloRecvInstruction::HloRecvInstruction(const Shape& shape, HloInstruction* token, - int64_t channel_id, + std::optional channel_id, bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecv, @@ -870,7 +868,7 @@ std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return std::make_unique( - ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(), + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), is_host_transfer()); } @@ -881,13 +879,13 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(operand->shape(), 0), ShapeUtil::MakeTokenShape()}), - CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) { + operand->channel_id(), is_host_transfer) { AppendOperand(operand); } -HloRecvDoneInstruction::HloRecvDoneInstruction(HloInstruction* operand, - int64_t channel_id, - bool is_host_transfer) +HloRecvDoneInstruction::HloRecvDoneInstruction( + HloInstruction* operand, std::optional channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecvDone, ShapeUtil::MakeTupleShape( @@ -907,8 +905,8 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( return std::make_unique(recv, is_host_transfer()); } - return std::make_unique( - new_operands[0], channel_id().value(), is_host_transfer()); + return std::make_unique(new_operands[0], channel_id(), + is_host_transfer()); } HloCollectiveInstruction::HloCollectiveInstruction( diff --git a/xla/hlo/ir/hlo_instructions.h b/xla/hlo/ir/hlo_instructions.h index 96d0e93bf4402..6830061d85036 100644 --- a/xla/hlo/ir/hlo_instructions.h +++ b/xla/hlo/ir/hlo_instructions.h @@ -555,7 +555,8 @@ class HloSendRecvInstruction : public HloChannelInstruction { protected: explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64_t channel_id, bool is_host_transfer); + std::optional channel_id, + bool is_host_transfer); private: void PrintExtraAttributesImpl(AttributePrinter& printer, @@ -571,7 +572,8 @@ class HloSendRecvInstruction : public HloChannelInstruction { class HloSendInstruction : public HloSendRecvInstruction { public: explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, - int64_t channel_id, bool is_host_transfer); + std::optional channel_id, + bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kSend; @@ -588,7 +590,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { public: explicit HloSendDoneInstruction(HloSendInstruction* operand, bool is_host_transfer); - explicit HloSendDoneInstruction(HloInstruction* operand, int64_t channel_id, + explicit HloSendDoneInstruction(HloInstruction* operand, + std::optional channel_id, bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kSendDone; @@ -604,7 +607,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { class HloRecvInstruction : public HloSendRecvInstruction { public: explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, - int64_t channel_id, bool is_host_transfer); + std::optional channel_id, + bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kRecv; @@ -621,7 +625,8 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { public: explicit HloRecvDoneInstruction(HloRecvInstruction* operand, bool is_host_transfer); - explicit HloRecvDoneInstruction(HloInstruction* operand, int64_t channel_id, + explicit HloRecvDoneInstruction(HloInstruction* operand, + std::optional channel_id, bool is_host_transfer); static bool ClassOf(const HloInstruction* hlo) { diff --git a/xla/hlo/parser/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc index 90cb625c4745f..ffee1d9e87938 100644 --- a/xla/hlo/parser/hlo_parser.cc +++ b/xla/hlo/parser/hlo_parser.cc @@ -2188,7 +2188,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, &is_host_transfer}; if ((!preset_operands && @@ -2198,13 +2198,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } // If the is_host_transfer attribute is not present then default to false. return builder->AddInstruction(HloInstruction::CreateRecv( - shape->tuple_shapes(0), operands[0], *channel_id, *is_host_transfer)); + shape->tuple_shapes(0), operands[0], channel_id, *is_host_transfer)); } case HloOpcode::kRecvDone: { optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, &is_host_transfer}; if ((!preset_operands && @@ -2220,13 +2220,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } return builder->AddInstruction(HloInstruction::CreateRecvDone( - operands[0], channel_id.value(), *is_host_transfer)); + operands[0], channel_id, *is_host_transfer)); } case HloOpcode::kSend: { optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, &is_host_transfer}; if ((!preset_operands && @@ -2235,13 +2235,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return nullptr; } return builder->AddInstruction(HloInstruction::CreateSend( - operands[0], operands[1], *channel_id, *is_host_transfer)); + operands[0], operands[1], channel_id, *is_host_transfer)); } case HloOpcode::kSendDone: { optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, &is_host_transfer}; if ((!preset_operands && @@ -2257,7 +2257,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } return builder->AddInstruction(HloInstruction::CreateSendDone( - operands[0], channel_id.value(), *is_host_transfer)); + operands[0], channel_id, *is_host_transfer)); } case HloOpcode::kGetTupleElement: { optional index; diff --git a/xla/hlo/parser/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc index ade7da2c25191..73e9d1c2550d3 100644 --- a/xla/hlo/parser/hlo_parser_test.cc +++ b/xla/hlo/parser/hlo_parser_test.cc @@ -428,6 +428,21 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } +)" +}, +{ +"SendRecvWoChannelID", +R"(HloModule SendRecvWoChannelID_module, entry_computation_layout={()->(f32[], token[])} + +ENTRY %computation () -> (f32[], token[]) { + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0) + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv) + %constant = f32[] constant(2.1) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0) + %send-done = token[] send-done((f32[], u32[], token[]) %send) +} + )" }, { diff --git a/xla/service/collective_permute_decomposer.cc b/xla/service/collective_permute_decomposer.cc index daa6feb484010..a9e4708521dc2 100644 --- a/xla/service/collective_permute_decomposer.cc +++ b/xla/service/collective_permute_decomposer.cc @@ -152,10 +152,10 @@ absl::Status DecomposeCollectivePermute( send->add_frontend_attributes(attributes); send->set_metadata(metadata); - HloInstruction* recv_done = - computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); - HloInstruction* send_done = - computation->AddInstruction(HloInstruction::CreateSendDone(send)); + HloInstruction* recv_done = computation->AddInstruction( + HloInstruction::CreateRecvDone(recv, channel_id)); + HloInstruction* send_done = computation->AddInstruction( + HloInstruction::CreateSendDone(send, channel_id)); // We will add control dependence to represent how we want to order Send/Recv // and other collective operations. Here we only add the necessary control