diff --git a/src/core/search/rax_tree.h b/src/core/search/rax_tree.h index f99f3d3d4514..1f5662faad5c 100644 --- a/src/core/search/rax_tree.h +++ b/src/core/search/rax_tree.h @@ -24,30 +24,37 @@ template struct RaxTreeMap { // Simple seeking iterator struct SeekIterator { - friend struct FindIterator; - SeekIterator() { - raxStart(&it_, nullptr); - it_.node = nullptr; - } - - ~SeekIterator() { - raxStop(&it_); + it_.rt = nullptr; } - SeekIterator(SeekIterator&&) = delete; // self-referential - SeekIterator(const SeekIterator&) = delete; // self-referential - SeekIterator(rax* tree, const char* op, std::string_view key) { raxStart(&it_, tree); - raxSeek(&it_, op, to_key_ptr(key), key.size()); - operator++(); + if (raxSeek(&it_, op, to_key_ptr(key), key.size())) { // Successfuly seeked + operator++(); + } else { + InvalidateIterator(); + } } explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) { } + /* Remove copy/move constructors to avoid double iterator invalidation */ + SeekIterator(SeekIterator&&) = delete; + SeekIterator(const SeekIterator&) = delete; + SeekIterator& operator=(SeekIterator&&) = delete; + SeekIterator& operator=(const SeekIterator&) = delete; + + ~SeekIterator() { + if (IsValid()) { + InvalidateIterator(); + } + } + bool operator==(const SeekIterator& rhs) const { + if (!IsValid() || !rhs.IsValid()) + return !IsValid() && !rhs.IsValid(); return it_.node == rhs.it_.node; } @@ -56,31 +63,40 @@ template struct RaxTreeMap { } SeekIterator& operator++() { - if (!raxNext(&it_)) { - raxStop(&it_); - it_.node = nullptr; + int next_result = raxNext(&it_); + if (!next_result) { // OOM or we reached the end of the tree + InvalidateIterator(); } return *this; } + /* After operator++() the first value (string_view) is invalid. So make sure your copied it to + * string */ std::pair operator*() const { + assert(IsValid() && it_.node && it_.node->iskey && it_.data); return {std::string_view{reinterpret_cast(it_.key), it_.key_len}, *reinterpret_cast(it_.data)}; } + bool IsValid() const { + return it_.rt; + } + private: + void InvalidateIterator() { + raxStop(&it_); + it_.rt = nullptr; + } + raxIterator it_; }; // Result of find() call. Inherits from pair to mimic iterator interface, not incrementable. struct FindIterator : public std::optional> { bool operator==(const SeekIterator& rhs) const { - if (this->has_value() != !bool(rhs.it_.flags & RAX_ITER_EOF)) - return false; - if (!this->has_value()) - return true; - return (*this)->first == - std::string_view{reinterpret_cast(rhs.it_.key), rhs.it_.key_len}; + if (!this->has_value() || !rhs.IsValid()) + return !this->has_value() && !rhs.IsValid(); + return (*this)->first == (*rhs).first; } bool operator!=(const SeekIterator& rhs) const { @@ -160,7 +176,7 @@ std::pair::FindIterator, bool> RaxTreeMap::try_emplace V* old = nullptr; raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast(&old)); - assert(old == nullptr); + assert(!old); auto it = std::make_optional(std::pair(std::string(key), *ptr)); return std::make_pair(std::move(FindIterator{it}), true); diff --git a/src/core/search/rax_tree_test.cc b/src/core/search/rax_tree_test.cc index 69179e705ea3..519fae929871 100644 --- a/src/core/search/rax_tree_test.cc +++ b/src/core/search/rax_tree_test.cc @@ -104,4 +104,28 @@ TEST_F(RaxTreeTest, Find) { EXPECT_TRUE(map.find(string_view{}) == map.end()); } +/* Run with mimalloc to make sure there is no double free */ +TEST_F(RaxTreeTest, Iterate) { + const char* kKeys[] = { + "aaaaaaaaaaaaaaaaaaaa", + "bbbbbbbbbbbbbbbbbbbbbb" + "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd" + "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee", + }; + + RaxTreeMap map(pmr::get_default_resource()); + for (const char* key : kKeys) { + map.try_emplace(key, 2); + } + + for (auto it = map.begin(); it != map.end(); ++it) { + EXPECT_EQ((*it).second, 2); + } + + for (auto it = map.begin(); it != map.end(); ++it) { + EXPECT_EQ((*it).second, 2); + } +} + } // namespace dfly::search