Skip to content

Commit

Permalink
[XLA:GPU] Support partially pipelined async send recv ops
Browse files Browse the repository at this point in the history
This is needed for pipeline parallelism on GPU where the send/recv operations
are issued in one loop iteration and completed in the next. The same buffer
must be alive throughout the process and no copies can be inserted.

Avoid copies for these partially pipelined async send/recv ops. Insert the
required copies and controlflow constraints on the send/recv ops separately.
This is to ensure that the live times of the buffers do not overlap.

Send: For send, a copy is inserted on the operand, starting a new live range.
By enforcing this copy after the corresponding send/done, buffer live times are
disjoint.

Recv: For recv, a copy is inserted after recv-done, ending the live time of the
buffer. Bt enforcing the copy to be before the corresponding recv. buffer live
times are disjoint.
PiperOrigin-RevId: 676075106
  • Loading branch information
frgossen authored and Google-ML-Automation committed Oct 2, 2024
1 parent 93be085 commit b886ebe
Show file tree
Hide file tree
Showing 5 changed files with 452 additions and 39 deletions.
3 changes: 3 additions & 0 deletions xla/hlo/utils/hlo_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) {
return *it;
}

// TODO: Make this return only the instruction.
std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, absl::string_view name) {
int current_index = 0;
Expand All @@ -293,6 +294,7 @@ std::pair<HloInstruction*, int> FindFirstInstruction(
return {nullptr, -1};
}

// TODO: Make this return only the instruction.
std::pair<HloInstruction*, int> FindFirstInstruction(
const HloComputation* computation, HloOpcode opcode) {
int current_index = 0;
Expand All @@ -306,6 +308,7 @@ std::pair<HloInstruction*, int> FindFirstInstruction(
return {nullptr, -1};
}

// TODO: Remove this. It could be misleading as there is no linear order.
bool IsBeforeInComputation(const HloComputation* computation,
absl::string_view inst1, absl::string_view inst2) {
return FindFirstInstruction(computation, inst1).second <
Expand Down
2 changes: 2 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5234,6 +5234,7 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_reachability",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_query",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down Expand Up @@ -5295,6 +5296,7 @@ xla_cc_test(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/hlo/utils:hlo_query",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/log",
Expand Down
188 changes: 149 additions & 39 deletions xla/service/copy_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
#include "xla/service/copy_insertion.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
Expand All @@ -35,12 +37,15 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/frontend_attributes.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_reachability.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/map_util.h"
#include "xla/service/call_graph.h"
#include "xla/service/compile_time_cap.h"
Expand Down Expand Up @@ -186,6 +191,22 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
return std::make_pair(from_deep_copy, to_deep_copy);
}

bool IsSendRecv(const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kRecv;
}

bool IsSendRecvDone(const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kSendDone ||
instruction->opcode() == HloOpcode::kRecvDone;
}

bool IsSendRecvInInit(const HloInstruction* init, const ShapeIndex& index) {
if (index.empty()) return false;
int64_t i = index.front();
return i < init->operand_count() && IsSendRecv(init->operand(i));
}

// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed through transparently through the
Expand All @@ -202,9 +223,14 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
for (auto& pair : *indices_to_copy) {
const ShapeIndex& index = pair.first;
bool& should_copy = pair.second;
// If there is any ambiguity, then loop state must be copied.
if (dataflow.GetValueSet(init, index).values().size() > 1 ||
dataflow.GetValueSet(xla_while, index).values().size() > 1) {
if (IsSendRecvInInit(init, index)) {
// Do not copy partially pipelined send/recv ops. The required copies will
// be inserted specifically for the send/recv ops.
should_copy = false;
continue;
} else if (dataflow.GetValueSet(init, index).values().size() > 1 ||
dataflow.GetValueSet(xla_while, index).values().size() > 1) {
// If there is any ambiguity, then loop state must be copied.
should_copy = true;
} else {
// If the output of the while instruction is not the same as the init
Expand Down Expand Up @@ -1307,42 +1333,6 @@ class CopyRemover {
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
continue;
}
if (check_live_range_ordering) {
// Skip checking if execution thread is not included.
auto should_skip_value = [&execution_threads](const HloValue* value) {
return value->defining_instruction()->parent() != nullptr &&
!HloInstruction::IsThreadIncluded(value->defining_instruction()
->parent()
->execution_thread(),
execution_threads);
};
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
if (should_skip_value(value_a)) {
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (!should_skip_value(value_b) && value_a != value_b) {
DCHECK(ordering_->LiveRangeStrictlyBefore(
*value_a, *value_b, dataflow_,
/*use_is_always_before_def_in_same_instr=*/true) ||
ordering_->LiveRangeStrictlyBefore(
*value_b, *value_a, dataflow_,
/*use_is_always_before_def_in_same_instr=*/true))
<< value_a->ToString() << " and " << value_b->ToString()
<< " are not ordered";
}
}
}
}

std::vector<const HloValue*> values = buffer.values();
absl::c_sort(values, [this, &instruction_ids](const HloValue* a,
Expand Down Expand Up @@ -2014,6 +2004,122 @@ absl::Status CopyInsertion::AddCopiesForConditional(
return absl::OkStatus();
}

HloInstruction* FindAsyncSendRecvDoneInWhileBody(
const HloComputation* while_body, const HloInstruction* start_op) {
// Partially pipelined send/recv must have a single user.
if (start_op->user_count() != 1) return nullptr;
HloInstruction* unique_user = start_op->users().front();
// Send/recv must be consumed by send/recv-done op or be passed through the
// loop.
if (IsSendRecvDone(unique_user)) return unique_user;
if (unique_user->opcode() != HloOpcode::kTuple || !unique_user->IsRoot())
return nullptr;
int64_t index = unique_user->operand_index(start_op);
for (const HloInstruction* it :
while_body->parameter_instruction(0)->users()) {
const auto* gte = DynCast<HloGetTupleElementInstruction>(it);
if (gte->tuple_index() == index) {
CHECK_EQ(gte->user_count(), 1) << "send/recv in next loop iteration must "
"be consumed by unique send/recv-done.";
HloInstruction* next_unique_user = gte->users().front();
if (IsSendRecvDone(next_unique_user)) return next_unique_user;
}
}
return nullptr;
}

// Add copies for partially pipelined async send/recv.
// Copies are added before before startint to send and after finishing to recv.
// This is to prevent overlapping live times of the buffers. The control edges
// from the added copy to the recv or send-done operation guarantee disjoint
// live times.
//
//
// Before:
//
// kParameter kParameter
// | |
// kSendDone kRecvDone
// |
// ... consumer
//
// producer ...
// |
// kSend kRecv
// | |
// (body root) (body root)
//
//
// After:
//
// kParameter kParameter
// | |
// kSendDone ----+ kRecvDone
// | |
// ctrl kCopy ----+
// producer edge | |
// | | consumer ctrl
// kCopy <-----+ edge
// | |
// kSend kRecv <---+
// | |
// (body root) (body root)
//
absl::Status CopyInsertion::AddCopiesForAsyncSendRecv(
const HloAliasAnalysis& alias_analysis, HloInstruction* start_op) {
// If start op has multiple users, this must be the synchronous use of
// send/recv.
// TODO: Disambiguate sync and async use of send/recv b/369589022
if (start_op->users().size() != 1) return absl::OkStatus();

// If start feeds directly into done, the live time is contained and we don't
// need to add any copies.
HloInstruction* unique_user = start_op->users().front();
const HloOpcode done_opcode = start_op->opcode() == HloOpcode::kSend
? HloOpcode::kSendDone
: HloOpcode::kRecvDone;
if (unique_user->opcode() == done_opcode) {
return absl::OkStatus();
}

// For send/recv outside of the while loop, live times are disjoint. No copies
// needed.
HloComputation* while_body = start_op->parent();
if (!while_body->IsWhileBodyComputation()) return absl::OkStatus();

// Handle send case.
HloInstruction* done_op =
FindAsyncSendRecvDoneInWhileBody(while_body, start_op);
// TODO: Disambiguate sync and async use of send/recv b/369589022
if (done_op == nullptr) return absl::OkStatus();
if (start_op->opcode() == HloOpcode::kSend) {
HloInstruction* operand = start_op->mutable_operand(0);
HloInstruction* copied_operand =
while_body->AddInstruction(HloInstruction::CreateUnary(
operand->shape(), HloOpcode::kCopy, operand));
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand));
TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(copied_operand));
return absl::OkStatus();
}

// Handle recv case.
CHECK_EQ(start_op->opcode(), HloOpcode::kRecv);
PtrVec<HloInstruction*> done_op_users = done_op->users();
ShapeTree<HloInstruction*> copies_added(done_op->shape());
TF_ASSIGN_OR_RETURN(HloInstruction * done_op_copy,
while_body->DeepCopyInstruction(
done_op, /*indices_to_copy=*/nullptr, &copies_added));
for (auto [shape_index, instr] : copies_added) {
if (instr != nullptr)
TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(start_op));
}
TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(start_op));
for (HloInstruction* it : done_op_users) {
TF_RETURN_IF_ERROR(done_op->ReplaceUseWith(it, done_op_copy));
}
return absl::OkStatus();
}

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around kWhile
// instructions which have update-in-place semantics.
Expand All @@ -2034,6 +2140,10 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference(
} else if (instruction->opcode() == HloOpcode::kConditional) {
TF_RETURN_IF_ERROR(
AddCopiesForConditional(*alias_analysis, instruction));
} else if (IsSendRecv(instruction)) {
// TODO: Generalize this to all async collectives.
TF_RETURN_IF_ERROR(
AddCopiesForAsyncSendRecv(*alias_analysis, instruction));
} else {
// When an operand is a tuple, we avoid copying the operand multiple
// times by recording and checking the operand number of operands that
Expand Down
4 changes: 4 additions & 0 deletions xla/service/copy_insertion.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class CopyInsertion : public HloModulePass {
virtual absl::Status AddCopiesForConditional(
const HloAliasAnalysis& alias_analysis, HloInstruction* conditional);

// Add copies for async send/recv instructions.
absl::Status AddCopiesForAsyncSendRecv(const HloAliasAnalysis& alias_analysis,
HloInstruction* async);

// Backend specific function that decides whether an instruction can share
// buffer with its operand.
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
Expand Down
Loading

0 comments on commit b886ebe

Please sign in to comment.