From 4ebb39a6fba40935bec9fcd436fd26ba29a549c2 Mon Sep 17 00:00:00 2001 From: luyuncheng Date: Tue, 5 Mar 2024 03:43:27 +0800 Subject: [PATCH] Optimize Faiss Query With Filters: Reduce iteration and memory for id filter (#1402) * Optimize Faiss Query With Filters. Reduce iteration copy for docid set iterator Signed-off-by: luyuncheng * Optimize Faiss Query With Filters. Reduce iteration copy for docid set iterator. Use Bitmap And Batch to do id filter. and you sparse or fixed bitset do exact ANN search Signed-off-by: luyuncheng * Using int64_t instead of long type for GetLongArrayElements Signed-off-by: luyuncheng * Add IDSelectorJlongBitmap Signed-off-by: luyuncheng * 1. Add IDSelectorJlongBitmap and UT for it 2. Move FilterIdsSelectorType to a util class Signed-off-by: luyuncheng * 1. Add IDSelectorJlongBitmap and UT for it 2. Move FilterIdsSelectorType to a util class 3. Spotless apply Signed-off-by: luyuncheng * Rebase remote-tracking branch 'origin/main' into Filter Signed-off-by: luyuncheng * tidy Signed-off-by: luyuncheng * Add Changelog Signed-off-by: luyuncheng * fix javadoc tasks Signed-off-by: luyuncheng * fix bwc javadoc Signed-off-by: luyuncheng * UpdatedFilterIdsSelector Signed-off-by: luyuncheng * UpdatedFilterIdsSelector Signed-off-by: luyuncheng * Rebase faiss_wrapper.cpp Signed-off-by: luyuncheng * UpdatedFilterIdsSelector For description Select different FilterIdsSelectorType Signed-off-by: luyuncheng * UpdatedFilterIdsSelector For description Select different FilterIdsSelectorType Signed-off-by: luyuncheng * UpdatedFilterIdsSelector as Byte.SIZE Signed-off-by: luyuncheng * UpdatedFilterIdsSelector For comments Signed-off-by: luyuncheng --------- Signed-off-by: luyuncheng (cherry picked from commit 3eeb855b0e0434c8eda8ae45343775548c067824) --- CHANGELOG.md | 1 + jni/include/faiss_wrapper.h | 3 +- jni/include/jni_util.h | 9 ++ .../org_opensearch_knn_jni_FaissService.h | 2 +- jni/src/faiss_wrapper.cpp | 120 +++++------------- jni/src/jni_util.cpp | 
26 ++++ .../org_opensearch_knn_jni_FaissService.cpp | 4 +- jni/tests/faiss_wrapper_test.cpp | 14 +- jni/tests/test_util.cpp | 34 +++++ jni/tests/test_util.h | 9 ++ .../knn/index/query/FilterIdsSelector.java | 107 ++++++++++++++++ .../opensearch/knn/index/query/KNNWeight.java | 51 ++++---- .../filtered/FilteredIdsKNNIterator.java | 22 ++-- .../NestedFilteredIdsKNNIterator.java | 16 ++- .../org/opensearch/knn/jni/FaissService.java | 3 +- .../org/opensearch/knn/jni/JNIService.java | 6 +- .../knn/index/codec/KNNCodecTestUtil.java | 2 +- .../memory/NativeMemoryLoadStrategyTests.java | 2 +- .../index/query/FilterIdsSelectorTests.java | 60 +++++++++ .../knn/index/query/KNNWeightTests.java | 29 +++-- .../filtered/FilteredIdsKNNIteratorTests.java | 5 +- .../NestedFilteredIdsKNNIteratorTests.java | 10 +- .../opensearch/knn/jni/JNIServiceTests.java | 25 ++-- 23 files changed, 401 insertions(+), 159 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java create mode 100644 src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 93ad21dc9..69b213ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x) ### Features ### Enhancements +* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 078526000..e99cdafb2 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -43,7 +43,8 @@ namespace knn_jni { // // Return an array of KNNQueryResults jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, 
jint kJ, jintArray filterIdsJ, jintArray parentIdsJ); + jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, + jint filterIdsTypeJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index b4dd44891..52b08a202 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -80,6 +80,8 @@ namespace knn_jni { virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0; + virtual int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) = 0; + virtual int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) = 0; virtual int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) = 0; @@ -94,6 +96,8 @@ namespace knn_jni { virtual jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) = 0; + virtual jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) = 0; + virtual jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) = 0; virtual jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) = 0; @@ -108,6 +112,8 @@ namespace knn_jni { virtual void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) = 0; + virtual void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) = 0; + virtual void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) = 0; virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0; @@ -139,6 +145,7 @@ namespace knn_jni { int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ); int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ); int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ); + int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ); int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ); int GetJavaFloatArrayLength(JNIEnv *env, 
jfloatArray arrayJ); @@ -146,6 +153,7 @@ namespace knn_jni { jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy); jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy); jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy); + jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy); jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index); jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance); jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init); @@ -153,6 +161,7 @@ namespace knn_jni { void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode); void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode); void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode); + void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode); void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index aefadcee4..3b649c227 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -56,7 +56,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd * Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 
e88254b86..43f1454b5 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -29,9 +29,32 @@ // Defines type of IDSelector enum FilterIdsSelectorType{ - BITMAP, BATCH + BITMAP = 0, BATCH = 1, }; +namespace faiss { +// Using jlong to do Bitmap selector, jlong[] equals to lucene FixedBitSet#bits +struct IDSelectorJlongBitmap : IDSelector { + size_t n; + const jlong* bitmap; + + /** Construct with a binary mask like Lucene FixedBitSet + * + * @param n size of the bitmap array + * @param bitmap id like Lucene FixedBitSet bits + */ + IDSelectorJlongBitmap(size_t n, const jlong* bitmap) : n(n), bitmap(bitmap) {}; + bool is_member(idx_t id) const final { + uint64_t index = id; + uint64_t i = index >> 6; // div 64 + if (i >= n ) { + return false; + } + return (bitmap[i] >> ( index & 63)) & 1L; + } + ~IDSelectorJlongBitmap() override {} +}; +} // Translate space type to faiss metric faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType); @@ -42,9 +65,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); -// Helps to choose the right FilterIdsSelectorType for Faiss -FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength); - // Converts the int FilterIds to Faiss ids type array. 
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -199,11 +219,12 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ); } jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ) { + jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } @@ -225,28 +246,14 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter omp_set_num_threads(1); // create the filterSearch params if the filterIdsJ is not a null pointer if(filterIdsJ != nullptr) { - int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr); - int filterIdsLength = jniUtil->GetJavaIntArrayLength(env, filterIdsJ); + jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); std::unique_ptr idSelector; - FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength); - // start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data - // during the returns. We could have used pass by reference, but we choose pointers. 
Returning reference to local - // vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers. - // To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/ - std::vector convertedIds; - std::vector bitmap; - // Choose a selector which suits best - if(idSelectorType == BATCH) { - convertedIds.resize(filterIdsLength); - convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data()); - idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data())); + if(filterIdsTypeJ == BITMAP) { + idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); } else { - int maxIdValue = filteredIdsArray[filterIdsLength - 1]; - // >> 3 is equivalent to value / 8 - const int bitsetArraySize = (maxIdValue >> 3) + 1; - bitmap.resize(bitsetArraySize, 0); - buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data()); - idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap.data())); + faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); + idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices)); } faiss::SearchParameters *searchParameters; faiss::SearchParametersHNSW hnswParams; @@ -276,10 +283,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); } catch (...) 
{ jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); throw; } - jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); } else { faiss::SearchParameters *searchParameters = nullptr; faiss::SearchParametersHNSW hnswParams; @@ -454,63 +461,6 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } -/** - * This function takes a call on what ID Selector to use: - * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap - * - * class storage lookup construction(Opensearch + Faiss) - * IDSelectorArray O(k) O(k) O(2k) - * IDSelectorBatch O(k) O(1) O(2k) - * IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index - * - * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts: - * an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to - * 7.5M vectors. - * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation - * M = 16 - * Dimension = 128 - * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB - * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be - * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease. - * - * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids - * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size - * as factor. We need to improve on this. 
- * - * TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss. - * - * @param filterIds - * @param filterIdsLength - * @return std::string - */ -FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) { - int maxIdValue = filterIds[filterIdsLength - 1]; - if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) { - return BATCH; - } - return BITMAP; -} - -void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) { - for (int i = 0; i < filterIdsLength; i++) { - convertedFilterIds[i] = filterIds[i]; - } -} - -void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) { - /** - * Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected - * iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1. - */ - for(int i = 0 ; i < filterIdsLength ; i ++) { - int value = filterIds[i]; - // / , % are expensive operation. Hence, using BitShift operation as they are fast. 
- int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8 - // (value & 7) equivalent to value % 8 - bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7)); - } -} - std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index cb9270c22..a0c1d5733 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -319,6 +319,17 @@ int knn_jni::JNIUtil::GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) { return length; } +int knn_jni::JNIUtil::GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) { + + if (arrayJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + int length = env->GetArrayLength(arrayJ); + this->HasExceptionInStack(env, "Unable to get array length"); + return length; +} + int knn_jni::JNIUtil::GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) { if (arrayJ == nullptr) { @@ -376,6 +387,17 @@ jint * knn_jni::JNIUtil::GetIntArrayElements(JNIEnv *env, jintArray array, jbool return intArray; } +jlong * knn_jni::JNIUtil::GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) { + // Lets check for error here + jlong * longArray = env->GetLongArrayElements(array, isCopy); + if (longArray == nullptr) { + this->HasExceptionInStack(env, "Unable to get long array"); + throw std::runtime_error("Unable to get long array"); + } + + return longArray; +} + jobject knn_jni::JNIUtil::GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) { jobject object = env->GetObjectArrayElement(array, index); this->HasExceptionInStack(env, "Unable to get object"); @@ -424,6 +446,10 @@ void knn_jni::JNIUtil::ReleaseIntArrayElements(JNIEnv *env, jintArray array, jin env->ReleaseIntArrayElements(array, elems, 
mode); } +void knn_jni::JNIUtil::ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) { + env->ReleaseLongArrayElements(array, elems, mode); +} + void knn_jni::JNIUtil::SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) { env->SetObjectArrayElement(array, index, val); this->HasExceptionInStack(env, "Unable to set object array element"); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index a7b24fcab..e8e761ad7 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -89,10 +89,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ, jintArray parentIdsJ) { + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 58daaee2b..4318e8ef9 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -251,9 +251,13 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { queries.push_back(query); } - std::vector filterIds; + int num_bits = test_util::bits2words(164); + std::vector bitmap(num_bits,0); + std::vector filterIds; + for (int64_t i = 154; i < 163; i++) { filterIds.push_back(i); + test_util::setBitSet(i, bitmap.data(), bitmap.size()); } std::unordered_set filterIdSet(filterIds.begin(), filterIds.end()); @@ -270,9 +274,9 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; EXPECT_CALL(mockJNIUtil, - GetJavaIntArrayLength( - jniEnv, reinterpret_cast(&filterIds))) - .WillRepeatedly(Return(filterIds.size())); + GetJavaLongArrayLength( + jniEnv, reinterpret_cast(&bitmap))) + .WillRepeatedly(Return(bitmap.size())); int k = 20; for (auto query : queries) { @@ -282,7 +286,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), reinterpret_cast(&query), k, - reinterpret_cast(&filterIds), nullptr))); + reinterpret_cast(&bitmap), 0, nullptr))); ASSERT_TRUE(results->size() <= filterIds.size()); ASSERT_TRUE(results->size() > 0); diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index a75886c51..3c6933d89 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -116,6 +116,15 @@ test_util::MockJNIUtil::MockJNIUtil() { reinterpret_cast *>(arrayJ)->data()); }); + + // arrayJ is re-interpreted as a std::vector * and then the data is + // re-interpreted as a jlong * + ON_CALL(*this, GetLongArrayElements) + .WillByDefault([this](JNIEnv *env, jlongArray arrayJ, jboolean *isCopy) { + return reinterpret_cast( + reinterpret_cast *>(arrayJ)->data()); + }); + // arrayJ is 
re-interpreted as a std::vector * and then the data is // re-interpreted as a jfloat * ON_CALL(*this, GetFloatArrayElements) @@ -146,6 +155,13 @@ test_util::MockJNIUtil::MockJNIUtil() { return reinterpret_cast *>(arrayJ)->size(); }); + // arrayJ is re-interpreted as a std::vector * and then the size is + // returned + ON_CALL(*this, GetJavaLongArrayLength) + .WillByDefault([this](JNIEnv *env, jlongArray arrayJ) { + return reinterpret_cast *>(arrayJ)->size(); + }); + // arrayJ is re-interpreted as a std::vector> * and then // the 'index' element is re-interpreted as a jobject ON_CALL(*this, GetObjectArrayElement) @@ -193,6 +209,11 @@ test_util::MockJNIUtil::MockJNIUtil() { .WillByDefault( [this](JNIEnv *env, jintArray array, jint *elems, int mode) {}); + // This function should not do anything meaningful in the unit tests + ON_CALL(*this, ReleaseLongArrayElements) + .WillByDefault( + [this](JNIEnv *env, jlongArray array, jlong *elems, int mode) {}); + // array is re-interpreted as a std::vector * and then the bytes from // buf are copied to it ON_CALL(*this, SetByteArrayRegion) @@ -347,3 +368,16 @@ float test_util::RandomFloat(float min, float max) { std::uniform_real_distribution distribution(min, max); return distribution(e1); } + +size_t test_util::bits2words(uint64_t numBits) { + return ((numBits - 1) >> 6) + 1; +} + +void test_util::setBitSet(uint64_t value, jlong* array, size_t size) { + uint64_t wordNum = value >> 6; + if (wordNum >= size ) { + return; + } + jlong bitmask = (1L << (value & 63)); + array[wordNum] |= bitmask; +} \ No newline at end of file diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 6eac70fcf..a8ce6ca8f 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -64,9 +64,12 @@ namespace test_util { (JNIEnv * env, jobjectArray array2dJ)); MOCK_METHOD(jint*, GetIntArrayElements, (JNIEnv * env, jintArray array, jboolean* isCopy)); + MOCK_METHOD(jlong*, GetLongArrayElements, + (JNIEnv * env, jlongArray array, jboolean* 
isCopy)); MOCK_METHOD(int, GetJavaBytesArrayLength, (JNIEnv * env, jbyteArray arrayJ)); MOCK_METHOD(int, GetJavaFloatArrayLength, (JNIEnv * env, jfloatArray arrayJ)); MOCK_METHOD(int, GetJavaIntArrayLength, (JNIEnv * env, jintArray arrayJ)); + MOCK_METHOD(int, GetJavaLongArrayLength, (JNIEnv * env, jlongArray arrayJ)); MOCK_METHOD(int, GetJavaObjectArrayLength, (JNIEnv * env, jobjectArray arrayJ)); MOCK_METHOD(jobject, GetObjectArrayElement, @@ -86,6 +89,8 @@ namespace test_util { (JNIEnv * env, jfloatArray array, jfloat* elems, int mode)); MOCK_METHOD(void, ReleaseIntArrayElements, (JNIEnv * env, jintArray array, jint* elems, jint mode)); + MOCK_METHOD(void, ReleaseLongArrayElements, + (JNIEnv * env, jlongArray array, jlong* elems, jint mode)); MOCK_METHOD(void, SetByteArrayRegion, (JNIEnv * env, jbyteArray array, jsize start, jsize len, const jbyte* buf)); @@ -150,6 +155,10 @@ namespace test_util { float RandomFloat(float min, float max); + // returns the number of 64 bit words it would take to hold numBits + size_t bits2words(uint64_t numBits); + + void setBitSet(uint64_t value, jlong* array, size_t size); // ------------------------------------------------------------------------------- } // namespace test_util diff --git a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java new file mode 100644 index 000000000..bf06e8c5e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.query; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; + +import java.io.IOException; + +/** + * Util Class for filter ids selector + */ +@AllArgsConstructor +@Getter +public class FilterIdsSelector { + + /** + * When do ann query with filters, there are two types: + * BitMap using FixedBitSet, BATCH using a long array stands for filter result docids. + */ + @AllArgsConstructor + @Getter + public enum FilterIdsSelectorType { + BITMAP(0), + BATCH(1); + + private final int value; + } + + long[] filterIds; + private FilterIdsSelectorType filterType; + + /** + * This function takes a call on what ID Selector to use: + * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap + * + * class storage lookup construction(Opensearch + Faiss) + * IDSelectorArray O(k) O(k) O(2k) + * IDSelectorBatch O(k) O(1) O(2k) + * IDSelectorBitmap O(n/8) O(1) O(k) n is the max value of id in the index + * + * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts: + * an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to + * 7.5M vectors. + * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation + * M = 16 + * Dimension = 128 + * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB + * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be + * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease. 
+ * + * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids + * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size + * as factor. We need to improve on this. + * + * Array Memory: Cardinality * Long.BYTES + * BitSet Memory: MaxId / Byte.SIZE + * When Array Memory less than or equal to BitSet Memory return FilterIdsSelectorType.BATCH + * Else return FilterIdsSelectorType.BITMAP; + * + * @param filterIdsBitSet Filter query result docs + * @param cardinality The number of bits that are set + * @return {@link FilterIdsSelector} + */ + public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException { + long[] filterIds; + FilterIdsSelector.FilterIdsSelectorType filterType; + if (filterIdsBitSet instanceof FixedBitSet) { + /** + * When filterIds is dense filter, using fixed bitset + */ + filterIds = ((FixedBitSet) filterIdsBitSet).getBits(); + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } else if ((cardinality * Long.BYTES * Byte.SIZE) <= filterIdsBitSet.length()) { + /** + * When filterIds is sparse bitset, using ram usage to decide FilterIdsSelectorType + */ + BitSetIterator bitSetIterator = new BitSetIterator(filterIdsBitSet, cardinality); + filterIds = new long[cardinality]; + int idx = 0; + for (int docId = bitSetIterator.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = bitSetIterator.nextDoc()) { + filterIds[idx++] = docId; + } + filterType = FilterIdsSelectorType.BATCH; + } else { + FixedBitSet fixedBitSet = new FixedBitSet(filterIdsBitSet.length()); + BitSetIterator sparseBitSetIterator = new BitSetIterator(filterIdsBitSet, cardinality); + fixedBitSet.or(sparseBitSetIterator); + filterIds = fixedBitSet.getBits(); + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } + return new FilterIdsSelector(filterIds, filterType); + } +} diff --git 
a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index afa7e93b0..5bd4e9359 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -100,11 +100,13 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - final int[] filterIdsArray = getFilterIdsArray(context); + + final BitSet filterBitSet = getFilteredDocsBitSet(context); + int cardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, - if (filterWeight != null && filterIdsArray.length == 0) { + if (filterWeight != null && cardinality == 0) { return KNNScorer.emptyScorer(this); } final Map docIdsToScoreMap = new HashMap<>(); @@ -114,22 +116,22 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. * This improves the recall. 
*/ - if (filterWeight != null && canDoExactSearch(filterIdsArray.length)) { - docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); + if (filterWeight != null && canDoExactSearch(cardinality)) { + docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet)); } else { - Map annResults = doANNSearch(context, filterIdsArray); + Map annResults = doANNSearch(context, filterBitSet, cardinality); if (annResults == null) { return null; } - if (canDoExactSearchAfterANNSearch(filterIdsArray.length, annResults.size())) { + if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) { log.debug( "Doing ExactSearch after doing ANNSearch as the number of documents returned are less than " + "K, even when we have more than K filtered Ids. K: {}, ANNResults: {}, filteredIdCount: {}", knnQuery.getK(), annResults.size(), - filterIdsArray.length + cardinality ); - annResults = doExactSearch(context, filterIdsArray); + annResults = doExactSearch(context, filterBitSet); } docIdsToScoreMap.putAll(annResults); } @@ -139,7 +141,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException { return convertSearchResponseToScorer(docIdsToScoreMap); } - private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException { + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { + if (this.filterWeight == null) { + return new FixedBitSet(0); + } + final Bits liveDocs = ctx.reader().getLiveDocs(); final int maxDoc = ctx.reader().maxDoc(); @@ -166,13 +172,6 @@ protected boolean match(int doc) { return BitSet.of(filterIterator, maxDoc); } - private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException { - if (filterWeight == null) { - return new int[0]; - } - return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight)); - } - private int[] getParentIdsArray(final LeafReaderContext context) throws IOException { if (knnQuery.getParentsFilter() == null) { 
return null; @@ -194,7 +193,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } - private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException { + private Map doANNSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality) + throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); @@ -265,6 +265,10 @@ private Map doANNSearch(final LeafReaderContext context, final i throw new RuntimeException(e); } + // From cardinality select different filterIds type + FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality); + long[] filterIds = filterIdsSelector.getFilterIds(); + FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType(); // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); try { @@ -277,7 +281,8 @@ private Map doANNSearch(final LeafReaderContext context, final i knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName(), - filterIdsArray, + filterIds, + filterType.getValue(), parentIds ); @@ -303,13 +308,13 @@ private Map doANNSearch(final LeafReaderContext context, final i .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - private Map doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) throws IOException { + private Map doExactSearch(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) { try { // Creating min heap and init with MAX DocID and Score as -INF. 
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); - FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsArray); + FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (iterator.score() > topDoc.score) { @@ -340,16 +345,16 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont return Collections.emptyMap(); } - private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) + private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException { final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); final SpaceType spaceType = getSpaceType(fieldInfo); return knnQuery.getParentsFilter() == null - ? new FilteredIdsKNNIterator(filterIdsArray, knnQuery.getQueryVector(), values, spaceType) + ? 
new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType) : new NestedFilteredIdsKNNIterator( - filterIdsArray, + filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType, diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java index a286829d5..a53cb8d60 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java @@ -7,6 +7,8 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; @@ -23,23 +25,26 @@ */ public class FilteredIdsKNNIterator { // Array of doc ids to iterate - protected final int[] filterIdsArray; + protected final BitSet filterIdsBitSet; + protected final BitSetIterator bitSetIterator; protected final float[] queryVector; protected final BinaryDocValues binaryDocValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; - protected int currentPos = 0; + protected int docId; public FilteredIdsKNNIterator( - final int[] filterIdsArray, + final BitSet filterIdsBitSet, final float[] queryVector, final BinaryDocValues binaryDocValues, final SpaceType spaceType ) { - this.filterIdsArray = filterIdsArray; + this.filterIdsBitSet = filterIdsBitSet; + this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; this.binaryDocValues = binaryDocValues; this.spaceType = spaceType; + this.docId = bitSetIterator.nextDoc(); } /** @@ -49,13 +54,14 @@ public FilteredIdsKNNIterator( * @return next doc id */ public int 
nextDoc() throws IOException { - if (currentPos >= filterIdsArray.length) { + + if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } - int docId = binaryDocValues.advance(filterIdsArray[currentPos]); + int doc = binaryDocValues.advance(docId); currentScore = computeScore(); - currentPos++; - return docId; + docId = bitSetIterator.nextDoc(); + return doc; } public float score() { diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java index d2e4d3e25..9776ebbe9 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java @@ -20,7 +20,7 @@ public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator { private final BitSet parentBitSet; public NestedFilteredIdsKNNIterator( - final int[] filterIdsArray, + final BitSet filterIdsArray, final float[] queryVector, final BinaryDocValues values, final SpaceType spaceType, @@ -38,20 +38,22 @@ public NestedFilteredIdsKNNIterator( */ @Override public int nextDoc() throws IOException { - if (currentPos >= filterIdsArray.length) { + if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } + currentScore = Float.NEGATIVE_INFINITY; - int currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]); + int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; - while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent) { - binaryDocValues.advance(filterIdsArray[currentPos]); + + while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + binaryDocValues.advance(docId); float score = computeScore(); if (score > currentScore) { - bestChild = filterIdsArray[currentPos]; + bestChild = docId; currentScore = score; } - currentPos++; + 
docId = bitSetIterator.nextDoc(); } return bestChild; diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index abf3e052a..f330352ec 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -104,7 +104,8 @@ public static native KNNQueryResult[] queryIndexWithFilter( long indexPointer, float[] queryVector, int k, - int[] filterIds, + long[] filterIds, + int filterIdsType, int[] parentIds ); diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index beef9f927..2835be23d 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -99,6 +99,7 @@ public static long loadIndex(String indexPath, Map parameters, S * @param k neighbors to be returned * @param engineName name of engine to query index * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryIndex( @@ -106,7 +107,8 @@ public static KNNQueryResult[] queryIndex( float[] queryVector, int k, String engineName, - int[] filteredIds, + long[] filteredIds, + int filterIdsType, int[] parentIds ) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { @@ -119,7 +121,7 @@ public static KNNQueryResult[] queryIndex( // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. 
if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds); + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, filterIdsType, parentIds); } return FaissService.queryIndex(indexPointer, queryVector, k, parentIds); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index b1d9d9434..472f05113 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -341,7 +341,7 @@ public static void assertLoadableByEngine( ); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, null); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, 0, null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 798c64a17..1f1088de1 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, null); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, 0, null); assertTrue(results.length > 0); } diff --git 
a/src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java b/src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java new file mode 100644 index 000000000..02b1553a3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.query; + +import lombok.SneakyThrows; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.SparseFixedBitSet; +import org.opensearch.knn.KNNTestCase; + +public class FilterIdsSelectorTests extends KNNTestCase { + + @SneakyThrows + public void testGetIdSelectorTypeWithFixedBitSet() { + FixedBitSet bits = new FixedBitSet(101); + for (int i = 1; i <= 100; i++) { + bits.set(i); + } + FilterIdsSelector idsSelector = FilterIdsSelector.getFilterIdSelector(bits, bits.cardinality()); + assertEquals(idsSelector.getFilterType(), FilterIdsSelector.FilterIdsSelectorType.BITMAP); + assertArrayEquals(bits.getBits(), idsSelector.filterIds); + } + + @SneakyThrows + public void testGetIdSelectorTypeWithSparseBitSetHigh() { + SparseFixedBitSet bits = new SparseFixedBitSet(101); + for (int i = 1; i <= 100; i++) { + bits.set(i); + } + FilterIdsSelector idsSelector = FilterIdsSelector.getFilterIdSelector(bits, bits.cardinality()); + assertEquals(idsSelector.getFilterType(), FilterIdsSelector.FilterIdsSelectorType.BITMAP); + FixedBitSet fixedBitSet = new FixedBitSet(bits.length()); + BitSetIterator sparseBitSetIterator = new BitSetIterator(bits, 101); + fixedBitSet.or(sparseBitSetIterator); + assertArrayEquals(fixedBitSet.getBits(), idsSelector.filterIds); + } + + @SneakyThrows + 
public void testGetIdSelectorTypeWithSparseBitSetLow() { + int maxDoc = (Integer.MAX_VALUE) / 2; + SparseFixedBitSet bits = new SparseFixedBitSet(maxDoc); + long array[] = new long[100]; + for (int i = maxDoc - 100, idx = 0; i < maxDoc; i++) { + bits.set(i); + array[idx++] = i; + } + FilterIdsSelector idsSelector = FilterIdsSelector.getFilterIdSelector(bits, bits.cardinality()); + assertEquals(idsSelector.getFilterType(), FilterIdsSelector.FilterIdsSelectorType.BATCH); + assertArrayEquals(array, idsSelector.filterIds); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 5d49f052c..49c7c7566 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -157,7 +157,7 @@ public void testQueryScoreForFaissWithModel() { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), any())) .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -300,7 +300,7 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), any())) .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -345,8 +345,13 @@ public 
void testEmptyQueryResults() { public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())) - .thenReturn(getFilteredKNNQueryResults()); + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterBitSet.getBits()), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); final Bits liveDocsBits = mock(Bits.class); @@ -404,7 +409,10 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterBitSet.getBits()), anyInt(), any()) + ); final List actualDocIds = new ArrayList<>(); final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); @@ -669,14 +677,17 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, 1, INDEX_NAME, null, bitSetProducer); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), eq(parentsFilter))) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), 
any(), anyInt(), anyString(), any(), anyInt(), eq(parentsFilter)) + ).thenReturn(getKNNQueryResults()); // Execute Scorer knnScorer = knnWeight.scorer(leafReaderContext); // Verify - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), eq(parentsFilter))); + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), eq(parentsFilter)) + ); assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); @@ -735,7 +746,7 @@ private void testQueryScore( final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), any())) .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java index cfb66662e..dce703050 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java @@ -9,6 +9,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; @@ -41,12 +42,14 @@ public void testNextDoc_whenCalled_IterateAllDocs() { .collect(Collectors.toList()); when(values.binaryValue()).thenReturn(byteRefs.get(0), 
byteRefs.get(1), byteRefs.get(2)); + FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); } // Execute and verify - FilteredIdsKNNIterator iterator = new FilteredIdsKNNIterator(filterIds, queryVector, values, spaceType); + FilteredIdsKNNIterator iterator = new FilteredIdsKNNIterator(filterBitSet, queryVector, values, spaceType); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); assertEquals(expectedScores.get(i), (Float) iterator.score()); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java index 90d56fddc..d732376ef 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java @@ -47,12 +47,20 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { .collect(Collectors.toList()); when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); } // Execute and verify - NestedFilteredIdsKNNIterator iterator = new NestedFilteredIdsKNNIterator(filterIds, queryVector, values, spaceType, parentBitSet); + NestedFilteredIdsKNNIterator iterator = new NestedFilteredIdsKNNIterator( + filterBitSet, + queryVector, + values, + spaceType, + parentBitSet + ); assertEquals(filterIds[0], iterator.nextDoc()); assertEquals(expectedScores.get(0), iterator.score()); assertEquals(filterIds[2], iterator.nextDoc()); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 2afcff83d..8e3382ece 100644 --- 
a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -537,13 +537,13 @@ public void testQueryIndex_faiss_sqfp16_valid() { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, 0, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new long[] { 0 }, 0, null); assertEquals(0, results.length); } } @@ -726,12 +726,15 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null, null)); + expectThrows( + IllegalArgumentException.class, + () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null, 0, null) + ); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null, 0, null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -754,7 +757,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null, 0, 
null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -780,7 +783,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null, 0, null); assertEquals(k, results.length); } } @@ -788,7 +791,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null, 0, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -807,7 +810,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null, 0, null)); } public void testQueryIndex_faiss_valid() throws IOException { @@ -836,13 +839,13 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, 0, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }, 
null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new long[] { 0 }, 0, null); assertEquals(0, results.length); } } @@ -877,7 +880,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { assertNotEquals(0, pointer); for (float[] query : testDataNested.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, parentIds); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, 0, parentIds); // Verify there is no more than one result from same parent Set parentIdSet = toParentIdSet(results, idToParentIdMap); assertEquals(results.length, parentIdSet.size());