Skip to content

Commit

Permalink
Optimize vector index (TuGraph-family#717)
Browse files Browse the repository at this point in the history
* update

* update

* fix

* fix test case error

* fix case error

* fix mem leak
  • Loading branch information
ljcui authored Oct 27, 2024
1 parent e4c5a44 commit 3126051
Show file tree
Hide file tree
Showing 17 changed files with 335 additions and 83 deletions.
12 changes: 5 additions & 7 deletions src/core/index_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ IndexManager::IndexManager(KvTransaction& txn, SchemaManager* v_schema_manager,
new HNSW(idx.label, idx.field, idx.distance_type, idx.index_type,
idx.dimension, {idx.hnsw_m, idx.hnsw_ef_construction})));
uint64_t count = 0;
std::vector<std::vector<float>> floatvector;
std::vector<int64_t> vids;
auto kv_iter = schema->GetPropertyTable().GetIterator(txn);
for (kv_iter->GotoFirstKey(); kv_iter->IsValid(); kv_iter->Next()) {
auto prop = kv_iter->GetValue();
Expand All @@ -119,14 +117,14 @@ IndexManager::IndexManager(KvTransaction& txn, SchemaManager* v_schema_manager,
}
auto vid = graph::KeyPacker::GetVidFromPropertyTableKey(kv_iter->GetKey());
auto vector = (extractor->GetConstRef(prop)).AsType<std::vector<float>>();
floatvector.emplace_back(vector);
vids.emplace_back(vid);
vsag_index->Add({std::move(vector)}, {vid});
count++;
if ((count % 10000) == 0) {
LOG_INFO() << "vector index count: " << count;
}
}
vsag_index->Build();
vsag_index->Add(floatvector, vids, count);
kv_iter.reset();
LOG_DEBUG() << "index count: " << count;
LOG_INFO() << "vector index count: " << count;
schema->MarkVectorIndexed(extractor->GetFieldId(), vsag_index.release());
LOG_INFO() << FMA_FMT("end building vertex vector index for {}:{} in detached model",
idx.label, idx.field);
Expand Down
25 changes: 21 additions & 4 deletions src/core/index_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,43 +289,60 @@ class IndexManager {
std::vector<VectorIndexSpec> ListVectorIndex(KvTransaction& txn);

// vertex index
std::pair<std::vector<IndexSpec>, std::vector<CompositeIndexSpec>> ListAllIndexes(
std::tuple<std::vector<IndexSpec>, std::vector<CompositeIndexSpec>,
std::vector<VectorIndexSpec>> ListAllIndexes(
KvTransaction& txn) {
std::vector<IndexSpec> indexes;
std::vector<CompositeIndexSpec> compositeIndexes;
IndexSpec is;
CompositeIndexSpec cis;
std::vector<VectorIndexSpec> vectorIndexes;
size_t v_index_len = strlen(_detail::VERTEX_INDEX);
size_t e_index_len = strlen(_detail::EDGE_INDEX);
size_t c_index_len = strlen(_detail::COMPOSITE_INDEX);
size_t ve_index_len = strlen(_detail::VERTEX_VECTOR_INDEX);
auto it = index_list_table_->GetIterator(txn);
for (it->GotoFirstKey(); it->IsValid(); it->Next()) {
std::string index_name = it->GetKey().AsString();
if (index_name.size() > v_index_len &&
index_name.substr(index_name.size() - v_index_len) == _detail::VERTEX_INDEX) {
_detail::IndexEntry ent = LoadIndex(it->GetValue());
IndexSpec is;
is.label = ent.label;
is.field = ent.field;
is.type = ent.type;
indexes.emplace_back(std::move(is));
} else if (index_name.size() > e_index_len &&
index_name.substr(index_name.size() - e_index_len) == _detail::EDGE_INDEX) {
_detail::IndexEntry ent = LoadIndex(it->GetValue());
IndexSpec is;
is.label = ent.label;
is.field = ent.field;
is.type = ent.type;
indexes.emplace_back(std::move(is));
} else if (index_name.size() > ve_index_len &&
index_name.substr(index_name.size() - ve_index_len)
== _detail::VERTEX_VECTOR_INDEX) {
_detail::VectorIndexEntry ent = LoadVectorIndex(it->GetValue());
VectorIndexSpec vis;
vis.label = ent.label;
vis.field = ent.field;
vis.distance_type = ent.distance_type;
vis.dimension = ent.dimension;
vis.hnsw_ef_construction = ent.hnsw_ef_construction;
vis.hnsw_m = ent.hnsw_m;
vis.index_type = ent.index_type;
vectorIndexes.emplace_back(vis);
} else if (index_name.size() > c_index_len &&
index_name.substr(index_name.size() - c_index_len) ==
_detail::COMPOSITE_INDEX) {
_detail::CompositeIndexEntry idx = LoadCompositeIndex(it->GetValue());
CompositeIndexSpec cis;
cis.label = idx.label;
cis.fields = idx.field_names;
cis.type = idx.index_type;
compositeIndexes.emplace_back(std::move(cis));
}
}
return {indexes, compositeIndexes};
return {std::move(indexes), std::move(compositeIndexes), std::move(vectorIndexes)};
}
};
} // namespace lgraph
35 changes: 26 additions & 9 deletions src/core/lightning_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ void LightningGraph::DropAllVertex() {
Transaction txn = CreateWriteTxn(false);
ScopedRef<SchemaInfo> curr_schema = schema_.GetScopedRef();
// clear indexes
auto [indexes, composite_indexes] = index_manager_->ListAllIndexes(txn.GetTxn());
auto [indexes, composite_indexes, vector_indexes]
= index_manager_->ListAllIndexes(txn.GetTxn());
for (auto& idx : indexes) {
auto v_schema = curr_schema->v_schema_manager.GetSchema(idx.label);
auto e_schema = curr_schema->e_schema_manager.GetSchema(idx.label);
Expand All @@ -89,6 +90,15 @@ void LightningGraph::DropAllVertex() {
v_schema->GetCompositeIndex(idx.fields)->Clear(txn.GetTxn());
}
}
for (auto& idx : vector_indexes) {
auto v_schema = curr_schema->v_schema_manager.GetSchema(idx.label);
FMA_DBG_ASSERT(v_schema);
if (v_schema) {
auto ext = v_schema->GetFieldExtractor(idx.field);
FMA_DBG_ASSERT(ext);
ext->GetVectorIndex()->Clear();
}
}
// clear detached property data
for (auto& name : curr_schema->v_schema_manager.GetAllLabels()) {
auto s = curr_schema->v_schema_manager.GetSchema(name);
Expand Down Expand Up @@ -2219,8 +2229,6 @@ bool LightningGraph::BlockingAddVectorIndex(bool is_vertex, const std::string& l
label, field);
VectorIndex* index = extractor->GetVectorIndex();
uint64_t count = 0;
std::vector<std::vector<float>> floatvector;
std::vector<int64_t> vids;
auto dim = index->GetVecDimension();
auto kv_iter = schema->GetPropertyTable().GetIterator(txn.GetTxn());
for (kv_iter->GotoFirstKey(); kv_iter->IsValid(); kv_iter->Next()) {
Expand All @@ -2234,13 +2242,13 @@ bool LightningGraph::BlockingAddVectorIndex(bool is_vertex, const std::string& l
THROW_CODE(VectorIndexException,
"vector size error, size:{}, dim:{}", vector.size(), dim);
}
floatvector.emplace_back(std::move(vector));
vids.emplace_back(vid);
index->Add({std::move(vector)}, {vid});
count++;
if ((count % 10000) == 0) {
LOG_INFO() << "vector index count: " << count;
}
}
index->Build();
index->Add(floatvector, vids, count);
LOG_INFO() << "index count: " << count;
LOG_INFO() << "vector index count: " << count;
LOG_INFO() << FMA_FMT("end building vertex vector index for {}:{} in detached model",
label, field);
kv_iter.reset();
Expand Down Expand Up @@ -2809,7 +2817,8 @@ void LightningGraph::DropAllIndex() {
ScopedRef<SchemaInfo> curr_schema = schema_.GetScopedRef();
std::unique_ptr<SchemaInfo> new_schema(new SchemaInfo(*curr_schema.Get()));
std::unique_ptr<SchemaInfo> backup_schema(new SchemaInfo(*curr_schema.Get()));
auto [indexes, composite_indexes] = index_manager_->ListAllIndexes(txn.GetTxn());
auto [indexes, composite_indexes, vector_indexes]
= index_manager_->ListAllIndexes(txn.GetTxn());

bool success = true;
for (auto& idx : indexes) {
Expand Down Expand Up @@ -2840,6 +2849,14 @@ void LightningGraph::DropAllIndex() {
}
v_schema->UnVertexCompositeIndex(idx.fields);
}
for (auto& idx : vector_indexes) {
auto v_schema = new_schema->v_schema_manager.GetSchema(idx.label);
FMA_DBG_ASSERT(v_schema);
auto ret = index_manager_->DeleteVectorIndex(txn.GetTxn(), idx.label, idx.field);
FMA_DBG_ASSERT(ret);
auto ext = v_schema->GetFieldExtractor(idx.field);
v_schema->UnVectorIndex(ext->GetFieldId());
}
if (success) {
schema_.Assign(new_schema.release());
AutoCleanupAction revert_assign_new_schema(
Expand Down
6 changes: 2 additions & 4 deletions src/core/schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ void Schema::AddVectorToVectorIndex(KvTransaction& txn, VertexId vid, const Valu
"vector index dimension mismatch, vector size:{}, dim:{}",
floatvector.back().size(), dim);
}
index->Add(floatvector, vids, 1);
index->Add(floatvector, vids);
}
}

Expand All @@ -323,9 +323,7 @@ void Schema::DeleteVectorIndex(KvTransaction& txn, VertexId vid, const Value& re
auto& fe = fields_[idx];
if (fe.GetIsNull(record)) continue;
VectorIndex* index = fe.GetVectorIndex();
std::vector<int64_t> vids;
vids.push_back(vid);
index->Add({}, vids, 0);
index->Remove({vid});
}
}

Expand Down
42 changes: 40 additions & 2 deletions src/core/transaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,46 @@ Transaction::SetVertexProperty(VertexIterator& it, size_t n_fields, const FieldT
// no need to update index since blob cannot be indexed
} else if (fe->Type() == FieldType::FLOAT_VECTOR) {
fe->ParseAndSet(new_prop, values[i]);
schema->DeleteVectorIndex(*txn_, vid, old_prop);
schema->AddVectorToVectorIndex(*txn_, vid, new_prop);
VectorIndex* index = fe->GetVectorIndex();
if (index) {
bool oldnull = fe->GetIsNull(old_prop);
bool newnull = fe->GetIsNull(new_prop);
std::vector<int64_t> vids {vid};
if (!oldnull && !newnull) {
const auto& old_v = fe->GetConstRef(old_prop);
const auto& new_v = fe->GetConstRef(new_prop);
if (old_v == new_v) {
continue;
}
// delete
index->Remove(vids);
// add
auto dim = index->GetVecDimension();
std::vector<std::vector<float>> floatvector;
floatvector.emplace_back(new_v.AsFloatVector());
if (floatvector.back().size() != (size_t)dim) {
THROW_CODE(InputError,
"vector index dimension mismatch, vector size:{}, dim:{}",
floatvector.back().size(), dim);
}
index->Add(floatvector, vids);
} else if (oldnull && !newnull) {
// add
const auto& new_v = fe->GetConstRef(new_prop);
auto dim = index->GetVecDimension();
std::vector<std::vector<float>> floatvector;
floatvector.emplace_back(new_v.AsFloatVector());
if (floatvector.back().size() != (size_t)dim) {
THROW_CODE(InputError,
"vector index dimension mismatch, vector size:{}, dim:{}",
floatvector.back().size(), dim);
}
index->Add(floatvector, vids);
} else if (!oldnull && newnull) {
// delete
index->Remove(vids);
}
}
} else {
fe->ParseAndSet(new_prop, values[i]);
// update index if there is no error
Expand Down
1 change: 1 addition & 0 deletions src/core/transaction.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Transaction {
std::vector<IteratorBase*> iterators_;
FullTextIndex* fulltext_index_;
std::vector<FTIndexEntry> fulltext_buffers_;
std::vector<VectorIndexEntry> vector_buffers_;
std::unordered_map<LabelId, int64_t> vertex_delta_count_;
std::unordered_map<LabelId, int64_t> edge_delta_count_;
std::set<LabelId> vertex_label_delete_;
Expand Down
2 changes: 2 additions & 0 deletions src/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ class Value {
*/
std::string AsString() const { return AsType<std::string>(); }

std::vector<float> AsFloatVector() const { return AsType<std::vector<float>>(); }

/**
* Create a Value that is a const reference to the object t
*
Expand Down
22 changes: 18 additions & 4 deletions src/core/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

namespace lgraph {


class VectorIndex {
friend class Schema;
friend class LightningGraph;
Expand Down Expand Up @@ -66,12 +67,13 @@ class VectorIndex {

// add vector to index and build index
virtual void Add(const std::vector<std::vector<float>>& vectors,
const std::vector<int64_t>& vids, int64_t num_vectors) = 0;
const std::vector<int64_t>& vids) = 0;

virtual void Remove(const std::vector<int64_t>& vids) = 0;

// build index
virtual void Build() = 0;
virtual void Clear() = 0;

// serialize index
// serialize index
virtual std::vector<uint8_t> Save() = 0;

// load index form serialization
Expand All @@ -83,5 +85,17 @@ class VectorIndex {

virtual std::vector<std::pair<int64_t, float>>
RangeSearch(const std::vector<float>& query, float radius, int ef_search, int limit) = 0;

virtual int64_t GetElementsNum() = 0;
virtual int64_t GetMemoryUsage() = 0;
virtual int64_t GetDeletedIdsNum() = 0;
};

struct VectorIndexEntry {
VectorIndex* index;
bool add;
std::vector<int64_t> vids;
std::vector<std::vector<float>> vectors;
};

} // namespace lgraph
Loading

0 comments on commit 3126051

Please sign in to comment.