From 98d9b016b8ab97c7313163ff152a3027c5de49ed Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:13:13 -0700 Subject: [PATCH] Use correct type for binary vector in ivf training (#2086) (#2089) Signed-off-by: Heemin Kim (cherry picked from commit bf11cbb12731f5471f441f92f9360065a6caedd1) Co-authored-by: Heemin Kim --- jni/src/faiss_wrapper.cpp | 16 ++++++++++------ .../opensearch-knn.release-notes-2.17.0.0.md | 1 + 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index ba15c3ce7..227fcb477 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -71,7 +71,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); // Train a binary index with data provided -void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const uint8_t* x); // Converts the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -286,7 +286,7 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); int dim = (int)dimJ; if (dim % 8 != 0) { - throw std::runtime_error("Dimensions should be multiply of 8"); + throw std::runtime_error("Dimensions should be multiple of 8"); } int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); @@ -848,8 +848,12 @@ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * } // Train index if needed - auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); - int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + int dim = (int)dimensionJ; + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiple of 8"); + } + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = (int) (trainingVectorsPointerCpp->size() / (dim / 8)); if(!indexWriter->is_trained) { InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); } @@ -997,12 +1001,12 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } -void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const uint8_t* x) { if (auto * indexIvf = dynamic_cast(index)) { indexIvf->make_direct_map(); } if (!index->is_trained) { - index->train(n, reinterpret_cast(x)); + index->train(n, x); } } diff --git a/release-notes/opensearch-knn.release-notes-2.17.0.0.md b/release-notes/opensearch-knn.release-notes-2.17.0.0.md index 8b4aa8e95..1e046be9b 100644 --- a/release-notes/opensearch-knn.release-notes-2.17.0.0.md +++ b/release-notes/opensearch-knn.release-notes-2.17.0.0.md @@ -18,6 +18,7 @@ Compatible with OpenSearch 2.17.0 * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) * Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) * Fix memory overflow caused by cache behavior [#2015](https://github.com/opensearch-project/k-NN/pull/2015) +* Use correct type for binary vector in ivf training [#2086](https://github.com/opensearch-project/k-NN/pull/2086) ### Infrastructure * Parallelize make to reduce build time [#2006] (https://github.com/opensearch-project/k-NN/pull/2006) ### Maintenance