Skip to content

Commit

Permalink
Make channel_id optional for send/recv ops
Browse files Browse the repository at this point in the history
This change is preparation for enabling cross-replica send/recv ops.

PiperOrigin-RevId: 695481949
  • Loading branch information
frgossen authored and Google-ML-Automation committed Nov 14, 2024
1 parent dde3c51 commit 8153844
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 77 deletions.
40 changes: 14 additions & 26 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ absl::StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kSendDone:
TF_RET_CHECK(DynCast<HloSendInstruction>(operands(0)) != nullptr)
<< "SendDone must take the context operand from Send";
instruction = CreateSendDone(operands(0), proto.is_host_transfer());
instruction = CreateSendDone(operands(0), proto.channel_id(),
proto.is_host_transfer());
break;
case HloOpcode::kRecv:
instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
Expand All @@ -474,7 +475,8 @@ absl::StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kRecvDone:
TF_RET_CHECK(DynCast<HloRecvInstruction>(operands(0)) != nullptr)
<< "RecvDone must take the context operand from Recv";
instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
instruction = CreateRecvDone(operands(0), proto.channel_id(),
proto.is_host_transfer());
break;
case HloOpcode::kReverse:
instruction =
Expand Down Expand Up @@ -1814,45 +1816,31 @@ HloInstruction::CreateCollectivePermuteStart(
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, HloInstruction* token, int64_t channel_id,
bool is_host_transfer) {
HloInstruction* operand, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer) {
return std::make_unique<HloSendInstruction>(operand, token, channel_id,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand, bool is_host_transfer) {
auto send_operand = DynCast<HloSendInstruction>(operand);
CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
return std::make_unique<HloSendDoneInstruction>(send_operand,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand, int64_t channel_id, bool is_host_transfer) {
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer) {
CHECK(operand->channel_id() == channel_id);
return std::make_unique<HloSendDoneInstruction>(operand, channel_id,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, HloInstruction* token, int64_t channel_id,
bool is_host_transfer) {
const Shape& shape, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer) {
return std::make_unique<HloRecvInstruction>(shape, token, channel_id,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand, bool is_host_transfer) {
auto recv_operand = DynCast<HloRecvInstruction>(operand);
CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
return std::make_unique<HloRecvDoneInstruction>(recv_operand,
is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand, int64_t channel_id, bool is_host_transfer) {
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer) {
CHECK(operand->channel_id() == channel_id);
return std::make_unique<HloRecvDoneInstruction>(operand, channel_id,
is_host_transfer);
}
Expand Down
18 changes: 6 additions & 12 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1191,16 +1191,13 @@ class HloInstruction {
// another computation that has the same channel id. If is_host_transfer is
// true, then this Send operation transfers data to the host.
static std::unique_ptr<HloInstruction> CreateSend(
HloInstruction* operand, HloInstruction* token, int64_t channel_id,
bool is_host_transfer = false);
HloInstruction* operand, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer = false);

// Blocks until data transfer for the Send instruction (operand) is complete.
// The operand is typically a kSend, but this is not enforced; only its
// channel id is checked against channel_id.
static std::unique_ptr<HloInstruction> CreateSendDone(
HloInstruction* operand, bool is_host_transfer = false);
// Similar to the above, but the operand doesn't have to be a kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
HloInstruction* operand, int64_t channel_id,
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer = false);

// Creates an asynchronous receive instruction with the given channel id,
Expand All @@ -1209,16 +1206,13 @@ class HloInstruction {
// is_host_transfer is true, then this Recv operation transfers data from the
// host.
static std::unique_ptr<HloInstruction> CreateRecv(
const Shape& shape, HloInstruction* token, int64_t channel_id,
bool is_host_transfer = false);
const Shape& shape, HloInstruction* token,
std::optional<int64_t> channel_id, bool is_host_transfer = false);

// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand is typically a kRecv, but this
// is not enforced; only its channel id is checked against channel_id.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand, bool is_host_transfer = false);
// Similar to the above, but the operand doesn't have to be a kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand, int64_t channel_id,
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer = false);

// Creates a slice instruction, where the operand is sliced by the given
Expand Down
42 changes: 20 additions & 22 deletions xla/hlo/ir/hlo_instructions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,9 @@ bool HloTopKInstruction::IdenticalSlowPath(
return k() == casted_other.k() && largest() == casted_other.largest();
}

HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
const Shape& shape,
int64_t channel_id,
bool is_host_transfer)
HloSendRecvInstruction::HloSendRecvInstruction(
HloOpcode opcode, const Shape& shape, std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloChannelInstruction(opcode, shape, channel_id),
is_host_transfer_(is_host_transfer) {}

Expand Down Expand Up @@ -802,7 +801,7 @@ bool HloSendRecvInstruction::IdenticalSlowPathIgnoringChannelIdValues(
// Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction::HloSendInstruction(HloInstruction* operand,
HloInstruction* token,
int64_t channel_id,
std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kSend,
Expand All @@ -818,21 +817,20 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return std::make_unique<HloSendInstruction>(
new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
return std::make_unique<HloSendInstruction>(new_operands[0], new_operands[1],
channel_id(), is_host_transfer());
}

HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
bool is_host_transfer)
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
CHECK_NOTNULL(operand)->channel_id().value(),
is_host_transfer) {
operand->channel_id(), is_host_transfer) {
AppendOperand(operand);
}

HloSendDoneInstruction::HloSendDoneInstruction(HloInstruction* operand,
int64_t channel_id,
bool is_host_transfer)
HloSendDoneInstruction::HloSendDoneInstruction(
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
channel_id, is_host_transfer) {
AppendOperand(operand);
Expand All @@ -848,14 +846,14 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
return std::make_unique<HloSendDoneInstruction>(send, is_host_transfer());
}

return std::make_unique<HloSendDoneInstruction>(
new_operands[0], channel_id().value(), is_host_transfer());
return std::make_unique<HloSendDoneInstruction>(new_operands[0], channel_id(),
is_host_transfer());
}

// Recv instruction produces a tuple of {receive buffer, U32 context}.
HloRecvInstruction::HloRecvInstruction(const Shape& shape,
HloInstruction* token,
int64_t channel_id,
std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecv,
Expand All @@ -870,7 +868,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return std::make_unique<HloRecvInstruction>(
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
is_host_transfer());
}

Expand All @@ -881,13 +879,13 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
ShapeUtil::MakeTokenShape()}),
CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
operand->channel_id(), is_host_transfer) {
AppendOperand(operand);
}

