Skip to content

Commit

Permalink
implement the interfaces for range search, filter in pyramid
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <[email protected]>
  • Loading branch information
jinjiabao.jjb committed Jan 23, 2025
1 parent 797d7f4 commit 12602df
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 50 deletions.
198 changes: 158 additions & 40 deletions src/index/pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,115 @@

namespace vsag {

/**
 * Flattens a BinarySet into a single contiguous Binary.
 *
 * @param binary_set the set to serialize (taken by value to match the declaration in pyramid.h)
 * @return a Binary whose `size` is the exact number of bytes written to `data`
 */
Binary
binaryset_to_binary(const BinarySet binary_set) {
    /*
     * The serialized layout of the Binary data in memory will be as follows:
     * | key_size_0     | key_0 (L_0 bytes)         | binary_size_0     | binary_data_0 (S_0 bytes)         |
     * | key_size_1     | key_1 (L_1 bytes)         | binary_size_1     | binary_data_1 (S_1 bytes)         |
     * | ...            | ...                       | ...               | ...                               |
     * | key_size_(N-1) | key_(N-1) (L_(N-1) bytes) | binary_size_(N-1) | binary_data_(N-1) (S_(N-1) bytes) |
     * Where:
     * - `key_size_k`: size of the k-th key (in bytes), stored as a raw size_t
     * - `key_k`: the actual k-th key data (length L_k)
     * - `binary_size_k`: size of the binary data associated with the k-th key (in bytes),
     *                    stored as a raw size_t
     * - `binary_data_k`: the actual binary data contents (length S_k)
     * - N: total number of keys in the BinarySet
     */
    // Fetch the key list once so both passes walk the entries in the same order.
    const auto keys = binary_set.GetKeys();

    // Pass 1: compute the exact buffer size.
    size_t total_size = 0;
    for (const auto& key : keys) {
        total_size += sizeof(size_t) + key.size();
        total_size += sizeof(size_t) + binary_set.Get(key).size;
    }

    Binary result;
    result.data = std::shared_ptr<int8_t[]>(new int8_t[total_size]);
    result.size = total_size;

    // Pass 2: write every record back to back.
    int8_t* cursor = result.data.get();
    for (const auto& key : keys) {
        const size_t key_size = key.size();
        memcpy(cursor, &key_size, sizeof(size_t));
        cursor += sizeof(size_t);
        if (key_size > 0) {
            memcpy(cursor, key.data(), key_size);
            cursor += key_size;
        }

        const Binary binary = binary_set.Get(key);
        const size_t binary_size = binary.size;
        memcpy(cursor, &binary_size, sizeof(size_t));
        cursor += sizeof(size_t);
        // An empty entry may carry a null data pointer; passing null to memcpy is
        // undefined behavior even for a zero length, so skip the copy entirely.
        if (binary_size > 0) {
            memcpy(cursor, binary.data.get(), binary_size);
            cursor += binary_size;
        }
    }

    return result;
}

/**
 * Rebuilds a BinarySet from the flat buffer produced by binaryset_to_binary().
 *
 * @param binary the flat buffer (taken by value to match the declaration in pyramid.h)
 * @return a BinarySet holding one freshly allocated Binary per record found in `binary`
 * @throws std::invalid_argument if `binary.size` is non-zero but `binary.data` is null
 * @throws std::out_of_range if a record header or payload extends past `binary.size`
 */
BinarySet
binary_to_binaryset(const Binary binary) {
    /*
     * The Binary structure is serialized in the following layout:
     * | key_size (sizeof(size_t)) | key (of length key_size) | binary_size (sizeof(size_t)) | binary data (of length binary_size) |
     * Each key and its associated binary data are sequentially stored in the Binary object's data array,
     * and this information guides the deserialization process here.
     */
    BinarySet binary_set;
    const size_t total_size = binary.size;
    const int8_t* base = binary.data.get();
    if (total_size > 0 && base == nullptr) {
        throw std::invalid_argument("binary_to_binaryset: non-empty Binary has no data");
    }

    size_t offset = 0;

    // `offset <= total_size` always holds, so the subtraction cannot wrap around.
    const auto require = [&](size_t len) {
        if (len > total_size - offset) {
            throw std::out_of_range("binary_to_binaryset: truncated or corrupted Binary");
        }
    };

    // Reads one raw size_t field at the cursor and advances past it.
    const auto read_size = [&]() {
        require(sizeof(size_t));
        size_t value = 0;
        memcpy(&value, base + offset, sizeof(size_t));
        offset += sizeof(size_t);
        return value;
    };

    while (offset < total_size) {
        const size_t key_size = read_size();
        require(key_size);
        std::string key(reinterpret_cast<const char*>(base + offset), key_size);
        offset += key_size;

        const size_t binary_size = read_size();
        require(binary_size);
        Binary new_binary;
        new_binary.size = binary_size;
        new_binary.data = std::shared_ptr<int8_t[]>(new int8_t[binary_size]);
        if (binary_size > 0) {
            memcpy(new_binary.data.get(), base + offset, binary_size);
        }
        offset += binary_size;

        binary_set.Set(key, new_binary);
    }

    return binary_set;
}

Check warning on line 100 in src/index/pyramid.cpp

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.cpp#L100

Added line #L100 was not covered by tests

/**
 * Splits a reader over the flat layout
 * | key_size | key | binary_size | binary data | ... into a ReaderSet.
 *
 * No payload bytes are copied: every entry becomes a SubReader window onto `reader`.
 *
 * @param reader source reader; it is shared with every SubReader that is created
 * @return a ReaderSet with one SubReader per record
 * @throws std::out_of_range if a record header or payload extends past `reader->Size()`
 */
ReaderSet
reader_to_readerset(std::shared_ptr<Reader> reader) {
    ReaderSet reader_set;
    const uint64_t total_size = reader->Size();
    uint64_t offset = 0;

    // `offset <= total_size` always holds, so the subtraction cannot wrap around.
    const auto require = [&](uint64_t len) {
        if (len > total_size - offset) {
            throw std::out_of_range("reader_to_readerset: truncated or corrupted data");
        }
    };

    while (offset < total_size) {
        size_t key_size = 0;
        require(sizeof(size_t));
        reader->Read(offset, sizeof(size_t), &key_size);
        offset += sizeof(size_t);

        // Read the key bytes straight into the string's own storage.
        require(key_size);
        std::string key(key_size, '\0');
        reader->Read(offset, key_size, key.data());
        offset += key_size;

        size_t binary_size = 0;
        require(sizeof(size_t));
        reader->Read(offset, sizeof(size_t), &binary_size);
        offset += sizeof(size_t);

        // Validate the window before handing it out so the SubReader never points
        // past the end of the parent reader.
        require(binary_size);
        auto sub_reader = std::make_shared<SubReader>(reader, offset, binary_size);
        offset += binary_size;

        reader_set.Set(key, sub_reader);
    }

    return reader_set;
}

Check warning on line 127 in src/index/pyramid.cpp

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.cpp#L127

Added line #L127 was not covered by tests

// std::deque specialization whose allocator type is vsag::AllocatorWrapper<T>.
template <typename T>
using Deque = std::deque<T, vsag::AllocatorWrapper<T>>;

Expand Down Expand Up @@ -85,10 +194,10 @@ Pyramid::Add(const DatasetPtr& base) {
}

tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid) const {
Pyramid::knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const {
auto path = query->GetPaths(); // TODO(inabao): provide different search modes.

std::string current_path = path[0];
Expand Down Expand Up @@ -118,7 +227,7 @@ Pyramid::KnnSearch(const DatasetPtr& query,
auto node = candidate_indexes.front();
candidate_indexes.pop_front();
if (node->index) {
auto result = node->index->KnnSearch(query, k, parameters, invalid);
auto result = search_func(node->index);
if (result.has_value()) {
DatasetPtr r = result.value();
for (int i = 0; i < r->GetDim(); ++i) {
Expand Down Expand Up @@ -162,52 +271,61 @@ Pyramid::KnnSearch(const DatasetPtr& query,
return result;
}

// Placeholder overload (filter-callback variant): ignores every argument and
// returns an empty-brace-initialized result.
// NOTE(review): in this rendered diff this looks like the pre-change stub that the
// commit replaces with an inline definition in pyramid.h — confirm before relying on it.
tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const {
return {};
}

// Placeholder overload (no filtering): ignores every argument and returns an
// empty-brace-initialized result.
// NOTE(review): in this rendered diff this looks like the pre-change stub that the
// commit replaces with an inline definition in pyramid.h — confirm before relying on it.
tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size) const {
return {};
}

// Placeholder overload (bitset variant): ignores every argument and returns an
// empty-brace-initialized result.
// NOTE(review): in this rendered diff this looks like the pre-change stub that the
// commit replaces with an inline definition in pyramid.h — confirm before relying on it.
tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size) const {
return {};
}

// Placeholder overload (filter-callback variant): ignores every argument and
// returns an empty-brace-initialized result.
// NOTE(review): in this rendered diff this looks like the pre-change stub that the
// commit replaces with an inline definition in pyramid.h — confirm before relying on it.
tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size) const {
return {};
}

/**
 * Serializes every sub-index of the pyramid into one BinarySet.
 *
 * Each node that owns an index contributes one entry: the key is the node's path
 * (segments joined with PART_OCTOTHORPE) and the value is that index's own
 * BinarySet flattened by binaryset_to_binary().
 *
 * @return the populated BinarySet, or the first error reported by a sub-index
 */
tl::expected<BinarySet, Error>
Pyramid::Serialize() const {
    BinarySet binary_set;
    for (const auto& root_index : indexes_) {
        // Depth-first walk of the subtree under this root, using an explicit stack
        // of (path, node) pairs.
        std::vector<std::pair<std::string, std::shared_ptr<IndexNode>>> pending;
        pending.emplace_back(root_index.first, root_index.second);
        while (not pending.empty()) {
            auto [current_path, index_node] = std::move(pending.back());
            pending.pop_back();

            if (index_node->index) {
                auto serialize_result = index_node->index->Serialize();
                if (not serialize_result.has_value()) {
                    // Abort on the first failure instead of returning a partial set.
                    return tl::unexpected(serialize_result.error());
                }
                binary_set.Set(current_path, binaryset_to_binary(serialize_result.value()));
            }

            for (const auto& sub_index_node : index_node->children) {
                pending.emplace_back(current_path + PART_OCTOTHORPE + sub_index_node.first,
                                     sub_index_node.second);
            }
        }
    }
    return binary_set;
}

/**
 * Restores the pyramid from a BinarySet.
 *
 * Every key is split on PART_OCTOTHORPE into path segments; the node for that path
 * is obtained through try_get_node_with_init(), given a fresh index via
 * CreateIndex(), and that index is loaded from the entry's payload after expanding
 * it with binary_to_binaryset().
 *
 * @param binary_set entries keyed by node path
 * @return success, or the first error reported by a sub-index while deserializing
 */
tl::expected<void, Error>
Pyramid::Deserialize(const BinarySet& binary_set) {
    const auto keys = binary_set.GetKeys();
    for (const auto& path : keys) {
        const auto& binary = binary_set.Get(path);
        const auto path_slices = split(path, PART_OCTOTHORPE);

        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
        for (size_t j = 1; j < path_slices.size(); ++j) {
            node = try_get_node_with_init(node->children, path_slices[j]);
        }

        node->CreateIndex(pyramid_param_.index_builder);
        // Propagate sub-index failures instead of silently reporting success.
        auto deserialize_result = node->index->Deserialize(binary_to_binaryset(binary));
        if (not deserialize_result.has_value()) {
            return tl::unexpected(deserialize_result.error());
        }
    }
    return {};
}

/**
 * Restores the pyramid from a ReaderSet.
 *
 * Every key is split on PART_OCTOTHORPE into path segments; the node for that path
 * is obtained through try_get_node_with_init(), given a fresh index via
 * CreateIndex(), and that index is loaded from the entry's reader after splitting
 * it with reader_to_readerset().
 *
 * @param reader_set readers keyed by node path
 * @return success, or the first error reported by a sub-index while deserializing
 */
tl::expected<void, Error>
Pyramid::Deserialize(const ReaderSet& reader_set) {
    const auto keys = reader_set.GetKeys();
    for (const auto& path : keys) {
        const auto& reader = reader_set.Get(path);
        const auto path_slices = split(path, PART_OCTOTHORPE);

        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
        for (size_t j = 1; j < path_slices.size(); ++j) {
            node = try_get_node_with_init(node->children, path_slices[j]);
        }

        node->CreateIndex(pyramid_param_.index_builder);
        // Propagate sub-index failures instead of silently reporting success.
        auto deserialize_result = node->index->Deserialize(reader_to_readerset(reader));
        if (not deserialize_result.has_value()) {
            return tl::unexpected(deserialize_result.error());
        }
    }
    return {};
}

Expand Down
86 changes: 81 additions & 5 deletions src/index/pyramid.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,48 @@

#include <utility>

#include "base_filter_functor.h"
#include "logger.h"
#include "pyramid_zparameters.h"
#include "safe_allocator.h"

namespace vsag {

class SubReader : public Reader {
public:
SubReader(std::shared_ptr<Reader> parrent_reader, uint64_t start_pos, uint64_t size)
: parrent_reader_(std::move(parrent_reader)), size_(size), start_pos_(start_pos) {
}

void
Read(uint64_t offset, uint64_t len, void* dest) override {
if (offset + len > size_)
throw std::out_of_range("Read out of range.");

Check warning on line 36 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L36

Added line #L36 was not covered by tests
parrent_reader_->Read(offset + start_pos_, len, dest);
}

void
AsyncRead(uint64_t offset, uint64_t len, void* dest, CallBack callback) override {
}

Check warning on line 42 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L41-L42

Added lines #L41 - L42 were not covered by tests

uint64_t
Size() const override {
return size_;
}

private:
std::shared_ptr<Reader> parrent_reader_;
uint64_t size_;
uint64_t start_pos_;
};

// Flattens a BinarySet into one Binary laid out as repeated
// | key_size | key | binary_size | binary data | records (defined in pyramid.cpp).
// NOTE(review): the argument is taken by value, so the BinarySet object is copied per call.
Binary
binaryset_to_binary(const BinarySet binary_set);
// Inverse of binaryset_to_binary(): rebuilds a BinarySet from the flat record layout.
BinarySet
binary_to_binaryset(const Binary binary);
// Reader-based counterpart: walks the same record layout through `reader` and wraps
// each payload in a SubReader instead of copying it.
ReaderSet
reader_to_readerset(std::shared_ptr<Reader> reader);

struct IndexNode {
std::shared_ptr<Index> index{nullptr};
UnorderedMap<std::string, std::shared_ptr<IndexNode>> children;
Expand All @@ -35,6 +72,8 @@ struct IndexNode {
}
};

// Callable that runs one search against a single sub-index (passed as IndexPtr)
// and returns either the resulting dataset or an Error.
using SearchFunc = std::function<tl::expected<DatasetPtr, Error>(IndexPtr)>;

class Pyramid : public Index {
public:
Pyramid(PyramidParameters pyramid_param, const IndexCommonParam& commom_param)
Expand All @@ -55,33 +94,64 @@ class Pyramid : public Index {
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid = nullptr) const override;
BitsetPtr invalid = nullptr) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, invalid);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

/**
 * k-nearest-neighbor search with a caller-supplied id filter.
 *
 * Wraps the per-index KnnSearch call in a SearchFunc and delegates to knn_search().
 * The lambda captures by reference; that is safe because it is only used during
 * this call.
 */
tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
          int64_t k,
          const std::string& parameters,
          const std::function<bool(int64_t)>& filter) const override {
    SearchFunc search_func = [&](IndexPtr index) {
        return index->KnnSearch(query, k, parameters, filter);
    };
    SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

/**
 * Range search without filtering.
 *
 * Wraps the per-index RangeSearch call in a SearchFunc and delegates to
 * knn_search(). A `limited_size` of -1 is translated to the largest int64_t
 * before being passed as the `k` argument of knn_search().
 */
tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
            float radius,
            const std::string& parameters,
            int64_t limited_size = -1) const override {
    SearchFunc search_func = [&](IndexPtr index) {
        return index->RangeSearch(query, radius, parameters, limited_size);
    };
    int64_t final_limit =
        limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
    SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

/**
 * Range search with an `invalid` bitset forwarded to each sub-index.
 *
 * Wraps the per-index RangeSearch call in a SearchFunc and delegates to
 * knn_search(). A `limited_size` of -1 is translated to the largest int64_t
 * before being passed as the `k` argument of knn_search().
 */
tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
            float radius,
            const std::string& parameters,
            BitsetPtr invalid,
            int64_t limited_size = -1) const override {
    SearchFunc search_func = [&](IndexPtr index) {
        return index->RangeSearch(query, radius, parameters, invalid, limited_size);
    };
    int64_t final_limit =
        limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
    SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

Check warning on line 140 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L138-L140

Added lines #L138 - L140 were not covered by tests

/**
 * Range search with a caller-supplied id filter forwarded to each sub-index.
 *
 * Wraps the per-index RangeSearch call in a SearchFunc and delegates to
 * knn_search(). A `limited_size` of -1 is translated to the largest int64_t
 * before being passed as the `k` argument of knn_search().
 */
tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
            float radius,
            const std::string& parameters,
            const std::function<bool(int64_t)>& filter,
            int64_t limited_size = -1) const override {
    SearchFunc search_func = [&](IndexPtr index) {
        return index->RangeSearch(query, radius, parameters, filter, limited_size);
    };
    int64_t final_limit =
        limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
    SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

Check warning on line 154 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L152-L154

Added lines #L152 - L154 were not covered by tests

tl::expected<BinarySet, Error>
Serialize() const override;
Expand All @@ -99,6 +169,12 @@ class Pyramid : public Index {
GetMemoryUsage() const override;

private:
// Shared implementation behind the public KnnSearch/RangeSearch overloads.
// `search_func` is invoked on each visited node that holds an index.
// NOTE(review): `k` is presumably the cap on the number of merged results (the
// RangeSearch overloads pass their size limit here) — confirm against the definition.
tl::expected<DatasetPtr, Error>
knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const;

inline std::shared_ptr<IndexNode>
try_get_node_with_init(UnorderedMap<std::string, std::shared_ptr<IndexNode>>& index_map,
const std::string& key) {
Expand Down
7 changes: 5 additions & 2 deletions tests/fixtures/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,11 @@ TestDataset::CreateTestDataset(uint64_t dim,
dataset->range_ground_truth_ = dataset->ground_truth_;
dataset->range_radius_.resize(query_count);
for (uint64_t i = 0; i < query_count; ++i) {
dataset->range_radius_[i] = 0.5f * (result.first[i * count + dataset->top_k] +
result.first[i * count + dataset->top_k - 1]);
dataset->range_radius_[i] =
0.5f * (dataset->range_ground_truth_
->GetDistances()[i * dataset->top_k + dataset->top_k - 1] +
dataset->range_ground_truth_
->GetDistances()[i * dataset->top_k + dataset->top_k - 2]);
}
delete[] result.first;
delete[] result.second;
Expand Down
Loading

0 comments on commit 12602df

Please sign in to comment.