Skip to content

Commit

Permalink
support safe update
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou committed Jan 9, 2025
1 parent e87defc commit 652697a
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 12 deletions.
4 changes: 2 additions & 2 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ class Index {
*
* @param id indicates the old id of a base point in index
* @param new_base is the updated new vector of the base point
* @param need_fine_tune indicates whether the connection of the base point needs to be fine-tuned
* @param force_update when false, the neighborhood (connectivity) of the point in the graph is checked before the update and the update is rejected if the check fails; when true, the check is skipped
* @return result indicates whether the update operation is successful.
*/
virtual tl::expected<bool, Error>
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) {
UpdateVector(int64_t id, const DatasetPtr& new_base, bool force_update = false) {
throw std::runtime_error("Index not support update vector");
}

Expand Down
49 changes: 42 additions & 7 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ const static int64_t DEFAULT_MAX_ELEMENT = 1;
const static int MINIMAL_M = 8;
const static int MAXIMAL_M = 64;
const static uint32_t GENERATE_SEARCH_K = 50;
const static uint32_t UPDATE_CHECK_SEARCH_K = 10;
const static uint32_t GENERATE_SEARCH_L = 400;
const static uint32_t UPDATE_CHECK_SEARCH_L = 100;
const static float GENERATE_OMEGA = 0.51;

HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_param)
Expand Down Expand Up @@ -630,8 +632,7 @@ HNSW::update_id(int64_t old_id, int64_t new_id) {
}

tl::expected<bool, Error>
HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) {
// TODO(ZXY): implement need_fine_tune to allow update with distant vector
HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool force_update) {
if (use_static_) {
LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
"static hnsw does not support update");
Expand All @@ -643,6 +644,41 @@ HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune)
size_t data_size = 0;
get_vectors(new_base, &new_base_vec, &data_size);

if (not force_update) {
const void* base_data;
auto base = Dataset::Make();

// check if id exists
base_data =
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->getDataByLabel(
id);
set_dataset(base, base_data, 1);

// search neighbors
auto result = this->knn_search(base,
vsag::UPDATE_CHECK_SEARCH_K,
fmt::format(R"(
{{
"hnsw": {{
"ef_search": {}
}}
}})",
vsag::UPDATE_CHECK_SEARCH_L),
nullptr);

// check whether the neighborhood relationship stays the same
float self_dist = std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)
->getDistanceByLabel(id, new_base_vec);
for (int i = 0; i < result.value()->GetDim(); i++) {
float neighbor_dist =
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)
->getDistanceByLabel(result.value()->GetIds()[i], new_base_vec);
if (neighbor_dist < self_dist) {
return false;
}
}
}

// note that the validation of id is handled within updateVector.
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateVector(
id, new_base_vec);
Expand Down Expand Up @@ -818,11 +854,10 @@ HNSW::pretrain(const std::vector<int64_t>& base_tag_ids,
base_data = (const void*)this->alg_hnsw_->getDataByLabel(base_tag_id);
set_dataset(base, base_data, 1);
} catch (const std::runtime_error& e) {
LOG_ERROR_AND_RETURNS(
ErrorType::INVALID_ARGUMENT,
fmt::format(
"failed to pretrain(invalid argument): bas tag id ({}) doesn't belong to index",
base_tag_id));
LOG_ERROR_AND_RETURNS(ErrorType::INVALID_ARGUMENT,
fmt::format("failed to pretrain(invalid argument): base tag id "
"({}) doesn't belong to index",
base_tag_id));
}

auto result = this->knn_search(base,
Expand Down
6 changes: 3 additions & 3 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class HNSW : public Index {
}

tl::expected<bool, Error>
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) override {
SAFE_CALL(return this->update_vector(id, new_base, need_fine_tune));
UpdateVector(int64_t id, const DatasetPtr& new_base, bool force_update = false) override {
SAFE_CALL(return this->update_vector(id, new_base, force_update));
}

tl::expected<DatasetPtr, Error>
Expand Down Expand Up @@ -212,7 +212,7 @@ class HNSW : public Index {
update_id(int64_t old_id, int64_t new_id);

tl::expected<bool, Error>
update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune);
update_vector(int64_t id, const DatasetPtr& new_base, bool force_update);

template <typename FilterType>
tl::expected<DatasetPtr, Error>
Expand Down
37 changes: 37 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ TestIndex::TestUpdateVector(const IndexPtr& index,
}

std::vector<int> correct_num = {0, 0};
uint32_t success_force_updated = 0, failed_force_updated = 0;
for (int round = 0; round < 2; round++) {
// round 0 for update, round 1 for validate update results
for (int i = 0; i < num_vectors; i++) {
Expand All @@ -159,15 +160,18 @@ TestIndex::TestUpdateVector(const IndexPtr& index,
}

std::vector<float> update_vecs(dim);
std::vector<float> far_vecs(dim);
for (int d = 0; d < dim; d++) {
update_vecs[d] = base[i * dim + d] + 0.001f;
far_vecs[d] = base[i * dim + d] + 1.0f;
}
auto new_base = vsag::Dataset::Make();
new_base->NumElements(1)
->Dim(dim)
->Float32Vectors(update_vecs.data())
->Owner(false);

// success case
auto before_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
auto succ_vec_res = index->UpdateVector(ids[i], new_base);
REQUIRE(succ_vec_res.has_value());
Expand All @@ -177,6 +181,38 @@ TestIndex::TestUpdateVector(const IndexPtr& index,
auto after_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
REQUIRE(before_update_dist < after_update_dist);

// update with far vector
new_base->Float32Vectors(far_vecs.data());
auto fail_vec_res = index->UpdateVector(ids[i], new_base);
REQUIRE(fail_vec_res.has_value());
if (fail_vec_res.value()) {
// note that the update is expected to fail, but in some cases it succeeds
auto force_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
REQUIRE(after_update_dist < force_update_dist);
success_force_updated++;
} else {
failed_force_updated++;
}

// force update with far vector
new_base->Float32Vectors(far_vecs.data());
auto force_update_res1 = index->UpdateVector(ids[i], new_base, true);
REQUIRE(force_update_res1.has_value());
if (expected_success) {
REQUIRE(force_update_res1.value());
auto force_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
REQUIRE(after_update_dist < force_update_dist);
}

new_base->Float32Vectors(update_vecs.data());
auto force_update_res2 = index->UpdateVector(ids[i], new_base, true);
REQUIRE(force_update_res2.has_value());
if (expected_success) {
REQUIRE(force_update_res2.value());
auto force_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
REQUIRE(std::abs(after_update_dist - force_update_dist) < 1e-5);
}

// old id doesn't exist
auto failed_old_res = index->UpdateVector(ids[i] + 2 * max_id, new_base);
REQUIRE(failed_old_res.has_value());
Expand All @@ -190,6 +226,7 @@ TestIndex::TestUpdateVector(const IndexPtr& index,
}

REQUIRE(correct_num[0] == correct_num[1]);
REQUIRE(success_force_updated < failed_force_updated);
}

void
Expand Down

0 comments on commit 652697a

Please sign in to comment.