diff --git a/include/vsag/index.h b/include/vsag/index.h index b9068cc1..96812420 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -105,11 +105,11 @@ class Index { * * @param id indicates the old id of a base point in index * @param new_base is the updated new vector of the base point - * @param need_fine_tune indicates whether the connection of the base point needs to be fine-tuned + * @param force_update is false means that a check of the connectivity of the graph updated by this operation is performed * @return result indicates whether the update operation is successful. */ virtual tl::expected<bool, Error> - UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) { + UpdateVector(int64_t id, const DatasetPtr& new_base, bool force_update = false) { throw std::runtime_error("Index not support update vector"); } diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 9b200c54..ea052d91 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -41,7 +41,9 @@ const static int64_t DEFAULT_MAX_ELEMENT = 1; const static int MINIMAL_M = 8; const static int MAXIMAL_M = 64; const static uint32_t GENERATE_SEARCH_K = 50; +const static uint32_t UPDATE_CHECK_SEARCH_K = 10; const static uint32_t GENERATE_SEARCH_L = 400; +const static uint32_t UPDATE_CHECK_SEARCH_L = 100; const static float GENERATE_OMEGA = 0.51; HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_param) @@ -630,8 +632,7 @@ HNSW::update_id(int64_t old_id, int64_t new_id) { } tl::expected<bool, Error> -HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) { - // TODO(ZXY): implement need_fine_tune to allow update with distant vector +HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool force_update) { if (use_static_) { LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION, "static hnsw does not support update"); @@ -643,6 +644,41 @@ HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) size_t data_size = 0; get_vectors(new_base, &new_base_vec, &data_size); + if (not force_update) { + const void* base_data; + auto base = Dataset::Make(); + + // check if id exists + base_data = + std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->getDataByLabel( + id); + set_dataset(base, base_data, 1); + + // search neighbors + auto result = this->knn_search(base, + vsag::UPDATE_CHECK_SEARCH_K, + fmt::format(R"( + {{ + "hnsw": {{ + "ef_search": {} + }} + }})", + vsag::UPDATE_CHECK_SEARCH_L), + nullptr); + + // check whether the neighborhood relationship is same + float self_dist = std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_) + ->getDistanceByLabel(id, new_base_vec); + for (int i = 0; i < result.value()->GetDim(); i++) { + float neighbor_dist = + std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_) + ->getDistanceByLabel(result.value()->GetIds()[i], new_base_vec); + if (neighbor_dist < self_dist) { + return false; + } + } + } + // note that the validation of old_id is handled within updatePoint. std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateVector( id, new_base_vec); @@ -818,11 +854,10 @@ HNSW::pretrain(const std::vector<int64_t>& base_tag_ids, base_data = (const void*)this->alg_hnsw_->getDataByLabel(base_tag_id); set_dataset(base, base_data, 1); } catch (const std::runtime_error& e) { - LOG_ERROR_AND_RETURNS( - ErrorType::INVALID_ARGUMENT, - fmt::format( - "failed to pretrain(invalid argument): bas tag id ({}) doesn't belong to index", - base_tag_id)); + LOG_ERROR_AND_RETURNS(ErrorType::INVALID_ARGUMENT, + fmt::format("failed to pretrain(invalid argument): base tag id " + "({}) doesn't belong to index", + base_tag_id)); } auto result = this->knn_search(base, diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 9a06e245..1a170ea8 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -76,8 +76,8 @@ class HNSW : public Index { } tl::expected<bool, Error> - UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) override { - SAFE_CALL(return this->update_vector(id, new_base, need_fine_tune)); + UpdateVector(int64_t id, const DatasetPtr& new_base, bool force_update = false) override { + SAFE_CALL(return this->update_vector(id, new_base, force_update)); } tl::expected<DatasetPtr, Error> @@ -212,7 +212,7 @@ class HNSW : public Index { update_id(int64_t old_id, int64_t new_id); tl::expected<bool, Error> - update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune); + update_vector(int64_t id, const DatasetPtr& new_base, bool force_update); template <typename FilterType> tl::expected<DatasetPtr, Error> diff --git a/tests/test_index.cpp b/tests/test_index.cpp index ef79cd47..91bd3db8 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -144,6 +144,7 @@ TestIndex::TestUpdateVector(const IndexPtr& index, } std::vector<int> correct_num = {0, 0}; + uint32_t success_force_updated = 0, failed_force_updated = 0; for (int round = 0; round < 2; round++) { // round 0 for update, round 1 for validate update results for (int i = 0; i < num_vectors; i++) { @@ -159,8 +160,10 @@ TestIndex::TestUpdateVector(const IndexPtr& index, } std::vector<float> update_vecs(dim); + std::vector<float> far_vecs(dim); for (int d = 0; d < dim; d++) { update_vecs[d] = base[i * dim + d] + 0.001f; + far_vecs[d] = base[i * dim + d] + 1.0f; } auto new_base = vsag::Dataset::Make(); new_base->NumElements(1) @@ -168,6 +171,7 @@ TestIndex::TestUpdateVector(const IndexPtr& index, ->Float32Vectors(update_vecs.data()) ->Owner(false); + // success case auto before_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]); auto succ_vec_res = index->UpdateVector(ids[i], new_base); REQUIRE(succ_vec_res.has_value()); @@ -177,6 +181,38 @@ TestIndex::TestUpdateVector(const IndexPtr& index, auto after_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]); REQUIRE(before_update_dist < after_update_dist); + // update with far vector + new_base->Float32Vectors(far_vecs.data()); + auto fail_vec_res = index->UpdateVector(ids[i], new_base); + REQUIRE(fail_vec_res.has_value()); + if (fail_vec_res.value()) { + // note that the update should be failed, but for some cases, it success + auto force_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]); + REQUIRE(after_update_dist < force_update_dist); + success_force_updated++; + } else { + failed_force_updated++; + } + + // force update with far vector + new_base->Float32Vectors(far_vecs.data()); + auto force_update_res1 = index->UpdateVector(ids[i], new_base, true); + REQUIRE(force_update_res1.has_value()); + if (expected_success) { + REQUIRE(force_update_res1.value()); + auto force_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]); + REQUIRE(after_update_dist < force_update_dist); + } + + new_base->Float32Vectors(update_vecs.data()); + auto force_update_res2 = index->UpdateVector(ids[i], new_base, true); + REQUIRE(force_update_res2.has_value()); + if (expected_success) { + REQUIRE(force_update_res2.value()); + auto force_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]); + REQUIRE(std::abs(after_update_dist - force_update_dist) < 1e-5); + } + // old id don't exist auto failed_old_res = index->UpdateVector(ids[i] + 2 * max_id, new_base); REQUIRE(failed_old_res.has_value()); @@ -190,6 +226,7 @@ TestIndex::TestUpdateVector(const IndexPtr& index, } REQUIRE(correct_num[0] == correct_num[1]); + REQUIRE(success_force_updated < failed_force_updated); } void