Skip to content

Commit

Permalink
[Take 2] Generalize global jit cpp cache keys so we can add more keys…
Browse files Browse the repository at this point in the history
… than the current donate_argnums.

This allows us to get more cache hits globally. For example:

Before:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
After:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit

Reverts d8f4200

PiperOrigin-RevId: 674367894
  • Loading branch information
pschuh authored and Google-ML-Automation committed Sep 17, 2024
1 parent bc1aad8 commit 52e0c5e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 36 deletions.
3 changes: 3 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,12 @@ cc_library(
":traceback",
":transfer_guard_lib",
# placeholder for index annotation deps
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
102 changes: 68 additions & 34 deletions xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@ limitations under the License.
#include <string>
#include <string_view>
#include <thread> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "nanobind/nanobind.h"
Expand Down Expand Up @@ -122,10 +127,10 @@ class PjitFunctionCache {
// might be evicted before we finish tracing/compiling.
typedef xla::LRUCache<CallSignature, std::shared_ptr<PjitCacheEntry>> Cache;

// We include as part of the cache key `donate_argnums` (and any other fields
// that aren't subsumed by the CallSignature we compute for each call).
// We include as part of the cache key `global_cache_key` (and any other
// fields that aren't subsumed by the CallSignature we compute for each call).
std::shared_ptr<Cache> Lookup(nb::handle function,
absl::Span<const int> donate_argnums);
nb::object global_cache_key);
std::shared_ptr<Cache> DefaultCache();

int Size() const { return lru_list_.Size(); }
Expand All @@ -138,19 +143,35 @@ class PjitFunctionCache {

// Other fields that are part of the arguments to `jit`, but are not
// otherwise part of CallSignature.
std::vector<int> donate_argnums;
nb::object global_cache_key;

bool operator==(const Key& other) const {
return function.ptr() == other.function.ptr() &&
donate_argnums == other.donate_argnums;
bool global_cache_eq;
try {
global_cache_eq = global_cache_key.equal(other.global_cache_key);
} catch (const nanobind::python_error& e) {
throw std::invalid_argument(
absl::StrCat("Equality of global cache key lead to an exception. "
"The error was:\n",
e.what(), "\n"));
}
return function.ptr() == other.function.ptr() && global_cache_eq;
}
};

template <typename H>
friend H AbslHashValue(H h, const Key& key) {
h = H::combine(std::move(h), key.function.ptr());
h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
key.donate_argnums.size());
Py_hash_t hash;
try {
hash = xla::nb_hash(key.global_cache_key);
} catch (const nanobind::python_error& e) {
if (!e.matches(PyExc_TypeError)) throw;
throw std::invalid_argument(absl::StrCat(
"Hashing global cache key lead to an exception. The error was:\n",
e.what(), "\n"));
}
h = H::combine(std::move(h), hash);
return h;
}

Expand All @@ -167,7 +188,9 @@ class PjitFunctionCache {
};

Cache::LRUList lru_list_;
absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
absl::Mutex mu_; // Non-trivial hashes need to be mutex locked.
// ABSL containers are not exception safe:
std::unordered_map<Key, std::unique_ptr<Value>, absl::Hash<Key>> functions_;
};

PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {}
Expand All @@ -177,19 +200,31 @@ std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::DefaultCache() {
}

std::shared_ptr<PjitFunctionCache::Cache> PjitFunctionCache::Lookup(
nb::handle function, absl::Span<const int> donate_argnums) {
nb::handle function,
nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS {
{
// Because the gil can be released during cache insertion, this forces
// the lock order to be mu_ then gil so we must release the gil first.
nb::gil_scoped_release release;
// Acquire a mutex to avoid problems where the gil is released during
// cache insertion and then a second thread invalidates the cache order.
mu_.Lock();
}
absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
Key key;
key.function = function;
key.donate_argnums =
std::vector<int>(donate_argnums.begin(), donate_argnums.end());
key.global_cache_key = global_cache_key;
auto insert = functions_.emplace(key, nullptr);
if (!insert.second) {
return insert.first->second->cache;
}
std::shared_ptr<Cache> cache = std::make_shared<Cache>(&lru_list_);
auto callback =
nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) {
functions_.erase(key);
auto it = functions_.find(key);
if (it != functions_.end()) {
functions_.erase(it);
}
});
PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr());
if (weakref) {
Expand All @@ -211,7 +246,7 @@ class PjitFunction {
PjitFunction(std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames,
std::vector<int> donate_argnums,
nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback,
std::shared_ptr<PjitFunctionCache> cache);
Expand Down Expand Up @@ -258,7 +293,7 @@ class PjitFunction {
const std::vector<nb::str>& static_argnames() const {
return static_argnames_;
}
const std::vector<int>& donate_argnums() const { return donate_argnums_; }
const nb::object& global_cache_key() const { return global_cache_key_; }
const std::shared_ptr<PjitFunctionCache>& cache() const { return cache_; }

int cache_capacity() const { return executables_->Size(); }
Expand Down Expand Up @@ -291,7 +326,7 @@ class PjitFunction {
nb::callable cache_miss_;
std::vector<int> static_argnums_;
std::vector<nb::str> static_argnames_;
std::vector<int> donate_argnums_;
nb::object global_cache_key_;

std::shared_ptr<xla::PyTreeRegistry> pytree_registry_;
nb::callable shard_arg_fallback_;
Expand Down Expand Up @@ -325,14 +360,14 @@ PjitFunctionStore& GetGlobalPjitFunctionStore() {
PjitFunction::PjitFunction(
std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
std::vector<nb::str> static_argnames, nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback, std::shared_ptr<PjitFunctionCache> cache)
: function_name_(std::move(function_name)),
fun_(std::move(fun)),
cache_miss_(std::move(cache_miss)),
static_argnums_(std::move(static_argnums)),
donate_argnums_(donate_argnums),
global_cache_key_(global_cache_key),
pytree_registry_(std::move(pytree_registry)),
shard_arg_fallback_(std::move(shard_arg_fallback)),
cache_(std::move(cache)) {
Expand All @@ -346,7 +381,7 @@ PjitFunction::PjitFunction(
if (!fun_.has_value()) {
executables_ = cache_->DefaultCache();
} else {
executables_ = cache_->Lookup(fun_.value(), donate_argnums);
executables_ = cache_->Lookup(fun_.value(), global_cache_key);
}

GetGlobalPjitFunctionStore().Insert(this);
Expand Down Expand Up @@ -1029,20 +1064,20 @@ void InitializePjitFunction(
PjitFunctionObject* fn_obj, std::string function_name,
std::optional<nb::callable> fun, nb::callable cache_miss,
std::vector<int> static_argnums, std::vector<nb::str> static_argnames,
std::vector<int> donate_argnums,
nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback, std::shared_ptr<PjitFunctionCache> cache) {
new (&fn_obj->fun) PjitFunction(
std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(donate_argnums), std::move(pytree_registry),
std::move(global_cache_key), std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(cache));
}

nb::object MakePjitFunction(
std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
std::vector<nb::str> static_argnames, nb::object global_cache_key,
std::shared_ptr<xla::PyTreeRegistry> pytree_registry,
nb::callable shard_arg_fallback,
std::optional<std::shared_ptr<PjitFunctionCache>> cache) {
Expand All @@ -1053,11 +1088,11 @@ nb::object MakePjitFunction(
cache = std::make_shared<PjitFunctionCache>(
PjitFunctionCache::kDefaultCapacity);
}
InitializePjitFunction(fn_obj, std::move(function_name), std::move(fun),
std::move(cache_miss), std::move(static_argnums),
std::move(static_argnames), std::move(donate_argnums),
std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(*cache));
InitializePjitFunction(
fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(global_cache_key), std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(*cache));
return obj;
}

Expand Down Expand Up @@ -1169,7 +1204,7 @@ void BuildPjitSubmodule(nb::module_& m) {
pickle["cache_miss"] = fn->cache_miss();
pickle["static_argnums"] = fn->static_argnums();
pickle["static_argnames"] = nb::cast(fn->static_argnames());
pickle["donate_argnums"] = fn->donate_argnums();
pickle["global_cache_key"] = fn->global_cache_key();
pickle["pytree_registry"] = nb::cast(fn->pytree_registry());
pickle["shard_arg_fallback"] = fn->shard_arg_fallback();
pickle["cache"] = fn->cache();
Expand Down Expand Up @@ -1197,8 +1232,7 @@ void BuildPjitSubmodule(nb::module_& m) {
nb::cast<std::vector<int>>(pickle["static_argnums"]);
std::vector<nb::str> static_argnames =
nb::cast<std::vector<nb::str>>(pickle["static_argnames"]);
std::vector<int> donate_argnums =
nb::cast<std::vector<int>>(pickle["donate_argnums"]);
nb::object global_cache_key = pickle["global_cache_key"];
std::shared_ptr<xla::PyTreeRegistry> pytree_registry =
nb::cast<std::shared_ptr<xla::PyTreeRegistry>>(
nb::handle(pickle["pytree_registry"].ptr()));
Expand All @@ -1210,7 +1244,7 @@ void BuildPjitSubmodule(nb::module_& m) {
reinterpret_cast<PjitFunctionObject*>(self.ptr()),
std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(donate_argnums), std::move(pytree_registry),
std::move(global_cache_key), std::move(pytree_registry),
std::move(shard_arg_fallback), std::move(cache));
},
nb::is_method());
Expand All @@ -1236,7 +1270,7 @@ void BuildPjitSubmodule(nb::module_& m) {
"pjit",
[](std::string function_name, std::optional<nb::callable> fun,
nb::callable cache_miss, std::vector<int> static_argnums,
std::vector<nb::str> static_argnames, std::vector<int> donate_argnums,
std::vector<nb::str> static_argnames, nb::object global_cache_key,
nb::object pytree_registry, nb::callable shard_arg_fallback,
std::optional<std::shared_ptr<PjitFunctionCache>> cache) {
std::shared_ptr<xla::PyTreeRegistry> registry =
Expand All @@ -1245,12 +1279,12 @@ void BuildPjitSubmodule(nb::module_& m) {
return MakePjitFunction(
std::move(function_name), std::move(fun), std::move(cache_miss),
std::move(static_argnums), std::move(static_argnames),
std::move(donate_argnums), std::move(registry),
std::move(global_cache_key), std::move(registry),
std::move(shard_arg_fallback), std::move(cache));
},
nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"),
nb::arg("static_argnums"), nb::arg("static_argnames"),
nb::arg("donate_argnums"), nb::arg("pytree_registry"),
nb::arg("global_cache_key"), nb::arg("pytree_registry"),
nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none());
}

Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 285
_version = 286

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def pjit(
cache_miss: Callable,
static_argnums: Sequence[int],
static_argnames: Sequence[str],
donate_argnums: Sequence[int],
global_cache_key: Any,
pytree_registry: pytree.PyTreeRegistry,
shard_arg_fallback: Callable,
cache: Optional[PjitFunctionCache] = ...,
Expand Down

0 comments on commit 52e0c5e

Please sign in to comment.