diff --git a/xla/BUILD b/xla/BUILD index ac7f71b338b9e..4794c47fed273 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -1393,6 +1393,32 @@ xla_cc_test( ], ) +cc_library( + name = "online_topsort", + hdrs = ["online_topsort.h"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "online_topsort_test", + srcs = ["online_topsort_test.cc"], + deps = [ + ":online_topsort", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_streamer", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + ], +) + bzl_library( name = "lit_bzl", srcs = ["lit.bzl"], diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 59c7b9cb4a0b8..b059601bbcfb0 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -1036,13 +1036,19 @@ absl::StatusOr HloEvaluator::EvaluateWithSubstitutions( std::unique_ptr cloned_instruction = instruction->CloneWithNewOperands(instruction->shape(), operands); - // TODO(phawkins): it's unfortunate that we need to call set_parent() here. + // TODO(phawkins): it's unfortunate that we need to call set_parent() here, + // since it violates the invariant that an instruction has a parent iff it is + // in a computation. // It's probably better to avoid constructing new instructions here in the // first place. cloned_instruction->set_parent( const_cast(instruction->parent())); auto result = Evaluate(cloned_instruction.get()); + // Undo the parent change, since it will confuse code that expects the + // instruction to be in a computation. + cloned_instruction->set_parent(nullptr); + return result; } diff --git a/xla/hlo/ir/hlo_computation.cc b/xla/hlo/ir/hlo_computation.cc index ee0c0a7ce4b5b..4b2051adc6ede 100644 --- a/xla/hlo/ir/hlo_computation.cc +++ b/xla/hlo/ir/hlo_computation.cc @@ -29,10 +29,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -183,11 +185,38 @@ HloComputation::~HloComputation() { async_start_->ClearCalledComputations(); } Cleanup(); + ClearCalledComputations(); + + // We need to make sure there are no dangling references to this computation + // from instructions in other computations. + std::vector callers; + for (const auto& [caller, count] : caller_computations_) { + callers.push_back(caller); + } + for (HloComputation* caller : callers) { + for (HloInstruction* inst : caller->instructions()) { + for (int i = 0; i < inst->called_computations().size(); ++i) { + if (inst->called_computations()[i] == this) { + inst->set_called_computation(i, nullptr); + } + } + } + } + CHECK(caller_computations_.empty()); + for (const auto& i : instructions_) { delete i.inst(); } } +void HloComputation::ClearCalledComputations() { + for (HloInstruction* i : instructions()) { + i->ClearCalledComputations(); + } + // Clearing the instructions should have removed all callee computations. + CHECK(callee_computations_.empty()); +} + void HloComputation::SetInstruction(HloInstruction* instruction, InstructionType type) { static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1, @@ -241,6 +270,38 @@ HloInstruction* HloComputation::AddInstruction( return AddInstruction(std::move(instruction)); } +static void IncrementCount( + absl::btree_map& + map, + HloComputation* key) { + ++map[key]; +} + +// Returns true if the callee was present and its count was decremented; returns +// false if the callee was not present. +static void DecrementCount( + absl::btree_map& + map, + HloComputation* key) { + auto it = map.find(key); + CHECK(it != map.end()); + CHECK_GT(it->second, 0); + --it->second; + if (it->second == 0) { + map.erase(it); + } +} + +void HloComputation::AddCallee(HloComputation* callee) { + IncrementCount(callee_computations_, callee); + IncrementCount(callee->caller_computations_, this); +} + +void HloComputation::RemoveCallee(HloComputation* callee) { + DecrementCount(callee_computations_, callee); + DecrementCount(callee->caller_computations_, this); +} + HloInstruction* HloComputation::AddInstructionInternal( std::unique_ptr instruction) { if (parent() != nullptr) { @@ -265,6 +326,7 @@ HloInstruction* HloComputation::AddInstructionInternal( CHECK(parent() == nullptr || called_computation->parent() == parent()) << "Called computation " << called_computation->name() << " is not in the same module as " << name(); + AddCallee(called_computation); } return pinst; } @@ -521,13 +583,13 @@ absl::Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, HloInstructionInfo* info = &instructions_[instruction->index_in_parent_]; DCHECK_EQ(info->inst(), instruction); - info->inst()->set_parent(nullptr); to_be_deleted_.push_back(info->inst()); // Takes ownership to_be_deleted_.back()->DetachFromOperandsAndUsers(); // Clear all operands to avoid Null operands. to_be_deleted_.back()->RemoveAllOperands(); to_be_deleted_.back()->ClearCalledComputations(); to_be_deleted_.back()->MarkAsDead(); + info->inst()->set_parent(nullptr); // If this instruction is a constant, clear the literal eagerly instead of // waiting for the instruction to be deleted in Cleanup(). This greatly @@ -1089,7 +1151,7 @@ HloComputation::CreateFromProto( auto computation = absl::WrapUnique( new HloComputation(proto.name(), parameter_count, &instructions, root)); - computation->unique_id_ = proto.id(); + computation->SetUniqueIdHelper(proto.id()); if (proto.is_fusion_computation()) { computation->instruction_and_type_ = static_cast(InstructionType::kFusion); @@ -1840,4 +1902,36 @@ bool HloComputation::CanExpandIntoSingleInstruction() const { }); } +void HloComputation::ClearUniqueIdInternal() { SetUniqueIdHelper(-1); } + +void HloComputation::SetUniqueId(int64_t id) { + CHECK_EQ(unique_id_, -1); + CHECK_GE(id, 0); + SetUniqueIdHelper(id); +} + +void HloComputation::SetUniqueIdHelper(int64_t id) { + // The caller/callee computations are ordered by unique ID, so we need to + // remove and readd them to our neighbor's data structures. + for (auto& [computation, count] : caller_computations_) { + auto it = computation->callee_computations_.find(this); + CHECK(it != computation->callee_computations_.end()); + CHECK_EQ(it->second, count); + computation->callee_computations_.erase(it); + } + for (auto& [computation, count] : callee_computations_) { + auto it = computation->caller_computations_.find(this); + CHECK(it != computation->caller_computations_.end()); + CHECK_EQ(it->second, count); + computation->caller_computations_.erase(it); + } + unique_id_ = id; + for (auto& [computation, count] : caller_computations_) { + computation->callee_computations_[this] = count; + } + for (auto& [computation, count] : callee_computations_) { + computation->caller_computations_[this] = count; + } +} + } // namespace xla diff --git a/xla/hlo/ir/hlo_computation.h b/xla/hlo/ir/hlo_computation.h index ba706e882e1e6..c21a225c712c8 100644 --- a/xla/hlo/ir/hlo_computation.h +++ b/xla/hlo/ir/hlo_computation.h @@ -21,10 +21,12 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -917,14 +919,10 @@ class HloComputation { // Clear the unique ID of the computation so that it can be re-assigned, such // as for the purpose of compacting the unique IDs. - void ClearUniqueIdInternal() { unique_id_ = -1; } + void ClearUniqueIdInternal(); // The id of this computation should be unique within the module. - void SetUniqueId(int64_t id) { - CHECK_EQ(unique_id_, -1); - CHECK_GE(id, 0); - unique_id_ = id; - } + void SetUniqueId(int64_t id); // Returns the instruction in this computation that has name `name`. Returns // null if there is no such computation. @@ -957,6 +955,34 @@ class HloComputation { // Returns true iff this computation can be inlined as a single instruction. bool CanExpandIntoSingleInstruction() const; + // A comparator that orders computations by their unique IDs. This is used + // for determinism. + struct UniqueIdComparator { + bool operator()(const HloComputation* lhs, + const HloComputation* rhs) const { + // We include the computation pointer so that we can disambiguate + // computations that do not belong to any module and therefore have a + // unique ID of -1. This is not deterministic, but we don't need + // determinism for computations not in a module since they are ignored + // by the topological sorting code. + return std::tie(lhs->unique_id_, lhs) < std::tie(rhs->unique_id_, rhs); + } + }; + + // Count of times this computation calls other computations. + absl::btree_map + callee_computations() const { + return callee_computations_; + } + + // Count of times this computation is called by other computations. + absl::btree_map + caller_computations() const { + return caller_computations_; + } + + void ClearCalledComputations(); + private: friend class HloModule; @@ -1018,6 +1044,18 @@ class HloComputation { // set the parent of a computation is to add it to a module. void set_parent(HloModule* module) { parent_ = module; } + // Helper that updates the unique ID of the computation. This requires + // updating the callee_computations_ and caller_computations_ sets since they + // are ordered by unique ID. + void SetUniqueIdHelper(int64_t id); + + friend class HloInstruction; + void AddCallee(HloComputation* callee); + void RemoveCallee(HloComputation* callee); + + // Unique ID of this computation. + // This is set to -1 if the computation is not in a module. Should only be + // updated by SetUniqueIdHelper(). int64_t unique_id_; HloInstruction* root_instruction_; @@ -1056,6 +1094,25 @@ class HloComputation { std::string name_; + // Callers and callees of this computation. + // * These include all computations that have a caller/callee relationship + // with this computation, even those that may not belong to a module. For + // example, a computation that has been created and is in the process of + // being constructed but has not been added to a module yet may appear here. + // * These are ordered maps, ordered by (unique ID, computation pointer). The + // unique ID is used to ensure determinism, whereas the computation pointer + // is used to disambiguate computations that do not belong to any module and + // therefore have a unique ID of -1. We assume that determinism only matters + // for computations that belong to a module (i.e, unique_id != -1), since + // the primary use case for this data structure is to topologically sort + // computations in a module. + // * The values of the maps are the number of times the computation is + // referenced. In a graph sense, this is the number of parallel edges. + absl::btree_map + callee_computations_; + absl::btree_map + caller_computations_; + HloComputation(const HloComputation&) = delete; HloComputation& operator=(const HloComputation&) = delete; }; diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 89760b9b77133..e444e8d79b75f 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -219,16 +219,28 @@ void HloInstruction::AppendComputation(HloComputation* computation) { // In .cc file since PtrVec::push_back() wants to check the alignment // of T and hlo_instruction.h does not include hlo_computation.h. mutable_rare()->called_computations.push_back(computation); + if (parent()) { + parent()->AddCallee(computation); + } } void HloInstruction::set_called_computation(int index, HloComputation* computation) { - mutable_rare()->called_computations[index] = computation; // TODO(b/399394039): Consider also enforcing that computation->parent() != // nullptr. CHECK(parent() == nullptr || parent()->parent() == nullptr || - parent()->parent() == computation->parent()) + computation == nullptr || parent()->parent() == computation->parent()) << ToString(); + HloComputation* old_computation = computation; + std::swap(old_computation, mutable_rare()->called_computations[index]); + if (parent()) { + if (old_computation) { + parent()->RemoveCallee(old_computation); + } + if (computation) { + parent()->AddCallee(computation); + } + } } void HloInstruction::ReplaceCalledComputations( @@ -238,6 +250,19 @@ void HloInstruction::ReplaceCalledComputations( } } +void HloInstruction::ClearCalledComputations() { + if (has_rare()) { + if (parent()) { + for (HloComputation* computation : called_computations()) { + if (computation) { + parent()->RemoveCallee(computation); + } + } + } + mutable_rare()->called_computations.clear(); + } +} + HloInstruction* HloInstruction::AddInstruction( std::unique_ptr derived_instruction) { HloInstruction* derived = diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 344b320edda6f..73103a909613b 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -1742,11 +1742,7 @@ class HloInstruction { // clearing out the computations, we reflect the fact that all side-effecting // properties have been reflected in the caller, and make the call HLO // removable. - virtual void ClearCalledComputations() { - if (has_rare()) { - mutable_rare()->called_computations.clear(); - } - } + virtual void ClearCalledComputations(); // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, diff --git a/xla/hlo/ir/hlo_module.cc b/xla/hlo/ir/hlo_module.cc index f9ab1aeb7897d..edf395438b310 100644 --- a/xla/hlo/ir/hlo_module.cc +++ b/xla/hlo/ir/hlo_module.cc @@ -90,6 +90,14 @@ HloModule::HloModule(const std::string& name, metadata_.set_canonical_module_id(unique_id_); } +HloModule::~HloModule() { + // To avoid dangling references between computations, we first clear all the + // inter-computation references before deleting any of the computations. + for (const auto& computation : computations_) { + computation->ClearCalledComputations(); + } +} + absl::Status HloModule::set_schedule(HloSchedule schedule) { TF_RET_CHECK(schedule.module() == this); TF_RETURN_IF_ERROR(schedule.Verify()); diff --git a/xla/hlo/ir/hlo_module.h b/xla/hlo/ir/hlo_module.h index 260b33720b9de..cd1dd49d7515a 100644 --- a/xla/hlo/ir/hlo_module.h +++ b/xla/hlo/ir/hlo_module.h @@ -88,7 +88,7 @@ class HloModule { HloModule(const std::string& name, std::shared_ptr config, std::unique_ptr comp_envs); - virtual ~HloModule() = default; + virtual ~HloModule(); // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. diff --git a/xla/hlo/ir/hlo_module_test.cc b/xla/hlo/ir/hlo_module_test.cc index 66588a9611636..7aee920b94faf 100644 --- a/xla/hlo/ir/hlo_module_test.cc +++ b/xla/hlo/ir/hlo_module_test.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include #include +#include #include +#include #include #include "absl/hash/hash.h" #include "absl/strings/str_cat.h" @@ -32,11 +34,15 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla_data.pb.h" namespace xla { namespace { +using ::testing::ElementsAre; + TEST(HloModuleTest, AbslHashValue) { HloModule module1("temp_module", HloModuleConfig()); HloModule module2("temp_module3", HloModuleConfig()); @@ -469,5 +475,43 @@ TEST(HloModuleTest, CheckToStringHonorsDebugOptions) { EXPECT_TRUE(filecheck_matched); } +TEST(HloModuleTest, TestCallersAndCallees) { + // Check that the debug options xla_dump_large_constants, + // xla_syntax_sugar_async_ops are honored. + const char* hlo = R"( + HloModule jit_h + + f { + Arg_0.3 = f32[] parameter(0) + ROOT sine.4 = f32[] sine(Arg_0.3) + } + + g { + Arg_0.13 = f32[] parameter(0) + call.14 = f32[] call(Arg_0.13), to_apply=f + ROOT call.15 = f32[] call(call.14), to_apply=f + } + + ENTRY main { + Arg_0.1 = f32[] parameter(0) + call.5 = f32[] call(Arg_0.1), to_apply=f + call.16 = f32[] call(call.5), to_apply=g + ROOT call.27 = f32[] call(call.16), to_apply=g + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(hlo)); + EXPECT_EQ(module->computation_count(), 3); + HloComputation* main = module->GetComputationWithName("main"); + HloComputation* f = module->GetComputationWithName("f"); + HloComputation* g = module->GetComputationWithName("g"); + EXPECT_THAT(main->callee_computations(), + ElementsAre(std::make_pair(f, 1), std::make_pair(g, 2))); + EXPECT_THAT(f->callee_computations(), ElementsAre()); + EXPECT_THAT(g->callee_computations(), ElementsAre(std::make_pair(f, 2))); + EXPECT_THAT(f->caller_computations(), + ElementsAre(std::make_pair(g, 2), std::make_pair(main, 1))); + EXPECT_THAT(g->caller_computations(), ElementsAre(std::make_pair(main, 2))); +} + } // namespace } // namespace xla diff --git a/xla/online_topsort.h b/xla/online_topsort.h new file mode 100644 index 0000000000000..1b9f7352947cb --- /dev/null +++ b/xla/online_topsort.h @@ -0,0 +1,784 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This module implements an online topological sort using the two-way-search +// algorithm for sparse graphs of Bender et al., Section 2. The algorithm +// incorporates the extension from section 4 to maintain the topological order +// explicitly in a doubly-linked list. +// +// Per Bender et al, inserting m edges into a graph of n nodes takes +// O(m*min(m**(1/2), n**(2/3))). For the use case of our compiler IR, we assume +// that the number of edges is at most a small multiple of the number of nodes, +// and so the graph is quite sparse, and the dominant bound is O(m**(3/2)). +// +// We implement several extensions to the algorithm: +// - we allow adding and removing nodes. This does not require any significant +// changes to the algorithm. The original algorithm uses the values of m and n +// as part of a scheme for numbering nodes, but the purpose of that scheme is +// to combine (level, index) tuples into a single total order. We don't need +// explicit position numbers, only the topological order, so we can just use +// a lexicographic order of (level, index) tuples directly. +// - we number indices decreasing from std::numeric_limits::max(). The +// careful numbering of indices in the original paper is only to avoid +// collisions in the ID space with the level numbers, but since we don't try +// to combine these into a single number, we don't need to be quite as +// careful. +// - we allow removing edges. This is a trivial extension; removing an edge +// preserves topological ordering. Removing edges may affect the algorithmic +// complexity guarantees, but we probably don't care that much. +// +// This implementation is not thread-safe. +// +// Type parameters: +// - T is the type of the nodes in the graph. +// - Index is the type of the index_in_parent field in the nodes. We only care +// that the index values form a reasonably dense range starting at 0, since +// we use them to index into vectors. If we didn't have a dense range, we +// could use an associative map data structure instead, but that would be +// slower to lookup. +// - Link is a pointer to the embedded TopologicalSortNode field in T. +// - IndexInParent is a pointer to the index_in_parent field in T. +// These indices must remain fixed only during a call to AddEdge(), which +// is obviously true because we don't allow threads and the topological sort +// will not change them, but they are allowed to change between calls. +// - PredecessorIterator, PredecessorsBegin, PredecessorsEnd iterate over the +// predecessors of the node. Duplicates are allowed. +// - SuccessorIterator, SuccessorsBegin, SuccessorsEnd iterate over the +// successors of the node. Duplicates are allowed. +// +// References: +// * Bender, M.A., Fineman, J.T., Gilbert, S. and Tarjan, R.E., 2015. A new +// approach to incremental cycle detection and related problems. +// ACM Transactions on Algorithms (TALG), 12(2), pp.1-22. +// https://dl.acm.org/doi/abs/10.1145/2756553 + +#ifndef XLA_ONLINE_TOPSORT_H_ +#define XLA_ONLINE_TOPSORT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +// The topological sort is an intrusive data structure. Nodes of type T that +// participate in the topological sort must have a TopologicalSortNode +// embedded within them. +template +class TopologicalSortNode { + public: + TopologicalSortNode() = default; + ~TopologicalSortNode() { DCHECK(!in_topological_order()) << level_; } + + TopologicalSortNode(const TopologicalSortNode&) = delete; + TopologicalSortNode(TopologicalSortNode&&) = delete; + TopologicalSortNode& operator=(const TopologicalSortNode&) = delete; + TopologicalSortNode& operator=(TopologicalSortNode&&) = delete; + + void clear() { + next_ = nullptr; + prev_ = nullptr; + level_ = -1; + index_ = -1; + } + + // Returns true if this node has been added to a topological order. + // It may have temporarily been removed from a specific location in that + // order if we are in the middle of an AddEdge() operation. + bool in_topological_order() const { return level_ >= 0; } + + private: + template S::* Link, + Index S::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (S::*PredecessorsBegin)() const, + PredecessorIterator (S::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (S::*SuccessorsBegin)() const, + SuccessorIterator (S::*SuccessorsEnd)() const> + friend class TopologicalSort; + + template S::* Link> + friend class TopologicalSortForwardIterator; + template S::* Link> + friend class TopologicalSortReverseIterator; + + int index_ = -1; + int level_ = -1; + + // The nodes form a doubly-linked list, where the `next_` pointers are not + // circular, but the `prev_` pointers are circular. + // There is also an asymmetry in the types of `next_` and `prev_`: the former + // is a pointer to a node, while the latter is a pointer to a + // TopologicalSortNode embedded within a node. This trick helps us define + // an intrusive templated list in C++. + T* next_ = nullptr; + TopologicalSortNode* prev_ = nullptr; +}; + +// Iterator that traverses through the topological sort in order. +template T::* Link> +class TopologicalSortForwardIterator { + public: + TopologicalSortForwardIterator() : current_(nullptr) {} + explicit TopologicalSortForwardIterator(const TopologicalSortNode* current) + : current_(current) {} + + TopologicalSortForwardIterator(const TopologicalSortForwardIterator&) = + default; + TopologicalSortForwardIterator(TopologicalSortForwardIterator&&) = default; + TopologicalSortForwardIterator& operator=( + const TopologicalSortForwardIterator&) = default; + TopologicalSortForwardIterator& operator=(TopologicalSortForwardIterator&&) = + default; + + T& operator*() const { return *current_->next_; } + T* operator->() const { return current_->next_; } + + bool operator==(const TopologicalSortForwardIterator& other) const { + return current_ == other.current_; + } + bool operator!=(const TopologicalSortForwardIterator& other) const { + return current_ != other.current_; + } + + TopologicalSortForwardIterator& operator++() { + current_ = &(current_->next_->*Link); + return *this; + } + + TopologicalSortForwardIterator& operator--() { + current_ = ¤t_->prev_; + return *this; + } + + private: + // Note: the iterator is a pointer to a node whose *next* pointer points to + // the current node. + TopologicalSortNode const* current_; +}; + +// Iterator that traverses through the topological sort in reverse order. +template T::* Link> +class TopologicalSortReverseIterator { + public: + TopologicalSortReverseIterator() : current_(nullptr) {} + explicit TopologicalSortReverseIterator(const TopologicalSortNode* current) + : current_(current) {} + + TopologicalSortReverseIterator(const TopologicalSortReverseIterator&) = + default; + TopologicalSortReverseIterator(TopologicalSortReverseIterator&&) = default; + TopologicalSortReverseIterator& operator=( + const TopologicalSortReverseIterator&) = default; + TopologicalSortReverseIterator& operator=(TopologicalSortReverseIterator&&) = + default; + + T& operator*() const { return *current_->next_; } + T* operator->() const { return current_->next_; } + + bool operator==(const TopologicalSortReverseIterator& other) const { + return current_ == other.current_; + } + bool operator!=(const TopologicalSortReverseIterator& other) const { + return current_ != other.current_; + } + + TopologicalSortReverseIterator& operator++() { + current_ = current_->prev_; + return *this; + } + + private: + // Note: the iterator is a pointer to a node whose *next* pointer points to + // the current node. + TopologicalSortNode const* current_; +}; + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +class TopologicalSort { + public: + TopologicalSort() { + node_.next_ = nullptr; + node_.prev_ = &node_; + first_in_level_.push_back(&node_); + } + + ~TopologicalSort(); + + // Invalidates iterators. + void AddNode(T* v); + + // Invalidates iterators. + void RemoveNode(T* v); + + // Caution: this data structure assumes that there are no parallel edges. + // Invalidates any iterators. We assume the user has added the edge to their + // own data structure before calling this method. + void AddEdge(T* v, T* w); + + // You might wonder why we don't have the following method: + // void RemoveEdge(T* v, T* w); + // The reason is that we don't need it. Removing an edge preserves topological + // ordering, and there's nothing for us to do here. The user still needs to + // remove the edge from their own data structure, of course. + + // Returns an iterator over the nodes in topological order. + TopologicalSortForwardIterator begin() const { + return TopologicalSortForwardIterator(&node_); + } + TopologicalSortForwardIterator end() const { + return TopologicalSortForwardIterator(node_.prev_); + } + + // Returns an iterator over the nodes in reverse topological order. + TopologicalSortReverseIterator rbegin() const { + return TopologicalSortReverseIterator(node_.prev_->prev_); + } + TopologicalSortReverseIterator rend() const { + return TopologicalSortReverseIterator(node_.prev_); + } + + // This is a helper for debugging. It logs the current order and checks a + // number of invariants. + void LogOrder() { + std::vector order; + int level = -1; + for (T& node : *this) { + const auto& link = node.*Link; + CHECK_GE(link.level_, level); + level = link.level_; + if (link.next_) { + CHECK((link.next_->*Link).prev_ == &link); + } else { + CHECK(node_.prev_ == &link); + } + CHECK(link.prev_->next_ == &node); + order.push_back(&node); + } + auto node_formatter = [](std::string* out, T* v) { + absl::StrAppend(out, v->*IndexInParent, "[", (v->*Link).level_, ":", + (v->*Link).index_, "]"); + }; + DVLOG(2) << this << " order=" << absl::StrJoin(order, ", ", node_formatter); + auto first_in_level_formatter = [](std::string* out, + TopologicalSortNode* v) { + if (v->next_) { + absl::StrAppend(out, v->next_->*IndexInParent, ":", + (v->next_->*Link).level_); + } else { + absl::StrAppend(out, "-:-"); + } + }; + DVLOG(2) << this << " first_in_level_=" + << absl::StrJoin(first_in_level_, ", ", first_in_level_formatter); + + CHECK(first_in_level_[0] == &node_); + auto it = order.begin(); + for (TopologicalSortNode* v : first_in_level_) { + it = std::find(it, order.end(), v->next_); + CHECK(v->next_ == nullptr || it != order.end()); + } + } + + void clear() { node_.clear(); } + + private: + // Updates delta_ after we have increased num_edges_ and num_nodes_. + // We don't bother decreasing delta_ after removals, since we assume that our + // graphs will not significantly shrink. + void UpdateDelta(); + + // Performs a DFS backwards from v of at most delta_ nodes on the same level, + // populating b with nodes in postorder with respect to the search (i.e., a + // node appears later in b than its predecessors). Returns true if we should + // run a forwards search. + bool SearchBackwards(T* v, T* w, std::vector& b); + + // Performs a DFS forwards from v populating f with nodes in postorder with + // respect to the search (i.e., a node appears later in f than all its + // predecessors). + // (Note "f" is reversed from the paper, which just because we can save time + // and reverse it when updating the indices, rather than explicitly reversing + // it here.) + void SearchForwards(T* v, T* w, std::vector& f); + + // Removes v from the topological order. + void RemoveFromOrder(T* v); + + void UpdateIndex(T* v); + + // Helper that makes sure that the AddEdge() data structures are large enough + // to hold nodes with index max_index_in_parent. + void UpdateMaxIndexInParent(Index max_index_in_parent) { + if (max_index_in_parent >= visited_backwards_.size()) { + visited_backwards_.resize(max_index_in_parent + 1); + visited_forwards_.resize(max_index_in_parent + 1); + increased_.resize(max_index_in_parent + 1); + } + } + + TopologicalSortNode node_; + + int num_edges_ = 0; // aka "m" in the paper. + int num_nodes_ = 0; // aka "n" in the paper. + + // How many nodes to search backwards when adding an edge. This should be + // ceil(min(m**(1/2), n**(2/3))), but we compute that bound online as we add + // nodes and edges via UpdateDelta(). + int64_t delta_ = 0; + + // The next value of index_ to assign, aka "a" in the paper. Monotonically + // decreasing as indices are assigned. + // You might also wonder where 'b' from the paper is, but we simply don't + // need it, since we're trying to maintain a doubly-linked list in topological + // order, and we don't care about computing a topological numbering. + int next_index_ = std::numeric_limits::max(); + + // The first node in each level or a higher level. + // As is the usual convention for this data structure, this is actually the + // TopologicalSortNode whose next_ pointer points to that node, if any. + // Invariant: There is always at least one level. Futher, these pointers are + // never nullptr: there's always a preceding node (node_, if nothing else). + std::vector*> first_in_level_; + + // Visited state for forwards and backwards searches which are used during + // AddEdge(). We keep this state in the class to save repeatedly allocating + // it. This would not be thread-safe, but neither is AddEdge(). + std::vector visited_backwards_; + std::vector visited_forwards_; + std::vector increased_; +}; + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +TopologicalSort::~TopologicalSort() { + TopologicalSortNode* next; + for (TopologicalSortNode* node = &node_; node != nullptr; node = next) { + if (node->next_) { + next = &(node->next_->*Link); + } else { + next = nullptr; + } + node->clear(); + } +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::AddNode(T* v) { + TopologicalSortNode* node = &(v->*Link); + if (VLOG_IS_ON(1)) { + DVLOG(1) << this << " AddNode(" << v->*IndexInParent << ")"; + LogOrder(); + } + + // next_ and prev_ should be nullptr for a new node. + CHECK(node->next_ == nullptr); + CHECK(node->prev_ == nullptr); + node->level_ = 0; + node->index_ = next_index_--; + ++num_nodes_; + UpdateDelta(); + + // Add the node to the front of the topological ordering. + node->next_ = first_in_level_[0]->next_; + node->prev_ = first_in_level_[0]; + if (node->next_) { + (node->next_->*Link).prev_ = node; + } else { + node_.prev_ = node; + } + first_in_level_[0]->next_ = v; + for (int level = 1; + level < first_in_level_.size() && first_in_level_[level] == &node_; + ++level) { + first_in_level_[level] = node; + } + if (VLOG_IS_ON(1)) { + LogOrder(); + } +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::RemoveNode(T* v) { + TopologicalSortNode* node = &(v->*Link); + DVLOG(1) << this << " RemoveNode(" << v->*IndexInParent << ")"; + CHECK(node->prev_ == &node_ || node->prev_->in_topological_order()); + --num_nodes_; + if (VLOG_IS_ON(1)) { + LogOrder(); + } + RemoveFromOrder(v); + node->level_ = -1; + node->index_ = -1; + if (VLOG_IS_ON(1)) { + LogOrder(); + } +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::AddEdge(T* v, T* w) { + TopologicalSortNode* v_node = &(v->*Link); + TopologicalSortNode* w_node = &(w->*Link); + + ++num_edges_; + UpdateDelta(); + + DVLOG(1) << this << " AddEdge(" << v->*IndexInParent << ", " + << w->*IndexInParent << ") v={level=" << v_node->level_ << " " + << "index=" << v_node->index_ << "} " + << " w={level=" << w_node->level_ << " " + << "index=" << w_node->index_ << "} " + << "delta_=" << delta_; + + // Verify that both nodes are in the topological order. + DCHECK(v_node->in_topological_order()); + DCHECK(w_node->in_topological_order()); + + // Step 1: test order: if w is already higher than v in the lexicographical + // order then the current ordering is fine. + if (std::tie(v_node->level_, v_node->index_) < + std::tie(w_node->level_, w_node->index_)) { + if (VLOG_IS_ON(1)) { + LogOrder(); + } + return; + } + + // Step 2: search backwards from v, until we either find `w`, which means we + // have a cycle, visit delta_ edges, or run out of edges to visit. + std::vector b; + bool should_search_forwards; + bool visited_delta_edges = SearchBackwards(v, w, b); + if (visited_delta_edges) { + b.resize(1); + b.front() = v; + RemoveFromOrder(w); + w_node->level_ = v_node->level_ + 1; + + should_search_forwards = true; + } else if (w_node->level_ == v_node->level_) { + // l = b; + should_search_forwards = false; + } else { + // We know that w_node->level < v_node->level, by the case above and by the + // test in step 1. + DCHECK_LT(w_node->level_, v_node->level_); + RemoveFromOrder(w); + w_node->level_ = v_node->level_; + should_search_forwards = true; + } + + // Step 3: search forwards from w, following outgoing edges only from nodes + // whose level increases. + std::vector f; + if (should_search_forwards) { + SearchForwards(v, w, f); + if (v_node->level_ < w_node->level_) { + b.clear(); // l = reverse(f) + } else { + CHECK_EQ(v_node->level_, w_node->level_); + // l = b + reverse(f) + } + } + + // Step 4: update indices. + auto node_formatter = [](std::string* out, T* v) { + absl::StrAppend(out, v->*IndexInParent); + }; + DVLOG(2) << "b=" << absl::StrJoin(b, ", ", node_formatter) + << " f=" << absl::StrJoin(f, ", ", node_formatter); + for (auto it = f.begin(); it != f.end(); ++it) { + UpdateIndex(*it); + } + for (auto it = b.rbegin(); it != b.rend(); ++it) { + UpdateIndex(*it); + } + + // Step 5: add the edge. + // There's actually nothing to do here, because it's up to the user to add + // the edge to their own data structures. It doesn't matter whether the user + // does that before or after they call our AddEdge(), since we only search + // backwards from v and forwards from w. + + if (VLOG_IS_ON(1)) { + LogOrder(); + + DVLOG(1) << "end AddEdge(" << v->*IndexInParent << ", " << w->*IndexInParent + << ") v={level=" << v_node->level_ << " " + << "index=" << v_node->index_ << "} " + << " w={level=" << w_node->level_ << " " + << "index=" << w_node->index_ << "} " + << "delta_=" << delta_; + } +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +bool TopologicalSort::SearchBackwards(T* v, T* w, + std::vector& b) { + std::vector> agenda; + std::fill(visited_backwards_.begin(), visited_backwards_.end(), false); + int num_edges_visited = 0; + agenda.emplace_back(v, false); + while (!agenda.empty()) { + auto [y, post] = agenda.back(); + agenda.pop_back(); + DVLOG(3) << "SearchBackwards visiting " << y->*IndexInParent + << " post=" << post; + CHECK(y != w) << "Cycle detected"; + TopologicalSortNode* y_node = &(y->*Link); + int level = y_node->level_; + if (post) { + b.push_back(y); + continue; + } + + Index y_index_in_parent = y->*IndexInParent; + UpdateMaxIndexInParent(y_index_in_parent); + if (visited_backwards_[y_index_in_parent]) { + continue; + } + visited_backwards_[y_index_in_parent] = true; + + agenda.emplace_back(y, true); + for (auto it = std::invoke(PredecessorsBegin, y); + num_edges_visited < delta_ && it != std::invoke(PredecessorsEnd, y); + ++it) { + T* x = *it; + TopologicalSortNode* x_node = &(x->*Link); + if (!x_node->in_topological_order()) { + continue; + } + CHECK_LE(x_node->level_, level); + VLOG(2) << "visiting edge " << x->*IndexInParent; + if (x_node->level_ == level) { + ++num_edges_visited; + if (num_edges_visited >= delta_) { + return true; + } + agenda.emplace_back(x, false); + } + } + } + return false; +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::SearchForwards(T* v, T* w, + std::vector& f) { + std::fill(visited_forwards_.begin(), visited_forwards_.end(), false); + std::fill(increased_.begin(), increased_.end(), false); + std::vector> agenda; + agenda.emplace_back(w, false); + UpdateMaxIndexInParent(w->*IndexInParent); + increased_[w->*IndexInParent] = true; + + // f list of vertices whose level increases, in reverse postorder, i.e., + // a vertex appears in f before its successors. + while (!agenda.empty()) { + auto [x, post] = agenda.back(); + agenda.pop_back(); + DVLOG(3) << "SearchForwards visiting " << x->*IndexInParent + << " post=" << post; + if (post) { + f.push_back(x); + continue; + } + Index x_index_in_parent = x->*IndexInParent; + UpdateMaxIndexInParent(x_index_in_parent); + if (visited_forwards_[x_index_in_parent] || + !increased_[x_index_in_parent]) { + continue; + } + visited_forwards_[x_index_in_parent] = true; + + agenda.emplace_back(x, true); + + TopologicalSortNode* x_node = &(x->*Link); + for (auto it = std::invoke(SuccessorsBegin, x); + it != std::invoke(SuccessorsEnd, x); ++it) { + T* y = *it; + VLOG(3) << "fwd edge to " << y->*IndexInParent; + TopologicalSortNode* y_node = &(y->*Link); + if (!y_node->in_topological_order()) { + continue; + } + Index y_index_in_parent = y->*IndexInParent; + UpdateMaxIndexInParent(y_index_in_parent); + DCHECK(y != v) << "Cycle detected " << y->*IndexInParent; + DCHECK(!visited_backwards_[y_index_in_parent]) + << "Cycle detected " << y->*IndexInParent; + agenda.emplace_back(y, false); + if (x_node->level_ > y_node->level_) { + RemoveFromOrder(y); + y_node->level_ = x_node->level_; + increased_[y_index_in_parent] = true; + } + } + } +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::RemoveFromOrder(T* v) { + TopologicalSortNode* v_node = &(v->*Link); + // If this node is the last node in any level, it may appear in the + // first_in_level_ vector for subsequent levels. + for (int level = v_node->level_ + 1; + level < first_in_level_.size() && first_in_level_[level] == v_node; + ++level) { + first_in_level_[level] = v_node->prev_; + } + v_node->prev_->next_ = v_node->next_; + if (v_node->next_) { + (v_node->next_->*Link).prev_ = v_node->prev_; + } else { + node_.prev_ = v_node->prev_; + } + v_node->next_ = nullptr; + v_node->prev_ = nullptr; +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::UpdateIndex(T* v) { + TopologicalSortNode* v_node = &(v->*Link); + + if (v_node->prev_) { + // TODO(phawkins): could we just do this above? + RemoveFromOrder(v); + } + + // Since this node just decreased in index, it now becomes the first node on + // its level. + v_node->index_ = next_index_--; + if (v_node->level_ >= first_in_level_.size()) { + TopologicalSortNode* t = first_in_level_.back(); + while (t->next_ != nullptr) { + t = &(t->next_->*Link); + } + first_in_level_.resize(v_node->level_ + 1, t); + } + + TopologicalSortNode* old_first = first_in_level_[v_node->level_]; + v_node->next_ = old_first->next_; + v_node->prev_ = old_first; + if (v_node->next_) { + (v_node->next_->*Link).prev_ = v_node; + } else { + node_.prev_ = v_node; + } + old_first->next_ = v; + for (int level = v_node->level_ + 1; + level < first_in_level_.size() && first_in_level_[level] == old_first; + ++level) { + first_in_level_[level] = v_node; + } +} + +template T::* Link, + Index T::* IndexInParent, typename PredecessorIterator, + PredecessorIterator (T::*PredecessorsBegin)() const, + PredecessorIterator (T::*PredecessorsEnd)() const, + typename SuccessorIterator, + SuccessorIterator (T::*SuccessorsBegin)() const, + SuccessorIterator (T::*SuccessorsEnd)() const> +void TopologicalSort::UpdateDelta() { + int64_t m = num_edges_; + int64_t n = num_nodes_; + // delta should be ceil(min(m**(1/2), n**(2/3))) + while (delta_ * delta_ < m && delta_ * delta_ * delta_ < n * n) { + ++delta_; + } +} + +#endif // XLA_ONLINE_TOPSORT_H_ diff --git a/xla/online_topsort_test.cc b/xla/online_topsort_test.cc new file mode 100644 index 0000000000000..5deb39f4487b1 --- /dev/null +++ b/xla/online_topsort_test.cc @@ -0,0 +1,278 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/online_topsort.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_streamer.h" +#include "absl/random/random.h" +#include "absl/strings/str_join.h" +#include "xla/tsl/platform/test.h" + +namespace { + +struct TestNode { + explicit TestNode(int id) : id(id) {} + + int id; + std::vector in; + std::vector out; + TopologicalSortNode node; + + std::vector::const_iterator incoming_begin() const { + return in.begin(); + } + std::vector::const_iterator incoming_end() const { + return in.end(); + } + std::vector::const_iterator outgoing_begin() const { + return out.begin(); + } + std::vector::const_iterator outgoing_end() const { + return out.end(); + } +}; + +using Topsort = + TopologicalSort::const_iterator, + &TestNode::incoming_begin, &TestNode::incoming_end, + std::vector::const_iterator, + &TestNode::outgoing_begin, &TestNode::outgoing_end>; + +struct TestGraph { + void AddNode(int id) { + if (id >= node_index.size()) { + node_index.resize(id + 1, nullptr); + } + auto node = std::make_unique(id); + CHECK(node_index[id] == nullptr) << id; + node_index[id] = node.get(); + topsort.AddNode(node.get()); + nodes.push_back(std::move(node)); + } + + void RemoveNode(int id) { + TestNode* node = node_index[id]; + for (TestNode* x : node->in) { + RemoveEdge(x->id, node->id); + } + for (TestNode* x : node->out) { + RemoveEdge(id, x->id); + } + node_index[id] = nullptr; + topsort.RemoveNode(node); + auto it = std::find_if(nodes.begin(), nodes.end(), + [node](const auto& x) { return x.get() == node; }); + CHECK(it != nodes.end()); + nodes.erase(it); + } + + void AddEdge(int from, int to) { + CHECK_GE(from, 0); + CHECK_LT(from, node_index.size()); + CHECK_GE(to, 0); + CHECK_LT(to, node_index.size()); + TestNode* from_node = node_index[from]; + TestNode* to_node = node_index[to]; + topsort.AddEdge(from_node, to_node); + from_node->out.push_back(to_node); + to_node->in.push_back(from_node); + } + + bool HasEdge(int from, int to) const { + TestNode* from_node = node_index[from]; + TestNode* to_node = node_index[to]; + return std::find(from_node->out.begin(), from_node->out.end(), to_node) != + from_node->out.end(); + } + + void RemoveEdge(int from, int to) { + TestNode* from_node = node_index[from]; + TestNode* to_node = node_index[to]; + auto it = std::find(from_node->out.begin(), from_node->out.end(), to_node); + CHECK(it != from_node->out.end()); + from_node->out.erase(it); + it = std::find(to_node->in.begin(), to_node->in.end(), from_node); + CHECK(it != to_node->in.end()); + to_node->in.erase(it); + } + + // Returns std::nullopt if the topological order is valid. Otherwise, returns + // an edge that is inconsistent with the topological order. + std::optional> TopologicalOrderIsValid() const { + std::vector order(node_index.size(), -1); + int i = 0; + std::vector forward; + for (const TestNode& node : topsort) { + forward.push_back(&node); + order[node.id] = i++; + } + + // Verifies that the reverse iterator gives the same order. + std::vector reverse; + for (auto it = topsort.rbegin(); it != topsort.rend(); ++it) { + reverse.push_back(&*it); + } + absl::c_reverse(reverse); + CHECK(forward == reverse); + + for (const auto& x : nodes) { + for (TestNode* y : x->out) { + if (order[x->id] >= order[y->id]) { + return std::make_pair(x->id, y->id); + } + } + } + return std::nullopt; + } + + std::vector> nodes; + std::vector node_index; + Topsort topsort; +}; + +std::string OrderString(const Topsort& top) { + std::vector order; + for (TestNode& node : top) { + order.push_back(node.id); + } + return absl::StrJoin(order, ","); +} + +MATCHER(HasValidTopologicalOrder, "") { + std::optional> result = arg.TopologicalOrderIsValid(); + if (!result) { + return true; + } + *result_listener << "Topological order: " << OrderString(arg.topsort) + << " is inconsistent with edge " << result->first << "->" + << result->second; + return false; +} + +TEST(TopologicalSortTest, Basic) { + TestGraph g; + for (int i = 0; i < 10; ++i) { + g.AddNode(i); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + g.AddEdge(0, 1); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.AddEdge(1, 2); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.RemoveNode(0); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.RemoveNode(1); + ASSERT_THAT(g, HasValidTopologicalOrder()); +} + +TEST(TopologicalSortTest, Stick) { + TestGraph g; + int n = 20; + for (int i = 0; i < n; ++i) { + g.AddNode(i); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + for (int i = 0; i < n - 1; ++i) { + g.AddEdge(i, i + 1); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + for (int i = 0; i < n; ++i) { + g.RemoveNode(i); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } +} + +TEST(TopologicalSortTest, ChangeOrder) { + TestGraph g; + int n = 20; + for (int i = 0; i < n; ++i) { + g.AddNode(i); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + for (int i = 0; i < n - 1; ++i) { + g.AddEdge(i, i + 1); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + g.RemoveEdge(13, 14); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.AddEdge(n - 1, 0); + ASSERT_THAT(g, HasValidTopologicalOrder()); +} + +TEST(TopologicalSortTest, Diamonds) { + TestGraph g; + g.AddNode(0); + for (int i = 0; i < 500; ++i) { + int j = 3 * i; + for (int k = 1; k <= 3; ++k) { + g.AddNode(j + k); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + g.AddEdge(j, j + 1); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.AddEdge(j, j + 2); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.AddEdge(j + 1, j + 3); + ASSERT_THAT(g, HasValidTopologicalOrder()); + g.AddEdge(j + 2, j + 3); + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + ASSERT_THAT(g, HasValidTopologicalOrder()); +} + +TEST(TopologicalSortTest, Random) { + absl::BitGen gen( + absl::MakeTaggedSeedSeq("TestPRNG", absl::LogInfoStreamer().stream())); + + for (int trial = 0; trial < 10; ++trial) { + int n = absl::Uniform(gen, 10, 1000); + int m = absl::Uniform(gen, 0, std::min(n * 5, (n * (n - 1)) / 2)); + LOG(INFO) << "trial: " << trial << " n: " << n << " m: " << m; + std::vector order(n); + TestGraph g; + for (int i = 0; i < n; ++i) { + g.AddNode(i); + } + absl::c_iota(order, 0); + absl::c_shuffle(order, gen); + for (int i = 0; i < m; ++i) { + int a, b; + do { + a = absl::Uniform(gen, 0, n); + b = absl::Uniform(gen, 0, n); + if (a > b) { + std::swap(a, b); + } + } while (a == b || g.HasEdge(order[a], order[b])); + g.AddEdge(order[a], order[b]); + // Note: this check makes the test O(m^2), but it's valuable to verify + // the invariant is maintained. + ASSERT_THAT(g, HasValidTopologicalOrder()); + } + } +} + +} // namespace