Skip to content

Commit

Permalink
🔨 version 1.0.3 - Add GetMaxElements, GetCurrentElementCount, `Ge…
Browse files Browse the repository at this point in the history
…tDeleteCount`, `GetVectorByLabel` APIs
  • Loading branch information
sunhailin-Leo committed Feb 27, 2023
1 parent 7a3eb65 commit be02fde
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 28 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Hnswlib to go. Golang interface to hnswlib(https://github.com/nmslib/hnswlib). T

### Version

* version 1.0.3
* Add `GetMaxElements`, `GetCurrentElementCount`, `GetDeleteCount`, `GetVectorByLabel` APIs

* version 1.0.2
* Update hnswlib compatible version to 0.7.0
* Add `AddBatchPoints`, `SearchBatchKNN`, `SetNormalize`, `ResizeIndex`, `MarkDelete`, `UnmarkDelete`, `GetLabelIsMarkedDeleted` APIs
Expand Down
55 changes: 46 additions & 9 deletions example/demo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"math/rand"
"reflect"
"runtime"
"time"

Expand Down Expand Up @@ -43,26 +44,35 @@ func randVector(dim int) []float32 {
}

// 单个写入
func exampleAddPoint(indexFileName string) {
func exampleAddPoint(indexFileName string) []float32 {
var dim, M, ef = 128, 32, 300
// 最大的 elements 数
var maxElements uint32 = 10000
var maxElements uint32 = 100
// 定义距离 cosine
var spaceType = "cosine"
var randomSeed = 100
var randomSeed = 2000
fmt.Println("Before Create HNSW")
traceMemStats()
// Init new index
h := hnswgo.New(dim, M, ef, randomSeed, maxElements, spaceType)

// randomIndex to test the api GetVectorByLabel
var randomIndex []float32

// Insert 1000 vectors to index. Label Type is uint32
var i uint32
for ; i < maxElements; i++ {
if i%1000 == 0 {
fmt.Println(i)
}
h.AddPoint(randVector(dim), i)
randVec := randVector(dim)
h.AddPoint(randVec, i)
if i == 0 {
randomIndex = randVec
}
}
h.Save(indexFileName)
return randomIndex
}

// 批量写入
Expand Down Expand Up @@ -97,7 +107,7 @@ func exampleBatchAddPoint(indexFileName string) {
}

// 读取
func exampleLoadIndex(indexFileName, spaceType string, dim int) {
func exampleLoadIndex(indexFileName, spaceType string, dim int) []float32 {
h := hnswgo.Load(indexFileName, dim, spaceType)
// Search vector with maximum 5 NN
h.SetEf(15)
Expand All @@ -109,36 +119,63 @@ func exampleLoadIndex(indexFileName, spaceType string, dim int) {
fmt.Println(endTime - startTime)
fmt.Println(labels, vectors)

// Test GetMaxElements API Before Resize
maxElementsBeforeResize := h.GetMaxElements()
currentElementsBeforeResize := h.GetCurrentElementCount()
fmt.Println("maxElements, currentElements(before resize): ", maxElementsBeforeResize, currentElementsBeforeResize)

// Test ResizeIndex API
isResize := h.ResizeIndex(12000)
fmt.Println("Size flag: ", isResize)

// Test GetMaxElements API After Resize
maxElementsAfterResize := h.GetMaxElements()
currentElementsAfterResize := h.GetCurrentElementCount()
fmt.Println("maxElements, currentElements(after resize): ", maxElementsAfterResize, currentElementsAfterResize)

// Test GetDeleteCount API
deleteCountBeforeDelete := h.GetDeleteCount()
fmt.Println("GetDeleteCount(before): ", deleteCountBeforeDelete)

// Test Mark API
isMarkDelete := h.MarkDelete(10)
fmt.Println("isMarkDelete: ", isMarkDelete)

labelIsDelete := h.GetLabelIsMarkedDeleted(10)
fmt.Println("labelIsDelete: ", labelIsDelete)

// Test GetDeleteCount API
deleteCountBeforeAfter := h.GetDeleteCount()
fmt.Println("GetDeleteCount(after): ", deleteCountBeforeAfter)

isUnmarkDelete := h.UnmarkDelete(10)
fmt.Println("isUnmarkDelete: ", isUnmarkDelete)

// Test GetVectorByLabel API
getVectorByIdRes := h.GetVectorByLabel(0, dim)
fmt.Println("Vector: ", getVectorByIdRes)

// Test Unload API
fmt.Println("Before Unload")
traceMemStats()
h.Unload()
fmt.Println("After Unload")
traceMemStats()

return getVectorByIdRes
}

func main() {
// 单条写入 add index point by point
exampleAddPoint("hnsw_demo_single.bin")
demoVector := exampleAddPoint("hnsw_demo_single.bin")
// 测试读取 test loading
exampleLoadIndex("hnsw_demo_single.bin", "cosine", 128)
demoSearchVector := exampleLoadIndex("hnsw_demo_single.bin", "cosine", 128)
// test GetVectorByLabel API
isEqual := reflect.DeepEqual(demoVector, demoSearchVector)
fmt.Println("GetVectorByLabel return data is equal: ", isEqual)

// 批量写入 add index with batch mode
//exampleBatchAddPoint("hnsw_demo_multiple.bin")
exampleBatchAddPoint("hnsw_demo_multiple.bin")
// 测试读取 test loading
//exampleLoadIndex("hnsw_demo_multiple.bin", "cosine", 128)
exampleLoadIndex("hnsw_demo_multiple.bin", "cosine", 128)
}
72 changes: 58 additions & 14 deletions hnsw.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
package hnswgo

// #cgo LDFLAGS: -L${SRCDIR} -lhnsw -lm
// #include <stdlib.h>
// #include <stdbool.h>
// #include "hnsw_wrapper.h"
// HNSW initHNSW(int dim, unsigned long int max_elements, int M, int ef_construction, int rand_seed, char stype);
// HNSW loadHNSW(char *location, int dim, char stype);
// void addPoint(HNSW index, float *vec, unsigned long int label);
// int searchKnn(HNSW index, float *vec, int N, unsigned long int *label, float *dist);
// void setEf(HNSW index, int ef);
// bool resizeIndex(HNSW index, unsigned long int new_max_elements);
// bool markDelete(HNSW index, unsigned long int label);
// bool unmarkDelete(HNSW index, unsigned long int label);
// bool isMarkedDeleted(HNSW index, unsigned long int label);
// bool updatePoint(HNSW index, float *vec, unsigned long int label);
import "C"

/*
#cgo CXXFLAGS: -std=c++11
#cgo LDFLAGS: -L${SRCDIR} -lhnsw -lm
#include <stdlib.h>
#include <stdbool.h>
#include "hnsw_wrapper.h"
HNSW initHNSW(int dim, unsigned long int max_elements, int M, int ef_construction, int rand_seed, char stype);
HNSW loadHNSW(char *location, int dim, char stype);
void addPoint(HNSW index, float *vec, unsigned long int label);
int searchKnn(HNSW index, float *vec, int N, unsigned long int *label, float *dist);
void setEf(HNSW index, int ef);
bool resizeIndex(HNSW index, unsigned long int new_max_elements);
bool markDelete(HNSW index, unsigned long int label);
bool unmarkDelete(HNSW index, unsigned long int label);
bool isMarkedDeleted(HNSW index, unsigned long int label);
bool updatePoint(HNSW index, float *vec, unsigned long int label);
void getDataByLabel(HNSW index, unsigned long int label, float* out_data);
*/
import "C"
import (
"math"
Expand All @@ -22,6 +30,13 @@ import (
"unsafe"
)

func toSlice(v *C.float, len int) []float32 {
// 创建一个指向C数组的slice
slice := (*[1 << 30]float32)(unsafe.Pointer(v))[:len:len]
// 复制slice的值,将其转换为一个新的Go切片
return append([]float32(nil), slice...)
}

type HNSW struct {
index C.HNSW
spaceType string
Expand Down Expand Up @@ -224,3 +239,32 @@ func (h *HNSW) GetLabelIsMarkedDeleted(label uint32) bool {
isDelete := bool(C.isMarkedDeleted(h.index, C.ulong(label)))
return isDelete
}

// GetMaxElements get index max elements
func (h *HNSW) GetMaxElements() int {
maxElements := int(C.getMaxElements(h.index))
return maxElements
}

// GetCurrentElementCount get index current elements
func (h *HNSW) GetCurrentElementCount() int {
elementCnt := int(C.getCurrentElementCount(h.index))
return elementCnt
}

// GetDeleteCount get index count which mark deleted
func (h *HNSW) GetDeleteCount() int {
deleteElementCnt := int(C.getDeleteCount(h.index))
return deleteElementCnt
}

// GetVectorByLabel get index by label
func (h *HNSW) GetVectorByLabel(label uint32, dim int) []float32 {
var outDataPtr C.float
C.getDataByLabel(h.index, C.ulong(label), &outDataPtr)
outData := make([]float32, dim)
for i := 0; i < dim; i++ {
outData[i] = float32(*(*C.float)(unsafe.Pointer(uintptr(unsafe.Pointer(&outDataPtr)) + uintptr(i)*unsafe.Sizeof(C.float(0)))))
}
return outData
}
31 changes: 27 additions & 4 deletions hnsw_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,30 @@ bool updatePoint(HNSW index, float *vec, unsigned long int label) {
return false;
}

// TODO
//std::vector<float> getDataByLabel(HNSW index, unsigned long int label) {
// return ((hnswlib::HierarchicalNSW<float>*)index)->getDataByLabel<float>(label);
//}
// Copies the vector stored under `label` into the caller-provided buffer.
//
// `out_data` must have room for every element of the stored vector; elements
// are written in order starting at out_data[0]. If `out_data` is null, or if
// the hnswlib lookup throws (presumably for an unknown label -- confirm
// against the hnswlib version in use), the function returns without writing
// anything, because an exception must not unwind through this C-linkage entry
// point into the C/Go caller.
void getDataByLabel(HNSW index, unsigned long int label, float* out_data) {
    if (out_data == nullptr) {
        return;
    }

    try {
        auto data = ((hnswlib::HierarchicalNSW<float>*)index)->getDataByLabel<float>(label);
        // Copy straight from the returned container; no intermediate heap copy is needed.
        size_t i = 0;
        for (const float value : data) {
            out_data[i++] = value;
        }
    } catch (...) {
        // Swallow deliberately: the void C signature has no error channel.
    }
}

// Returns the index's getMaxElements() value, converted to int for the C API.
int getMaxElements(HNSW index) {
    auto *hnsw = (hnswlib::HierarchicalNSW<float> *) index;
    return static_cast<int>(hnsw->getMaxElements());
}

// Returns the index's getCurrentElementCount() value, converted to int for the C API.
int getCurrentElementCount(HNSW index) {
    auto *hnsw = (hnswlib::HierarchicalNSW<float> *) index;
    return static_cast<int>(hnsw->getCurrentElementCount());
}

// Returns the index's getDeletedCount() value, converted to int for the C API.
int getDeleteCount(HNSW index) {
    auto *hnsw = (hnswlib::HierarchicalNSW<float> *) index;
    return static_cast<int>(hnsw->getDeletedCount());
}

9 changes: 8 additions & 1 deletion hnsw_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ bool isMarkedDeleted(HNSW index, unsigned long int label);

bool updatePoint(HNSW index, float *vec, unsigned long int label);

/* Returns the index's getMaxElements() value as an int. */
int getMaxElements(HNSW index);

/* Returns the index's getCurrentElementCount() value as an int. */
int getCurrentElementCount(HNSW index);

/* Returns the index's getDeletedCount() value as an int. */
int getDeleteCount(HNSW index);

/*
 * Copies the vector stored under `label` into `out_data`, element by element
 * starting at out_data[0]. The caller must supply a buffer large enough for
 * every element of the stored vector.
 */
void getDataByLabel(HNSW index, unsigned long int label, float* out_data);
#ifdef __cplusplus
}
#endif
#endif
Binary file modified hnsw_wrapper.o
Binary file not shown.
Binary file modified libhnsw.a
Binary file not shown.

0 comments on commit be02fde

Please sign in to comment.