diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc
index 37d7a39d8ee0e0..40a7a5bd5a1950 100644
--- a/xla/hlo/ir/hlo_instruction.cc
+++ b/xla/hlo/ir/hlo_instruction.cc
@@ -3411,6 +3411,17 @@ HloComputation* HloInstruction::branch_computation(int b) const {
   return called_computations()[b];
 }
 
+int HloInstruction::branch_index(HloComputation* computation) const {
+  CHECK(HloOpcode::kConditional == opcode_);
+  CHECK_NE(computation, nullptr);
+  for (int idx = 0; idx < branch_count(); idx++) {
+    if (branch_computation(idx) == computation) {
+      return idx;
+    }
+  }
+  CHECK(false);
+}
+
 void HloInstruction::set_branch_computation(int b,
                                             HloComputation* computation) {
   CHECK_EQ(HloOpcode::kConditional, opcode_);
diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h
index 3dcf016acd6cd0..2a3aff0568cbfe 100644
--- a/xla/hlo/ir/hlo_instruction.h
+++ b/xla/hlo/ir/hlo_instruction.h
@@ -1810,6 +1810,7 @@ class HloInstruction {
   const PtrVec<HloComputation*>& branch_computations() const;
   int branch_count() const;
   HloComputation* branch_computation(int b) const;
+  int branch_index(HloComputation* computation) const;
   // Sets a branch HloComputation for Conditional.
   // The setter should only be called by HloModule or HloComputation methods.
   //
diff --git a/xla/python/BUILD b/xla/python/BUILD
index a9081b5d63f1c7..ca95369305f077 100644
--- a/xla/python/BUILD
+++ b/xla/python/BUILD
@@ -745,9 +745,12 @@ cc_library(
         ":traceback",
         ":transfer_guard_lib",
         # placeholder for index annotation deps
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc
index 12d4814e460cc7..c14556d1b58651 100644
--- a/xla/python/pjit.cc
+++ b/xla/python/pjit.cc
@@ -28,16 +28,21 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <stdexcept>  // NOLINT
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
+#include "absl/base/thread_annotations.h"
+#include "absl/cleanup/cleanup.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/hash/hash.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "nanobind/nanobind.h"
@@ -122,10 +127,10 @@ class PjitFunctionCache {
   // might be evicted before we finish tracing/compiling.
   typedef xla::LRUCache<CallSignature, std::shared_ptr<PjitCacheEntry>> Cache;
 
-  // We include as part of the cache key `donate_argnums` (and any other fields
-  // that aren't subsumed by the CallSignature we compute for each call).
+  // We include as part of the cache key `global_cache_key` (and any other
+  // fields that aren't subsumed by the CallSignature we compute for each call).
   std::shared_ptr<Cache> Lookup(nb::handle function,
-                                absl::Span<const int> donate_argnums);
+                                nb::object global_cache_key);
   std::shared_ptr<Cache> DefaultCache();
 
   int Size() const { return lru_list_.Size(); }
@@ -138,19 +143,35 @@ class PjitFunctionCache {
     // Other fields that are part of the arguments to `jit`, but are not
     // otherwise part of CallSignature.
-    std::vector<int> donate_argnums;
+    nb::object global_cache_key;
 
     bool operator==(const Key& other) const {
-      return function.ptr() == other.function.ptr() &&
-             donate_argnums == other.donate_argnums;
+      bool global_cache_eq;
+      try {
+        global_cache_eq = global_cache_key.equal(other.global_cache_key);
+      } catch (const nanobind::python_error& e) {
+        throw std::invalid_argument(
+            absl::StrCat("Equality of global cache key led to an exception. "
                         "The error was:\n",
+                         e.what(), "\n"));
+      }
+      return function.ptr() == other.function.ptr() && global_cache_eq;
     }
   };
 
   template <typename H>
   friend H AbslHashValue(H h, const Key& key) {
     h = H::combine(std::move(h), key.function.ptr());
-    h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
-                              key.donate_argnums.size());
+    Py_hash_t hash;
+    try {
+      hash = xla::nb_hash(key.global_cache_key);
+    } catch (const nanobind::python_error& e) {
+      if (!e.matches(PyExc_TypeError)) throw;
+      throw std::invalid_argument(absl::StrCat(
+          "Hashing global cache key led to an exception. The error was:\n",
+          e.what(), "\n"));
+    }
+    h = H::combine(std::move(h), hash);
     return h;
   }
 
@@ -167,7 +188,9 @@ class PjitFunctionCache {
   };
 
   Cache::LRUList lru_list_;
-  absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
+  absl::Mutex mu_;  // Non-trivial hashes need to be mutex locked.
+  // ABSL containers are not exception safe:
+  std::unordered_map<Key, std::unique_ptr<Value>, absl::Hash<Key>> functions_;
 };
 
 PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {}
 
@@ -177,11 +200,20 @@ std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::DefaultCache() {
 }
 
 std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::Lookup(
-    nb::handle function, absl::Span<const int> donate_argnums) {
+    nb::handle function,
+    nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS {
+  {
+    // Because the gil can be released during cache insertion, this forces the
+    // lock order to be mu_ then gil, so we must release the gil first.
+    nb::gil_scoped_release release;
+    // Acquire a mutex to avoid problems where the gil is released during
+    // cache insertion and then a second thread invalidates the cache order.
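+    // NOTE: Lookup is annotated ABSL_NO_THREAD_SAFETY_ANALYSIS above because
+    // the lock is taken here but released via the absl::Cleanup below, a
+    // split that static thread-safety analysis cannot follow.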
+    mu_.Lock();
+  }
+  absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
   Key key;
   key.function = function;
-  key.donate_argnums =
-      std::vector<int>(donate_argnums.begin(), donate_argnums.end());
+  key.global_cache_key = global_cache_key;
   auto insert = functions_.emplace(key, nullptr);
   if (!insert.second) {
     return insert.first->second->cache;
   }
@@ -189,7 +221,10 @@ std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::Lookup(
   std::shared_ptr<Cache> cache = std::make_shared<Cache>(&lru_list_);
   auto callback =
       nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) {
-        functions_.erase(key);
+        auto it = functions_.find(key);
+        if (it != functions_.end()) {
+          functions_.erase(it);
+        }
       });
   PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr());
   if (weakref) {
@@ -211,7 +246,7 @@ class PjitFunction {
   PjitFunction(std::string function_name, std::optional<nb::callable> fun,
                nb::callable cache_miss, std::vector<int> static_argnums,
                std::vector<nb::str> static_argnames,
-               std::vector<int> donate_argnums,
+               nb::object global_cache_key,
                std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
                nb::callable shard_arg_fallback,
                std::shared_ptr<PjitFunctionCache> cache);
@@ -244,6 +279,8 @@ class PjitFunction {
   absl::StatusOr<nb::object> Call(nb::handle callable, PyObject* const* args,
                                   size_t nargs, PyObject* kwnames);
 
+  void InitExecutables();
+
   void ClearPythonReferences();
 
   const std::string& function_name() const { return function_name_; }
@@ -258,7 +295,7 @@ class PjitFunction {
   const std::vector<nb::str>& static_argnames() const {
     return static_argnames_;
   }
-  const std::vector<int>& donate_argnums() const { return donate_argnums_; }
+  const nb::object& global_cache_key() const { return global_cache_key_; }
   const std::shared_ptr<PjitFunctionCache>& cache() const { return cache_; }
 
   int cache_capacity() const { return executables_->Size(); }
@@ -291,7 +328,7 @@ class PjitFunction {
   nb::callable cache_miss_;
   std::vector<int> static_argnums_;
   std::vector<nb::str> static_argnames_;
-  std::vector<int> donate_argnums_;
+  nb::object global_cache_key_;
 
   std::shared_ptr<xla::PyTreeRegistry> pytree_registry_;
   nb::callable shard_arg_fallback_;
@@ -325,14 +362,14 @@ PjitFunctionStore& GetGlobalPjitFunctionStore() {
 PjitFunction::PjitFunction(
     std::string function_name, std::optional<nb::callable> fun,
     nb::callable cache_miss, std::vector<int> static_argnums,
-    std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
+    std::vector<nb::str> static_argnames, nb::object global_cache_key,
     std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
     nb::callable shard_arg_fallback, std::shared_ptr<PjitFunctionCache> cache)
     : function_name_(std::move(function_name)),
       fun_(std::move(fun)),
       cache_miss_(std::move(cache_miss)),
       static_argnums_(std::move(static_argnums)),
-      donate_argnums_(donate_argnums),
+      global_cache_key_(std::move(global_cache_key)),
       pytree_registry_(std::move(pytree_registry)),
       shard_arg_fallback_(std::move(shard_arg_fallback)),
       cache_(std::move(cache)) {
@@ -343,13 +380,16 @@ PjitFunction::PjitFunction(
     PyUnicode_InternInPlace(&s);
     static_argnames_.push_back(nb::steal<nb::str>(s));
   }
+
+  GetGlobalPjitFunctionStore().Insert(this);
+}
+
+void PjitFunction::InitExecutables() {
   if (!fun_.has_value()) {
     executables_ = cache_->DefaultCache();
   } else {
-    executables_ = cache_->Lookup(fun_.value(), donate_argnums);
+    executables_ = cache_->Lookup(fun_.value(), global_cache_key_);
   }
-
-  GetGlobalPjitFunctionStore().Insert(this);
 }
 
 PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); }
@@ -1029,20 +1069,26 @@ void InitializePjitFunction(
     PjitFunctionObject* fn_obj, std::string function_name,
     std::optional<nb::callable> fun, nb::callable cache_miss,
     std::vector<int> static_argnums, std::vector<nb::str> static_argnames,
-    std::vector<int> donate_argnums,
+    nb::object global_cache_key,
     std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
     nb::callable shard_arg_fallback, std::shared_ptr<PjitFunctionCache> cache) {
+  if (nb::isinstance<nb::list>(global_cache_key)) {
+    global_cache_key = nb::tuple(global_cache_key);
+  }
   new (&fn_obj->fun) PjitFunction(
       std::move(function_name), std::move(fun), std::move(cache_miss),
       std::move(static_argnums), std::move(static_argnames),
-      std::move(donate_argnums), std::move(pytree_registry),
+      std::move(global_cache_key), std::move(pytree_registry),
       std::move(shard_arg_fallback), std::move(cache));
+  // Handled separately because it is not exception safe to call this in the
+  // constructor: if it throws, the object is left improperly constructed.
+  fn_obj->fun.InitExecutables();
 }
 
 nb::object MakePjitFunction(
     std::string function_name, std::optional<nb::callable> fun,
     nb::callable cache_miss, std::vector<int> static_argnums,
-    std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
+    std::vector<nb::str> static_argnames, nb::object global_cache_key,
     std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
     nb::callable shard_arg_fallback,
     std::optional<std::shared_ptr<PjitFunctionCache>> cache) {
@@ -1053,11 +1099,11 @@ nb::object MakePjitFunction(
     cache = std::make_shared<PjitFunctionCache>(
         PjitFunctionCache::kDefaultCapacity);
   }
-  InitializePjitFunction(fn_obj, std::move(function_name), std::move(fun),
-                         std::move(cache_miss), std::move(static_argnums),
-                         std::move(static_argnames), std::move(donate_argnums),
-                         std::move(pytree_registry),
-                         std::move(shard_arg_fallback), std::move(*cache));
+  InitializePjitFunction(
+      fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss),
+      std::move(static_argnums), std::move(static_argnames),
+      std::move(global_cache_key), std::move(pytree_registry),
+      std::move(shard_arg_fallback), std::move(*cache));
   return obj;
 }
 
@@ -1169,7 +1215,7 @@ void BuildPjitSubmodule(nb::module_& m) {
         pickle["cache_miss"] = fn->cache_miss();
         pickle["static_argnums"] = fn->static_argnums();
         pickle["static_argnames"] = nb::cast(fn->static_argnames());
-        pickle["donate_argnums"] = fn->donate_argnums();
+        pickle["global_cache_key"] = fn->global_cache_key();
         pickle["pytree_registry"] = nb::cast(fn->pytree_registry());
         pickle["shard_arg_fallback"] = fn->shard_arg_fallback();
         pickle["cache"] = fn->cache();
@@ -1197,8 +1243,7 @@ void BuildPjitSubmodule(nb::module_& m) {
         std::vector<int> static_argnums =
             nb::cast<std::vector<int>>(pickle["static_argnums"]);
         std::vector<nb::str> static_argnames =
             nb::cast<std::vector<nb::str>>(pickle["static_argnames"]);
-        std::vector<int> donate_argnums =
-            nb::cast<std::vector<int>>(pickle["donate_argnums"]);
+        nb::object global_cache_key = pickle["global_cache_key"];
         std::shared_ptr<xla::PyTreeRegistry> pytree_registry =
             nb::cast<std::shared_ptr<xla::PyTreeRegistry>>(
                 nb::handle(pickle["pytree_registry"].ptr()));
@@ -1210,7 +1255,7 @@ void BuildPjitSubmodule(nb::module_& m) {
             reinterpret_cast<PjitFunctionObject*>(self.ptr()),
             std::move(function_name), std::move(fun), std::move(cache_miss),
             std::move(static_argnums), std::move(static_argnames),
-            std::move(donate_argnums), std::move(pytree_registry),
+            std::move(global_cache_key), std::move(pytree_registry),
             std::move(shard_arg_fallback), std::move(cache));
       },
       nb::is_method());
@@ -1236,7 +1281,7 @@ void BuildPjitSubmodule(nb::module_& m) {
       "pjit",
      [](std::string function_name, std::optional<nb::callable> fun,
         nb::callable cache_miss, std::vector<int> static_argnums,
-        std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
+        std::vector<nb::str> static_argnames, nb::object global_cache_key,
         nb::object pytree_registry, nb::callable shard_arg_fallback,
         std::optional<std::shared_ptr<PjitFunctionCache>> cache) {
        std::shared_ptr<xla::PyTreeRegistry> registry =
            nb::cast<std::shared_ptr<xla::PyTreeRegistry>>(
                nb::handle(pytree_registry.ptr()));
        return MakePjitFunction(
            std::move(function_name), std::move(fun), std::move(cache_miss),
            std::move(static_argnums), std::move(static_argnames),
-           std::move(donate_argnums), std::move(registry),
+           std::move(global_cache_key), std::move(registry),
            std::move(shard_arg_fallback), std::move(cache));
      },
      nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"),
      nb::arg("static_argnums"), nb::arg("static_argnames"),
-     nb::arg("donate_argnums"), nb::arg("pytree_registry"),
+     nb::arg("global_cache_key"), nb::arg("pytree_registry"),
      nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none());
 }
diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py
index 927ad0f19bb0cd..89332de94b0b82 100644
--- a/xla/python/xla_client.py
+++ b/xla/python/xla_client.py
@@ -50,7 +50,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.xla_extension_version.
-_version = 285
+_version = 286
 
 # Version number for MLIR:Python components.
 mlir_api_version = 57
diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi
index 93088d1ae9c06f..b5ae4c6431ca66 100644
--- a/xla/python/xla_extension/__init__.pyi
+++ b/xla/python/xla_extension/__init__.pyi
@@ -932,7 +932,7 @@ def pjit(
     cache_miss: Callable,
     static_argnums: Sequence[int],
     static_argnames: Sequence[str],
-    donate_argnums: Sequence[int],
+    global_cache_key: Any,
     pytree_registry: pytree.PyTreeRegistry,
     shard_arg_fallback: Callable,
     cache: Optional[PjitFunctionCache] = ...,
diff --git a/xla/service/BUILD b/xla/service/BUILD
index 959606abb44b5f..bef5cf91161899 100644
--- a/xla/service/BUILD
+++ b/xla/service/BUILD
@@ -8522,4 +8522,39 @@ xla_cc_test(
     ],
 )
 
+cc_library(
+    name = "infeed_token_propagation",
+    srcs = ["infeed_token_propagation.cc"],
+    hdrs = ["infeed_token_propagation.h"],
+    deps = [
+        ":hlo_dce",
+        ":tuple_simplifier",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/pass:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/types:span",
+        "@tsl//tsl/platform:errors",
+        "@tsl//tsl/platform:statusor",
+    ],
+)
+
+cc_test(
+    name = "infeed_token_propagation_test",
+    srcs = ["infeed_token_propagation_test.cc"],
+    deps = [
+        ":infeed_token_propagation",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_matchers",
+        "//xla/tests:hlo_test_base",
+        "@com_google_googletest//:gtest_main",
+        "@tsl//tsl/platform:statusor",
+    ],
+)
+
 exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"])
diff --git a/xla/service/infeed_token_propagation.cc b/xla/service/infeed_token_propagation.cc
new file mode 100644
index 00000000000000..9408557913aeeb
--- /dev/null
+++ b/xla/service/infeed_token_propagation.cc
@@ -0,0 +1,464 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/infeed_token_propagation.h"
+
+#include <cstdint>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/hlo_dce.h"
+#include "xla/service/tuple_simplifier.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace {
+bool IsDanglingInfeed(HloInstruction* infeed) {
+  CHECK(infeed->opcode() == HloOpcode::kInfeed);
+  if (infeed->has_sharding()) {
+    // TODO(vsytch): handle SPMD.
+    return false;
+  }
+
+  bool is_dangling_input_token = true;
+  bool is_dangling_output_token = true;
+  if (const HloInstruction* after_all = infeed->operand(0);
+      after_all->opcode() != HloOpcode::kAfterAll ||
+      after_all->operand_count() != 0) {
+    is_dangling_input_token = false;
+  }
+  for (const HloInstruction* user : infeed->users()) {
+    if (user->opcode() == HloOpcode::kGetTupleElement &&
+        user->tuple_index() == 1) {
+      is_dangling_output_token = false;
+    }
+  }
+  return is_dangling_input_token && is_dangling_output_token;
+}
+
+bool IsDanglingOutfeed(HloInstruction* outfeed) {
+  CHECK(outfeed->opcode() == HloOpcode::kOutfeed);
+  if (outfeed->has_sharding()) {
+    // TODO(vsytch): handle SPMD.
+    return false;
+  }
+
+  bool is_dangling_input_token = true;
+  bool is_dangling_output_token = true;
+  if (const HloInstruction* after_all = outfeed->operand(1);
+      after_all->opcode() != HloOpcode::kAfterAll ||
+      after_all->operand_count() != 0) {
+    is_dangling_input_token = false;
+  }
+  if (outfeed->user_count() != 0) {
+    is_dangling_output_token = false;
+  }
+  return is_dangling_input_token && is_dangling_output_token;
+}
+
+HloInstruction* ReconstructTuple(HloInstruction* tuple) {
+  CHECK(tuple->shape().IsTuple());
+  HloComputation* computation = tuple->parent();
+
+  std::vector<HloInstruction*> gtes;
+  gtes.resize(tuple->shape().tuple_shapes_size());
+  for (int64_t idx = 0; idx < gtes.size(); idx++) {
+    gtes[idx] = computation->AddInstruction(
+        HloInstruction::CreateGetTupleElement(tuple, idx));
+  }
+
+  return computation->AddInstruction(HloInstruction::CreateTuple(gtes));
+}
+
+absl::StatusOr<HloInstruction*> InsertTokenIntoTuple(HloInstruction* tuple,
+                                                     bool add_token_operand) {
+  CHECK(tuple->shape().IsTuple());
+  HloComputation* computation = tuple->parent();
+
+  // Recreate the original tuple; we'll need to pass this to all the users.
+  std::vector<HloInstruction*> original_users = tuple->users();
+  HloInstruction* original_tuple = ReconstructTuple(tuple);
+  for (HloInstruction* original_user : original_users) {
+    int64_t idx = original_user->operand_index(tuple);
+    TF_RETURN_IF_ERROR(original_user->ReplaceOperandWith(idx, original_tuple));
+  }
+
+  // Append the token to the parameter tuple.
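+  // The token always becomes the last tuple element. With add_token_operand
+  // set, a fresh token operand is created to back it; otherwise the caller is
+  // expected to append a matching operand itself.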
+  *tuple->mutable_shape()->add_tuple_shapes() = ShapeUtil::MakeTokenShape();
+  if (add_token_operand) {
+    tuple->AppendOperand(
+        computation->AddInstruction(HloInstruction::CreateToken()));
+  }
+
+  HloInstruction* input_token_gte =
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          tuple, tuple->shape().tuple_shapes_size() - 1));
+  return input_token_gte;
+}
+
+absl::Status CanonicalizeConditionalBranch(HloComputation* branch) {
+  CHECK(branch->IsConditionalBranchComputation());
+  CHECK_EQ(branch->num_parameters(), 1);
+
+  // Tuplify the branch parameter if needed.
+  HloInstruction* parameter = branch->parameter_instruction(0);
+  if (!parameter->shape().IsTuple()) {
+    *parameter->mutable_shape() =
+        ShapeUtil::MakeTupleShape({parameter->shape()});
+    HloInstruction* original = branch->AddInstruction(
+        HloInstruction::CreateGetTupleElement(parameter, 0));
+    TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original));
+  }
+
+  // Tuplify the branch tuple if needed.
+  HloInstruction* conditional = branch->ConditionalCallInstruction();
+  int64_t branch_operand_idx = conditional->branch_index(branch) + 1;
+  HloInstruction* branch_tuple =
+      conditional->mutable_operand(branch_operand_idx);
+  if (!branch_tuple->shape().IsTuple()) {
+    branch_tuple = conditional->parent()->AddInstruction(
+        HloInstruction::CreateTuple({branch_tuple}));
+    TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape(
+        branch_operand_idx, branch_tuple));
+  }
+
+  // Explicitly disjoin computation parameters from branch inputs, so we can
+  // insert tokens into the input tuple.
+  if (branch_tuple->opcode() == HloOpcode::kParameter) {
+    branch_tuple = ReconstructTuple(branch_tuple);
+    TF_RETURN_IF_ERROR(
+        conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple));
+  }
+
+  // If the computation root is also a computation parameter, explicitly split
+  // them, as the input and output tokens cannot be part of the same
+  // instruction.
+  HloInstruction* root = branch->root_instruction();
+  if (root->opcode() == HloOpcode::kParameter) {
+    root = ReconstructTuple(root);
+    branch->set_root_instruction(root);
+  }
+
+  // ConditionalCanonicalizer should have already turned the conditional output
+  // to be a tuple.
+  CHECK(conditional->shape().IsTuple());
+  return absl::OkStatus();
+}
+
+absl::Status CanonicalizeWhileBody(HloComputation* body) {
+  CHECK(body->IsWhileBodyComputation());
+  CHECK_EQ(body->num_parameters(), 1);
+
+  // Tuplify the body parameter if needed.
+  HloInstruction* parameter = body->parameter_instruction(0);
+  if (!parameter->shape().IsTuple()) {
+    *parameter->mutable_shape() =
+        ShapeUtil::MakeTupleShape({parameter->shape()});
+    HloInstruction* original = body->AddInstruction(
+        HloInstruction::CreateGetTupleElement(parameter, 0));
+    TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original));
+  }
+
+  // Tuplify the body root if needed.
+  HloInstruction* root = body->root_instruction();
+  if (!root->shape().IsTuple()) {
+    root = body->AddInstruction(HloInstruction::CreateTuple({root}));
+    body->set_root_instruction(root, /*accept_different_shape=*/true);
+  }
+
+  // Tuplify the condition parameter if needed.
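+  // The condition computation shares the loop carry with the body, so its
+  // parameter is tuplified the same way, even though the token is never read
+  // inside the condition.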
+  HloInstruction* loop = body->WhileCallInstruction();
+  HloComputation* cond = loop->while_condition();
+  HloInstruction* cond_parameter = cond->parameter_instruction(0);
+  if (!cond_parameter->shape().IsTuple()) {
+    *cond_parameter->mutable_shape() =
+        ShapeUtil::MakeTupleShape({cond_parameter->shape()});
+    HloInstruction* original = cond->AddInstruction(
+        HloInstruction::CreateGetTupleElement(cond_parameter, 0));
+    TF_RETURN_IF_ERROR(
+        cond_parameter->ReplaceAllUsesWithDifferentShape(original));
+  }
+
+  // Tuplify the while instruction if needed.
+  if (!loop->shape().IsTuple()) {
+    *loop->mutable_shape() = ShapeUtil::MakeTupleShape({loop->shape()});
+    HloInstruction* original = loop->parent()->AddInstruction(
+        HloInstruction::CreateGetTupleElement(loop, 0));
+    TF_RETURN_IF_ERROR(loop->ReplaceAllUsesWithDifferentShape(original));
+  }
+
+  // Tuplify the while tuple if needed.
+  HloInstruction* loop_tuple = loop->mutable_operand(0);
+  if (!loop_tuple->shape().IsTuple()) {
+    loop_tuple = loop->parent()->AddInstruction(
+        HloInstruction::CreateTuple({loop_tuple}));
+    TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, loop_tuple));
+  }
+
+  // Explicitly disjoin computation parameters from loop inputs, so we can
+  // insert tokens into the input tuple.
+  if (loop_tuple->opcode() == HloOpcode::kParameter) {
+    loop_tuple = ReconstructTuple(loop_tuple);
+    TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple));
+  }
+
+  // If the computation root is also a computation parameter, explicitly
+  // split them, as the input and output tokens cannot be part of the same
+  // instruction.
+  if (root->opcode() == HloOpcode::kParameter) {
+    root = ReconstructTuple(root);
+    body->set_root_instruction(root);
+  }
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::tuple<HloInstruction*, HloInstruction*, HloInstruction*>>
+PropagateTokenThroughConditionalBranch(HloInstruction* instruction,
+                                       HloInstruction* input_token,
+                                       HloInstruction* output_token) {
+  // Conditional branches can diverge in inputs, but must converge on outputs.
+  HloInstruction* next_instruction = nullptr;
+  HloInstruction* next_input_token = nullptr;
+  HloInstruction* next_output_token = nullptr;
+
+  // Fixup the branch.
+  HloComputation* comp = instruction->parent();
+  TF_RETURN_IF_ERROR(CanonicalizeConditionalBranch(comp));
+  next_instruction = comp->ConditionalCallInstruction();
+
+  // Insert the output token into each branch.
+  for (HloComputation* branch : next_instruction->branch_computations()) {
+    HloInstruction* root = branch->root_instruction();
+    if (branch == comp) {
+      TF_RETURN_IF_ERROR(
+          InsertTokenIntoTuple(root, /*add_token_operand=*/false).status());
+      root->AppendOperand(output_token);
+    } else {
+      TF_RETURN_IF_ERROR(
+          InsertTokenIntoTuple(root, /*add_token_operand=*/true).status());
+    }
+  }
+
+  // Insert the input token into the branch parameter.
+  HloInstruction* parameter = comp->parameter_instruction(0);
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * input_token_gte,
+      InsertTokenIntoTuple(parameter, /*add_token_operand=*/false));
+  TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte));
+
+  // Insert the input token into the branch tuple.
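+  // Operand 0 of a conditional is the branch predicate/index, so the inputs
+  // of branch N live at operand N + 1.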
+  int64_t branch_operand_idx = next_instruction->branch_index(comp) + 1;
+  HloInstruction* branch_tuple =
+      next_instruction->mutable_operand(branch_operand_idx);
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * next_input_token_gte,
+      InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true));
+  TF_RETURN_IF_ERROR(next_instruction->ReplaceOperandWithDifferentShape(
+      branch_operand_idx, branch_tuple));
+  next_input_token =
+      branch_tuple->mutable_operand(next_input_token_gte->tuple_index());
+
+  // Insert the output token into the conditional instruction.
+  TF_ASSIGN_OR_RETURN(
+      next_output_token,
+      InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false));
+
+  return std::make_tuple(next_instruction, next_input_token, next_output_token);
+}
+
+absl::StatusOr<std::tuple<HloInstruction*, HloInstruction*, HloInstruction*>>
+PropagateTokenThroughWhileBody(HloInstruction* instruction,
+                               HloInstruction* input_token,
+                               HloInstruction* output_token) {
+  // While loops need to converge on input and output.
+  HloInstruction* next_instruction = nullptr;
+  HloInstruction* next_input_token = nullptr;
+  HloInstruction* next_output_token = nullptr;
+
+  // Fixup the while body.
+  HloComputation* comp = instruction->parent();
+  TF_RETURN_IF_ERROR(CanonicalizeWhileBody(comp));
+  next_instruction = comp->WhileCallInstruction();
+
+  // Insert the output token into the body root.
+  HloInstruction* root = comp->root_instruction();
+  TF_RETURN_IF_ERROR(
+      InsertTokenIntoTuple(root, /*add_token_operand=*/false).status());
+  root->AppendOperand(output_token);
+
+  // Insert the input token into the body parameter.
+  HloInstruction* body_parameter = comp->parameter_instruction(0);
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * input_token_gte,
+      InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false));
+  TF_RETURN_IF_ERROR(input_token->ReplaceAllUsesWith(input_token_gte));
+
+  // Insert the input token into the condition parameter.
+  HloComputation* cond = next_instruction->while_condition();
+  HloInstruction* cond_parameter = cond->parameter_instruction(0);
+  TF_RETURN_IF_ERROR(
+      InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false)
+          .status());
+
+  // Insert the input token into the while tuple.
+  HloInstruction* while_tuple = next_instruction->mutable_operand(0);
+  TF_ASSIGN_OR_RETURN(
+      next_input_token,
+      InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true));
+  TF_RETURN_IF_ERROR(
+      next_instruction->ReplaceOperandWithDifferentShape(0, while_tuple));
+
+  // Insert the output token into the while instruction.
+  TF_ASSIGN_OR_RETURN(
+      next_output_token,
+      InsertTokenIntoTuple(next_instruction, /*add_token_operand=*/false));
+
+  return std::make_tuple(next_instruction, next_input_token, next_output_token);
+}
+
+absl::Status PropagateToken(HloInstruction* instruction,
+                            HloInstruction* input_token,
+                            HloInstruction* output_token) {
+  HloComputation* comp = instruction->parent();
+  if (comp->IsEntryComputation()) {
+    // If we propagate through the root instruction, reconstruct the original
+    // tuple and set that to be root.
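+    // This keeps the entry computation's original signature: the token is
+    // stripped back out of the root tuple once it has bubbled up to the top
+    // level.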
+    if (instruction->IsRoot() &&
+        (instruction->opcode() == HloOpcode::kWhile ||
+         instruction->opcode() == HloOpcode::kConditional)) {
+      std::vector<HloInstruction*> gtes;
+      int64_t output_token_idx = output_token->tuple_index();
+      for (int64_t idx = 0; idx < instruction->shape().tuple_shapes_size();
+           idx++) {
+        if (idx != output_token_idx) {
+          gtes.push_back(comp->AddInstruction(
+              HloInstruction::CreateGetTupleElement(instruction, idx)));
+        }
+      }
+      HloInstruction* original_tuple =
+          comp->AddInstruction(HloInstruction::CreateTuple(gtes));
+      comp->set_root_instruction(original_tuple,
+                                 /*accept_different_shape=*/true);
+    }
+    return absl::OkStatus();
+  }
+
+  HloInstruction* next_instruction = nullptr;
+  HloInstruction* next_input_token = nullptr;
+  HloInstruction* next_output_token = nullptr;
+  if (comp->IsConditionalBranchComputation()) {
+    // TODO(vsytch): handle SPMD.
+    if (comp->ConditionalCallInstruction()->has_sharding()) {
+      return absl::OkStatus();
+    }
+    TF_ASSIGN_OR_RETURN(
+        std::tie(next_instruction, next_input_token, next_output_token),
+        PropagateTokenThroughConditionalBranch(instruction, input_token,
+                                               output_token));
+  } else if (comp->IsWhileBodyComputation()) {
+    // TODO(vsytch): handle SPMD.
+    if (comp->WhileCallInstruction()->has_sharding()) {
+      return absl::OkStatus();
+    }
+    TF_ASSIGN_OR_RETURN(
+        std::tie(next_instruction, next_input_token, next_output_token),
+        PropagateTokenThroughWhileBody(instruction, input_token, output_token));
+  } else {
+    // We only expect to encounter computations behind while and conditional
+    // instructions. In the case of it being behind a while condition, there is
+    // no way to propagate the output token, as the root only returns a
+    // predicate. All other computations that could possibly contain infeed
+    // or outfeed ops should have already been inlined.
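+    // Bail out gracefully rather than CHECK-failing, since reaching an
+    // unhandled computation is not an error in itself.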
+    VLOG(2) << "Unhandled computation: " << comp->name();
+    return absl::OkStatus();
+  }
+  CHECK_NE(next_instruction, nullptr);
+  CHECK_NE(next_input_token, nullptr);
+  CHECK_NE(next_output_token, nullptr);
+
+  return PropagateToken(next_instruction, next_input_token, next_output_token);
+}
+}  // namespace
+
+absl::StatusOr<bool> InfeedTokenPropagation::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  VLOG(5) << "Before InfeedTokenPropagation:";
+  XLA_VLOG_LINES(5, module->ToString());
+
+  std::vector<HloInstruction*> dangling_infeeds;
+  std::vector<HloInstruction*> dangling_outfeeds;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    if (!computation->IsEntryComputation()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->opcode() == HloOpcode::kInfeed &&
+            IsDanglingInfeed(instruction)) {
+          VLOG(1) << "Found dangling infeed: " << instruction->ToString();
+          dangling_infeeds.push_back(instruction);
+        } else if (instruction->opcode() == HloOpcode::kOutfeed &&
+                   IsDanglingOutfeed(instruction)) {
+          VLOG(1) << "Found dangling outfeed: " << instruction->ToString();
+          dangling_outfeeds.push_back(instruction);
+        }
+      }
+    }
+  }
+  if (!dangling_infeeds.empty() || !dangling_outfeeds.empty()) {
+    changed = true;
+  }
+
+  for (HloInstruction* dangling_infeed : dangling_infeeds) {
+    HloInstruction* input_token = dangling_infeed->mutable_operand(0);
+    HloInstruction* output_token = dangling_infeed->AddInstruction(
+        HloInstruction::CreateGetTupleElement(dangling_infeed, 1));
+    TF_RETURN_IF_ERROR(
+        PropagateToken(dangling_infeed, input_token, output_token));
+  }
+  for (HloInstruction* dangling_outfeed : dangling_outfeeds) {
+    HloInstruction* input_token = dangling_outfeed->mutable_operand(1);
+    HloInstruction* output_token = dangling_outfeed;
+    TF_RETURN_IF_ERROR(
+        PropagateToken(dangling_outfeed, input_token, output_token));
+  }
+
+  if (changed) {
+    TF_RETURN_IF_ERROR(
+        TupleSimplifier().Run(module, execution_threads).status());
+    TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status());
+  }
+
+  VLOG(5) << "After InfeedTokenPropagation:";
+  XLA_VLOG_LINES(5, module->ToString());
+  return changed;
+}
+}  // namespace xla
diff --git a/xla/service/infeed_token_propagation.h b/xla/service/infeed_token_propagation.h
new file mode 100644
index 00000000000000..cc6994a62a98a9
--- /dev/null
+++ b/xla/service/infeed_token_propagation.h
@@ -0,0 +1,45 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_
+#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_
+
+#include <string_view>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+
+namespace xla {
+// Finds dangling infeed/outfeed tokens inside nested computations and bubbles
+// them up through callers until they reach the entry computation. This is
+// needed to prepare these computations to be inlined; otherwise the previous
+// computation boundaries won't be there to stop infeeds/outfeeds from being
+// reordered during scheduling.
+//
+// This pass assumes the HLO graph is flattened.
+class InfeedTokenPropagation : public HloModulePass {
+ public:
+  std::string_view name() const override { return "infeed-token-propagation"; }
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+}  // namespace xla
+
+#endif  // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_
diff --git a/xla/service/infeed_token_propagation_test.cc b/xla/service/infeed_token_propagation_test.cc
new file mode 100644
index 00000000000000..8c1024253868d6
--- /dev/null
+++ b/xla/service/infeed_token_propagation_test.cc
@@ -0,0 +1,601 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/infeed_token_propagation.h"
+
+#include <memory>
+#include <string_view>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/utils/hlo_matchers.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/statusor.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace {
+
+class InfeedTokenPropagationTest : public HloTestBase {
+ protected:
+  InfeedTokenPropagationTest() = default;
+};
+
+TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+ENTRY main {
+  token.0 = after-all()
+  infeed.0 = (s32[], token[]) infeed(token.0)
+  ROOT gte.0 = get-tuple-element(infeed.0), index=0
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  tuple.0 = tuple(arg.0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[])
+  ROOT tuple.1 = tuple()
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+true_comp {
+  arg.0 = () parameter(0)
+  token.0 = after-all()
+  infeed.0 = (s32[], token[]) infeed(token.0)
+  ROOT tuple.0 = tuple()
+}
+
+false_comp {
+  arg.0 = () parameter(0)
+  ROOT tuple.0 = tuple()
+}
+
+ENTRY main {
+  pred.0 = pred[] constant(true)
+  true_tuple.0 = tuple()
+  false_tuple.0 = tuple()
+  ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The infeed output token should have propagated through the conditional.
+  HloInstruction* cond = FindInstruction(module.get(), "cond.0");
+  EXPECT_EQ(cond->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken());
+
+  // The infeed input token should have propagated through the true tuple.
+  HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0");
+  EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken());
+
+  // The infeed input token should not have propagated through the false
+  // tuple.
+  HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0");
+  EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0);
+
+  // The infeed output token should have propagated through the true
+  // computation's root.
+  HloComputation* true_comp = FindComputation(module.get(), "true_comp");
+  EXPECT_THAT(true_comp->root_instruction(),
+              op::Tuple(op::GetTupleElement(op::Infeed(), 1)));
+
+  // The infeed output token should have propagated to the false computation's
+  // root.
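+  // Since the false branch never had a token to return, a fresh one gets
+  // synthesized for it, which shows up as an after-all with no operands.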
+  HloComputation* false_comp = FindComputation(module.get(), "false_comp");
+  EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll()));
+}
+
+TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+true_comp {
+  arg.0 = (s32[]) parameter(0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[])
+  ROOT tuple.0 = tuple()
+}
+
+false_comp {
+  arg.0 = () parameter(0)
+  ROOT tuple.0 = tuple()
+}
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  pred.0 = pred[] constant(true)
+  true_tuple.0 = tuple(arg.0)
+  false_tuple.0 = tuple()
+  ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The outfeed output token should have propagated through the conditional.
+  HloInstruction* cond = FindInstruction(module.get(), "cond.0");
+  EXPECT_EQ(cond->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed input token should have propagated through the true tuple.
+  HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0");
+  EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed input token should not have propagated through the false
+  // tuple.
+  HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0");
+  EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0);
+
+  // The outfeed output token should have propagated through the true
+  // computation's root.
+  HloComputation* true_comp = FindComputation(module.get(), "true_comp");
+  EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed()));
+
+  // The outfeed output token should have propagated to the false computation's
+  // root.
+  HloComputation* false_comp = FindComputation(module.get(), "false_comp");
+  EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll()));
+}
+
+TEST_F(InfeedTokenPropagationTest, NonTupleConditional) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+true_comp {
+  arg.0 = s32[] parameter(0)
+  outfeed_tuple.0 = tuple(arg.0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[])
+  ROOT tuple.0 = tuple()
+}
+
+false_comp {
+  arg.0 = () parameter(0)
+  ROOT tuple.0 = tuple()
+}
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  pred.0 = pred[] constant(true)
+  false_tuple.0 = tuple()
+  ROOT cond.0 = () conditional(pred.0, arg.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The outfeed output token should have propagated through the conditional.
+  HloInstruction* cond = FindInstruction(module.get(), "cond.0");
+  EXPECT_EQ(cond->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed input token should have propagated through the true tuple.
+  HloInstruction* true_tuple = cond->mutable_operand(1);
+  EXPECT_TRUE(true_tuple->shape().IsTuple());
+  EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed input token should not have propagated through the false
+  // tuple.
+  HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0");
+  EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0);
+
+  // The outfeed output token should have propagated through the true
+  // computation's root.
+  HloComputation* true_comp = FindComputation(module.get(), "true_comp");
+  EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed()));
+
+  // The outfeed output token should have propagated to the false computation's
+  // root.
+  HloComputation* false_comp = FindComputation(module.get(), "false_comp");
+  EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll()));
+}
+
+TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+true_comp {
+  ROOT arg.0 = () parameter(0)
+  one.0 = s32[] constant(1)
+  outfeed_tuple.0 = tuple(one.0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[])
+}
+
+false_comp {
+  arg.0 = () parameter(0)
+  ROOT tuple.0 = tuple()
+}
+
+ENTRY main {
+  pred.0 = pred[] constant(true)
+  true_tuple.0 = tuple()
+  false_tuple.0 = tuple()
+  ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The outfeed output token should have propagated through the conditional.
+  HloInstruction* cond = FindInstruction(module.get(), "cond.0");
+  EXPECT_EQ(cond->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed input token should have propagated through the true tuple.
+  HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0");
+  EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed input token should not have propagated through the false
+  // tuple.
+  HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0");
+  EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0);
+
+  // The outfeed output token should have propagated through the true
+  // computation's root.
+  HloComputation* true_comp = FindComputation(module.get(), "true_comp");
+  EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed()));
+
+  // The outfeed output token should have propagated to the false computation's
+  // root.
+  HloComputation* false_comp = FindComputation(module.get(), "false_comp");
+  EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll()));
+}
+
+TEST_F(InfeedTokenPropagationTest, WhileInfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+comp {
+  arg.0 = () parameter(0)
+  token.0 = after-all()
+  infeed.0 = (s32[], token[]) infeed(token.0)
+  ROOT tuple.0 = tuple()
+}
+
+cond {
+  arg.0 = () parameter(0)
+  ROOT true.0 = pred[] constant(true)
+}
+
+ENTRY main {
+  while_tuple.0 = tuple()
+  ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The infeed output token should have propagated through the loop.
+  HloInstruction* loop = FindInstruction(module.get(), "while.0");
+  EXPECT_EQ(loop->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken());
+
+  // The infeed input token should have propagated through the loop tuple.
+  HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0");
+  EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken());
+
+  // The infeed output token should have propagated through the while body
+  // root.
+  HloComputation* body_comp = FindComputation(module.get(), "comp");
+  EXPECT_THAT(body_comp->root_instruction(),
+              op::Tuple(op::GetTupleElement(op::Infeed(), 1)));
+
+  // The infeed input token should have propagated through the body parameter.
+  HloInstruction* body_param = body_comp->parameter_instruction(0);
+  EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken());
+
+  // The infeed input token should have propagated through the condition
+  // parameter.
+  HloComputation* cond_comp = FindComputation(module.get(), "cond");
+  HloInstruction* cond_param = cond_comp->parameter_instruction(0);
+  EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken());
+}
+
+TEST_F(InfeedTokenPropagationTest, WhileOutfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+comp {
+  arg.0 = (s32[]) parameter(0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[])
+  gte.0 = get-tuple-element(arg.0), index=0
+  ROOT tuple.0 = tuple(gte.0)
+}
+
+cond {
+  arg.0 = (s32[]) parameter(0)
+  ROOT true.0 = pred[] constant(true)
+}
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  while_tuple.0 = tuple(arg.0)
+  ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The outfeed output token should have propagated through the loop.
+  HloInstruction* loop = FindInstruction(module.get(), "while.0");
+  EXPECT_EQ(loop->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed input token should have propagated through the loop tuple.
+  HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0");
+  EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed output token should have propagated through the while body
+  // root.
+  HloComputation* body_comp = FindComputation(module.get(), "comp");
+  EXPECT_THAT(body_comp->root_instruction(),
+              op::Tuple(op::GetTupleElement(), op::Outfeed()));
+
+  // The outfeed output token should have propagated through the body
+  // parameter.
+  HloInstruction* body_param = body_comp->parameter_instruction(0);
+  EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed output token should have propagated through the condition
+  // parameter.
+  HloComputation* cond_comp = FindComputation(module.get(), "cond");
+  HloInstruction* cond_param = cond_comp->parameter_instruction(0);
+  EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken());
+}
+
+TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+comp {
+  ROOT arg.0 = () parameter(0)
+  one.0 = s32[] constant(1)
+  outfeed_tuple.0 = tuple(one.0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[])
+}
+
+cond {
+  arg.0 = () parameter(0)
+  ROOT true.0 = pred[] constant(true)
+}
+
+ENTRY main {
+  while_tuple.0 = tuple()
+  ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The outfeed output token should have propagated through the loop.
+  HloInstruction* loop = FindInstruction(module.get(), "while.0");
+  EXPECT_EQ(loop->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed input token should have propagated through the loop tuple.
+  HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0");
+  EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed output token should have propagated through the while body
+  // root.
+  HloComputation* body_comp = FindComputation(module.get(), "comp");
+  EXPECT_THAT(body_comp->root_instruction(), op::Tuple(op::Outfeed()));
+
+  // The outfeed output token should have propagated through the body
+  // parameter.
+  HloInstruction* body_param = body_comp->parameter_instruction(0);
+  EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed output token should have propagated through the condition
+  // parameter.
+  HloComputation* cond_comp = FindComputation(module.get(), "cond");
+  HloInstruction* cond_param = cond_comp->parameter_instruction(0);
+  EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken());
+}
+
+TEST_F(InfeedTokenPropagationTest, NonTupleWhile) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+comp {
+  ROOT arg.0 = s32[] parameter(0)
+  tuple.0 = tuple(arg.0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[])
+}
+
+cond {
+  arg.0 = s32[] parameter(0)
+  ROOT true.0 = pred[] constant(true)
+}
+
+ENTRY main {
+  arg.0 = s32[] parameter(0)
+  ROOT while.0 = s32[] while(arg.0), condition=cond, body=comp
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The outfeed output token should have propagated through the loop.
+  HloInstruction* loop = FindInstruction(module.get(), "while.0");
+  EXPECT_TRUE(loop->shape().IsTuple());
+  EXPECT_EQ(loop->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed input token should have propagated through the loop tuple.
+  EXPECT_THAT(loop->operand(0), op::Tuple(op::Parameter(), op::AfterAll()));
+
+  // The outfeed output token should have propagated through the while body
+  // root.
+  HloComputation* body_comp = FindComputation(module.get(), "comp");
+  EXPECT_THAT(body_comp->root_instruction(),
+              op::Tuple(op::GetTupleElement(), op::Outfeed()));
+
+  // The outfeed output token should have propagated through the body
+  // parameter.
+  HloInstruction* body_param = body_comp->parameter_instruction(0);
+  EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed output token should have propagated through the condition
+  // parameter.
+  HloComputation* cond_comp = FindComputation(module.get(), "cond");
+  HloInstruction* cond_param = cond_comp->parameter_instruction(0);
+  EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken());
+}
+
+TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) {
+  constexpr std::string_view hlo = R"(
+HloModule main
+
+true_comp {
+  arg.0 = (s32[]) parameter(0)
+  token.0 = after-all()
+  outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[])
+  ROOT tuple.0 = tuple()
+}
+
+false_comp {
+  arg.0 = () parameter(0)
+  ROOT tuple.0 = tuple()
+}
+
+comp {
+  arg.0 = () parameter(0)
+  token.0 = after-all()
+  infeed.0 = (s32[], token[]) infeed(token.0)
+  gte.0 = get-tuple-element(infeed.0), index=0
+  pred.0 = pred[] constant(true)
+  true_tuple.0 = tuple(gte.0)
+  false_tuple.0 = tuple()
+  cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
+  ROOT tuple.0 = tuple()
+}
+
+cond {
+  arg.0 = () parameter(0)
+  ROOT true.0 = pred[] constant(true)
+}
+
+ENTRY main {
+  while_tuple.0 = tuple()
+  ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
+  InfeedTokenPropagation itp;
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // The infeed and outfeed output tokens should have propagated through the
+  // loop.
+  HloInstruction* loop = FindInstruction(module.get(), "while.0");
+  EXPECT_EQ(loop->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken());
+  EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken());
+
+  // The infeed and outfeed input tokens should have propagated through the
+  // loop tuple.
+  HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0");
+  EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken());
+  EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken());
+
+  // The infeed and outfeed output tokens should have propagated through the
+  // while body root.
+  HloComputation* body_comp = FindComputation(module.get(), "comp");
+  EXPECT_THAT(body_comp->root_instruction(),
+              op::Tuple(op::GetTupleElement(op::Infeed(), 1),
+                        op::GetTupleElement(op::Conditional(), 0)));
+
+  // The outfeed output token should have propagated through the conditional.
+  HloInstruction* cond = FindInstruction(module.get(), "cond.0");
+  EXPECT_EQ(cond->shape().tuple_shapes_size(), 1);
+  EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken());
+
+  // The outfeed input token should have propagated through the true tuple.
+  HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0");
+  EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2);
+  EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken());
+
+  // The outfeed input token should not have propagated through the false
+  // tuple.
+  HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0");
+  EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0);
+
+  // The outfeed output token should have propagated through the true
+  // computation's root.
+  HloComputation* true_comp = FindComputation(module.get(), "true_comp");
+  EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed()));
+
+  // The outfeed output token should have propagated to the false computation's
+  // root.
+  HloComputation* false_comp = FindComputation(module.get(), "false_comp");
+  EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll()));
+}
+}  // namespace
+}  // namespace xla
diff --git a/xla/service/while_loop_invariant_code_motion.cc b/xla/service/while_loop_invariant_code_motion.cc
index ed44547af3fca4..423b15c69a63bf 100644
--- a/xla/service/while_loop_invariant_code_motion.cc
+++ b/xla/service/while_loop_invariant_code_motion.cc
@@ -134,6 +134,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
     case HloOpcode::kReshape:
       return !hoist_reshapes_;
 
+    case HloOpcode::kAfterAll:
     case HloOpcode::kBitcast:
     case HloOpcode::kBroadcast:
     case HloOpcode::kIota: