Skip to content

Commit

Permalink
Use correct type for binary vector in ivf training (#2086) (#2089)
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
(cherry picked from commit bf11cbb)

Co-authored-by: Heemin Kim <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and heemin32 authored Sep 11, 2024
1 parent 009c8f3 commit 98d9b01
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
16 changes: 10 additions & 6 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Train a binary index with data provided
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const uint8_t* x);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);
Expand Down Expand Up @@ -286,7 +286,7 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
throw std::runtime_error("Dimensions should be multiple of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
Expand Down Expand Up @@ -848,8 +848,12 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface *
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
int dim = (int)dimensionJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiple of 8");
}
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<uint8_t>*>(trainVectorsPointerJ);
int numVectors = (int) (trainingVectorsPointerCpp->size() / (dim / 8));
if(!indexWriter->is_trained) {
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
Expand Down Expand Up @@ -997,12 +1001,12 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const uint8_t* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
indexIvf->make_direct_map();
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
index->train(n, x);
}
}

Expand Down
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Compatible with OpenSearch 2.17.0
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936)
* Fix memory overflow caused by cache behavior [#2015](https://github.com/opensearch-project/k-NN/pull/2015)
* Use correct type for binary vector in ivf training [#2086](https://github.com/opensearch-project/k-NN/pull/2086)
### Infrastructure
* Parallelize make to reduce build time [#2006](https://github.com/opensearch-project/k-NN/pull/2006)
### Maintenance
Expand Down

0 comments on commit 98d9b01

Please sign in to comment.