Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou committed Jan 10, 2025
1 parent fc5e695 commit 8b7d194
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 11 deletions.
13 changes: 13 additions & 0 deletions src/allocator_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,17 @@ class AllocatorWrapper {

Allocator* allocator_{};
};

template <typename T, typename U>
bool
operator==(const AllocatorWrapper<T>&, const AllocatorWrapper<U>&) noexcept {
return true;
}

template <typename T, typename U>
bool
operator!=(const AllocatorWrapper<T>& a, const AllocatorWrapper<U>& b) noexcept {
return !(a == b);
}

} // namespace vsag
22 changes: 22 additions & 0 deletions src/impl/basic_optimizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "basic_optimizer.h"

namespace vsag {
template class Optimizer<
BasicSearcher<AdaptGraphDataCell,
FlattenDataCell<FP32Quantizer<vsag::MetricType::METRIC_TYPE_L2SQR>, MemoryIO>>>;
}
20 changes: 12 additions & 8 deletions src/impl/basic_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@

#pragma once
#include "../utils.h"
#include "basic_searcher.h"
#include "logger.h"
#include "runtime_parameter.h"

namespace vsag {

template <typename OptimizableOBJ>
class Optimizer {
Optimizer(std::shared_ptr<Allocator> allocator, int trials = 100)
: parameters_(allocator.get()),
best_params_(allocator.get()),
public:
Optimizer(const IndexCommonParam& common_param, int trials = 100)
: parameters_(common_param.allocator_.get()),
best_params_(common_param.allocator_.get()),
n_trials_(trials),
best_loss_(std::numeric_limits<double>::max()) {
allocator_ = allocator.get();
allocator_ = common_param.allocator_.get();
std::random_device rd;
gen_.seed(rd());
}
Expand All @@ -38,19 +40,19 @@ class Optimizer {
}

void
Optimize(OptimizableOBJ& obj) {
double original_loss = obj.MockRun();
Optimize(std::shared_ptr<OptimizableOBJ> obj) {
double original_loss = obj->MockRun();

for (int i = 0; i < n_trials_; ++i) {
// generate a group of runtime params
UnorderedMap<std::string, ParamValue> current_params(allocator_);
for (auto& param : parameters_) {
current_params[param->name_] = param->sample(gen_);
}
obj.SetRuntimeParameters(current_params);
obj->SetRuntimeParameters(current_params);

// evaluate
double loss = obj.MockRun();
double loss = obj->MockRun();

// update
if (loss < best_loss_) {
Expand All @@ -62,6 +64,8 @@ class Optimizer {
(original_loss - best_loss_) / original_loss));
}
}

obj->SetRuntimeParameters(best_params_);
}

UnorderedMap<std::string, ParamValue>
Expand Down
2 changes: 1 addition & 1 deletion src/impl/basic_searcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

namespace vsag {

static const InnerIdType SAMPLE_SIZE = 10000;
static const InnerIdType SAMPLE_SIZE = 1000;
static const uint32_t CENTROID_EF = 500;
static const uint32_t PREFETCH_DEGREE_DIVIDE = 3;
static const uint32_t PREFETCH_MAXIMAL_DEGREE = 1;
Expand Down
13 changes: 11 additions & 2 deletions src/impl/basic_searcher_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "algorithm/hnswlib/hnswalg.h"
#include "algorithm/hnswlib/space_l2.h"
#include "basic_optimizer.h"
#include "catch2/catch_template_test_macros.hpp"
#include "data_cell/adapter_graph_datacell.h"
#include "data_cell/flatten_datacell.h"
Expand All @@ -28,7 +29,7 @@

using namespace vsag;

TEST_CASE("search with alg_hnsw", "[ut][basic_searcher]") {
TEST_CASE("search with alg_hnsw and optimizer", "[ut][basic_searcher]") {
// data attr
uint32_t base_size = 1000;
uint32_t query_size = 100;
Expand Down Expand Up @@ -88,9 +89,17 @@ TEST_CASE("search with alg_hnsw", "[ut][basic_searcher]") {
vector_data_cell->BatchInsertVector(base_vectors.data(), base_size, ids.data());
using VectorDataTmpl = std::remove_pointer_t<decltype(vector_data_cell.get())>;

// searcher
// init searcher and optimizer
auto searcher = std::make_shared<BasicSearcher<GraphTmpl, VectorDataTmpl>>(
graph_data_cell, vector_data_cell, common);
auto optimizer =
std::make_shared<Optimizer<BasicSearcher<GraphTmpl, VectorDataTmpl>>>(common, 1);
optimizer->RegisterParameter(std::make_shared<IntRuntimeParameter>(PREFETCH_CACHE_LINE, 1, 10));
optimizer->RegisterParameter(
std::make_shared<IntRuntimeParameter>(PREFETCH_NEIGHBOR_CODE_NUM, 1, 10));
optimizer->RegisterParameter(
std::make_shared<IntRuntimeParameter>(PREFETCH_NEIGHBOR_VISIT_NUM, 1, 10));
optimizer->Optimize(searcher);

// search
InnerSearchParam search_param;
Expand Down

0 comments on commit 8b7d194

Please sign in to comment.