From 0c70e61363a2673265ac72b7534c3b35cf2d194c Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Wed, 4 Sep 2024 15:47:34 -0700 Subject: [PATCH] [XLA] Introduce infeed token propagation During computation inlining, specifically loop unrolling, it is posibble for infeeds (and outfeeds) to get reordered in a way that breaks the original scheduling constraints set by the computation boundaries. This is a result of Tensorflow not exposing tokens for these ops to the user, so the input and output tokens end up dangling. Loop unrolling in XLA can be thought of applying the same function repeatedly to itself, e.g. transforming f(x) into f(f(x)). By pushing the tokens outside the loop body, we can guarantee that the output token of the first infeed will become the input token of the next infeed, thus creating a data dependency chain and preserving the original ordering. Reverts d8f4200c1bea46f4c970da688015d77dde1b0a55 PiperOrigin-RevId: 671128754 --- xla/hlo/ir/hlo_instruction.cc | 11 + xla/hlo/ir/hlo_instruction.h | 1 + xla/python/BUILD | 3 + xla/python/pjit.cc | 117 ++-- xla/python/xla_client.py | 2 +- xla/python/xla_extension/__init__.pyi | 2 +- xla/service/BUILD | 35 + xla/service/infeed_token_propagation.cc | 464 ++++++++++++++ xla/service/infeed_token_propagation.h | 45 ++ xla/service/infeed_token_propagation_test.cc | 601 ++++++++++++++++++ .../while_loop_invariant_code_motion.cc | 1 + 11 files changed, 1244 insertions(+), 38 deletions(-) create mode 100644 xla/service/infeed_token_propagation.cc create mode 100644 xla/service/infeed_token_propagation.h create mode 100644 xla/service/infeed_token_propagation_test.cc diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 37d7a39d8ee0e0..40a7a5bd5a1950 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -3411,6 +3411,17 @@ HloComputation* HloInstruction::branch_computation(int b) const { return called_computations()[b]; } +int HloInstruction::branch_index(HloComputation* computation) const { + CHECK(HloOpcode::kConditional == opcode_); + CHECK_NE(computation, nullptr); + for (int idx = 0; idx < branch_count(); idx++) { + if (branch_computation(idx) == computation) { + return idx; + } + } + CHECK(false); +} + void HloInstruction::set_branch_computation(int b, HloComputation* computation) { CHECK_EQ(HloOpcode::kConditional, opcode_); diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 3dcf016acd6cd0..2a3aff0568cbfe 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -1810,6 +1810,7 @@ class HloInstruction { const PtrVec& branch_computations() const; int branch_count() const; HloComputation* branch_computation(int b) const; + int branch_index(HloComputation* computation) const; // Sets a branch HloComputation for Conditional. // The setter should only be called by HloModule or HloComputation methods. // diff --git a/xla/python/BUILD b/xla/python/BUILD index a9081b5d63f1c7..ca95369305f077 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -745,9 +745,12 @@ cc_library( ":traceback", ":transfer_guard_lib", # placeholder for index annotation deps + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index 12d4814e460cc7..c14556d1b58651 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -28,16 +28,21 @@ limitations under the License. #include #include #include // NOLINT +#include #include #include +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" @@ -122,10 +127,10 @@ class PjitFunctionCache { // might be evicted before we finish tracing/compiling. typedef xla::LRUCache> Cache; - // We include as part of the cache key `donate_argnums` (and any other fields - // that aren't subsumed by the CallSignature we compute for each call). + // We include as part of the cache key `global_cache_key` (and any other + // fields that aren't subsumed by the CallSignature we compute for each call). std::shared_ptr Lookup(nb::handle function, - absl::Span donate_argnums); + nb::object global_cache_key); std::shared_ptr DefaultCache(); int Size() const { return lru_list_.Size(); } @@ -138,19 +143,35 @@ class PjitFunctionCache { // Other fields that are part of the arguments to `jit`, but are not // otherwise part of CallSignature. - std::vector donate_argnums; + nb::object global_cache_key; bool operator==(const Key& other) const { - return function.ptr() == other.function.ptr() && - donate_argnums == other.donate_argnums; + bool global_cache_eq; + try { + global_cache_eq = global_cache_key.equal(other.global_cache_key); + } catch (const nanobind::python_error& e) { + throw std::invalid_argument( + absl::StrCat("Equality of global cache key lead to an exception. " + "The error was:\n", + e.what(), "\n")); + } + return function.ptr() == other.function.ptr() && global_cache_eq; } }; template friend H AbslHashValue(H h, const Key& key) { h = H::combine(std::move(h), key.function.ptr()); - h = H::combine_contiguous(std::move(h), key.donate_argnums.data(), - key.donate_argnums.size()); + Py_hash_t hash; + try { + hash = xla::nb_hash(key.global_cache_key); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Hashing global cache key lead to an exception. The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); return h; } @@ -167,7 +188,9 @@ class PjitFunctionCache { }; Cache::LRUList lru_list_; - absl::flat_hash_map> functions_; + absl::Mutex mu_; // Non-trivial hashes need to be mutex locked. + // ABSL containers are not exception safe: + std::unordered_map, absl::Hash> functions_; }; PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} @@ -177,11 +200,20 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { } std::shared_ptr PjitFunctionCache::Lookup( - nb::handle function, absl::Span donate_argnums) { + nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.Lock(); + } + absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; Key key; key.function = function; - key.donate_argnums = - std::vector(donate_argnums.begin(), donate_argnums.end()); + key.global_cache_key = global_cache_key; auto insert = functions_.emplace(key, nullptr); if (!insert.second) { return insert.first->second->cache; @@ -189,7 +221,10 @@ std::shared_ptr PjitFunctionCache::Lookup( std::shared_ptr cache = std::make_shared(&lru_list_); auto callback = nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) { - functions_.erase(key); + auto it = functions_.find(key); + if (it != functions_.end()) { + functions_.erase(it); + } }); PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); if (weakref) { @@ -211,7 +246,7 @@ class PjitFunction { PjitFunction(std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, std::vector static_argnames, - std::vector donate_argnums, + nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::shared_ptr cache); @@ -244,6 +279,8 @@ class PjitFunction { absl::StatusOr Call(nb::handle callable, PyObject* const* args, size_t nargs, PyObject* kwnames); + void InitExecutables(); + void ClearPythonReferences(); const std::string& function_name() const { return function_name_; } @@ -258,7 +295,7 @@ class PjitFunction { const std::vector& static_argnames() const { return static_argnames_; } - const std::vector& donate_argnums() const { return donate_argnums_; } + const nb::object& global_cache_key() const { return global_cache_key_; } const std::shared_ptr& cache() const { return cache_; } int cache_capacity() const { return executables_->Size(); } @@ -291,7 +328,7 @@ class PjitFunction { nb::callable cache_miss_; std::vector static_argnums_; std::vector static_argnames_; - std::vector donate_argnums_; + nb::object global_cache_key_; std::shared_ptr pytree_registry_; nb::callable shard_arg_fallback_; @@ -325,14 +362,14 @@ PjitFunctionStore& GetGlobalPjitFunctionStore() { PjitFunction::PjitFunction( std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, + std::vector static_argnames, nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::shared_ptr cache) : function_name_(std::move(function_name)), fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), static_argnums_(std::move(static_argnums)), - donate_argnums_(donate_argnums), + global_cache_key_(std::move(global_cache_key)), pytree_registry_(std::move(pytree_registry)), shard_arg_fallback_(std::move(shard_arg_fallback)), cache_(std::move(cache)) { @@ -343,13 +380,16 @@ PjitFunction::PjitFunction( PyUnicode_InternInPlace(&s); static_argnames_.push_back(nb::steal(s)); } + + GetGlobalPjitFunctionStore().Insert(this); +} + +void PjitFunction::InitExecutables() { if (!fun_.has_value()) { executables_ = cache_->DefaultCache(); } else { - executables_ = cache_->Lookup(fun_.value(), donate_argnums); + executables_ = cache_->Lookup(fun_.value(), global_cache_key_); } - - GetGlobalPjitFunctionStore().Insert(this); } PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } @@ -1029,20 +1069,26 @@ void InitializePjitFunction( PjitFunctionObject* fn_obj, std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, std::vector static_argnames, - std::vector donate_argnums, + nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::shared_ptr cache) { + if (nb::isinstance(global_cache_key)) { + global_cache_key = nb::tuple(global_cache_key); + } new (&fn_obj->fun) PjitFunction( std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(pytree_registry), + std::move(global_cache_key), std::move(pytree_registry), std::move(shard_arg_fallback), std::move(cache)); + // Handled separately because it is not exception safe to call this + // in the constructor because it leaves the object improperly constructed. + fn_obj->fun.InitExecutables(); } nb::object MakePjitFunction( std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, + std::vector static_argnames, nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::optional> cache) { @@ -1053,11 +1099,11 @@ nb::object MakePjitFunction( cache = std::make_shared( PjitFunctionCache::kDefaultCapacity); } - InitializePjitFunction(fn_obj, std::move(function_name), std::move(fun), - std::move(cache_miss), std::move(static_argnums), - std::move(static_argnames), std::move(donate_argnums), - std::move(pytree_registry), - std::move(shard_arg_fallback), std::move(*cache)); + InitializePjitFunction( + fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); return obj; } @@ -1169,7 +1215,7 @@ void BuildPjitSubmodule(nb::module_& m) { pickle["cache_miss"] = fn->cache_miss(); pickle["static_argnums"] = fn->static_argnums(); pickle["static_argnames"] = nb::cast(fn->static_argnames()); - pickle["donate_argnums"] = fn->donate_argnums(); + pickle["global_cache_key"] = fn->global_cache_key(); pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); pickle["cache"] = fn->cache(); @@ -1197,8 +1243,7 @@ void BuildPjitSubmodule(nb::module_& m) { nb::cast>(pickle["static_argnums"]); std::vector static_argnames = nb::cast>(pickle["static_argnames"]); - std::vector donate_argnums = - nb::cast>(pickle["donate_argnums"]); + nb::object global_cache_key = pickle["global_cache_key"]; std::shared_ptr pytree_registry = nb::cast>( nb::handle(pickle["pytree_registry"].ptr())); @@ -1210,7 +1255,7 @@ void BuildPjitSubmodule(nb::module_& m) { reinterpret_cast(self.ptr()), std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(pytree_registry), + std::move(global_cache_key), std::move(pytree_registry), std::move(shard_arg_fallback), std::move(cache)); }, nb::is_method()); @@ -1236,7 +1281,7 @@ void BuildPjitSubmodule(nb::module_& m) { "pjit", [](std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, + std::vector static_argnames, nb::object global_cache_key, nb::object pytree_registry, nb::callable shard_arg_fallback, std::optional> cache) { std::shared_ptr registry = @@ -1245,12 +1290,12 @@ void BuildPjitSubmodule(nb::module_& m) { return MakePjitFunction( std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(registry), + std::move(global_cache_key), std::move(registry), std::move(shard_arg_fallback), std::move(cache)); }, nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), nb::arg("static_argnums"), nb::arg("static_argnames"), - nb::arg("donate_argnums"), nb::arg("pytree_registry"), + nb::arg("global_cache_key"), nb::arg("pytree_registry"), nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); } diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 927ad0f19bb0cd..89332de94b0b82 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 285 +_version = 286 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 93088d1ae9c06f..b5ae4c6431ca66 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -932,7 +932,7 @@ def pjit( cache_miss: Callable, static_argnums: Sequence[int], static_argnames: Sequence[str], - donate_argnums: Sequence[int], + global_cache_key: Any, pytree_registry: pytree.PyTreeRegistry, shard_arg_fallback: Callable, cache: Optional[PjitFunctionCache] = ..., diff --git a/xla/service/BUILD b/xla/service/BUILD index 959606abb44b5f..bef5cf91161899 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -8522,4 +8522,39 @@ xla_cc_test( ], ) +cc_library( + name = "infeed_token_propagation", + srcs = ["infeed_token_propagation.cc"], + hdrs = ["infeed_token_propagation.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +cc_test( + name = "infeed_token_propagation_test", + srcs = ["infeed_token_propagation_test.cc"], + deps = [ + ":infeed_token_propagation", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + ], +) + exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/xla/service/infeed_token_propagation.cc b/xla/service/infeed_token_propagation.cc new file mode 100644 index 00000000000000..9408557913aeeb --- /dev/null +++ b/xla/service/infeed_token_propagation.cc @@ -0,0 +1,464 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/infeed_token_propagation.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_dce.h" +#include "xla/service/tuple_simplifier.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { +bool IsDanglingInfeed(HloInstruction* infeed) { + CHECK(infeed->opcode() == HloOpcode::kInfeed); + if (infeed->has_sharding()) { + // TODO(vsytch): handle SPMD. + return false; + } + + bool is_dangling_input_token = true; + bool is_dangling_output_token = true; + if (const HloInstruction* after_all = infeed->operand(0); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + is_dangling_input_token = false; + } + for (const HloInstruction* user : infeed->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 1) { + is_dangling_output_token = false; + } + } + return is_dangling_input_token && is_dangling_output_token; +} + +bool IsDanglingOutfeed(HloInstruction* outfeed) { + CHECK(outfeed->opcode() == HloOpcode::kOutfeed); + if (outfeed->has_sharding()) { + // TODO(vsytch): handle SPMD. + return false; + } + + bool is_dangling_input_token = true; + bool is_dangling_output_token = true; + if (const HloInstruction* after_all = outfeed->operand(1); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + is_dangling_input_token = false; + } + if (outfeed->user_count() != 0) { + is_dangling_output_token = false; + } + return is_dangling_input_token && is_dangling_output_token; +} + +HloInstruction* ReconstructTuple(HloInstruction* tuple) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + std::vector gtes; + gtes.resize(tuple->shape().tuple_shapes_size()); + for (int64_t idx = 0; idx < gtes.size(); idx++) { + gtes[idx] = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(tuple, idx)); + } + + return computation->AddInstruction(HloInstruction::CreateTuple(gtes)); +} + +absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, + bool add_token_operand) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + // Recreate the original tuple, we'll need to pass this to all the users. + std::vector original_users = tuple->users(); + HloInstruction* original_tuple = ReconstructTuple(tuple); + for (HloInstruction* original_user : original_users) { + int64_t idx = original_user->operand_index(tuple); + TF_RETURN_IF_ERROR(original_user->ReplaceOperandWith(idx, original_tuple)); + } + + // Append the token to the parameter tuple. + *tuple->mutable_shape()->add_tuple_shapes() = ShapeUtil::MakeTokenShape(); + if (add_token_operand) { + tuple->AppendOperand( + computation->AddInstruction(HloInstruction::CreateToken())); + } + + HloInstruction* input_token_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + tuple, tuple->shape().tuple_shapes_size() - 1)); + return input_token_gte; +} + +absl::Status CanonicalizeConditionalBranch(HloComputation* branch) { + CHECK(branch->IsConditionalBranchComputation()); + CHECK_EQ(branch->num_parameters(), 1); + + // Tuplify the branch parameter if needed. + HloInstruction* parameter = branch->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = branch->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the branch tuple if needed. + HloInstruction* conditional = branch->ConditionalCallInstruction(); + int64_t branch_operand_idx = conditional->branch_index(branch) + 1; + HloInstruction* branch_tuple = + conditional->mutable_operand(branch_operand_idx); + if (!branch_tuple->shape().IsTuple()) { + branch_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple({branch_tuple})); + TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + } + + // Explicitly disjoin computation parameters from branch inputs, so we can + // insert tokens into the input tuple. + if (branch_tuple->opcode() == HloOpcode::kParameter) { + branch_tuple = ReconstructTuple(branch_tuple); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); + } + + // If the computation root is a also a computation parameter, explicitly split + // them, as the input and output tokens cannot be part of the same + // instruction. + HloInstruction* root = branch->root_instruction(); + if (root->opcode() == HloOpcode::kParameter) { + root = ReconstructTuple(root); + branch->set_root_instruction(root); + } + + // ConditionalCanonicalizer should have already turned the conditional output + // to be a tuple. + CHECK(conditional->shape().IsTuple()); + return absl::OkStatus(); +} + +absl::Status CanonicalizeWhileBody(HloComputation* body) { + CHECK(body->IsWhileBodyComputation()); + CHECK_EQ(body->num_parameters(), 1); + + // Tuplify the body parameter if needed. + HloInstruction* parameter = body->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = body->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the body root if needed. + HloInstruction* root = body->root_instruction(); + if (!root->shape().IsTuple()) { + root = body->AddInstruction(HloInstruction::CreateTuple({root})); + body->set_root_instruction(root, /*accept_different_shape=*/true); + } + + // Tuplify the condition parameter if needed. + HloInstruction* loop = body->WhileCallInstruction(); + HloComputation* cond = loop->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + if (!cond_parameter->shape().IsTuple()) { + *cond_parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({cond_parameter->shape()}); + HloInstruction* original = cond->AddInstruction( + HloInstruction::CreateGetTupleElement(cond_parameter, 0)); + TF_RETURN_IF_ERROR( + cond_parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while instruction if needed. + if (!loop->shape().IsTuple()) { + *loop->mutable_shape() = ShapeUtil::MakeTupleShape({loop->shape()}); + HloInstruction* original = loop->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(loop, 0)); + TF_RETURN_IF_ERROR(loop->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while tuple if needed. + HloInstruction* loop_tuple = loop->mutable_operand(0); + if (!loop_tuple->shape().IsTuple()) { + loop_tuple = loop->parent()->AddInstruction( + HloInstruction::CreateTuple({loop_tuple})); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, loop_tuple)); + } + + // Explicitly disjoin computation parameters from loop inputs, so we can + // insert tokens into the input tuple. + if (loop_tuple->opcode() == HloOpcode::kParameter) { + loop_tuple = ReconstructTuple(loop_tuple); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple)); + } + + // If the computation root is a also a computation parameter, explicitly + // split them, as the input and output tokens cannot be part of the same + // instruction. + if (root->opcode() == HloOpcode::kParameter) { + root = ReconstructTuple(root); + body->set_root_instruction(root); + } + + return absl::OkStatus(); +} + +absl::StatusOr> +PropagateTokenThroughConditionalBranch(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + // Conditional branches can diverge in inputs, but must converge on outputs. + HloInstruction* next_instruction = nullptr; + HloInstruction* next_input_token = nullptr; + HloInstruction* next_output_token = nullptr; + + // Fixup the branch. + HloComputation* comp = instruction->parent(); + TF_RETURN_IF_ERROR(CanonicalizeConditionalBranch(comp)); + next_instruction = comp->ConditionalCallInstruction(); + + // Insert the output token into each branch. + for (HloComputation* branch : next_instruction->branch_computations()) { + HloInstruction* root = branch->root_instruction(); + if (branch == comp) { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token); + } else { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/true).status()); + } + } + + // Insert the input token into the branch parameter. + HloInstruction* parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the branch tuple. + int64_t branch_operand_idx = next_instruction->branch_index(comp) + 1; + HloInstruction* branch_tuple = + next_instruction->mutable_operand(branch_operand_idx); + TF_ASSIGN_OR_RETURN( + HloInstruction * next_input_token_gte, + InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR(next_instruction->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + next_input_token = + branch_tuple->mutable_operand(next_input_token_gte->tuple_index()); + + // Insert the output token into conditional instruction. + TF_ASSIGN_OR_RETURN( + next_output_token, + InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + + return std::make_tuple(next_instruction, next_input_token, next_output_token); +} + +absl::StatusOr> +PropagateTokenThroughWhileBody(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + // While loops need to converge on input and output. + HloInstruction* next_instruction = nullptr; + HloInstruction* next_input_token = nullptr; + HloInstruction* next_output_token = nullptr; + + // Fixup the while body. + HloComputation* comp = instruction->parent(); + TF_RETURN_IF_ERROR(CanonicalizeWhileBody(comp)); + next_instruction = comp->WhileCallInstruction(); + + // Insert the output token into the body root. + HloInstruction* root = comp->root_instruction(); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token); + + // Insert the input token into the body parameter. + HloInstruction* body_parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the condition parameter. + HloComputation* cond = next_instruction->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false) + .status()); + + // Insert the input token into the while tuple. + HloInstruction* while_tuple = next_instruction->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + next_input_token, + InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR( + next_instruction->ReplaceOperandWithDifferentShape(0, while_tuple)); + + // Insert the input token into the while instruction. + TF_ASSIGN_OR_RETURN( + next_output_token, + InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false)); + + return std::make_tuple(next_instruction, next_input_token, next_output_token); +} + +absl::Status PropagateToken(HloInstruction* instruction, + HloInstruction* input_token, + HloInstruction* output_token) { + HloComputation* comp = instruction->parent(); + if (comp->IsEntryComputation()) { + // If we propagate through the root instruction, reconstruct the original + // tuple and set that to be root. + if (instruction->IsRoot() && + (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional)) { + std::vector gtes; + int64_t output_token_idx = output_token->tuple_index(); + for (int64_t idx = 0; idx < instruction->shape().tuple_shapes_size(); + idx++) { + if (idx != output_token_idx) { + gtes.push_back(comp->AddInstruction( + HloInstruction::CreateGetTupleElement(instruction, idx))); + } + } + HloInstruction* original_tuple = + comp->AddInstruction(HloInstruction::CreateTuple(gtes)); + comp->set_root_instruction(original_tuple, + /*accept_different_shape=*/true); + } + return absl::OkStatus(); + } + + HloInstruction* next_instruction = nullptr; + HloInstruction* next_input_token = nullptr; + HloInstruction* next_output_token = nullptr; + if (comp->IsConditionalBranchComputation()) { + // TODO(vsytch): handle SPMD. + if (comp->ConditionalCallInstruction()->has_sharding()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + std::tie(next_instruction, next_input_token, next_output_token), + PropagateTokenThroughConditionalBranch(instruction, input_token, + output_token)); + } else if (comp->IsWhileBodyComputation()) { + // TODO(vsytch): handle SPMD. + if (comp->WhileCallInstruction()->has_sharding()) { + return absl::OkStatus(); + } + TF_ASSIGN_OR_RETURN( + std::tie(next_instruction, next_input_token, next_output_token), + PropagateTokenThroughWhileBody(instruction, input_token, output_token)); + } else { + // We only expect to encounter computations behind while and conditional + // instructions. In the case of it being behind a while condition, there is + // no way to propagate the output token, as the root only returns a + // predicate. All other computations that could possibly contain infeed + // or outfeed ops should have already been inlined. + VLOG(2) << "Unhandled computation: " << comp->name(); + return absl::OkStatus(); + } + CHECK_NE(next_instruction, nullptr); + CHECK_NE(next_input_token, nullptr); + CHECK_NE(next_output_token, nullptr); + + return PropagateToken(next_instruction, next_input_token, next_output_token); +} +} // namespace + +absl::StatusOr InfeedTokenPropagation::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + VLOG(5) << "Before InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + + std::vector dangling_infeeds; + std::vector dangling_outfeeds; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + if (!computation->IsEntryComputation()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed && + IsDanglingInfeed(instruction)) { + VLOG(1) << "Found dangling infeed: " << instruction->ToString(); + dangling_infeeds.push_back(instruction); + } else if (instruction->opcode() == HloOpcode::kOutfeed && + IsDanglingOutfeed(instruction)) { + VLOG(1) << "Found dangling outfeed: " << instruction->ToString(); + dangling_outfeeds.push_back(instruction); + } + } + } + } + if (!dangling_infeeds.empty() || !dangling_outfeeds.empty()) { + changed = true; + } + + for (HloInstruction* dangling_infeed : dangling_infeeds) { + HloInstruction* input_token = dangling_infeed->mutable_operand(0); + HloInstruction* output_token = dangling_infeed->AddInstruction( + HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); + TF_RETURN_IF_ERROR( + PropagateToken(dangling_infeed, input_token, output_token)); + } + for (HloInstruction* dangling_outfeed : dangling_outfeeds) { + HloInstruction* input_token = dangling_outfeed->mutable_operand(1); + HloInstruction* output_token = dangling_outfeed; + TF_RETURN_IF_ERROR( + PropagateToken(dangling_outfeed, input_token, output_token)); + } + + if (changed) { + TF_RETURN_IF_ERROR( + TupleSimplifier().Run(module, execution_threads).status()); + TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + } + + VLOG(5) << "After InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + return changed; +} +} // namespace xla diff --git a/xla/service/infeed_token_propagation.h b/xla/service/infeed_token_propagation.h new file mode 100644 index 00000000000000..cc6994a62a98a9 --- /dev/null +++ b/xla/service/infeed_token_propagation.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ +#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +// Finds dangling infeed/outfeed tokens inside nested computations and bubbles +// them up through callers until they reach the entry computation. This is +// needed to prepare these computations to be inlined, otherwise the previous +// computation boundaries won't be there to stop infeeds/outfeeds from being +// reordered during scheduling. +// +// This pass assumes the HLO graph is flattened. +class InfeedTokenPropagation : public HloModulePass { + public: + std::string_view name() const override { return "infeed-token-propagation"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ diff --git a/xla/service/infeed_token_propagation_test.cc b/xla/service/infeed_token_propagation_test.cc new file mode 100644 index 00000000000000..8c1024253868d6 --- /dev/null +++ b/xla/service/infeed_token_propagation_test.cc @@ -0,0 +1,601 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/infeed_token_propagation.h" + +#include +#include + +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class InfeedTokenPropagationTest : public HloTestBase { + protected: + InfeedTokenPropagationTest() = default; +}; + +TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT gte.0 = get-tuple-element(infeed.0), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.1 = tuple() +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The infeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(arg.0) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleConditional) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = s32[] parameter(0) + outfeed_tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, arg.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = cond->mutable_operand(1); + EXPECT_TRUE(true_tuple->shape().IsTuple()); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, WhileInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed output token should have propagated through the while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed input token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, WhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + gte.0 = get-tuple-element(arg.0), index=0 + ROOT tuple.0 = tuple(gte.0) +} + +cond { + arg.0 = (s32[]) parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + while_tuple.0 = tuple(arg.0) + ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleWhile) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = s32[] parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + ROOT while.0 = s32[] while(arg.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_TRUE(loop->shape().IsTuple()); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + EXPECT_THAT(loop->operand(0), op::Tuple(op::Parameter(), op::AfterAll())); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + gte.0 = get-tuple-element(infeed.0), index=0 + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(gte.0) + false_tuple.0 = tuple() + cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp + ROOT tuple.0 = tuple() +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed and outfeed output tokens should have propagated through the + // loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed input tokens should have propagated through the loop + // tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed output tokens should have propagated through the + // while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1), + op::GetTupleElement(op::Conditional(), 0))); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} +} // namespace +} // namespace xla diff --git a/xla/service/while_loop_invariant_code_motion.cc b/xla/service/while_loop_invariant_code_motion.cc index ed44547af3fca4..423b15c69a63bf 100644 --- a/xla/service/while_loop_invariant_code_motion.cc +++ b/xla/service/while_loop_invariant_code_motion.cc @@ -134,6 +134,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( case HloOpcode::kReshape: return !hoist_reshapes_; + case HloOpcode::kAfterAll: case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kIota: