Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Support partially pipelined async send recv ops #17446

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xla/hlo/utils/hlo_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) {
return *it;
}

// TODO: Make this return only the instruction.
std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, absl::string_view name) {
int current_index = 0;
Expand All @@ -293,6 +294,7 @@ std::pair<HloInstruction*, int> FindFirstInstruction(
return {nullptr, -1};
}

// TODO: Make this return only the instruction.
std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, HloOpcode opcode) {
int current_index = 0;
Expand All @@ -306,6 +308,7 @@ std::pair<HloInstruction*, int> FindFirstInstruction(
return {nullptr, -1};
}

// TODO: Remove this. It could be misleading as there is no linear order.
bool IsBeforeInComputation(const HloComputation* computation,
absl::string_view inst1, absl::string_view inst2) {
return FindFirstInstruction(computation, inst1).second <
Expand Down
2 changes: 2 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5234,6 +5234,7 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_reachability",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down Expand Up @@ -5295,6 +5296,7 @@ xla_cc_test(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/hlo/utils:hlo_query",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/log",
Expand Down
188 changes: 149 additions & 39 deletions xla/service/copy_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
#include "xla/service/copy_insertion.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
Expand All @@ -35,12 +37,15 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/frontend_attributes.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_reachability.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/map_util.h"
#include "xla/service/call_graph.h"
#include "xla/service/compile_time_cap.h"
Expand Down Expand Up @@ -186,6 +191,22 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
return std::make_pair(from_deep_copy, to_deep_copy);
}

bool IsSendRecv(const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kRecv;
}

bool IsSendRecvDone(const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kSendDone ||
instruction->opcode() == HloOpcode::kRecvDone;
}

bool IsSendRecvInInit(const HloInstruction* init, const ShapeIndex& index) {
if (index.empty()) return false;
int64_t i = index.front();
return i < init->operand_count() && IsSendRecv(init->operand(i));
}

// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed through transparently through the
Expand All @@ -202,9 +223,14 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
for (auto& pair : *indices_to_copy) {
const ShapeIndex& index = pair.first;
bool& should_copy = pair.second;
// If there is any ambiguity, then loop state must be copied.
if (dataflow.GetValueSet(init, index).values().size() > 1 ||
dataflow.GetValueSet(xla_while, index).values().size() > 1) {
if (IsSendRecvInInit(init, index)) {
// Do not copy partially pipelined send/recv ops. The required copies will
// be inserted specifically for the send/recv ops.
should_copy = false;
continue;
} else if (dataflow.GetValueSet(init, index).values().size() > 1 ||
dataflow.GetValueSet(xla_while, index).values().size() > 1) {
// If there is any ambiguity, then loop state must be copied.
should_copy = true;
} else {
// If the output of the while instruction is not the same as the init
Expand Down Expand Up @@ -1307,42 +1333,6 @@ class CopyRemover {
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
continue;
}
if (check_live_range_ordering) {
// Skip checking if execution thread is not included.
auto should_skip_value = [&execution_threads](const HloValue* value) {
return value->defining_instruction()->parent() != nullptr &&
!HloInstruction::IsThreadIncluded(value->defining_instruction()
->parent()
->execution_thread(),
execution_threads);
};
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
if (should_skip_value(value_a)) {
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (!should_skip_value(value_b) && value_a != value_b) {
DCHECK(ordering_->LiveRangeStrictlyBefore(
*value_a, *value_b, dataflow_,
/*use_is_always_before_def_in_same_instr=*/true) ||
ordering_->LiveRangeStrictlyBefore(
*value_b, *value_a, dataflow_,
/*use_is_always_before_def_in_same_instr=*/true))
<< value_a->ToString() << " and " << value_b->ToString()
<< " are not ordered";
}
}
}
}

std::vector<const HloValue*> values = buffer.values();
absl::c_sort(values, [this, &instruction_ids](const HloValue* a,
Expand Down Expand Up @@ -2014,6 +2004,122 @@ absl::Status CopyInsertion::AddCopiesForConditional(
return absl::OkStatus();
}

// Returns the unique send/recv-done op that consumes `start_op`, either
// directly in the same iteration or — for partially pipelined send/recv that
// is threaded through the loop state — via the corresponding
// get-tuple-element on the while-body parameter in the next iteration.
// Returns nullptr if no such unique consumer can be identified.
HloInstruction* FindAsyncSendRecvDoneInWhileBody(
    const HloComputation* while_body, const HloInstruction* start_op) {
  // Partially pipelined send/recv must have a single user.
  if (start_op->user_count() != 1) return nullptr;
  HloInstruction* unique_user = start_op->users().front();
  // Send/recv must be consumed by send/recv-done op or be passed through the
  // loop.
  if (IsSendRecvDone(unique_user)) return unique_user;
  if (unique_user->opcode() != HloOpcode::kTuple || !unique_user->IsRoot())
    return nullptr;
  int64_t index = unique_user->operand_index(start_op);
  for (const HloInstruction* it :
       while_body->parameter_instruction(0)->users()) {
    const auto* gte = DynCast<HloGetTupleElementInstruction>(it);
    // The loop parameter can have users other than get-tuple-element (e.g. a
    // direct use of the whole parameter tuple). DynCast returns nullptr for
    // those; skip them instead of dereferencing a null pointer.
    if (gte == nullptr) continue;
    if (gte->tuple_index() == index) {
      CHECK_EQ(gte->user_count(), 1) << "send/recv in next loop iteration must "
                                        "be consumed by unique send/recv-done.";
      HloInstruction* next_unique_user = gte->users().front();
      if (IsSendRecvDone(next_unique_user)) return next_unique_user;
    }
  }
  return nullptr;
}

// Add copies for partially pipelined async send/recv.
// Copies are added before starting to send and after finishing to recv.
// This is to prevent overlapping live times of the buffers. The control edges
// from the added copy to the recv or send-done operation guarantee disjoint
// live times.
//
//
// Before:
//
// kParameter kParameter
// | |
// kSendDone kRecvDone
// |
// ... consumer
//
// producer ...
// |
// kSend kRecv
// | |
// (body root) (body root)
//
//
// After:
//
// kParameter kParameter
// | |
// kSendDone ----+ kRecvDone
// | |
// ctrl kCopy ----+
// producer edge | |
// | | consumer ctrl
// kCopy <-----+ edge
// | |
// kSend kRecv <---+
// | |
// (body root) (body root)
//
absl::Status CopyInsertion::AddCopiesForAsyncSendRecv(
    const HloAliasAnalysis& alias_analysis, HloInstruction* start_op) {
  // If start op has multiple users, this must be the synchronous use of
  // send/recv.
  // TODO: Disambiguate sync and async use of send/recv b/369589022
  if (start_op->users().size() != 1) return absl::OkStatus();

  // If start feeds directly into done, the live time is contained and we don't
  // need to add any copies.
  HloInstruction* unique_user = start_op->users().front();
  const HloOpcode done_opcode = start_op->opcode() == HloOpcode::kSend
                                    ? HloOpcode::kSendDone
                                    : HloOpcode::kRecvDone;
  if (unique_user->opcode() == done_opcode) {
    return absl::OkStatus();
  }

  // For send/recv outside of the while loop, live times are disjoint. No copies
  // needed.
  HloComputation* while_body = start_op->parent();
  if (!while_body->IsWhileBodyComputation()) return absl::OkStatus();

  // Find the matching done op, possibly through the loop's back edge. If none
  // is found, treat this as a (synchronous) use we do not need to handle.
  HloInstruction* done_op =
      FindAsyncSendRecvDoneInWhileBody(while_body, start_op);
  // TODO: Disambiguate sync and async use of send/recv b/369589022
  if (done_op == nullptr) return absl::OkStatus();

  // Handle send case: copy the operand before it is sent, and add a control
  // edge done -> copy so the copy (and hence the next send) cannot start
  // before the previous send-done has completed. This keeps the send buffer's
  // live range disjoint across iterations.
  if (start_op->opcode() == HloOpcode::kSend) {
    HloInstruction* operand = start_op->mutable_operand(0);
    HloInstruction* copied_operand =
        while_body->AddInstruction(HloInstruction::CreateUnary(
            operand->shape(), HloOpcode::kCopy, operand));
    TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand));
    TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(copied_operand));
    return absl::OkStatus();
  }

  // Handle recv case: deep-copy the recv-done result and redirect all of its
  // former users to the copy. Control edges copy -> recv (and done -> recv)
  // guarantee the copied-out value is materialized before the next recv may
  // reuse the buffer.
  CHECK_EQ(start_op->opcode(), HloOpcode::kRecv);
  // Snapshot the users before rewiring; ReplaceUseWith mutates the user list.
  PtrVec<HloInstruction*> done_op_users = done_op->users();
  ShapeTree<HloInstruction*> copies_added(done_op->shape());
  TF_ASSIGN_OR_RETURN(HloInstruction * done_op_copy,
                      while_body->DeepCopyInstruction(
                          done_op, /*indices_to_copy=*/nullptr, &copies_added));
  // Every copy instruction created by the deep copy must finish before the
  // next recv starts.
  for (auto [shape_index, instr] : copies_added) {
    if (instr != nullptr)
      TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(start_op));
  }
  TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(start_op));
  for (HloInstruction* it : done_op_users) {
    TF_RETURN_IF_ERROR(done_op->ReplaceUseWith(it, done_op_copy));
  }
  return absl::OkStatus();
}

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around kWhile
// instructions which have update-in-place semantics.
Expand All @@ -2034,6 +2140,10 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference(
} else if (instruction->opcode() == HloOpcode::kConditional) {
TF_RETURN_IF_ERROR(
AddCopiesForConditional(*alias_analysis, instruction));
} else if (IsSendRecv(instruction)) {
// TODO: Generalize this to all async collectives.
TF_RETURN_IF_ERROR(
AddCopiesForAsyncSendRecv(*alias_analysis, instruction));
} else {
// When an operand is a tuple, we avoid copying the operand multiple
// times by recording and checking the operand number of operands that
Expand Down
4 changes: 4 additions & 0 deletions xla/service/copy_insertion.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class CopyInsertion : public HloModulePass {
virtual absl::Status AddCopiesForConditional(
const HloAliasAnalysis& alias_analysis, HloInstruction* conditional);

// Add copies for async send/recv instructions.
absl::Status AddCopiesForAsyncSendRecv(const HloAliasAnalysis& alias_analysis,
HloInstruction* async);

// Backend specific function that decides whether an instruction can share
// buffer with its operand.
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
Expand Down
Loading
Loading