HloRecvDoneInstruction::HloRecvDoneInstruction(HloInstruction* operand,
int64_t channel_id,
bool is_host_transfer)
HloRecvDoneInstruction::HloRecvDoneInstruction(
HloInstruction* operand, std::optional<int64_t> channel_id,
bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecvDone,
ShapeUtil::MakeTupleShape(
Expand All @@ -907,8 +905,8 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
return std::make_unique<HloRecvDoneInstruction>(recv, is_host_transfer());
}

return std::make_unique<HloRecvDoneInstruction>(
new_operands[0], channel_id().value(), is_host_transfer());
return std::make_unique<HloRecvDoneInstruction>(new_operands[0], channel_id(),
is_host_transfer());
}

HloCollectiveInstruction::HloCollectiveInstruction(
Expand Down
15 changes: 10 additions & 5 deletions xla/hlo/ir/hlo_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ class HloSendRecvInstruction : public HloChannelInstruction {

protected:
explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
int64_t channel_id, bool is_host_transfer);
std::optional<int64_t> channel_id,
bool is_host_transfer);

private:
void PrintExtraAttributesImpl(AttributePrinter& printer,
Expand All @@ -571,7 +572,8 @@ class HloSendRecvInstruction : public HloChannelInstruction {
class HloSendInstruction : public HloSendRecvInstruction {
public:
explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
int64_t channel_id, bool is_host_transfer);
std::optional<int64_t> channel_id,
bool is_host_transfer);

static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kSend;
Expand All @@ -588,7 +590,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
public:
explicit HloSendDoneInstruction(HloSendInstruction* operand,
bool is_host_transfer);
explicit HloSendDoneInstruction(HloInstruction* operand, int64_t channel_id,
explicit HloSendDoneInstruction(HloInstruction* operand,
std::optional<int64_t> channel_id,
bool is_host_transfer);
static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kSendDone;
Expand All @@ -604,7 +607,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
class HloRecvInstruction : public HloSendRecvInstruction {
public:
explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
int64_t channel_id, bool is_host_transfer);
std::optional<int64_t> channel_id,
bool is_host_transfer);

static bool ClassOf(const HloInstruction* hlo) {
return hlo->opcode() == HloOpcode::kRecv;
Expand All @@ -621,7 +625,8 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
public:
explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
bool is_host_transfer);
explicit HloRecvDoneInstruction(HloInstruction* operand, int64_t channel_id,
explicit HloRecvDoneInstruction(HloInstruction* operand,
std::optional<int64_t> channel_id,
bool is_host_transfer);

static bool ClassOf(const HloInstruction* hlo) {
Expand Down
16 changes: 8 additions & 8 deletions xla/hlo/parser/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2198,13 +2198,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}
// is_host_transfer was defaulted to false above, so dereferencing it is safe.
return builder->AddInstruction(HloInstruction::CreateRecv(
shape->tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
shape->tuple_shapes(0), operands[0], channel_id, *is_host_transfer));
}
case HloOpcode::kRecvDone: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2220,13 +2220,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}

