Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make channel_id optional for send/recv ops #19239

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ absl::StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kSendDone:
TF_RET_CHECK(DynCast<HloSendInstruction>(operands(0)) != nullptr)
<< "SendDone must take the context operand from Send";
instruction = CreateSendDone(operands(0), proto.is_host_transfer());
instruction = CreateSendDone(operands(0), proto.channel_id(),
proto.is_host_transfer());
break;
case HloOpcode::kRecv:
instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
Expand All @@ -474,7 +475,8 @@ absl::StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kRecvDone:
TF_RET_CHECK(DynCast<HloRecvInstruction>(operands(0)) != nullptr)
<< "RecvDone must take the context operand from Recv";
instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
instruction = CreateRecvDone(operands(0), proto.channel_id(),
proto.is_host_transfer());
break;
case HloOpcode::kReverse:
instruction =
Expand Down Expand Up @@ -1814,45 +1816,31 @@ HloInstruction::CreateCollectivePermuteStart(
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, HloInstruction* token, int64_t channel_id,
bool is_host_transfer) {
HloInstruction* operand, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer) {
return std::make_unique<HloSendInstruction>(operand, token, channel_id,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand, bool is_host_transfer) {
auto send_operand = DynCast<HloSendInstruction>(operand);
CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
return std::make_unique<HloSendDoneInstruction>(send_operand,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand, int64_t channel_id, bool is_host_transfer) {
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer) {
CHECK(operand->channel_id() == channel_id);
return std::make_unique<HloSendDoneInstruction>(operand, channel_id,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, HloInstruction* token, int64_t channel_id,
bool is_host_transfer) {
const Shape& shape, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer) {
return std::make_unique<HloRecvInstruction>(shape, token, channel_id,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand, bool is_host_transfer) {
auto recv_operand = DynCast<HloRecvInstruction>(operand);
CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
return std::make_unique<HloRecvDoneInstruction>(recv_operand,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand, int64_t channel_id, bool is_host_transfer) {
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer) {
CHECK(operand->channel_id() == channel_id);
return std::make_unique<HloRecvDoneInstruction>(operand, channel_id,
is_host_transfer);
}
Expand Down
18 changes: 6 additions & 12 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1191,16 +1191,13 @@ class HloInstruction {
// another computation that has the same channel id. If is_host_transfer is
// true, then this Send operation transfers data to the host.
static std::unique_ptr<HloInstruction> CreateSend(
HloInstruction* operand, HloInstruction* token, int64_t channel_id,
bool is_host_transfer = false);
HloInstruction* operand, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer = false);

// Blocks until data transfer for the Send instruction (operand) is complete.
// The operand must be kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
HloInstruction* operand, bool is_host_transfer = false);
// Similar to the above, but the operand doesn't have to be a kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
HloInstruction* operand, int64_t channel_id,
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer = false);

// Creates an asynchronous receive instruction with the given channel id,
Expand All @@ -1209,16 +1206,13 @@ class HloInstruction {
// is_host_transfer is true, then this Recv operation transfers data from the
// host.
static std::unique_ptr<HloInstruction> CreateRecv(
const Shape& shape, HloInstruction* token, int64_t channel_id,
bool is_host_transfer = false);
const Shape& shape, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer = false);

// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand must be kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand, bool is_host_transfer = false);
// Similar to the above, but the operand doesn't have to be a kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand, int64_t channel_id,
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer = false);

// Creates a slice instruction, where the operand is sliced by the given
Expand Down
42 changes: 20 additions & 22 deletions xla/hlo/ir/hlo_instructions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,9 @@ bool HloTopKInstruction::IdenticalSlowPath(
return k() == casted_other.k() && largest() == casted_other.largest();
}

HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
const Shape& shape,
int64_t channel_id,
bool is_host_transfer)
HloSendRecvInstruction::HloSendRecvInstruction(
HloOpcode opcode, const Shape& shape, std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloChannelInstruction(opcode, shape, channel_id),
is_host_transfer_(is_host_transfer) {}

Expand Down Expand Up @@ -802,7 +801,7 @@ bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
// Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction::HloSendInstruction(HloInstruction* operand,
HloInstruction* token,
int64_t channel_id,
std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kSend,
Expand All @@ -818,21 +817,20 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return std::make_unique<HloSendInstruction>(
new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
return std::make_unique<HloSendInstruction>(new_operands[0], new_operands[1],
channel_id(), is_host_transfer());
}

HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
bool is_host_transfer)
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
CHECK_NOTNULL(operand)->channel_id().value(),
is_host_transfer) {
operand->channel_id(), is_host_transfer) {
AppendOperand(operand);
}

HloSendDoneInstruction::HloSendDoneInstruction(HloInstruction* operand,
int64_t channel_id,
bool is_host_transfer)
HloSendDoneInstruction::HloSendDoneInstruction(
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
channel_id, is_host_transfer) {
AppendOperand(operand);
Expand All @@ -848,14 +846,14 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
return std::make_unique<HloSendDoneInstruction>(send, is_host_transfer());
}

return std::make_unique<HloSendDoneInstruction>(
new_operands[0], channel_id().value(), is_host_transfer());
return std::make_unique<HloSendDoneInstruction>(new_operands[0], channel_id(),
is_host_transfer());
}

// Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction::HloRecvInstruction(const Shape& shape,
HloInstruction* token,
int64_t channel_id,
std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecv,
Expand All @@ -870,7 +868,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return std::make_unique<HloRecvInstruction>(
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
is_host_transfer());
}

Expand All @@ -881,13 +879,13 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
ShapeUtil::MakeTokenShape()}),
CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
operand->channel_id(), is_host_transfer) {
AppendOperand(operand);
}

HloRecvDoneInstruction::HloRecvDoneInstruction(HloInstruction* operand,
int64_t channel_id,
bool is_host_transfer)
HloRecvDoneInstruction::HloRecvDoneInstruction(
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecvDone,
ShapeUtil::MakeTupleShape(
Expand All @@ -907,8 +905,8 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
return std::make_unique<HloRecvDoneInstruction>(recv, is_host_transfer());
}

return std::make_unique<HloRecvDoneInstruction>(
new_operands[0], channel_id().value(), is_host_transfer());
return std::make_unique<HloRecvDoneInstruction>(new_operands[0], channel_id(),
is_host_transfer());
}

HloCollectiveInstruction::HloCollectiveInstruction(
Expand Down
15 changes: 10 additions & 5 deletions xla/hlo/ir/hlo_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ class HloSendRecvInstruction : public HloChannelInstruction {

protected:
explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
int64_t channel_id, bool is_host_transfer);
std::optional<int64_t> channel_id,
bool is_host_transfer);

private:
void PrintExtraAttributesImpl(AttributePrinter& printer,
Expand All @@ -571,7 +572,8 @@ class HloSendRecvInstruction : public HloChannelInstruction {
class HloSendInstruction : public HloSendRecvInstruction {
public:
explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
int64_t channel_id, bool is_host_transfer);
std::optional<int64_t> channel_id,
bool is_host_transfer);

static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kSend;
Expand All @@ -588,7 +590,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
public:
explicit HloSendDoneInstruction(HloSendInstruction* operand,
bool is_host_transfer);
explicit HloSendDoneInstruction(HloInstruction* operand, int64_t channel_id,
explicit HloSendDoneInstruction(HloInstruction* operand,
std::optional<int64_t> channel_id,
bool is_host_transfer);
static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kSendDone;
Expand All @@ -604,7 +607,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
class HloRecvInstruction : public HloSendRecvInstruction {
public:
explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
int64_t channel_id, bool is_host_transfer);
std::optional<int64_t> channel_id,
bool is_host_transfer);

static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kRecv;
Expand All @@ -621,7 +625,8 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
public:
explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
bool is_host_transfer);
explicit HloRecvDoneInstruction(HloInstruction* operand, int64_t channel_id,
explicit HloRecvDoneInstruction(HloInstruction* operand,
std::optional<int64_t> channel_id,
bool is_host_transfer);

static bool ClassOf(const HloInstruction* hlo) {
Expand Down
16 changes: 8 additions & 8 deletions xla/hlo/parser/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2198,13 +2198,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}
// If the is_host_transfer attribute is not present then default to false.
return builder->AddInstruction(HloInstruction::CreateRecv(
shape->tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
shape->tuple_shapes(0), operands[0], channel_id, *is_host_transfer));
}
case HloOpcode::kRecvDone: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2220,13 +2220,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}

return builder->AddInstruction(HloInstruction::CreateRecvDone(
operands[0], channel_id.value(), *is_host_transfer));
operands[0], channel_id, *is_host_transfer));
}
case HloOpcode::kSend: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2235,13 +2235,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
return nullptr;
}
return builder->AddInstruction(HloInstruction::CreateSend(
operands[0], operands[1], *channel_id, *is_host_transfer));
operands[0], operands[1], channel_id, *is_host_transfer));
}
case HloOpcode::kSendDone: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2257,7 +2257,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}

return builder->AddInstruction(HloInstruction::CreateSendDone(
operands[0], channel_id.value(), *is_host_transfer));
operands[0], channel_id, *is_host_transfer));
}
case HloOpcode::kGetTupleElement: {
optional<int64_t> index;
Expand Down
15 changes: 15 additions & 0 deletions xla/hlo/parser/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,21 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
},
{
"SendRecvWoChannelID",
R"(HloModule SendRecvWoChannelID_module, entry_computation_layout={()->(f32[], token[])}
ENTRY %computation () -> (f32[], token[]) {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0)
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv)
%constant = f32[] constant(2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
%send-done = token[] send-done((f32[], u32[], token[]) %send)
}
)"
},
{
Expand Down
8 changes: 4 additions & 4 deletions xla/service/collective_permute_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ absl::Status DecomposeCollectivePermute(
send->add_frontend_attributes(attributes);
send->set_metadata(metadata);

HloInstruction* recv_done =
computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
HloInstruction* send_done =
computation->AddInstruction(HloInstruction::CreateSendDone(send));
HloInstruction* recv_done = computation->AddInstruction(
HloInstruction::CreateRecvDone(recv, channel_id));
HloInstruction* send_done = computation->AddInstruction(
HloInstruction::CreateSendDone(send, channel_id));

// We will add control dependence to represent how we want to order Send/Recv
// and other collective operations. Here we only add the necessary control
Expand Down
Loading