Skip to content

Commit

Permalink
implement the engine and rewrite the factory
Browse files Browse the repository at this point in the history
- the allocator shared_ptr is held by Index and Dataset
- the allocator ptr always points to a SafeAllocator
- the allocator ptr is used inside the index
- only Engine/Resource can create Allocator objects

Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Dec 19, 2024
1 parent 413d456 commit 57bee65
Show file tree
Hide file tree
Showing 36 changed files with 394 additions and 331 deletions.
7 changes: 5 additions & 2 deletions include/vsag/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

#include <memory>

#include "dataset.h"
#include "index.h"
#include "resource.h"

namespace vsag {
class Engine {
public:
explicit Engine(Resource* resource = nullptr);
explicit Engine();

explicit Engine(Resource* resource);

void
Shutdown();
Shutdown(); // like ~Engine(), but will warn if the resources are still referenced outside

tl::expected<std::shared_ptr<Index>, Error>
CreateIndex(const std::string& name, const std::string& parameters);
Expand Down
1 change: 1 addition & 0 deletions include/vsag/vsag.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ init();
#include "bitset.h"
#include "constants.h"
#include "dataset.h"
#include "engine.h"
#include "errors.h"
#include "expected.hpp"
#include "factory.h"
Expand Down
12 changes: 6 additions & 6 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ next_multiple_of_power_of_two(uint64_t x, uint64_t n) {
HGraph::HGraph(const JsonType& index_param, const vsag::IndexCommonParam& common_param) noexcept
: index_param_(index_param),
common_param_(common_param),
label_lookup_(common_param.allocator_),
label_op_mutex_(MAX_LABEL_OPERATION_LOCKS, common_param_.allocator_),
neighbors_mutex_(0, common_param_.allocator_),
route_graphs_(common_param.allocator_),
labels_(common_param.allocator_) {
label_lookup_(common_param.allocator_.get()),
label_op_mutex_(MAX_LABEL_OPERATION_LOCKS, common_param.allocator_.get()),
neighbors_mutex_(0, common_param.allocator_.get()),
route_graphs_(common_param.allocator_.get()),
labels_(common_param.allocator_.get()) {
this->dim_ = common_param.dim_;
this->metric_ = common_param.metric_;
this->allocator_ = common_param.allocator_;
this->allocator_ = common_param.allocator_.get();
}

tl::expected<void, Error>
Expand Down
2 changes: 1 addition & 1 deletion src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ template <typename QuantTmpl, typename IOTmpl>
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const JsonType& quantization_param,
const JsonType& io_param,
const IndexCommonParam& common_param)
: allocator_(common_param.allocator_) {
: allocator_(common_param.allocator_.get()) {
this->quantizer_ = std::make_shared<QuantTmpl>(quantization_param, common_param);
this->io_ = std::make_shared<IOTmpl>(io_param, common_param);
this->code_size_ = quantizer_->GetCodeSize();
Expand Down
8 changes: 5 additions & 3 deletions src/data_cell/flatten_datacell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
#include "flatten_datacell.h"

#include <algorithm>
#include <utility>

#include "catch2/catch_template_test_macros.hpp"
#include "default_allocator.h"
#include "fixtures.h"
#include "flatten_interface_test.h"
#include "io/io_headers.h"
#include "quantization/quantizer_headers.h"
#include "safe_allocator.h"

using namespace vsag;

Expand All @@ -36,7 +38,7 @@ TestFlattenDataCell(int dim,
auto counts = {100, 1000};
IndexCommonParam common;
common.dim_ = dim;
common.allocator_ = allocator.get();
common.allocator_ = std::move(allocator);
common.metric_ = metric;
for (auto count : counts) {
auto flatten =
Expand Down Expand Up @@ -67,7 +69,7 @@ TestFlattenDataCellFP32(int dim,
}

TEST_CASE("fp32", "[ut][flatten_data_cell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto fp32_param = JsonType::parse("{}");
auto io_param = JsonType::parse("{}");
auto dims = {8, 64, 512};
Expand Down Expand Up @@ -96,7 +98,7 @@ TestFlattenDataCellSQ8(int dim,
}

TEST_CASE("sq8", "[ut][flatten_data_cell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto sq8_param = JsonType::parse("{}");
auto io_param = JsonType::parse("{}");
auto dims = {32, 64, 512};
Expand Down
5 changes: 3 additions & 2 deletions src/data_cell/graph_datacell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "fmt/format-inl.h"
#include "graph_interface_test.h"
#include "io/io_headers.h"
#include "safe_allocator.h"
using namespace vsag;

template <typename IOTemp>
Expand All @@ -37,7 +38,7 @@ TestGraphDataCell(const JsonType& graph_param,
}

TEST_CASE("graph basic test", "[ut][graph_datacell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto dims = {32, 64};
auto max_degrees = {5, 12, 24, 32, 64, 128};
auto max_capacities = {1, 100, 10000, 10'000'000, 32'179'837};
Expand All @@ -58,7 +59,7 @@ TEST_CASE("graph basic test", "[ut][graph_datacell]") {
for (auto dim : dims) {
IndexCommonParam param;
param.dim_ = dim;
param.allocator_ = allocator.get();
param.allocator_ = allocator;
for (auto& gp : graph_params) {
TestGraphDataCell<MemoryIO>(gp, io_param, param);
TestGraphDataCell<MemoryBlockIO>(gp, io_param, param);
Expand Down
3 changes: 2 additions & 1 deletion src/data_cell/graph_interface_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
#include "catch2/catch_test_macros.hpp"
#include "default_allocator.h"
#include "fixtures.h"
#include "safe_allocator.h"

using namespace vsag;

void
GraphInterfaceTest::BasicTest(uint64_t max_id, uint64_t count, GraphInterfacePtr other) {
auto allocator = std::make_shared<DefaultAllocator>();
auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto max_degree = this->graph_->MaximumDegree();
auto old_count = this->graph_->TotalCount();
UnorderedMap<InnerIdType, std::shared_ptr<Vector<InnerIdType>>> maps(allocator.get());
Expand Down
2 changes: 1 addition & 1 deletion src/data_cell/sparse_graph_datacell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace vsag {

SparseGraphDataCell::SparseGraphDataCell(const JsonType& graph_param,
const IndexCommonParam& common_param)
: allocator_(common_param.allocator_), neighbors_(common_param.allocator_) {
: allocator_(common_param.allocator_.get()), neighbors_(common_param.allocator_.get()) {
if (graph_param.contains(GRAPH_PARAM_MAX_DEGREE)) {
this->maximum_degree_ = graph_param[GRAPH_PARAM_MAX_DEGREE];
}
Expand Down
5 changes: 3 additions & 2 deletions src/data_cell/sparse_graph_datacell_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "default_allocator.h"
#include "fmt/format-inl.h"
#include "graph_interface_test.h"
#include "safe_allocator.h"

using namespace vsag;

Expand All @@ -35,7 +36,7 @@ TestSparseGraphDataCell(const JsonType& graph_param, const IndexCommonParam& par
}

TEST_CASE("graph basic test", "[ut][sparse_graph_datacell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto allocator = SafeAllocator::FactoryDefaultAllocator();
auto dims = {32, 64};
auto max_degrees = {5, 12, 24, 32, 64, 128};
auto max_capacities = {1, 100, 10000, 10'000'000, 32'179'837};
Expand All @@ -55,7 +56,7 @@ TEST_CASE("graph basic test", "[ut][sparse_graph_datacell]") {
for (auto dim : dims) {
IndexCommonParam param;
param.dim_ = dim;
param.allocator_ = allocator.get();
param.allocator_ = allocator;
for (auto& gp : graph_params) {
TestSparseGraphDataCell(gp, param);
}
Expand Down
10 changes: 0 additions & 10 deletions src/default_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@
namespace vsag {

class DefaultAllocator : public Allocator {
public:
static std::shared_ptr<Allocator>
Instance() {
static std::shared_ptr<Allocator> s_instance;
if (s_instance == nullptr) {
s_instance = std::make_shared<DefaultAllocator>();
}
return s_instance;
}

public:
DefaultAllocator() = default;
~DefaultAllocator() override {
Expand Down
12 changes: 8 additions & 4 deletions src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@

namespace vsag {

/// Default constructor: no external Resource is supplied, so a new one is
/// allocated here and handed to a ResourceOwnerWrapper with the `owned`
/// flag set to true.
Engine::Engine() {
    auto self_owned = std::make_shared<ResourceOwnerWrapper>(new Resource(), /*owned*/ true);
    this->resource_ = self_owned;
}

Engine::Engine(Resource* resource) {
if (resource == nullptr) {
this->resource_ = std::make_shared<ResourceOwnerWrapper>(new Resource(), /*owned*/ true);
Expand All @@ -42,13 +46,16 @@ Engine::Engine(Resource* resource) {

/// Drops this engine's reference to its resource wrapper.
///
/// The shared_ptr use count is sampled *before* the reference is released so
/// that a later revision can report references still held elsewhere; the
/// sampled value is not acted on yet.
void
Engine::Shutdown() {
    const auto refs_before_release = this->resource_.use_count();
    this->resource_ = nullptr;

    // TODO(LHT): add refcount warning
    (void)refs_before_release;  // kept for the pending warning above
}

tl::expected<std::shared_ptr<Index>, Error>
Engine::CreateIndex(const std::string& origin_name, const std::string& parameters) {
try {
auto* allocator = this->resource_->allocator.get();
auto& allocator = this->resource_->allocator;
std::string name = origin_name;
transform(name.begin(), name.end(), name.begin(), ::tolower);
JsonType parsed_params = JsonType::parse(parameters);
Expand Down Expand Up @@ -79,9 +86,6 @@ Engine::CreateIndex(const std::string& origin_name, const std::string& parameter
logger::debug("created a diskann index");
return std::make_shared<DiskANN>(diskann_params, index_common_params);
} else if (name == INDEX_HGRAPH) {
if (allocator == nullptr) {
index_common_params.allocator_ = DefaultAllocator::Instance().get();
}
logger::debug("created a hgraph index");
JsonType hgraph_params;
if (parsed_params.contains(INDEX_PARAM)) {
Expand Down
73 changes: 8 additions & 65 deletions src/factory/factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,87 +13,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "vsag/factory.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <exception>
#include <fstream>
#include <ios>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

#include "index/diskann.h"
#include "index/diskann_zparameters.h"
#include "index/hgraph_index.h"
#include "index/hgraph_zparameters.h"
#include "index/hnsw.h"
#include "index/hnsw_zparameters.h"
#include "index/index_common_param.h"
#include "vsag/vsag.h"
#include "ThreadPool.h"
#include "vsag/engine.h"
#include "vsag/options.h"

namespace vsag {

tl::expected<std::shared_ptr<Index>, Error>
Factory::CreateIndex(const std::string& origin_name,
const std::string& parameters,
Allocator* allocator) {
try {
std::string name = origin_name;
transform(name.begin(), name.end(), name.begin(), ::tolower);
JsonType parsed_params = JsonType::parse(parameters);
auto index_common_params = IndexCommonParam::CheckAndCreate(parsed_params, allocator);
if (name == INDEX_HNSW) {
// read parameters from json, throw exception if not exists
CHECK_ARGUMENT(parsed_params.contains(INDEX_HNSW),
fmt::format("parameters must contains {}", INDEX_HNSW));
auto& hnsw_param_obj = parsed_params[INDEX_HNSW];
auto hnsw_params = HnswParameters::FromJson(hnsw_param_obj, index_common_params);
logger::debug("created a hnsw index");
return std::make_shared<HNSW>(hnsw_params, index_common_params);
} else if (name == INDEX_FRESH_HNSW) {
// read parameters from json, throw exception if not exists
CHECK_ARGUMENT(parsed_params.contains(INDEX_HNSW),
fmt::format("parameters must contains {}", INDEX_HNSW));
auto& hnsw_param_obj = parsed_params[INDEX_HNSW];
auto hnsw_params = FreshHnswParameters::FromJson(hnsw_param_obj, index_common_params);
logger::debug("created a fresh-hnsw index");
return std::make_shared<HNSW>(hnsw_params, index_common_params);
} else if (name == INDEX_DISKANN) {
// read parameters from json, throw exception if not exists
CHECK_ARGUMENT(parsed_params.contains(INDEX_DISKANN),
fmt::format("parameters must contains {}", INDEX_DISKANN));
auto& diskann_param_obj = parsed_params[INDEX_DISKANN];
auto diskann_params =
DiskannParameters::FromJson(diskann_param_obj, index_common_params);
logger::debug("created a diskann index");
return std::make_shared<DiskANN>(diskann_params, index_common_params);
} else if (name == INDEX_HGRAPH) {
if (allocator == nullptr) {
index_common_params.allocator_ = DefaultAllocator::Instance().get();
}
logger::debug("created a hgraph index");
JsonType hgraph_params;
if (parsed_params.contains(INDEX_PARAM)) {
hgraph_params = std::move(parsed_params[INDEX_PARAM]);
}
HGraphParameters hgraph_param(hgraph_params, index_common_params);
auto hgraph_index =
std::make_shared<HGraphIndex>(hgraph_param.GetJson(), index_common_params);
hgraph_index->Init();
return hgraph_index;
} else {
LOG_ERROR_AND_RETURNS(
ErrorType::UNSUPPORTED_INDEX, "failed to create index(unsupported): ", name);
}
} catch (const std::invalid_argument& e) {
LOG_ERROR_AND_RETURNS(
ErrorType::INVALID_ARGUMENT, "failed to create index(invalid argument): ", e.what());
} catch (const std::exception& e) {
LOG_ERROR_AND_RETURNS(
ErrorType::UNSUPPORTED_INDEX, "failed to create index(unknown error): ", e.what());
}
Resource resource(allocator);
Engine e(&resource);
return e.CreateIndex(origin_name, parameters);
}

class LocalFileReader : public Reader {
Expand Down
6 changes: 6 additions & 0 deletions src/index/hgraph_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@ namespace vsag {
/// Wraps a newly constructed HGraph and keeps a shared handle to the
/// allocator carried by `common_param`.
///
/// @param index_param   JSON parameters forwarded unchanged to HGraph.
/// @param common_param  common index settings; its `allocator_` is copied
///                      into this index.
///
/// NOTE(review): the constructor is declared noexcept, so an exception
/// escaping std::make_unique or the HGraph constructor would terminate the
/// program rather than propagate.
HGraphIndex::HGraphIndex(const vsag::JsonType& index_param,
                         const vsag::IndexCommonParam& common_param) noexcept
    : hgraph_(std::make_unique<HGraph>(index_param, common_param)),
      allocator_(common_param.allocator_) {
}

/// Releases the members in an explicit order: the graph first, then this
/// index's allocator handle.
HGraphIndex::~HGraphIndex() {
    // NOTE(review): presumably the graph has to be destroyed while the
    // allocator is still alive, since it may free memory through it — confirm.
    this->hgraph_ = nullptr;
    this->allocator_ = nullptr;
}
} // namespace vsag
4 changes: 4 additions & 0 deletions src/index/hgraph_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class HGraphIndex : public Index {
public:
HGraphIndex(const JsonType& index_param, const IndexCommonParam& common_param) noexcept;

~HGraphIndex() override;

tl::expected<void, Error>
Init() {
SAFE_CALL(return this->hgraph_->Init());
Expand Down Expand Up @@ -141,5 +143,7 @@ class HGraphIndex : public Index {

private:
std::unique_ptr<HGraph> hgraph_{nullptr};

std::shared_ptr<Allocator> allocator_{nullptr};
};
} // namespace vsag
12 changes: 4 additions & 8 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
conjugate_graph_ = std::make_shared<ConjugateGraph>();
}

if (not index_common_param.allocator_) {
allocator_ = std::make_shared<SafeAllocator>(DefaultAllocator::Instance());
} else {
allocator_ = std::make_shared<SafeAllocator>(index_common_param.allocator_);
}
allocator_ = index_common_param.allocator_;

if (!use_static_) {
alg_hnsw_ =
Expand Down Expand Up @@ -276,7 +272,7 @@ HNSW::knn_search(const DatasetPtr& query,
results.pop();
}

result->Dim(results.size())->NumElements(1)->Owner(true, allocator_->GetRawAllocator());
result->Dim(results.size())->NumElements(1)->Owner(true, allocator_.get());

int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * results.size());
result->Ids(ids);
Expand Down Expand Up @@ -387,7 +383,7 @@ HNSW::range_search(const DatasetPtr& query,
if (limited_size >= 1) {
target_size = std::min((size_t)limited_size, target_size);
}
result->Dim(target_size)->NumElements(1)->Owner(true, allocator_->GetRawAllocator());
result->Dim(target_size)->NumElements(1)->Owner(true, allocator_.get());
int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * target_size);
result->Ids(ids);
float* dists = (float*)allocator_->Allocate(sizeof(float) * target_size);
Expand Down Expand Up @@ -714,7 +710,7 @@ HNSW::brute_force(const DatasetPtr& query, int64_t k) {
fmt::format("query.dim({}) must be equal to index.dim({})", query->GetDim(), dim_));

auto result = Dataset::Make();
result->NumElements(k)->Owner(true, allocator_->GetRawAllocator());
result->NumElements(k)->Owner(true, allocator_.get());
int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * k);
result->Ids(ids);
float* dists = (float*)allocator_->Allocate(sizeof(float) * k);
Expand Down
Loading

0 comments on commit 57bee65

Please sign in to comment.