return builder->AddInstruction(HloInstruction::CreateRecvDone(
operands[0], channel_id.value(), *is_host_transfer));
operands[0], channel_id, *is_host_transfer));
}
case HloOpcode::kSend: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2235,13 +2235,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
return nullptr;
}
return builder->AddInstruction(HloInstruction::CreateSend(
operands[0], operands[1], *channel_id, *is_host_transfer));
operands[0], operands[1], channel_id, *is_host_transfer));
}
case HloOpcode::kSendDone: {
optional<int64_t> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if ((!preset_operands &&
Expand All @@ -2257,7 +2257,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
}

return builder->AddInstruction(HloInstruction::CreateSendDone(
operands[0], channel_id.value(), *is_host_transfer));
operands[0], channel_id, *is_host_transfer));
}
case HloOpcode::kGetTupleElement: {
optional<int64_t> index;
Expand Down
15 changes: 15 additions & 0 deletions xla/hlo/parser/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,21 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
},
{
"SendRecvWoChannelID",
R"(HloModule SendRecvWoChannelID_module, entry_computation_layout={()->(f32[], token[])}
ENTRY %computation () -> (f32[], token[]) {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0)
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv)
%constant = f32[] constant(2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
%send-done = token[] send-done((f32[], u32[], token[]) %send)
}
)"
},
{
Expand Down
8 changes: 4 additions & 4 deletions xla/service/collective_permute_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ absl::Status DecomposeCollectivePermute(
send->add_frontend_attributes(attributes);
send->set_metadata(metadata);

HloInstruction* recv_done =
computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
HloInstruction* send_done =
computation->AddInstruction(HloInstruction::CreateSendDone(send));
HloInstruction* recv_done = computation->AddInstruction(
HloInstruction::CreateRecvDone(recv, channel_id));
HloInstruction* send_done = computation->AddInstruction(
HloInstruction::CreateSendDone(send, channel_id));

// We will add control dependence to represent how we want to order Send/Recv
// and other collective operations. Here we only add the necessary control
Expand Down

0 comments on commit 8153844

Please sign in to comment.