Skip to content

Commit

Permalink
[XLA] Introduce infeed token propagation
Browse files Browse the repository at this point in the history
During computation inlining, specifically loop unrolling, it is possible for infeeds (and outfeeds) to get reordered in a way that breaks the original scheduling constraints set by the computation boundaries. This is a result of Tensorflow not exposing tokens for these ops to the user, so the input and output tokens end up dangling.

Loop unrolling in XLA can be thought of as applying the same function repeatedly to itself, e.g. transforming f(x) into f(f(x)). By pushing the tokens outside the loop body, we can guarantee that the output token of the first infeed will become the input token of the next infeed, thus creating a data dependency chain and preserving the original ordering.

Reverts d8f4200

PiperOrigin-RevId: 671128754
  • Loading branch information
vsytch authored and Google-ML-Automation committed Sep 17, 2024
1 parent 9be7aca commit 0c70e61
Show file tree
Hide file tree
Showing 11 changed files with 1,244 additions and 38 deletions.
11 changes: 11 additions & 0 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3411,6 +3411,17 @@ HloComputation* HloInstruction::branch_computation(int b) const {
return called_computations()[b];
}

// Returns the index of `computation` among this conditional's branch
// computations. Must only be called on a kConditional instruction, and
// `computation` must actually be one of its branches; otherwise this crashes.
int HloInstruction::branch_index(HloComputation* computation) const {
  // Branch computations only exist on conditionals; keep the check style
  // consistent with set_branch_computation().
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_NE(computation, nullptr);
  for (int idx = 0; idx < branch_count(); idx++) {
    if (branch_computation(idx) == computation) {
      return idx;
    }
  }
  // LOG(FATAL) is [[noreturn]], so the compiler knows this path does not
  // fall off the end of the function, and the message names the offender.
  LOG(FATAL) << "Computation " << computation->name()
             << " is not a branch computation of " << name();
}

