[XLA:Python] Fix more bugs in the weakref_lru_cache implementation.
a) MSVC's std::unordered_map documents that behavior is undefined if the hash function throws an exception (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). That's easy to work around, though: we can just precompute all the hash values before touching the maps.
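
For illustration only — Key, KeyHash, and KeyEq below are made-up stand-ins, not code from this change — the workaround boils down to hashing keys eagerly, outside the container, so the hasher the map itself invokes is trivial and noexcept:

    #include <cstddef>
    #include <functional>
    #include <string>
    #include <unordered_map>

    // The key carries its own hash, computed up front where an exception can
    // propagate safely; the map's hasher just returns the cached value.
    struct Key {
      std::string name;
      std::size_t cached_hash;
    };

    struct KeyHash {
      std::size_t operator()(const Key& k) const noexcept { return k.cached_hash; }
    };

    struct KeyEq {
      bool operator()(const Key& a, const Key& b) const { return a.name == b.name; }
    };

    int main() {
      std::unordered_map<Key, int, KeyHash, KeyEq> m;
      // Any hash work that could throw happens here, before the map is involved.
      Key k{"example", std::hash<std::string>{}("example")};
      m.emplace(k, 42);  // the hasher invoked inside emplace never throws
      return 0;
    }

In the actual change this is what Key's cached_hash_ (computed via absl::HashOf in the constructor) and Call's wrcache_hash are for.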
b) My idiom for avoiding heterogeneous lookups had a use-after-free problem: the weakref callback runs after the referent is already in an invalid state, so the callback must not inspect the referent's contents. However, there's a much simpler solution: just create the weakref object up front and use it as the key unconditionally. Yes, this means we create more weak references than we perhaps had to otherwise, but it is simple and obviously correct.
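
A minimal sketch of that idiom in plain C++ rather than nanobind — WeakHandle, HandleHash, and OnDeath are illustrative stand-ins for nb::weakref and the GC callback, not the real types. The map is keyed on the weak handle itself, whose pointer identity and precomputed hash stay valid even while the referent is being destroyed:

    #include <cstdio>
    #include <functional>
    #include <memory>
    #include <unordered_map>

    // Stand-in for a Python weakref: stable identity plus a hash captured while
    // the referent was alive, so lookups never touch the referent itself.
    struct WeakHandle {
      const void* identity;
      std::size_t cached_hash;
      bool operator==(const WeakHandle& o) const { return identity == o.identity; }
    };

    struct HandleHash {
      std::size_t operator()(const WeakHandle& h) const noexcept {
        return h.cached_hash;
      }
    };

    std::unordered_map<WeakHandle, int, HandleHash> entries;

    // Models the weakref GC callback: erases by identity and cached hash,
    // without ever dereferencing the (already dying) referent.
    void OnDeath(const WeakHandle& h) { entries.erase(h); }

    int main() {
      auto obj = std::make_unique<int>(7);
      WeakHandle h{obj.get(), std::hash<const void*>{}(obj.get())};
      entries.emplace(h, 42);
      obj.reset();  // referent destroyed; identity and cached hash remain usable
      OnDeath(h);   // simulate the callback firing during destruction
      std::printf("entries left: %zu\n", entries.size());  // prints 0
      return 0;
    }

The real change does the analogous thing with nb::weakref: the GC callback rebuilds a WeakrefCacheKey from the borrowed weakref handle and the saved hash, so the dying referent's __eq__ and __hash__ are never invoked.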

PiperOrigin-RevId: 681479822
hawkinsp authored and Google-ML-Automation committed Oct 2, 2024
1 parent 849f758 commit 148e245
Showing 3 changed files with 98 additions and 45 deletions.
1 change: 1 addition & 0 deletions xla/python/BUILD
@@ -1112,6 +1112,7 @@ cc_library(
         # placeholder for index annotation deps
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
+        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@nanobind",
119 changes: 74 additions & 45 deletions xla/python/weakref_lru_cache.cc
@@ -28,6 +28,7 @@ limitations under the License.
 
 #include "absl/base/thread_annotations.h"
 #include "absl/cleanup/cleanup.h"
+#include "absl/hash/hash.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
@@ -78,34 +79,58 @@ class HashablePyDictIter {
   nb::detail::dict_iterator& iter_;
 };
 
+struct HashableKey {
+  nb::object context;
+  nb::args args;
+  nb::kwargs kwargs;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const HashableKey& key) {
+    // Note: Despite the fact this is an ABSL hash function, it's safe to call
+    // functions that may throw exceptions such as nb::hash(), because it is
+    // used by an LRUCache, which uses a std::unordered_map, which is
+    // exception-safe.
+    h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
+    nb::detail::dict_iterator begin = key.kwargs.begin();
+    nb::detail::dict_iterator end = key.kwargs.end();
+    h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
+                             HashablePyDictIter(end));
+    h = H::combine(std::move(h), key.kwargs.size());
+    return h;
+  }
+};
+
 }  // namespace
 
 class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
  public:
-  struct Key {
-    nb::object context;
-    nb::args args;
-    nb::kwargs kwargs;
+  class Key {
+   public:
+    Key(nb::object context, nb::args args, nb::kwargs kwargs)
+        : context_(std::move(context)),
+          args_(std::move(args)),
+          kwargs_(std::move(kwargs)),
+          cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {}
 
     bool operator==(const Key& other) const {
-      return context.equal(other.context) && args.equal(other.args) &&
-             kwargs.equal(other.kwargs);
+      return context_.equal(other.context_) && args_.equal(other.args_) &&
+             kwargs_.equal(other.kwargs_);
     }
 
     template <typename H>
     friend H AbslHashValue(H h, const Key& key) {
-      // Note: Despite the fact this is an ABSL hash function, it's safe to call
-      // functions that may throw exceptions such as nb::hash(), because it is
-      // used by an LRUCache, which uses a std::unordered_map, which is
-      // exception-safe.
-      h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
-      nb::detail::dict_iterator begin = key.kwargs.begin();
-      nb::detail::dict_iterator end = key.kwargs.end();
-      h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
-                               HashablePyDictIter(end));
-      h = H::combine(std::move(h), key.kwargs.size());
-      return h;
+      return H::combine(std::move(h), key.cached_hash_);
     }
 
+    nb::object context() const { return context_; }
+    nb::args args() const { return args_; }
+    nb::kwargs kwargs() const { return kwargs_; }
+
+   private:
+    nb::object context_;
+    nb::args args_;
+    nb::kwargs kwargs_;
+    size_t cached_hash_;
   };
 
   struct CacheEntry {
@@ -123,14 +148,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
   };
 
   struct WeakrefCacheKey {
-    nb::handle object;
+    nb::weakref ref;
     size_t cached_hash;
   };
 
   using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
 
   struct WeakrefCacheValue {
-    std::optional<nb::weakref> weakref;
     std::shared_ptr<Cache> cache;
   };
 
@@ -141,7 +165,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
 
   struct WeakrefKeyEq {
     bool operator()(const WeakrefCacheKey& lhs,
                     const WeakrefCacheKey& rhs) const {
-      return lhs.object.equal(rhs.object);
+      return lhs.ref.equal(rhs.ref);
     }
   };

@@ -150,43 +174,49 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
       : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
 
   std::shared_ptr<Cache> GetCache(WeakrefCacheKey key) {
-    auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue());
-    if (!inserted) {
-      return it->second.cache;
+    WeakrefCacheValue& value = entries_[key];
+    if (!value.cache) {
+      value.cache = std::make_shared<Cache>(&lru_list_);
     }
+    return value.cache;
+  }
 
-    auto& value = it->second;
+  nb::object Call(nb::object weakref_key, nb::args args,
+                  nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
+    nb::object context = cache_context_fn_();
+
+    // We precompute all of the hash values needed by the various maps rather
+    // than computing them during the std::unordered_map insertions. At the very
+    // least, MSVC's std::unordered_map has undefined behavior if the hash
+    // function throws an exception
+    // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace).
+    Key key(context, args, kwargs);
+    size_t wrcache_hash = static_cast<size_t>(nb::hash(weakref_key));
+
+    // No hash computations after this point.
 
-    value.cache = std::make_shared<Cache>(&lru_list_);
     auto weakref_gc_callback = nb::cpp_function(
-        [this_weak = weak_from_this(), key](nb::handle weakref) {
+        [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) {
           auto cache = this_weak.lock();
           if (cache == nullptr) {
            return;
          }
-          auto it = cache->entries_.find(key);
+          // The object the reference referred to is now in the process of being
+          // destroyed, so we cannot refer to its contents. Python weakref
+          // objects compare based on identity if the object they refer to is
+          // gone, so the hash lookup will work fine.
+          auto it = cache->entries_.find(
+              WeakrefCacheKey{nb::borrow<nb::weakref>(weakref), wrcache_hash});
           if (it == cache->entries_.end()) {
            return;
          }
          // Create temp-var to avoid re-entrant erase.
          auto tmp = std::move(it->second);
          cache->entries_.erase(it);
        });
-    PyObject* ref =
-        PyWeakref_NewRef(key.object.ptr(), weakref_gc_callback.ptr());
-    if (!ref) {
-      entries_.erase(it);
-      throw nb::python_error();
-    }
-    value.weakref = nb::steal<nb::weakref>(ref);
-    return value.cache;
-  }
-
-  nb::object Call(nb::object weakref_key, nb::args args,
-                  nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
-    nb::object context = cache_context_fn_();
-    std::shared_ptr<Cache> cache_ptr = GetCache(WeakrefCacheKey{
-        weakref_key, static_cast<size_t>(nb::hash(weakref_key))});
+    nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback);
+    WeakrefCacheKey wrcache_key{weakref, wrcache_hash};
+    std::shared_ptr<Cache> cache_ptr = GetCache(wrcache_key);
     Cache& cache = *cache_ptr;
     ++total_queries_;
 
@@ -206,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
       // released if that happens.
       absl::Cleanup unlock = [this]()
          ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
-      Key key{context, args, kwargs};
       entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
        inserted = true;
        return std::make_shared<CacheEntry>();
@@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
     for (const auto& wr_entry : entries_) {
       for (const auto& rest : *wr_entry.second.cache) {
         nb::tuple result =
-            nb::make_tuple(*wr_entry.second.weakref, rest.first.context,
-                           rest.first.args, rest.first.kwargs);
+            nb::make_tuple(*wr_entry.first.ref, rest.first.context(),
+                           rest.first.args(), rest.first.kwargs());
         results.push_back(std::move(result));
       }
     }
23 changes: 23 additions & 0 deletions xla/python/weakref_lru_cache_test.py
@@ -160,6 +160,29 @@ class WRKey:
         "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)",
     )
 
+  def testGCKeys(self):
+    class WRKey:
+
+      def __init__(self, x):
+        self.x = x
+
+      def __eq__(self, other):
+        return self.x == other.x
+
+      def __hash__(self):
+        return hash(self.x)
+
+    cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
+    keys = [WRKey(i) for i in range(10)]
+    for i in range(10):
+      cache(keys[i], i)
+
+    # Delete some keys, to exercise the weakref callback behavior.
+    del keys[::2]
+
+    for key in keys:
+      cache(key, 7)
+
 
 if __name__ == "__main__":
   absltest.main()
