Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:Python] Fix more bugs in the weakref_lru_cache implementation. #17866

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,7 @@ cc_library(
# placeholder for index annotation deps
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@nanobind",
Expand Down
119 changes: 74 additions & 45 deletions xla/python/weakref_lru_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.

#include "absl/base/thread_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
Expand Down Expand Up @@ -78,34 +79,58 @@ class HashablePyDictIter {
nb::detail::dict_iterator& iter_;
};

struct HashableKey {
nb::object context;
nb::args args;
nb::kwargs kwargs;

template <typename H>
friend H AbslHashValue(H h, const HashableKey& key) {
// Note: Despite the fact this is an ABSL hash function, it's safe to call
// functions that may throw exceptions such as nb::hash(), because it is
// used by an LRUCache, which uses a std::unordered_map, which is
// exception-safe.
h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
nb::detail::dict_iterator begin = key.kwargs.begin();
nb::detail::dict_iterator end = key.kwargs.end();
h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
HashablePyDictIter(end));
h = H::combine(std::move(h), key.kwargs.size());
return h;
}
};

} // namespace

class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
public:
struct Key {
nb::object context;
nb::args args;
nb::kwargs kwargs;
class Key {
public:
Key(nb::object context, nb::args args, nb::kwargs kwargs)
: context_(std::move(context)),
args_(std::move(args)),
kwargs_(std::move(kwargs)),
cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {}

bool operator==(const Key& other) const {
return context.equal(other.context) && args.equal(other.args) &&
kwargs.equal(other.kwargs);
return context_.equal(other.context_) && args_.equal(other.args_) &&
kwargs_.equal(other.kwargs_);
}

template <typename H>
friend H AbslHashValue(H h, const Key& key) {
// Note: Despite the fact this is an ABSL hash function, it's safe to call
// functions that may throw exceptions such as nb::hash(), because it is
// used by an LRUCache, which uses a std::unordered_map, which is
// exception-safe.
h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
nb::detail::dict_iterator begin = key.kwargs.begin();
nb::detail::dict_iterator end = key.kwargs.end();
h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
HashablePyDictIter(end));
h = H::combine(std::move(h), key.kwargs.size());
return h;
return H::combine(std::move(h), key.cached_hash_);
}

nb::object context() const { return context_; }
nb::args args() const { return args_; }
nb::kwargs kwargs() const { return kwargs_; }

private:
nb::object context_;
nb::args args_;
nb::kwargs kwargs_;
size_t cached_hash_;
};

struct CacheEntry {
Expand All @@ -123,14 +148,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
};

struct WeakrefCacheKey {
nb::handle object;
nb::weakref ref;
size_t cached_hash;
};

using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;

struct WeakrefCacheValue {
std::optional<nb::weakref> weakref;
std::shared_ptr<Cache> cache;
};

Expand All @@ -141,7 +165,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
struct WeakrefKeyEq {
bool operator()(const WeakrefCacheKey& lhs,
const WeakrefCacheKey& rhs) const {
return lhs.object.equal(rhs.object);
return lhs.ref.equal(rhs.ref);
}
};

Expand All @@ -150,43 +174,49 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
: cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}

std::shared_ptr<Cache> GetCache(WeakrefCacheKey key) {
auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue());
if (!inserted) {
return it->second.cache;
WeakrefCacheValue& value = entries_[key];
if (!value.cache) {
value.cache = std::make_shared<Cache>(&lru_list_);
}
return value.cache;
}

auto& value = it->second;
nb::object Call(nb::object weakref_key, nb::args args,
nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
nb::object context = cache_context_fn_();

// We precompute all of the hash values needed by the various maps rather
// than computing them during the std::unordered_map insertions. At the very
// least, MSVC's std::unordered_map has undefined behavior if the hash
// function throws an exception
// (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace).
Key key(context, args, kwargs);
size_t wrcache_hash = static_cast<size_t>(nb::hash(weakref_key));

// No hash computations after this point.

value.cache = std::make_shared<Cache>(&lru_list_);
auto weakref_gc_callback = nb::cpp_function(
[this_weak = weak_from_this(), key](nb::handle weakref) {
[this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) {
auto cache = this_weak.lock();
if (cache == nullptr) {
return;
}
auto it = cache->entries_.find(key);
// The object the reference referred to is now in the process of being
// destroyed, so we cannot refer to its contents. Python weakref
// objects compare based on identity if the object they refer to is
// gone, so the hash lookup will work fine.
auto it = cache->entries_.find(
WeakrefCacheKey{nb::borrow<nb::weakref>(weakref), wrcache_hash});
if (it == cache->entries_.end()) {
return;
}
// Create temp-var to avoid re-entrant erase.
auto tmp = std::move(it->second);
cache->entries_.erase(it);
});
PyObject* ref =
PyWeakref_NewRef(key.object.ptr(), weakref_gc_callback.ptr());
if (!ref) {
entries_.erase(it);
throw nb::python_error();
}
value.weakref = nb::steal<nb::weakref>(ref);
return value.cache;
}

nb::object Call(nb::object weakref_key, nb::args args,
nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
nb::object context = cache_context_fn_();
std::shared_ptr<Cache> cache_ptr = GetCache(WeakrefCacheKey{
weakref_key, static_cast<size_t>(nb::hash(weakref_key))});
nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback);
WeakrefCacheKey wrcache_key{weakref, wrcache_hash};
std::shared_ptr<Cache> cache_ptr = GetCache(wrcache_key);
Cache& cache = *cache_ptr;
++total_queries_;

Expand All @@ -206,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
// released if that happens.
absl::Cleanup unlock = [this]()
ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
Key key{context, args, kwargs};
entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
inserted = true;
return std::make_shared<CacheEntry>();
Expand Down Expand Up @@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
for (const auto& wr_entry : entries_) {
for (const auto& rest : *wr_entry.second.cache) {
nb::tuple result =
nb::make_tuple(*wr_entry.second.weakref, rest.first.context,
rest.first.args, rest.first.kwargs);
nb::make_tuple(*wr_entry.first.ref, rest.first.context(),
rest.first.args(), rest.first.kwargs());
results.push_back(std::move(result));
}
}
Expand Down
23 changes: 23 additions & 0 deletions xla/python/weakref_lru_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,29 @@ class WRKey:
"WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)",
)

def testGCKeys(self):
class WRKey:

def __init__(self, x):
self.x = x

def __eq__(self, other):
return self.x == other.x

def __hash__(self):
return hash(self.x)

cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
keys = [WRKey(i) for i in range(10)]
for i in range(10):
cache(keys[i], i)

# Delete some keys, to exercise the weakref callback behavior.
del keys[::2]

for key in keys:
cache(key, 7)


if __name__ == "__main__":
absltest.main()
Loading