void HloInstruction::set_branch_computation(int b,
HloComputation* computation) {
CHECK_EQ(HloOpcode::kConditional, opcode_);
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,7 @@ class HloInstruction {
const PtrVec<HloComputation*>& branch_computations() const;
int branch_count() const;
HloComputation* branch_computation(int b) const;
int branch_index(HloComputation* computation) const;
// Sets a branch HloComputation for Conditional.
// The setter should only be called by HloModule or HloComputation methods.
//
Expand Down
3 changes: 3 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,12 @@ cc_library(
":traceback",
":transfer_guard_lib",
# placeholder for index annotation deps
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
117 changes: 81 additions & 36 deletions xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@ limitations under the License.
#include <string>
#include <string_view>
#include <thread> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "nanobind/nanobind.h"
Expand Down Expand Up @@ -122,10 +127,10 @@ class PjitFunctionCache {
// might be evicted before we finish tracing/compiling.
typedef xla::LRUCache<CallSignature, std::shared_ptr<PjitCacheEntry>> Cache;

// We include as part of the cache key `donate_argnums` (and any other fields
// that aren't subsumed by the CallSignature we compute for each call).
// We include as part of the cache key `global_cache_key` (and any other
// fields that aren't subsumed by the CallSignature we compute for each call).
std::shared_ptr<Cache> Lookup(nb::handle function,
absl::Span<const int> donate_argnums);
nb::object global_cache_key);
std::shared_ptr<Cache> DefaultCache();

int Size() const { return lru_list_.Size(); }
Expand All @@ -138,19 +143,35 @@ class PjitFunctionCache {

// Other fields that are part of the arguments to `jit`, but are not
// otherwise part of CallSignature.
std::vector<int> donate_argnums;
nb::object global_cache_key;

bool operator==(const Key& other) const {
return function.ptr() == other.function.ptr() &&
donate_argnums == other.donate_argnums;
bool global_cache_eq;
try {
global_cache_eq = global_cache_key.equal(other.global_cache_key);
} catch (const nanobind::python_error& e) {
throw std::invalid_argument(
absl::StrCat("Equality of global cache key lead to an exception. "
"The error was:\n",
e.what(), "\n"));
}
return function.ptr() == other.function.ptr() && global_cache_eq;
}
};

template <typename H>
friend H AbslHashValue(H h, const Key& key) {
h = H::combine(std::move(h), key.function.ptr());
h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
key.donate_argnums.size());
Py_hash_t hash;
try {
hash = xla::nb_hash(key.global_cache_key);
} catch (const nanobind::python_error& e) {
if (!e.matches(PyExc_TypeError)) throw;
throw std::invalid_argument(absl::StrCat(
"Hashing global cache key lead to an exception. The error was:\n",
e.what(), "\n"));
}
h = H::combine(std::move(h), hash);
return h;
}

Expand All @@ -167,7 +188,9 @@ class PjitFunctionCache {
};

Cache::LRUList lru_list_;
absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
absl::Mutex mu_; // Non-trivial hashes need to be mutex locked.
// ABSL containers are not exception safe:
std::unordered_map<Key, std::unique_ptr<Value>, absl::Hash<Key>> functions_;
};

PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {}
Expand All @@ -177,19 +200,31 @@ std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::DefaultCache() {
}

std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::Lookup(
nb::handle function, absl::Span<const int> donate_argnums) {
nb::handle function,
nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS {
{
// Because the gil can be released during cache insertion, this forces
// the lock order to be mu_ then gil so we must release the gil first.
nb::gil_scoped_release release;
// Acquire a mutex to avoid problems where the gil is released during
// cache insertion and then a second thread invalidates the cache order.
mu_.Lock();
}
absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
Key key;
key.function = function;
key.donate_argnums =
std::vector<int>(donate_argnums.begin(), donate_argnums.end());
key.global_cache_key = global_cache_key;
auto insert = functions_.emplace(key, nullptr);
if (!insert.second) {
return insert.first->second->cache;
}
std::shared_ptr<Cache> cache = std::make_shared<Cache>(&lru_list_);
auto callback =
nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) {
functions_.erase(key);
auto it = functions_.find(key);
if (it != functions_.end()) {
functions_.erase(it);
}
});
PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr());
if (weakref) {
Expand All @@ -211,7 +246,7 @@ class PjitFunction {
PjitFunction(std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames,
std::vector<int> donate_argnums,
nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback,
std::shared_ptr<PjitFunctionCache> cache);
Expand Down Expand Up @@ -244,6 +279,8 @@ class PjitFunction {
absl::StatusOr<nb::object> Call(nb::handle callable, PyObject* const* args,
size_t nargs, PyObject* kwnames);

void InitExecutables();

void ClearPythonReferences();

const std::string& function_name() const { return function_name_; }
Expand All @@ -258,7 +295,7 @@ class PjitFunction {
const std::vector<nb::str>& static_argnames() const {
return static_argnames_;
}
const std::vector<int>& donate_argnums() const { return donate_argnums_; }
const nb::object& global_cache_key() const { return global_cache_key_; }
const std::shared_ptr<PjitFunctionCache>& cache() const { return cache_; }

int cache_capacity() const { return executables_->Size(); }
Expand Down Expand Up @@ -291,7 +328,7 @@ class PjitFunction {
nb::callable cache_miss_;
std::vector<int> static_argnums_;
std::vector<nb::str> static_argnames_;
std::vector<int> donate_argnums_;
nb::object global_cache_key_;

std::shared_ptr<xla::PyTreeRegistry> pytree_registry_;
nb::callable shard_arg_fallback_;
Expand Down Expand Up @@ -325,14 +362,14 @@ PjitFunctionStore& GetGlobalPjitFunctionStore() {
PjitFunction::PjitFunction(
std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
std::vector<nb::str> static_argnames, nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback, std::shared_ptr<PjitFunctionCache> cache)
: function_name_(std::move(function_name)),
fun_(std::move(fun)),
cache_miss_(std::move(cache_miss)),
static_argnums_(std::move(static_argnums)),
donate_argnums_(donate_argnums),
global_cache_key_(std::move(global_cache_key)),
pytree_registry_(std::move(pytree_registry)),
shard_arg_fallback_(std::move(shard_arg_fallback)),
cache_(std::move(cache)) {
Expand All @@ -343,13 +380,16 @@ PjitFunction::PjitFunction(
PyUnicode_InternInPlace(&s);
static_argnames_.push_back(nb::steal<nb::str>(s));
}

GetGlobalPjitFunctionStore().Insert(this);
}

void PjitFunction::InitExecutables() {
if (!fun_.has_value()) {
executables_ = cache_->DefaultCache();
} else {
executables_ = cache_->Lookup(fun_.value(), donate_argnums);
executables_ = cache_->Lookup(fun_.value(), global_cache_key_);
}

GetGlobalPjitFunctionStore().Insert(this);
}

PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); }
Expand Down Expand Up @@ -1029,20 +1069,26 @@ void InitializePjitFunction(
PjitFunctionObject* fn_obj, std::string function_name,
std::optional<nb::callable> fun, nb::callable cache_miss,
std::vector<int> static_argnums, std::vector<nb::str> static_argnames,
std::vector<int> donate_argnums,
nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback, std::shared_ptr<PjitFunctionCache> cache) {
if (nb::isinstance<nb::list>(global_cache_key)) {
global_cache_key = nb::tuple(global_cache_key);
}
new (&fn_obj->fun) PjitFunction(
std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(donate_argnums), std::move(pytree_registry),
std::move(global_cache_key), std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(cache));
// Handled separately because it is not exception safe to call this
// in the constructor because it leaves the object improperly constructed.
fn_obj->fun.InitExecutables();
}

nb::object MakePjitFunction(
std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
std::vector<nb::str> static_argnames, nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback,
std::optional<std::shared_ptr<PjitFunctionCache>> cache) {
Expand All @@ -1053,11 +1099,11 @@ nb::object MakePjitFunction(
cache = std::make_shared<PjitFunctionCache>(
PjitFunctionCache::kDefaultCapacity);
}
InitializePjitFunction(fn_obj, std::move(function_name), std::move(fun),
std::move(cache_miss), std::move(static_argnums),
std::move(static_argnames), std::move(donate_argnums),
std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(*cache));
InitializePjitFunction(
fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(global_cache_key), std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(*cache));
return obj;
}

Expand Down Expand Up @@ -1169,7 +1215,7 @@ void BuildPjitSubmodule(nb::module_& m) {
pickle["cache_miss"] = fn->cache_miss();
pickle["static_argnums"] = fn->static_argnums();
pickle["static_argnames"] = nb::cast(fn->static_argnames());
pickle["donate_argnums"] = fn->donate_argnums();
pickle["global_cache_key"] = fn->global_cache_key();
pickle["pytree_registry"] = nb::cast(fn->pytree_registry());
pickle["shard_arg_fallback"] = fn->shard_arg_fallback();
pickle["cache"] = fn->cache();
Expand Down Expand Up @@ -1197,8 +1243,7 @@ void BuildPjitSubmodule(nb::module_& m) {
nb::cast<std::vector<int>>(pickle["static_argnums"]);
std::vector<nb::str> static_argnames =
nb::cast<std::vector<nb::str>>(pickle["static_argnames"]);
std::vector<int> donate_argnums =
nb::cast<std::vector<int>>(pickle["donate_argnums"]);
nb::object global_cache_key = pickle["global_cache_key"];
std::shared_ptr<xla::PyTreeRegistry> pytree_registry =
nb::cast<std::shared_ptr<xla::PyTreeRegistry>>(
nb::handle(pickle["pytree_registry"].ptr()));
Expand All @@ -1210,7 +1255,7 @@ void BuildPjitSubmodule(nb::module_& m) {
reinterpret_cast<PjitFunctionObject*>(self.ptr()),
std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(donate_argnums), std::move(pytree_registry),
std::move(global_cache_key), std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(cache));
},
nb::is_method());
Expand All @@ -1236,7 +1281,7 @@ void BuildPjitSubmodule(nb::module_& m) {
"pjit",
[](std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
std::vector<nb::str> static_argnames, nb::object global_cache_key,
nb::object pytree_registry, nb::callable shard_arg_fallback,
std::optional<std::shared_ptr<PjitFunctionCache>> cache) {
std::shared_ptr<xla::PyTreeRegistry> registry =
Expand All @@ -1245,12 +1290,12 @@ void BuildPjitSubmodule(nb::module_& m) {
return MakePjitFunction(
std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(donate_argnums), std::move(registry),
std::move(global_cache_key), std::move(registry),
std::move(shard_arg_fallback), std::move(cache));
},
nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"),
nb::arg("static_argnums"), nb::arg("static_argnames"),
nb::arg("donate_argnums"), nb::arg("pytree_registry"),
nb::arg("global_cache_key"), nb::arg("pytree_registry"),
nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none());
}

Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 285
_version = 286

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def pjit(
cache_miss: Callable,
static_argnums: Sequence[int],
static_argnames: Sequence[str],
donate_argnums: Sequence[int],
global_cache_key: Any,
pytree_registry: pytree.PyTreeRegistry,
shard_arg_fallback: Callable,
cache: Optional[PjitFunctionCache] = ...,
Expand Down
35 changes: 35 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8522,4 +8522,39 @@ xla_cc_test(
],
)

# Pass that pushes infeed/outfeed tokens out of computation bodies so that
# loop unrolling preserves the original infeed/outfeed ordering (see the
# commit description: dangling tokens get chained through the loop boundary).
cc_library(
    name = "infeed_token_propagation",
    srcs = ["infeed_token_propagation.cc"],
    hdrs = ["infeed_token_propagation.h"],
    deps = [
        ":hlo_dce",
        ":tuple_simplifier",
        "//xla:shape_util",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:statusor",
    ],
)

# Unit tests for :infeed_token_propagation. Uses the project-wide xla_cc_test
# macro (as the other tests in this file do, e.g. the rule preceding this
# section) rather than the bare cc_test rule, so platform-specific defaults
# and tags are applied consistently.
xla_cc_test(
    name = "infeed_token_propagation_test",
    srcs = ["infeed_token_propagation_test.cc"],
    deps = [
        ":infeed_token_propagation",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_matchers",
        "//xla/tests:hlo_test_base",
        "@com_google_googletest//:gtest_main",
        "@tsl//tsl/platform:statusor",
    ],
)

exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"])
Loading

0 comments on commit 0c70e61

Please sign in to comment.