Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Directly track callers and callees of an HloComputation. #23516

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,32 @@ xla_cc_test(
],
)

# Header-only target (hdrs only, no srcs) exposing online_topsort.h together
# with the Abseil logging, CHECK and string dependencies listed below.
cc_library(
name = "online_topsort",
hdrs = ["online_topsort.h"],
deps = [
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
],
)

# Test target for :online_topsort, built from online_topsort_test.cc and linked
# against the TSL test main.
xla_cc_test(
name = "online_topsort_test",
srcs = ["online_topsort_test.cc"],
deps = [
":online_topsort",
"//xla/tsl/platform:test",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/log:log_streamer",
"@com_google_absl//absl/random",
"@com_google_absl//absl/strings",
],
)

bzl_library(
name = "lit_bzl",
srcs = ["lit.bzl"],
Expand Down
8 changes: 7 additions & 1 deletion xla/hlo/evaluator/hlo_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1036,13 +1036,19 @@ absl::StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(

std::unique_ptr<HloInstruction> cloned_instruction =
instruction->CloneWithNewOperands(instruction->shape(), operands);
// TODO(phawkins): it's unfortunate that we need to call set_parent() here.
// TODO(phawkins): it's unfortunate that we need to call set_parent() here,
// since it violates the invariant that an instruction has a parent iff it is
// in a computation.
// It's probably better to avoid constructing new instructions here in the
// first place.
cloned_instruction->set_parent(
const_cast<HloComputation*>(instruction->parent()));
auto result = Evaluate(cloned_instruction.get());

// Undo the parent change, since it will confuse code that expects the
// instruction to be in a computation.
cloned_instruction->set_parent(nullptr);

return result;
}

Expand Down
98 changes: 96 additions & 2 deletions xla/hlo/ir/hlo_computation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -183,11 +185,38 @@ HloComputation::~HloComputation() {
async_start_->ClearCalledComputations();
}
Cleanup();
ClearCalledComputations();

// We need to make sure there are no dangling references to this computation
// from instructions in other computations.
std::vector<HloComputation*> callers;
for (const auto& [caller, count] : caller_computations_) {
callers.push_back(caller);
}
for (HloComputation* caller : callers) {
for (HloInstruction* inst : caller->instructions()) {
for (int i = 0; i < inst->called_computations().size(); ++i) {
if (inst->called_computations()[i] == this) {
inst->set_called_computation(i, nullptr);
}
}
}
}
CHECK(caller_computations_.empty());

for (const auto& i : instructions_) {
delete i.inst();
}
}

void HloComputation::ClearCalledComputations() {
for (HloInstruction* i : instructions()) {
i->ClearCalledComputations();
}
// Clearing the instructions should have removed all callee computations.
CHECK(callee_computations_.empty());
}

void HloComputation::SetInstruction(HloInstruction* instruction,
InstructionType type) {
static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1,
Expand Down Expand Up @@ -241,6 +270,38 @@ HloInstruction* HloComputation::AddInstruction(
return AddInstruction(std::move(instruction));
}

// Adds one reference to `key` in `map`. A key that is not present yet is
// inserted with a count of zero by operator[] and then bumped to one.
static void IncrementCount(
    absl::btree_map<HloComputation*, int, HloComputation::UniqueIdComparator>&
        map,
    HloComputation* key) {
  int& count = map[key];
  ++count;
}

// Removes one reference to `key` from `map`, erasing the entry once its count
// drops to zero so that the map only ever holds keys with a positive count.
//
// `key` must already be present with a positive count. A violation is a
// bookkeeping bug and is reported by CHECK failure; nothing is returned to the
// caller.
static void DecrementCount(
    absl::btree_map<HloComputation*, int, HloComputation::UniqueIdComparator>&
        map,
    HloComputation* key) {
  auto it = map.find(key);
  CHECK(it != map.end());
  CHECK_GT(it->second, 0);
  if (--it->second == 0) {
    map.erase(it);
  }
}

// Records one additional call edge from this computation to `callee`. Both
// ends of the edge are updated: the callee map here and the caller map on
// `callee`.
void HloComputation::AddCallee(HloComputation* callee) {
  auto& outgoing = callee_computations_;
  auto& incoming = callee->caller_computations_;
  IncrementCount(outgoing, callee);
  IncrementCount(incoming, this);
}

// Removes one call edge from this computation to `callee`. Both ends of the
// edge are updated: the callee map here first, then the caller map on
// `callee`.
void HloComputation::RemoveCallee(HloComputation* callee) {
  auto& outgoing = callee_computations_;
  auto& incoming = callee->caller_computations_;
  DecrementCount(outgoing, callee);
  DecrementCount(incoming, this);
}

HloInstruction* HloComputation::AddInstructionInternal(
std::unique_ptr<HloInstruction> instruction) {
if (parent() != nullptr) {
Expand All @@ -265,6 +326,7 @@ HloInstruction* HloComputation::AddInstructionInternal(
CHECK(parent() == nullptr || called_computation->parent() == parent())
<< "Called computation " << called_computation->name()
<< " is not in the same module as " << name();
AddCallee(called_computation);
}
return pinst;
}
Expand Down Expand Up @@ -521,13 +583,13 @@ absl::Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,

HloInstructionInfo* info = &instructions_[instruction->index_in_parent_];
DCHECK_EQ(info->inst(), instruction);
info->inst()->set_parent(nullptr);
to_be_deleted_.push_back(info->inst()); // Takes ownership
to_be_deleted_.back()->DetachFromOperandsAndUsers();
// Clear all operands to avoid Null operands.
to_be_deleted_.back()->RemoveAllOperands();
to_be_deleted_.back()->ClearCalledComputations();
to_be_deleted_.back()->MarkAsDead();
info->inst()->set_parent(nullptr);

// If this instruction is a constant, clear the literal eagerly instead of
// waiting for the instruction to be deleted in Cleanup(). This greatly
Expand Down Expand Up @@ -1089,7 +1151,7 @@ HloComputation::CreateFromProto(

auto computation = absl::WrapUnique(
new HloComputation(proto.name(), parameter_count, &instructions, root));
computation->unique_id_ = proto.id();
computation->SetUniqueIdHelper(proto.id());
if (proto.is_fusion_computation()) {
computation->instruction_and_type_ =
static_cast<uintptr_t>(InstructionType::kFusion);
Expand Down Expand Up @@ -1840,4 +1902,36 @@ bool HloComputation::CanExpandIntoSingleInstruction() const {
});
}

// Resets the unique ID to -1. This goes through SetUniqueIdHelper() rather than
// writing the field directly because the caller/callee maps of neighboring
// computations are ordered by unique ID.
void HloComputation::ClearUniqueIdInternal() { SetUniqueIdHelper(-1); }

// Assigns `id` as the unique ID of this computation. CHECK-fails if an ID is
// already assigned (the current value is not -1) or if `id` is negative.
void HloComputation::SetUniqueId(int64_t id) {
CHECK_EQ(unique_id_, -1);
CHECK_GE(id, 0);
SetUniqueIdHelper(id);
}

// Sets unique_id_ to `id` while keeping the neighbors' ordered maps valid.
// The statement order matters: entries keyed by `this` must be looked up and
// erased while the old ID is still in place, and re-inserted only after the
// new ID has been stored.
void HloComputation::SetUniqueIdHelper(int64_t id) {
// The caller/callee computations are ordered by unique ID, so we need to
// remove and readd them to our neighbor's data structures.
// Step 1: remove `this` from every caller's callee map (found via the old ID).
for (auto& [computation, count] : caller_computations_) {
auto it = computation->callee_computations_.find(this);
CHECK(it != computation->callee_computations_.end());
// Both ends of an edge must agree on the number of parallel references.
CHECK_EQ(it->second, count);
computation->callee_computations_.erase(it);
}
// Step 2: remove `this` from every callee's caller map (found via the old ID).
for (auto& [computation, count] : callee_computations_) {
auto it = computation->caller_computations_.find(this);
CHECK(it != computation->caller_computations_.end());
CHECK_EQ(it->second, count);
computation->caller_computations_.erase(it);
}
// Step 3: switch to the new ID. The maps owned by `this` are keyed by the
// neighbors, not by `this`, so they are not modified here.
unique_id_ = id;
// Step 4: re-insert `this` into the neighbors' maps under the new ID, with the
// same counts as before.
for (auto& [computation, count] : caller_computations_) {
computation->callee_computations_[this] = count;
}
for (auto& [computation, count] : callee_computations_) {
computation->caller_computations_[this] = count;
}
}

} // namespace xla
69 changes: 63 additions & 6 deletions xla/hlo/ir/hlo_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
Expand Down Expand Up @@ -917,14 +919,10 @@ class HloComputation {

// Clear the unique ID of the computation so that it can be re-assigned, such
// as for the purpose of compacting the unique IDs.
void ClearUniqueIdInternal() { unique_id_ = -1; }
void ClearUniqueIdInternal();

// The id of this computation should be unique within the module.
void SetUniqueId(int64_t id) {
CHECK_EQ(unique_id_, -1);
CHECK_GE(id, 0);
unique_id_ = id;
}
void SetUniqueId(int64_t id);

// Returns the instruction in this computation that has name `name`. Returns
// null if there is no such computation.
Expand Down Expand Up @@ -957,6 +955,34 @@ class HloComputation {
// Returns true iff this computation can be inlined as a single instruction.
bool CanExpandIntoSingleInstruction() const;

// Orders computations by (unique ID, pointer value). Ordering by unique ID is
// what makes iteration over the caller/callee maps deterministic.
struct UniqueIdComparator {
  bool operator()(const HloComputation* lhs,
                  const HloComputation* rhs) const {
    // Primary key: the unique ID.
    if (lhs->unique_id_ != rhs->unique_id_) {
      return lhs->unique_id_ < rhs->unique_id_;
    }
    // Tie-break on the pointer itself. Computations that do not belong to any
    // module all share the unique ID -1 and would otherwise be
    // indistinguishable. Pointer order is not deterministic, but determinism
    // is not needed for computations outside a module since the topological
    // sorting code ignores them.
    return lhs < rhs;
  }
};

// Count of times this computation calls other computations.
// The map is keyed by callee and ordered by UniqueIdComparator. It is returned
// by value, so each call hands back a copy of the current state.
absl::btree_map<HloComputation*, int, UniqueIdComparator>
callee_computations() const {
return callee_computations_;
}

// Count of times this computation is called by other computations.
// The map is keyed by caller and ordered by UniqueIdComparator. It is returned
// by value, so each call hands back a copy of the current state.
absl::btree_map<HloComputation*, int, UniqueIdComparator>
caller_computations() const {
return caller_computations_;
}

// Clears the called computations of every instruction in this computation and
// CHECKs that no callee computations remain afterwards.
void ClearCalledComputations();

private:
friend class HloModule;

Expand Down Expand Up @@ -1018,6 +1044,18 @@ class HloComputation {
// set the parent of a computation is to add it to a module.
void set_parent(HloModule* module) { parent_ = module; }

// Helper that updates the unique ID of the computation. This requires
// updating the callee_computations_ and caller_computations_ sets since they
// are ordered by unique ID.
void SetUniqueIdHelper(int64_t id);

friend class HloInstruction;
void AddCallee(HloComputation* callee);
void RemoveCallee(HloComputation* callee);

// Unique ID of this computation.
// This is set to -1 if the computation is not in a module. Should only be
// updated by SetUniqueIdHelper().
int64_t unique_id_;
HloInstruction* root_instruction_;

Expand Down Expand Up @@ -1056,6 +1094,25 @@ class HloComputation {

std::string name_;

// Callers and callees of this computation.
// * These include all computations that have a caller/callee relationship
// with this computation, even those that may not belong to a module. For
// example, a computation that has been created and is in the process of
// being constructed but has not been added to a module yet may appear here.
// * These are ordered maps, ordered by (unique ID, computation pointer). The
// unique ID is used to ensure determinism, whereas the computation pointer
// is used to disambiguate computations that do not belong to any module and
// therefore have a unique ID of -1. We assume that determinism only matters
// for computations that belong to a module (i.e., unique_id != -1), since
// the primary use case for this data structure is to topologically sort
// computations in a module.
// * The values of the maps are the number of times the computation is
// referenced. In a graph sense, this is the number of parallel edges.
absl::btree_map<HloComputation*, int, UniqueIdComparator>
callee_computations_;
absl::btree_map<HloComputation*, int, UniqueIdComparator>
caller_computations_;

HloComputation(const HloComputation&) = delete;
HloComputation& operator=(const HloComputation&) = delete;
};
Expand Down
29 changes: 27 additions & 2 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,28 @@ void HloInstruction::AppendComputation(HloComputation* computation) {
// In .cc file since PtrVec<T*>::push_back() wants to check the alignment
// of T and hlo_instruction.h does not include hlo_computation.h.
mutable_rare()->called_computations.push_back(computation);
if (parent()) {
parent()->AddCallee(computation);
}
}

void HloInstruction::set_called_computation(int index,
HloComputation* computation) {
mutable_rare()->called_computations[index] = computation;
// TODO(b/399394039): Consider also enforcing that computation->parent() !=
// nullptr.
CHECK(parent() == nullptr || parent()->parent() == nullptr ||
parent()->parent() == computation->parent())
computation == nullptr || parent()->parent() == computation->parent())
<< ToString();
HloComputation* old_computation = computation;
std::swap(old_computation, mutable_rare()->called_computations[index]);
if (parent()) {
if (old_computation) {
parent()->RemoveCallee(old_computation);
}
if (computation) {
parent()->AddCallee(computation);
}
}
}

void HloInstruction::ReplaceCalledComputations(
Expand All @@ -238,6 +250,19 @@ void HloInstruction::ReplaceCalledComputations(
}
}

void HloInstruction::ClearCalledComputations() {
if (has_rare()) {
if (parent()) {
for (HloComputation* computation : called_computations()) {
if (computation) {
parent()->RemoveCallee(computation);
}
}
}
mutable_rare()->called_computations.clear();
}
}

HloInstruction* HloInstruction::AddInstruction(
std::unique_ptr<HloInstruction> derived_instruction) {
HloInstruction* derived =
Expand Down
6 changes: 1 addition & 5 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1742,11 +1742,7 @@ class HloInstruction {
// clearing out the computations, we reflect the fact that all side-effecting
// properties have been reflected in the caller, and make the call HLO
// removable.
virtual void ClearCalledComputations() {
if (has_rare()) {
mutable_rare()->called_computations.clear();
}
}
virtual void ClearCalledComputations();

// Returns true if this instruction performs an elementwise operation on
// `operand_idx`-th operand. An instruction is elementwise on an operand iff,
Expand Down
Loading
Loading