From 52e0c5eab65da182b09e68931c1fb20dd38fac5b Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 13 Sep 2024 11:19:55 -0700 Subject: [PATCH] [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums. This allows us to get more cache hits globally. For example: Before: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache miss After: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache hit Reverts d8f4200c1bea46f4c970da688015d77dde1b0a55 PiperOrigin-RevId: 674367894 --- xla/python/BUILD | 3 + xla/python/pjit.cc | 102 +++++++++++++++++--------- xla/python/xla_client.py | 2 +- xla/python/xla_extension/__init__.pyi | 2 +- 4 files changed, 73 insertions(+), 36 deletions(-) diff --git a/xla/python/BUILD b/xla/python/BUILD index a9081b5d63f1c7..ca95369305f077 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -745,9 +745,12 @@ cc_library( ":traceback", ":transfer_guard_lib", # placeholder for index annotation deps + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index 12d4814e460cc7..510e702d671278 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -28,16 +28,21 @@ limitations under the License. #include #include #include // NOLINT +#include #include #include +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" @@ -122,10 +127,10 @@ class PjitFunctionCache { // might be evicted before we finish tracing/compiling. typedef xla::LRUCache> Cache; - // We include as part of the cache key `donate_argnums` (and any other fields - // that aren't subsumed by the CallSignature we compute for each call). + // We include as part of the cache key `global_cache_key` (and any other + // fields that aren't subsumed by the CallSignature we compute for each call). std::shared_ptr Lookup(nb::handle function, - absl::Span donate_argnums); + nb::object global_cache_key); std::shared_ptr DefaultCache(); int Size() const { return lru_list_.Size(); } @@ -138,19 +143,35 @@ class PjitFunctionCache { // Other fields that are part of the arguments to `jit`, but are not // otherwise part of CallSignature. - std::vector donate_argnums; + nb::object global_cache_key; bool operator==(const Key& other) const { - return function.ptr() == other.function.ptr() && - donate_argnums == other.donate_argnums; + bool global_cache_eq; + try { + global_cache_eq = global_cache_key.equal(other.global_cache_key); + } catch (const nanobind::python_error& e) { + throw std::invalid_argument( + absl::StrCat("Equality of global cache key lead to an exception. " + "The error was:\n", + e.what(), "\n")); + } + return function.ptr() == other.function.ptr() && global_cache_eq; } }; template friend H AbslHashValue(H h, const Key& key) { h = H::combine(std::move(h), key.function.ptr()); - h = H::combine_contiguous(std::move(h), key.donate_argnums.data(), - key.donate_argnums.size()); + Py_hash_t hash; + try { + hash = xla::nb_hash(key.global_cache_key); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Hashing global cache key lead to an exception. The error was:\n", + e.what(), "\n")); + } + h = H::combine(std::move(h), hash); return h; } @@ -167,7 +188,9 @@ class PjitFunctionCache { }; Cache::LRUList lru_list_; - absl::flat_hash_map> functions_; + absl::Mutex mu_; // Non-trivial hashes need to be mutex locked. + // ABSL containers are not exception safe: + std::unordered_map, absl::Hash> functions_; }; PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} @@ -177,11 +200,20 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { } std::shared_ptr PjitFunctionCache::Lookup( - nb::handle function, absl::Span donate_argnums) { + nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.Lock(); + } + absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; Key key; key.function = function; - key.donate_argnums = - std::vector(donate_argnums.begin(), donate_argnums.end()); + key.global_cache_key = global_cache_key; auto insert = functions_.emplace(key, nullptr); if (!insert.second) { return insert.first->second->cache; @@ -189,7 +221,10 @@ std::shared_ptr PjitFunctionCache::Lookup( std::shared_ptr cache = std::make_shared(&lru_list_); auto callback = nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) { - functions_.erase(key); + auto it = functions_.find(key); + if (it != functions_.end()) { + functions_.erase(it); + } }); PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); if (weakref) { @@ -211,7 +246,7 @@ class PjitFunction { PjitFunction(std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, std::vector static_argnames, - std::vector donate_argnums, + nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::shared_ptr cache); @@ -258,7 +293,7 @@ class PjitFunction { const std::vector& static_argnames() const { return static_argnames_; } - const std::vector& donate_argnums() const { return donate_argnums_; } + const nb::object& global_cache_key() const { return global_cache_key_; } const std::shared_ptr& cache() const { return cache_; } int cache_capacity() const { return executables_->Size(); } @@ -291,7 +326,7 @@ class PjitFunction { nb::callable cache_miss_; std::vector static_argnums_; std::vector static_argnames_; - std::vector donate_argnums_; + nb::object global_cache_key_; std::shared_ptr pytree_registry_; nb::callable shard_arg_fallback_; @@ -325,14 +360,14 @@ PjitFunctionStore& GetGlobalPjitFunctionStore() { PjitFunction::PjitFunction( std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, + std::vector static_argnames, nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::shared_ptr cache) : function_name_(std::move(function_name)), fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), static_argnums_(std::move(static_argnums)), - donate_argnums_(donate_argnums), + global_cache_key_(global_cache_key), pytree_registry_(std::move(pytree_registry)), shard_arg_fallback_(std::move(shard_arg_fallback)), cache_(std::move(cache)) { @@ -346,7 +381,7 @@ PjitFunction::PjitFunction( if (!fun_.has_value()) { executables_ = cache_->DefaultCache(); } else { - executables_ = cache_->Lookup(fun_.value(), donate_argnums); + executables_ = cache_->Lookup(fun_.value(), global_cache_key); } GetGlobalPjitFunctionStore().Insert(this); @@ -1029,20 +1064,20 @@ void InitializePjitFunction( PjitFunctionObject* fn_obj, std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, std::vector static_argnames, - std::vector donate_argnums, + nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::shared_ptr cache) { new (&fn_obj->fun) PjitFunction( std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(pytree_registry), + std::move(global_cache_key), std::move(pytree_registry), std::move(shard_arg_fallback), std::move(cache)); } nb::object MakePjitFunction( std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, + std::vector static_argnames, nb::object global_cache_key, std::shared_ptr pytree_registry, nb::callable shard_arg_fallback, std::optional> cache) { @@ -1053,11 +1088,11 @@ nb::object MakePjitFunction( cache = std::make_shared( PjitFunctionCache::kDefaultCapacity); } - InitializePjitFunction(fn_obj, std::move(function_name), std::move(fun), - std::move(cache_miss), std::move(static_argnums), - std::move(static_argnames), std::move(donate_argnums), - std::move(pytree_registry), - std::move(shard_arg_fallback), std::move(*cache)); + InitializePjitFunction( + fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(*cache)); return obj; } @@ -1169,7 +1204,7 @@ void BuildPjitSubmodule(nb::module_& m) { pickle["cache_miss"] = fn->cache_miss(); pickle["static_argnums"] = fn->static_argnums(); pickle["static_argnames"] = nb::cast(fn->static_argnames()); - pickle["donate_argnums"] = fn->donate_argnums(); + pickle["global_cache_key"] = fn->global_cache_key(); pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); pickle["cache"] = fn->cache(); @@ -1197,8 +1232,7 @@ void BuildPjitSubmodule(nb::module_& m) { nb::cast>(pickle["static_argnums"]); std::vector static_argnames = nb::cast>(pickle["static_argnames"]); - std::vector donate_argnums = - nb::cast>(pickle["donate_argnums"]); + nb::object global_cache_key = pickle["global_cache_key"]; std::shared_ptr pytree_registry = nb::cast>( nb::handle(pickle["pytree_registry"].ptr())); @@ -1210,7 +1244,7 @@ void BuildPjitSubmodule(nb::module_& m) { reinterpret_cast(self.ptr()), std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(pytree_registry), + std::move(global_cache_key), std::move(pytree_registry), std::move(shard_arg_fallback), std::move(cache)); }, nb::is_method()); @@ -1236,7 +1270,7 @@ void BuildPjitSubmodule(nb::module_& m) { "pjit", [](std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, - std::vector static_argnames, std::vector donate_argnums, + std::vector static_argnames, nb::object global_cache_key, nb::object pytree_registry, nb::callable shard_arg_fallback, std::optional> cache) { std::shared_ptr registry = @@ -1245,12 +1279,12 @@ void BuildPjitSubmodule(nb::module_& m) { return MakePjitFunction( std::move(function_name), std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(static_argnames), - std::move(donate_argnums), std::move(registry), + std::move(global_cache_key), std::move(registry), std::move(shard_arg_fallback), std::move(cache)); }, nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), nb::arg("static_argnums"), nb::arg("static_argnames"), - nb::arg("donate_argnums"), nb::arg("pytree_registry"), + nb::arg("global_cache_key"), nb::arg("pytree_registry"), nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); } diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 927ad0f19bb0cd..89332de94b0b82 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 285 +_version = 286 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 93088d1ae9c06f..b5ae4c6431ca66 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -932,7 +932,7 @@ def pjit( cache_miss: Callable, static_argnums: Sequence[int], static_argnames: Sequence[str], - donate_argnums: Sequence[int], + global_cache_key: Any, pytree_registry: pytree.PyTreeRegistry, shard_arg_fallback: Callable, cache: Optional[PjitFunctionCache] = ...,