Skip to content

Commit

Permalink
add estimate feature & test (#289)
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 authored Jan 15, 2025
1 parent b6ae713 commit 7106bd7
Show file tree
Hide file tree
Showing 13 changed files with 156 additions and 15 deletions.
2 changes: 2 additions & 0 deletions include/vsag/index_features.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ enum IndexFeature {
SUPPORT_SEARCH_DELETE_CONCURRENT, /**< Supports concurrent searching and deletion */
SUPPORT_ADD_SEARCH_DELETE_CONCURRENT, /**< Supports concurrent addition, searching, and deletion */

SUPPORT_ESTIMATE_MEMORY, /**< Supports estimate memory usage by data count */

INDEX_FEATURE_COUNT /** must be last one */
};
} // namespace vsag
19 changes: 12 additions & 7 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ next_multiple_of_power_of_two(uint64_t x, uint64_t n) {
return result;
}

HGraph::HGraph(const HGraphParameter& hgraph_param,
const vsag::IndexCommonParam& common_param) noexcept
HGraph::HGraph(const HGraphParameter& hgraph_param, const vsag::IndexCommonParam& common_param)
: common_param_(common_param),
dim_(common_param.dim_),
metric_(common_param.metric_),
Expand Down Expand Up @@ -853,19 +852,25 @@ HGraph::resize(uint64_t new_size) {
void
HGraph::init_features() {
// Common Init
// Build & Add
feature_list_.SetFeatures({IndexFeature::SUPPORT_BUILD,
IndexFeature::SUPPORT_BUILD_WITH_MULTI_THREAD,
IndexFeature::SUPPORT_ADD_AFTER_BUILD,
IndexFeature::SUPPORT_KNN_SEARCH,
IndexFeature::SUPPORT_ADD_AFTER_BUILD});
// search
feature_list_.SetFeatures({IndexFeature::SUPPORT_KNN_SEARCH,
IndexFeature::SUPPORT_RANGE_SEARCH,
IndexFeature::SUPPORT_KNN_SEARCH_WITH_ID_FILTER,
IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER,
IndexFeature::SUPPORT_SEARCH_CONCURRENT,
IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET,
IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER});
// concurrency
feature_list_.SetFeature(IndexFeature::SUPPORT_SEARCH_CONCURRENT);
// serialize
feature_list_.SetFeatures({IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET,
IndexFeature::SUPPORT_DESERIALIZE_FILE,
IndexFeature::SUPPORT_DESERIALIZE_READER_SET,
IndexFeature::SUPPORT_SERIALIZE_BINARY_SET,
IndexFeature::SUPPORT_SERIALIZE_FILE});
// other
feature_list_.SetFeatures({IndexFeature::SUPPORT_ESTIMATE_MEMORY});

// About Train
auto name = this->basic_flatten_codes_->GetQuantizerName();
Expand Down
2 changes: 1 addition & 1 deletion src/algorithm/hgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
namespace vsag {
class HGraph {
public:
HGraph(const HGraphParameter& param, const IndexCommonParam& common_param) noexcept;
HGraph(const HGraphParameter& param, const IndexCommonParam& common_param);

tl::expected<std::vector<int64_t>, Error>
Build(const DatasetPtr& data);
Expand Down
2 changes: 1 addition & 1 deletion src/index/hgraph_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "hgraph_index.h"
namespace vsag {
HGraphIndex::HGraphIndex(const HGraphIndexParameter& param,
const vsag::IndexCommonParam& common_param) noexcept {
const vsag::IndexCommonParam& common_param) {
this->hgraph_ = std::make_unique<HGraph>(*param.hgraph_parameter_, common_param);
this->allocator_ = common_param.allocator_;
}
Expand Down
2 changes: 1 addition & 1 deletion src/index/hgraph_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace vsag {
class HGraphIndex : public Index {
public:
HGraphIndex(const HGraphIndexParameter& param, const IndexCommonParam& common_param) noexcept;
HGraphIndex(const HGraphIndexParameter& param, const IndexCommonParam& common_param);

~HGraphIndex() override;

Expand Down
2 changes: 1 addition & 1 deletion src/index/index_common_param.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace vsag {
IndexCommonParam
IndexCommonParam::CheckAndCreate(JsonType& params, const std::shared_ptr<Resource>& resource) {
IndexCommonParam result;
result.allocator_ = resource->allocator;
result.allocator_ = resource->GetAllocator();
result.thread_pool_ = std::dynamic_pointer_cast<SafeThreadPool>(resource->thread_pool);
// Check DataType
CHECK_ARGUMENT(params.contains(PARAMETER_DTYPE),
Expand Down
2 changes: 1 addition & 1 deletion src/resource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Resource::Resource(Allocator* allocator, ThreadPool* thread_pool) {
if (thread_pool != nullptr) {
this->thread_pool = std::make_shared<SafeThreadPool>(thread_pool, false);
} else {
this->allocator = SafeAllocator::FactoryDefaultAllocator();
this->thread_pool = SafeThreadPool::FactoryDefaultThreadPool();
}
}

Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/fixtures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "fixtures.h"

#include <unistd.h>

#include <cstdint>
#include <random>
#include <string>
Expand Down Expand Up @@ -301,5 +303,4 @@ SplitString(const std::string& s, char delimiter) {

return tokens;
}

} // namespace fixtures
87 changes: 87 additions & 0 deletions tests/fixtures/memory_record_allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <random>
#include <unordered_map>

#include "vsag/allocator.h"

namespace fixtures {

class MemoryRecordAllocator : public vsag::Allocator {
public:
std::string
Name() override {
return "memory_record_allocator";
}
MemoryRecordAllocator() : memory_bytes_(0) {
}

void*
Allocate(size_t size) override {
auto ptr = malloc(size);
{
std::lock_guard lock(mutex_);
records_[ptr] = size;
memory_bytes_ += size;
memory_peak_ = std::max(memory_peak_, memory_bytes_);
}
return ptr;
}

void
Deallocate(void* p) override {
{
std::lock_guard lock(mutex_);
memory_bytes_ -= records_[p];
}
return free(p);
}

void*
Reallocate(void* p, size_t size) override {
{
std::lock_guard lock(mutex_);
memory_bytes_ -= records_[p];
records_[p] = size;
memory_bytes_ += size;
memory_peak_ = std::max(memory_peak_, memory_bytes_);
}
return realloc(p, size);
}

uint64_t
GetMemoryPeak() const {
return this->memory_peak_;
}

uint64_t
GetCurrentMemory() const {
return this->memory_bytes_;
}

private:
uint64_t memory_bytes_{0};

uint64_t memory_peak_{0};

std::unordered_map<void*, uint64_t> records_{};

std::mutex mutex_{};
};
} // namespace fixtures
2 changes: 1 addition & 1 deletion tests/fixtures/random_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ class RandomAllocator : public vsag::Allocator {
std::uniform_real_distribution<> dis_;
float error_ratio_ = 0.0f;
};
} // namespace fixtures
} // namespace fixtures
23 changes: 22 additions & 1 deletion tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Serialize File",

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex,
"HGraph Build & ContinueAdd Test With Random Allocator",
"[ft][hnsw]") {
"[ft][hgraph]") {
auto allocator = std::make_shared<fixtures::RandomAllocator>();
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down Expand Up @@ -456,3 +456,24 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build"
}
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Estimate Memory", "[ft][hgraph]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");

const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
uint64_t estimate_count = 1000;
for (auto& dim : dims) {
for (auto& [base_quantization_str, recall] : test_cases) {
vsag::Options::Instance().set_block_size_limit(size);
auto param =
GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str);
auto dataset = pool.GetDatasetAndCreate(dim, estimate_count, metric_type);

TestEstimateMemory(name, param, dataset);
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}
}
20 changes: 20 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "test_index.h"

#include "fixtures/memory_record_allocator.h"
#include "fixtures/test_logger.h"
#include "fixtures/thread_pool.h"
#include "simd/fp32_simd.h"
Expand Down Expand Up @@ -620,5 +621,24 @@ TestIndex::TestDuplicateAdd(const TestIndex::IndexPtr& index, const TestDatasetP
REQUIRE(add_index_2.has_value());
check_func(add_index_2.value());
}
void
TestIndex::TestEstimateMemory(const std::string& index_name,
const std::string& build_param,
const TestDatasetPtr& dataset) {
auto allocator = std::make_shared<fixtures::MemoryRecordAllocator>();
{
auto index = vsag::Factory::CreateIndex(index_name, build_param, allocator.get()).value();
REQUIRE(index->GetNumElements() == 0);
if (index->CheckFeature(vsag::SUPPORT_ESTIMATE_MEMORY)) {
auto data_size = dataset->base_->GetNumElements();
auto estimate_memory = index->EstimateMemory(data_size);
auto build_index = index->Build(dataset->base_);
REQUIRE(build_index.has_value());
auto real_memory = allocator->GetCurrentMemory();
REQUIRE(estimate_memory >= static_cast<uint64_t>(real_memory * 0.8));
REQUIRE(estimate_memory <= static_cast<uint64_t>(real_memory * 1.2));
}
}
}

} // namespace fixtures
5 changes: 5 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ class TestIndex {
bool expected_success = true);
static void
TestDuplicateAdd(const IndexPtr& index, const TestDatasetPtr& dataset);

static void
TestEstimateMemory(const std::string& index_name,
const std::string& build_param,
const TestDatasetPtr& dataset);
};

} // namespace fixtures

0 comments on commit 7106bd7

Please sign in to comment.