Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inserting and querying vectors with NaN values #248

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,11 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
_mm_prefetch(vector_data_ptr, _MM_HINT_T0);
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) &&
((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
top_candidates.emplace(dist, candidate_id);
if ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))) {
if (not std::isnan(dist)) {
top_candidates.emplace(dist, candidate_id);
}
}

if (top_candidates.size() > ef)
top_candidates.pop();
Expand Down Expand Up @@ -531,9 +533,11 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
_mm_prefetch(vector_data_ptr, _MM_HINT_T0); ////////////////////////
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) &&
((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
top_candidates.emplace(dist, candidate_id);
if (((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) {
if (not std::isnan(dist)) {
top_candidates.emplace(dist, candidate_id);
}
}

if (not top_candidates.empty())
lowerBound = top_candidates.top().first;
Expand Down
83 changes: 67 additions & 16 deletions tests/fixtures/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ CopyVector(const std::vector<T>& vec) {
static TestDataset::DatasetPtr
GenerateRandomDataset(uint64_t dim, uint64_t count, std::string metric_str = "l2") {
auto base = vsag::Dataset::Make();
constexpr uint64_t query_count = 100;
bool need_normalize = (metric_str != "cosine");
auto vecs =
fixtures::generate_vectors(count, dim, need_normalize, fixtures::RandomValue(0, 564));
Expand All @@ -62,6 +61,27 @@ GenerateRandomDataset(uint64_t dim, uint64_t count, std::string metric_str = "l2
return base;
}

/// Builds a random dataset of `count` vectors of dimension `dim` in which the
/// trailing vectors (from index floor(count * 0.9) up to count) have their first
/// component replaced by a quiet NaN.
///
/// @param dim        number of components per vector
/// @param count      number of vectors to generate
/// @param metric_str metric name; normalization is requested from
///                   fixtures::generate_vectors unless it equals "cosine"
/// @return a dataset whose ids are consecutive, starting at 10086
static TestDataset::DatasetPtr
GenerateNanRandomDataset(uint64_t dim, uint64_t count, std::string metric_str = "l2") {
    auto base = vsag::Dataset::Make();
    const bool need_normalize = (metric_str != "cosine");

    std::vector<float> vecs =
        fixtures::generate_vectors(count, dim, need_normalize, fixtures::RandomValue(0, 564));

    // Index with the same unsigned width as `count`: the previous `int` index mixed
    // signed/unsigned in the comparison and narrowed the start position.
    const auto first_nan = static_cast<uint64_t>(count * 0.9);
    for (uint64_t i = first_nan; i < count; ++i) {
        // One NaN component is enough to make the whole vector's distance NaN-prone.
        vecs[i * dim] = std::numeric_limits<float>::quiet_NaN();
    }

    std::vector<int64_t> ids(count);
    std::iota(ids.begin(), ids.end(), 10086);
    base->Dim(dim)
        ->Ids(CopyVector(ids))
        ->Float32Vectors(CopyVector(vecs))
        ->NumElements(count)
        ->Owner(true);
    return base;
}

static std::pair<float*, int64_t*>
CalDistanceFloatMetrix(const vsag::DatasetPtr query,
const vsag::DatasetPtr base,
Expand Down Expand Up @@ -152,27 +172,58 @@ CalFilterGroundTruth(const std::pair<float*, int64_t*>& result,
return gt;
}

TestDataset::TestDataset(uint64_t dim, uint64_t count, std::string metric_str)
: dim_(dim), count_(count) {
this->base_ = GenerateRandomDataset(dim, count, metric_str);
/// Factory for a fully populated random test dataset.
///
/// Generates a base set of `count` vectors and a fixed set of 100 queries, then
/// derives the top-10 ground truth, per-query range radii, and the filtered
/// ground truth for the `id % 7 != 5` filter from the query/base distance matrix.
TestDatasetPtr
TestDataset::CreateTestDataset(uint64_t dim, uint64_t count, std::string metric_str) {
    constexpr uint64_t query_count = 100;

    // The default constructor is private, so the object is created with `new`
    // from inside the class instead of std::make_shared.
    TestDatasetPtr ds(new TestDataset);
    ds->dim_ = dim;
    ds->count_ = count;

    // Base is generated before the queries; keep this order.
    ds->base_ = GenerateRandomDataset(dim, count, metric_str);
    ds->query_ = GenerateRandomDataset(dim, query_count, metric_str);
    ds->filter_query_ = ds->query_;
    ds->range_query_ = ds->query_;

    ds->top_k = 10;
    ds->filter_function_ = [](int64_t id) -> bool { return id % 7 != 5; };

    auto distances = CalDistanceFloatMetrix(ds->query_, ds->base_, metric_str);

    ds->ground_truth_ = CalTopKGroundTruth(distances, ds->top_k, count, query_count);
    ds->range_ground_truth_ = ds->ground_truth_;

    // Radius for query q is the midpoint between entries top_k-1 and top_k of
    // its row in the distance matrix.
    ds->range_radius_.resize(query_count);
    for (uint64_t q = 0; q < query_count; ++q) {
        const auto row = q * count;
        ds->range_radius_[q] =
            0.5f * (distances.first[row + ds->top_k] + distances.first[row + ds->top_k - 1]);
    }

    ds->filter_ground_truth_ =
        CalFilterGroundTruth(distances, ds->top_k, 7, 5, count, query_count);

    // Both arrays of the pair are released here with delete[].
    delete[] distances.first;
    delete[] distances.second;
    return ds;
}

TestDatasetPtr
TestDataset::CreateNanDataset(const std::string& metric_str) {
TestDatasetPtr dataset = std::shared_ptr<TestDataset>(new TestDataset);
dataset->dim_ = 256;
dataset->count_ = 1000;
constexpr uint64_t query_count = 100;
this->query_ = GenerateRandomDataset(dim, query_count, metric_str);
this->filter_query_ = query_;
this->range_query_ = query_;
dataset->base_ = GenerateNanRandomDataset(dataset->dim_, dataset->count_, metric_str);
dataset->query_ = GenerateNanRandomDataset(dataset->dim_, query_count, metric_str);
{
auto result = CalDistanceFloatMetrix(query_, base_, metric_str);
this->top_k = 10;
this->ground_truth_ = CalTopKGroundTruth(result, top_k, count, query_count);
this->range_ground_truth_ = this->ground_truth_;
this->range_radius_.resize(query_count);
auto result = CalDistanceFloatMetrix(dataset->query_, dataset->base_, metric_str);
dataset->top_k = 10;
dataset->ground_truth_ =
CalTopKGroundTruth(result, dataset->top_k, dataset->count_, query_count);
dataset->range_ground_truth_ = dataset->ground_truth_;
dataset->range_radius_.resize(query_count);
for (uint64_t i = 0; i < query_count; ++i) {
this->range_radius_[i] =
0.5f * (result.first[i * count + top_k] + result.first[i * count + top_k - 1]);
dataset->range_radius_[i] =
dataset->ground_truth_->GetDistances()[i * dataset->top_k + dataset->top_k - 1];
}
this->filter_function_ = [](int64_t id) -> bool { return id % 7 != 5; };
this->filter_ground_truth_ = CalFilterGroundTruth(result, top_k, 7, 5, count, query_count);
delete[] result.first;
delete[] result.second;
}
return dataset;
}

} // namespace fixtures
13 changes: 10 additions & 3 deletions tests/fixtures/test_dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ class TestDataset {
public:
using DatasetPtr = vsag::DatasetPtr;

TestDataset(uint64_t dim, uint64_t count, std::string metric_str = "l2");
static std::shared_ptr<TestDataset>
CreateTestDataset(uint64_t dim, uint64_t count, std::string metric_str = "l2");
wxyucs marked this conversation as resolved.
Show resolved Hide resolved

static std::shared_ptr<TestDataset>
CreateNanDataset(const std::string& metric_str);

DatasetPtr base_{nullptr};

Expand All @@ -41,8 +45,11 @@ class TestDataset {
DatasetPtr filter_ground_truth_{nullptr};
std::function<bool(int64_t)> filter_function_{nullptr};

const uint64_t dim_;
const uint64_t count_;
uint64_t dim_{0};
uint64_t count_{0};

private:
TestDataset() = default;
};

using TestDatasetPtr = std::shared_ptr<TestDataset>;
Expand Down
11 changes: 10 additions & 1 deletion tests/fixtures/test_dataset_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ TestDatasetPool::GetDatasetAndCreate(uint64_t dim, uint64_t count, const std::st
auto key = key_gen(dim, count, metric_str);
if (this->pool_.find(key) == this->pool_.end()) {
this->dim_counts_.emplace_back(dim, count);
this->pool_[key] = std::make_shared<TestDataset>(dim, count, metric_str);
this->pool_[key] = TestDataset::CreateTestDataset(dim, count, metric_str);
}
return this->pool_.at(key);
}
/// Composes the pool key "<dim>_<count>_<metric_str>".
std::string
TestDatasetPool::key_gen(int64_t dim, uint64_t count, const std::string& metric_str) {
    std::string key = std::to_string(dim);
    key += "_";
    key += std::to_string(count);
    key += "_";
    key += metric_str;
    return key;
}

/// Returns the pooled NaN dataset for `metric_str`, creating and caching it on
/// first request. The cache key is NAN_DATASET followed by the metric name.
TestDatasetPtr
TestDatasetPool::GetNanDataset(const std::string& metric_str) {
    const std::string cache_key = NAN_DATASET + metric_str;
    const bool already_cached = this->pool_.find(cache_key) != this->pool_.end();
    if (!already_cached) {
        this->pool_[cache_key] = TestDataset::CreateNanDataset(metric_str);
    }
    return this->pool_.at(cache_key);
}
} // namespace fixtures
6 changes: 6 additions & 0 deletions tests/fixtures/test_dataset_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@
#include "test_dataset.h"

namespace fixtures {

static const std::string NAN_DATASET = "nan_dataset";

class TestDatasetPool {
public:
TestDatasetPtr
GetDatasetAndCreate(uint64_t dim, uint64_t count, const std::string& metric_str = "l2");

TestDatasetPtr
GetNanDataset(const std::string& metric_str);

private:
static std::string
key_gen(int64_t dim, uint64_t count, const std::string& metric_str = "l2");
Expand Down
17 changes: 17 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,23 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex,
vsag::Options::Instance().set_block_size_limit(origin_size);
}

// Functional test: build an HNSW index from the pooled NaN dataset and check that
// searching it still works, once per metric ("l2", "ip", "cosine").
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Search with Nan", "[ft][hnsw]") {
// Remember the global block size limit so it can be restored at the end of the test.
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
auto dataset = pool.GetNanDataset(metric_type);
auto dim = dataset->dim_;
const std::string name = "hnsw";
// NOTE(review): 100 is substituted into search_param_tmp; presumably the search ef value -- confirm.
auto search_param = fmt::format(search_param_tmp, 100);

vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
auto index = TestFactory(name, param, true);
TestContinueAdd(index, dataset, true);
// Expected recall threshold is 0.98 and the searches are expected to succeed.
TestSearchWithNan(index, dataset, search_param, 0.98, true);
// NOTE(review): not reached if an assertion above aborts the test case, leaving the limit modified.
vsag::Options::Instance().set_block_size_limit(origin_size);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Build", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
75 changes: 74 additions & 1 deletion tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,79 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from,
}
}
}
/**
 * Runs KNN and range searches against `index` using the queries of `dataset`
 * and checks the results when NaN values are involved.
 *
 * The first floor(query_count * 0.9) queries are scored: every returned KNN
 * distance must be non-NaN, and the average recall against
 * `dataset->ground_truth_` must exceed `expected_recall` for both KNN and range
 * search. The remaining queries are only checked for the success/failure status
 * of KnnSearch (for datasets built by TestDataset::CreateNanDataset these are
 * the queries that contain NaN).
 *
 * @param index            index under test
 * @param dataset          dataset providing queries, ground truth, top_k and range radii
 * @param search_param     search parameter string forwarded to the index
 * @param expected_recall  lower bound (exclusive) on the averaged recall
 * @param expected_success whether each search call is expected to return a value;
 *                         when false the function returns after the first search
 */
void
TestIndex::TestSearchWithNan(const TestIndex::IndexPtr& index,
                             const TestDatasetPtr& dataset,
                             const std::string& search_param,
                             float expected_recall,
                             bool expected_success) {
    auto queries = dataset->query_;
    auto query_count = queries->GetNumElements();
    auto dim = queries->GetDim();
    auto gts = dataset->ground_truth_;
    auto gt_topK = dataset->top_k;
    auto topk = gt_topK;

    // Number of queries that take part in the recall computation. Recall must be
    // averaged over exactly this many queries; dividing by query_count / 2 (as
    // before) inflated the reported recall and weakened the threshold check.
    const auto valid_count = static_cast<int64_t>(query_count * 0.9);
    REQUIRE(valid_count > 0);

    // Wraps the i-th query vector in a non-owning single-element dataset.
    auto make_query = [&](int64_t i) {
        auto query = vsag::Dataset::Make();
        query->NumElements(1)
            ->Dim(dim)
            ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
            ->Owner(false);
        return query;
    };

    // --- KNN search over the scored queries ---
    float cur_recall = 0.0f;
    for (int64_t i = 0; i < valid_count; ++i) {
        auto res = index->KnnSearch(make_query(i), topk, search_param);
        REQUIRE(res.has_value() == expected_success);
        if (!expected_success) {
            return;
        }
        REQUIRE(res.value()->GetDim() == topk);
        auto result = res.value()->GetIds();
        auto dists = res.value()->GetDistances();
        for (decltype(topk) j = 0; j < topk; ++j) {
            // A NaN distance must never be reported to the caller.
            REQUIRE_FALSE(std::isnan(dists[j]));
        }
        auto gt = gts->GetIds() + gt_topK * i;
        auto val = Intersection(gt, gt_topK, result, topk);
        cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
    }
    const float knn_recall = cur_recall / static_cast<float>(valid_count);
    REQUIRE(knn_recall > expected_recall);

    // --- Range search over the same scored queries ---
    cur_recall = 0.0f;
    const auto& radius = dataset->range_radius_;
    for (int64_t i = 0; i < valid_count; ++i) {
        auto res = index->RangeSearch(make_query(i), radius[i], search_param);
        REQUIRE(res.has_value() == expected_success);
        if (!expected_success) {
            return;
        }
        // The radius may admit one extra neighbour beyond top-k.
        REQUIRE((res.value()->GetDim() >= topk && res.value()->GetDim() <= topk + 1));
        auto result = res.value()->GetIds();
        auto gt = gts->GetIds() + gt_topK * i;
        auto val = Intersection(gt, gt_topK, result, topk);
        cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
    }
    const float range_recall = cur_recall / static_cast<float>(valid_count);
    REQUIRE(range_recall > expected_recall);

    // --- Remaining queries: only the call status is checked, not the content ---
    for (int64_t i = valid_count; i < query_count; ++i) {
        auto res = index->KnnSearch(make_query(i), topk, search_param);
        REQUIRE(res.has_value() == expected_success);
        if (!expected_success) {
            return;
        }
    }
}

void
TestIndex::TestConcurrentAdd(const TestIndex::IndexPtr& index,
Expand Down Expand Up @@ -407,4 +480,4 @@ TestIndex::TestDuplicateAdd(const TestIndex::IndexPtr& index, const TestDatasetP
check_func(add_index_2.value());
}

} // namespace fixtures
} // namespace fixtures
7 changes: 7 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ class TestIndex {
float expected_recall = 0.99,
bool expected_success = true);

/// Runs KNN and range searches with the queries of `dataset` against `index`
/// and asserts on the results for data that contains NaN values.
/// @param index            index under test
/// @param dataset          dataset supplying queries and ground truth
/// @param search_param     search parameter string passed to the index
/// @param expected_recall  recall threshold the searches must exceed (default 0.99)
/// @param expected_success whether the search calls are expected to succeed (default true)
static void
TestSearchWithNan(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float expected_recall = 0.99,
bool expected_success = true);

static void
TestRangeSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down
Loading