diff --git a/CHANGELOG.md b/CHANGELOG.md index 93ad21dc9..69b213ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x) ### Features ### Enhancements +* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 078526000..e99cdafb2 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -43,7 +43,8 @@ namespace knn_jni { // // Return an array of KNNQueryResults jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ); + jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, + jint filterIdsTypeJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index b4dd44891..52b08a202 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -80,6 +80,8 @@ namespace knn_jni { virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0; + virtual int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) = 0; + virtual int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) = 0; virtual int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) = 0; @@ -94,6 +96,8 @@ namespace knn_jni { virtual jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) = 0; + virtual jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) = 0; + virtual jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) = 0; virtual jobject NewObject(JNIEnv *env, jclass clazz, 
jmethodID methodId, int id, float distance) = 0; @@ -108,6 +112,8 @@ namespace knn_jni { virtual void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) = 0; + virtual void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) = 0; + virtual void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) = 0; virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0; @@ -139,6 +145,7 @@ namespace knn_jni { int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ); int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ); int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ); + int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ); int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ); int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ); @@ -146,6 +153,7 @@ namespace knn_jni { jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy); jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy); jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy); + jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy); jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index); jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance); jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init); @@ -153,6 +161,7 @@ namespace knn_jni { void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode); void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode); void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode); + void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode); void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, 
jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index aefadcee4..3b649c227 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -56,7 +56,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd * Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index e88254b86..43f1454b5 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -29,9 +29,32 @@ // Defines type of IDSelector enum FilterIdsSelectorType{ - BITMAP, BATCH + BITMAP = 0, BATCH = 1, }; +namespace faiss { +// Using jlong to do Bitmap selector, jlong[] equals to lucene FixedBitSet#bits +struct IDSelectorJlongBitmap : IDSelector { + size_t n; + const jlong* bitmap; + + /** Construct with a binary mask like Lucene FixedBitSet + * + * @param n size of the bitmap array + * @param bitmap id like Lucene FixedBitSet bits + */ + IDSelectorJlongBitmap(size_t n, const jlong* bitmap) : n(n), bitmap(bitmap) {}; + bool is_member(idx_t id) const final { + uint64_t index = id; + uint64_t i = index >> 6; // div 64 + if (i >= n ) { + return false; + } + return (bitmap[i] >> ( index & 63)) & 1L; + } + ~IDSelectorJlongBitmap() override {} +}; +} // Translate space type to faiss metric faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType); @@ -42,9 +65,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with 
data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); -// Helps to choose the right FilterIdsSelectorType for Faiss -FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength); - // Converts the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -199,11 +219,12 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ); } jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ) { + jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } @@ -225,28 +246,14 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter omp_set_num_threads(1); // create the filterSearch params if the filterIdsJ is not a null pointer if(filterIdsJ != nullptr) { - int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr); - int filterIdsLength = jniUtil->GetJavaIntArrayLength(env, filterIdsJ); + jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); std::unique_ptr idSelector; - FilterIdsSelectorType idSelectorType = 
getIdSelectorType(filteredIdsArray, filterIdsLength); - // start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data - // during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local - // vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers. - // To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/ - std::vector convertedIds; - std::vector bitmap; - // Choose a selector which suits best - if(idSelectorType == BATCH) { - convertedIds.resize(filterIdsLength); - convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data()); - idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data())); + if(filterIdsTypeJ == BITMAP) { + idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); } else { - int maxIdValue = filteredIdsArray[filterIdsLength - 1]; - // >> 3 is equivalent to value / 8 - const int bitsetArraySize = (maxIdValue >> 3) + 1; - bitmap.resize(bitsetArraySize, 0); - buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data()); - idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap.data())); + faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); + idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices)); } faiss::SearchParameters *searchParameters; faiss::SearchParametersHNSW hnswParams; @@ -276,10 +283,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); } catch (...) 
{ jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); - jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); throw; } - jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); } else { faiss::SearchParameters *searchParameters = nullptr; faiss::SearchParametersHNSW hnswParams; @@ -454,63 +461,6 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } -/** - * This function takes a call on what ID Selector to use: - * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap - * - * class storage lookup construction(Opensearch + Faiss) - * IDSelectorArray O(k) O(k) O(2k) - * IDSelectorBatch O(k) O(1) O(2k) - * IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index - * - * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts: - * an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to - * 7.5M vectors. - * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation - * M = 16 - * Dimension = 128 - * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB - * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be - * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease. - * - * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids - * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size - * as factor. We need to improve on this. 
- * - * TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss. - * - * @param filterIds - * @param filterIdsLength - * @return std::string - */ -FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) { - int maxIdValue = filterIds[filterIdsLength - 1]; - if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) { - return BATCH; - } - return BITMAP; -} - -void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) { - for (int i = 0; i < filterIdsLength; i++) { - convertedFilterIds[i] = filterIds[i]; - } -} - -void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) { - /** - * Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected - * iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1. - */ - for(int i = 0 ; i < filterIdsLength ; i ++) { - int value = filterIds[i]; - // / , % are expensive operation. Hence, using BitShift operation as they are fast. 
- int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8 - // (value & 7) equivalent to value % 8 - bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7)); - } -} - std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index cb9270c22..a0c1d5733 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -319,6 +319,17 @@ int knn_jni::JNIUtil::GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) { return length; } +int knn_jni::JNIUtil::GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) { + + if (arrayJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + int length = env->GetArrayLength(arrayJ); + this->HasExceptionInStack(env, "Unable to get array length"); + return length; +} + int knn_jni::JNIUtil::GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) { if (arrayJ == nullptr) { @@ -376,6 +387,17 @@ jint * knn_jni::JNIUtil::GetIntArrayElements(JNIEnv *env, jintArray array, jbool return intArray; } +jlong * knn_jni::JNIUtil::GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) { + // Lets check for error here + jlong * longArray = env->GetLongArrayElements(array, isCopy); + if (longArray == nullptr) { + this->HasExceptionInStack(env, "Unable to get long array"); + throw std::runtime_error("Unable to get long array"); + } + + return longArray; +} + jobject knn_jni::JNIUtil::GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) { jobject object = env->GetObjectArrayElement(array, index); this->HasExceptionInStack(env, "Unable to get object"); @@ -424,6 +446,10 @@ void knn_jni::JNIUtil::ReleaseIntArrayElements(JNIEnv *env, jintArray array, jin env->ReleaseIntArrayElements(array, elems, 
mode); } +void knn_jni::JNIUtil::ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) { + env->ReleaseLongArrayElements(array, elems, mode); +} + void knn_jni::JNIUtil::SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) { env->SetObjectArrayElement(array, index, val); this->HasExceptionInStack(env, "Unable to set object array element"); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index a7b24fcab..e8e761ad7 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -89,10 +89,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ, jintArray parentIdsJ) { + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 58daaee2b..4318e8ef9 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -251,9 +251,13 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { queries.push_back(query); } - std::vector filterIds; + int num_bits = test_util::bits2words(164); + std::vector bitmap(num_bits,0); + std::vector filterIds; + for (int64_t i = 154; i < 163; i++) { filterIds.push_back(i); + test_util::setBitSet(i, bitmap.data(), bitmap.size()); } std::unordered_set filterIdSet(filterIds.begin(), filterIds.end()); @@ -270,9 +274,9 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; EXPECT_CALL(mockJNIUtil, - GetJavaIntArrayLength( - jniEnv, reinterpret_cast(&filterIds))) - .WillRepeatedly(Return(filterIds.size())); + GetJavaLongArrayLength( + jniEnv, reinterpret_cast(&bitmap))) + .WillRepeatedly(Return(bitmap.size())); int k = 20; for (auto query : queries) { @@ -282,7 +286,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), reinterpret_cast(&query), k, - reinterpret_cast(&filterIds), nullptr))); + reinterpret_cast(&bitmap), 0, nullptr))); ASSERT_TRUE(results->size() <= filterIds.size()); ASSERT_TRUE(results->size() > 0); diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index a75886c51..3c6933d89 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -116,6 +116,15 @@ test_util::MockJNIUtil::MockJNIUtil() { reinterpret_cast *>(arrayJ)->data()); }); + + // arrayJ is re-interpreted as a std::vector * and then the data is + // re-interpreted as a jlong * + ON_CALL(*this, GetLongArrayElements) + .WillByDefault([this](JNIEnv *env, jlongArray arrayJ, jboolean *isCopy) { + return reinterpret_cast( + reinterpret_cast *>(arrayJ)->data()); + }); + // arrayJ is 
re-interpreted as a std::vector * and then the data is // re-interpreted as a jfloat * ON_CALL(*this, GetFloatArrayElements) @@ -146,6 +155,13 @@ test_util::MockJNIUtil::MockJNIUtil() { return reinterpret_cast *>(arrayJ)->size(); }); + // arrayJ is re-interpreted as a std::vector * and then the size is + // returned + ON_CALL(*this, GetJavaLongArrayLength) + .WillByDefault([this](JNIEnv *env, jlongArray arrayJ) { + return reinterpret_cast *>(arrayJ)->size(); + }); + // arrayJ is re-interpreted as a std::vector> * and then // the 'index' element is re-interpreted as a jobject ON_CALL(*this, GetObjectArrayElement) @@ -193,6 +209,11 @@ test_util::MockJNIUtil::MockJNIUtil() { .WillByDefault( [this](JNIEnv *env, jintArray array, jint *elems, int mode) {}); + // This function should not do anything meaningful in the unit tests + ON_CALL(*this, ReleaseLongArrayElements) + .WillByDefault( + [this](JNIEnv *env, jlongArray array, jlong *elems, int mode) {}); + // array is re-interpreted as a std::vector * and then the bytes from // buf are copied to it ON_CALL(*this, SetByteArrayRegion) @@ -347,3 +368,16 @@ float test_util::RandomFloat(float min, float max) { std::uniform_real_distribution distribution(min, max); return distribution(e1); } + +size_t test_util::bits2words(uint64_t numBits) { + return ((numBits - 1) >> 6) + 1; +} + +void test_util::setBitSet(uint64_t value, jlong* array, size_t size) { + uint64_t wordNum = value >> 6; + if (wordNum >= size ) { + return; + } + jlong bitmask = (1L << (value & 63)); + array[wordNum] |= bitmask; +} \ No newline at end of file diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 6eac70fcf..a8ce6ca8f 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -64,9 +64,12 @@ namespace test_util { (JNIEnv * env, jobjectArray array2dJ)); MOCK_METHOD(jint*, GetIntArrayElements, (JNIEnv * env, jintArray array, jboolean* isCopy)); + MOCK_METHOD(jlong*, GetLongArrayElements, + (JNIEnv * env, jlongArray array, jboolean* 
isCopy)); MOCK_METHOD(int, GetJavaBytesArrayLength, (JNIEnv * env, jbyteArray arrayJ)); MOCK_METHOD(int, GetJavaFloatArrayLength, (JNIEnv * env, jfloatArray arrayJ)); MOCK_METHOD(int, GetJavaIntArrayLength, (JNIEnv * env, jintArray arrayJ)); + MOCK_METHOD(int, GetJavaLongArrayLength, (JNIEnv * env, jlongArray arrayJ)); MOCK_METHOD(int, GetJavaObjectArrayLength, (JNIEnv * env, jobjectArray arrayJ)); MOCK_METHOD(jobject, GetObjectArrayElement, @@ -86,6 +89,8 @@ namespace test_util { (JNIEnv * env, jfloatArray array, jfloat* elems, int mode)); MOCK_METHOD(void, ReleaseIntArrayElements, (JNIEnv * env, jintArray array, jint* elems, jint mode)); + MOCK_METHOD(void, ReleaseLongArrayElements, + (JNIEnv * env, jlongArray array, jlong* elems, jint mode)); MOCK_METHOD(void, SetByteArrayRegion, (JNIEnv * env, jbyteArray array, jsize start, jsize len, const jbyte* buf)); @@ -150,6 +155,10 @@ namespace test_util { float RandomFloat(float min, float max); + // returns the number of 64 bit words it would take to hold numBits + size_t bits2words(uint64_t numBits); + + void setBitSet(uint64_t value, jlong* array, size_t size); // ------------------------------------------------------------------------------- } // namespace test_util diff --git a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java new file mode 100644 index 000000000..bf06e8c5e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.query; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; + +import java.io.IOException; + +/** + * Util Class for filter ids selector + */ +@AllArgsConstructor +@Getter +public class FilterIdsSelector { + + /** + * ANN queries with filters support two id selector types: + * BITMAP backed by a FixedBitSet, and BATCH backed by a long array of the filtered doc ids. + */ + @AllArgsConstructor + @Getter + public enum FilterIdsSelectorType { + BITMAP(0), + BATCH(1); + + private final int value; + } + + long[] filterIds; + private FilterIdsSelectorType filterType; + + /** + * This function takes a call on what ID Selector to use: + * https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap + * + * class storage lookup construction(Opensearch + Faiss) + * IDSelectorArray O(k) O(k) O(2k) + * IDSelectorBatch O(k) O(1) O(2k) + * IDSelectorBitmap O(n/8) O(1) O(k) n is the max value of id in the index + * + * TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts: + * an OpenSearch Index can have max segment size as 5GB, which on a vector with dimension of 128 boils down to + * 7.5M vectors. + * Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation + * M = 16 + * Dimension = 128 + * (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB + * Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be + * 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease. 
+ * + * With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids + * So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size + * as factor. We need to improve on this. + * + * Array Memory: Cardinality * Long.BYTES + * BitSet Memory: MaxId / Byte.SIZE + * When Array Memory less than or equal to BitSet Memory return FilterIdsSelectorType.BATCH + * Else return FilterIdsSelectorType.BITMAP; + * + * @param filterIdsBitSet Filter query result docs + * @param cardinality The number of bits that are set + * @return {@link FilterIdsSelector} + */ + public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException { + long[] filterIds; + FilterIdsSelector.FilterIdsSelectorType filterType; + if (filterIdsBitSet instanceof FixedBitSet) { + /** + * When filterIds is dense filter, using fixed bitset + */ + filterIds = ((FixedBitSet) filterIdsBitSet).getBits(); + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } else if ((cardinality * Long.BYTES * Byte.SIZE) <= filterIdsBitSet.length()) { + /** + * When filterIds is sparse bitset, using ram usage to decide FilterIdsSelectorType + */ + BitSetIterator bitSetIterator = new BitSetIterator(filterIdsBitSet, cardinality); + filterIds = new long[cardinality]; + int idx = 0; + for (int docId = bitSetIterator.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = bitSetIterator.nextDoc()) { + filterIds[idx++] = docId; + } + filterType = FilterIdsSelectorType.BATCH; + } else { + FixedBitSet fixedBitSet = new FixedBitSet(filterIdsBitSet.length()); + BitSetIterator sparseBitSetIterator = new BitSetIterator(filterIdsBitSet, cardinality); + fixedBitSet.or(sparseBitSetIterator); + filterIds = fixedBitSet.getBits(); + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } + return new FilterIdsSelector(filterIds, filterType); + } +} diff --git 
a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index afa7e93b0..5bd4e9359 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -100,11 +100,13 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - final int[] filterIdsArray = getFilterIdsArray(context); + + final BitSet filterBitSet = getFilteredDocsBitSet(context); + int cardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, - if (filterWeight != null && filterIdsArray.length == 0) { + if (filterWeight != null && cardinality == 0) { return KNNScorer.emptyScorer(this); } final Map docIdsToScoreMap = new HashMap<>(); @@ -114,22 +116,22 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. * This improves the recall. 
*/ - if (filterWeight != null && canDoExactSearch(filterIdsArray.length)) { - docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); + if (filterWeight != null && canDoExactSearch(cardinality)) { + docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet)); } else { - Map annResults = doANNSearch(context, filterIdsArray); + Map annResults = doANNSearch(context, filterBitSet, cardinality); if (annResults == null) { return null; } - if (canDoExactSearchAfterANNSearch(filterIdsArray.length, annResults.size())) { + if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) { log.debug( "Doing ExactSearch after doing ANNSearch as the number of documents returned are less than " + "K, even when we have more than K filtered Ids. K: {}, ANNResults: {}, filteredIdCount: {}", knnQuery.getK(), annResults.size(), - filterIdsArray.length + cardinality ); - annResults = doExactSearch(context, filterIdsArray); + annResults = doExactSearch(context, filterBitSet); } docIdsToScoreMap.putAll(annResults); } @@ -139,7 +141,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException { return convertSearchResponseToScorer(docIdsToScoreMap); } - private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException { + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { + if (this.filterWeight == null) { + return new FixedBitSet(0); + } + final Bits liveDocs = ctx.reader().getLiveDocs(); final int maxDoc = ctx.reader().maxDoc(); @@ -166,13 +172,6 @@ protected boolean match(int doc) { return BitSet.of(filterIterator, maxDoc); } - private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException { - if (filterWeight == null) { - return new int[0]; - } - return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight)); - } - private int[] getParentIdsArray(final LeafReaderContext context) throws IOException { if (knnQuery.getParentsFilter() == null) { 
return null; @@ -194,7 +193,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } - private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException { + private Map doANNSearch(final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality) + throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); @@ -265,6 +265,10 @@ private Map doANNSearch(final LeafReaderContext context, final i throw new RuntimeException(e); } + // From cardinality select different filterIds type + FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality); + long[] filterIds = filterIdsSelector.getFilterIds(); + FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType(); // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); try { @@ -277,7 +281,8 @@ private Map doANNSearch(final LeafReaderContext context, final i knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName(), - filterIdsArray, + filterIds, + filterType.getValue(), parentIds ); @@ -303,13 +308,13 @@ private Map doANNSearch(final LeafReaderContext context, final i .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - private Map doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) throws IOException { + private Map doExactSearch(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) { try { // Creating min heap and init with MAX DocID and Score as -INF. 
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); - FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsArray); + FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (iterator.score() > topDoc.score) { @@ -340,16 +345,16 @@ private Map doExactSearch(final LeafReaderContext leafReaderCont return Collections.emptyMap(); } - private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) + private FilteredIdsKNNIterator getFilteredKNNIterator(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) throws IOException { final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName()); final SpaceType spaceType = getSpaceType(fieldInfo); return knnQuery.getParentsFilter() == null - ? new FilteredIdsKNNIterator(filterIdsArray, knnQuery.getQueryVector(), values, spaceType) + ? 
new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType) : new NestedFilteredIdsKNNIterator( - filterIdsArray, + filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType, diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java index a286829d5..a53cb8d60 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java @@ -7,6 +7,8 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; @@ -23,23 +25,26 @@ */ public class FilteredIdsKNNIterator { // Array of doc ids to iterate - protected final int[] filterIdsArray; + protected final BitSet filterIdsBitSet; + protected final BitSetIterator bitSetIterator; protected final float[] queryVector; protected final BinaryDocValues binaryDocValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; - protected int currentPos = 0; + protected int docId; public FilteredIdsKNNIterator( - final int[] filterIdsArray, + final BitSet filterIdsBitSet, final float[] queryVector, final BinaryDocValues binaryDocValues, final SpaceType spaceType ) { - this.filterIdsArray = filterIdsArray; + this.filterIdsBitSet = filterIdsBitSet; + this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; this.binaryDocValues = binaryDocValues; this.spaceType = spaceType; + this.docId = bitSetIterator.nextDoc(); } /** @@ -49,13 +54,14 @@ public FilteredIdsKNNIterator( * @return next doc id */ public int 
nextDoc() throws IOException { - if (currentPos >= filterIdsArray.length) { + + if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } - int docId = binaryDocValues.advance(filterIdsArray[currentPos]); + int doc = binaryDocValues.advance(docId); currentScore = computeScore(); - currentPos++; - return docId; + docId = bitSetIterator.nextDoc(); + return doc; } public float score() { diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java index d2e4d3e25..9776ebbe9 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java @@ -20,7 +20,7 @@ public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator { private final BitSet parentBitSet; public NestedFilteredIdsKNNIterator( - final int[] filterIdsArray, + final BitSet filterIdsArray, final float[] queryVector, final BinaryDocValues values, final SpaceType spaceType, @@ -38,20 +38,22 @@ public NestedFilteredIdsKNNIterator( */ @Override public int nextDoc() throws IOException { - if (currentPos >= filterIdsArray.length) { + if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } + currentScore = Float.NEGATIVE_INFINITY; - int currentParent = parentBitSet.nextSetBit(filterIdsArray[currentPos]); + int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; - while (currentPos < filterIdsArray.length && filterIdsArray[currentPos] < currentParent) { - binaryDocValues.advance(filterIdsArray[currentPos]); + + while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + binaryDocValues.advance(docId); float score = computeScore(); if (score > currentScore) { - bestChild = filterIdsArray[currentPos]; + bestChild = docId; currentScore = score; } - currentPos++; + 
docId = bitSetIterator.nextDoc(); } return bestChild; diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index abf3e052a..f330352ec 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -104,7 +104,8 @@ public static native KNNQueryResult[] queryIndexWithFilter( long indexPointer, float[] queryVector, int k, - int[] filterIds, + long[] filterIds, + int filterIdsType, int[] parentIds ); diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index beef9f927..2835be23d 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -99,6 +99,7 @@ public static long loadIndex(String indexPath, Map parameters, S * @param k neighbors to be returned * @param engineName name of engine to query index * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryIndex( @@ -106,7 +107,8 @@ public static KNNQueryResult[] queryIndex( float[] queryVector, int k, String engineName, - int[] filteredIds, + long[] filteredIds, + int filterIdsType, int[] parentIds ) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { @@ -119,7 +121,7 @@ public static KNNQueryResult[] queryIndex( // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. 
if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds); + return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, filterIdsType, parentIds); } return FaissService.queryIndex(indexPointer, queryVector, k, parentIds); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index b1d9d9434..472f05113 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -341,7 +341,7 @@ public static void assertLoadableByEngine( ); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, null); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, 0, null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 798c64a17..1f1088de1 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, null); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, 0, null); assertTrue(results.length > 0); } diff --git 
a/src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java b/src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java new file mode 100644 index 000000000..02b1553a3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/FilterIdsSelectorTests.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.query; + +import lombok.SneakyThrows; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.SparseFixedBitSet; +import org.opensearch.knn.KNNTestCase; + +public class FilterIdsSelectorTests extends KNNTestCase { + + @SneakyThrows + public void testGetIdSelectorTypeWithFixedBitSet() { + FixedBitSet bits = new FixedBitSet(101); + for (int i = 1; i <= 100; i++) { + bits.set(i); + } + FilterIdsSelector idsSelector = FilterIdsSelector.getFilterIdSelector(bits, bits.cardinality()); + assertEquals(idsSelector.getFilterType(), FilterIdsSelector.FilterIdsSelectorType.BITMAP); + assertArrayEquals(bits.getBits(), idsSelector.filterIds); + } + + @SneakyThrows + public void testGetIdSelectorTypeWithSparseBitSetHigh() { + SparseFixedBitSet bits = new SparseFixedBitSet(101); + for (int i = 1; i <= 100; i++) { + bits.set(i); + } + FilterIdsSelector idsSelector = FilterIdsSelector.getFilterIdSelector(bits, bits.cardinality()); + assertEquals(idsSelector.getFilterType(), FilterIdsSelector.FilterIdsSelectorType.BITMAP); + FixedBitSet fixedBitSet = new FixedBitSet(bits.length()); + BitSetIterator sparseBitSetIterator = new BitSetIterator(bits, 101); + fixedBitSet.or(sparseBitSetIterator); + assertArrayEquals(fixedBitSet.getBits(), idsSelector.filterIds); + } + + @SneakyThrows + 
public void testGetIdSelectorTypeWithSparseBitSetLow() { + int maxDoc = (Integer.MAX_VALUE) / 2; + SparseFixedBitSet bits = new SparseFixedBitSet(maxDoc); + long array[] = new long[100]; + for (int i = maxDoc - 100, idx = 0; i < maxDoc; i++) { + bits.set(i); + array[idx++] = i; + } + FilterIdsSelector idsSelector = FilterIdsSelector.getFilterIdSelector(bits, bits.cardinality()); + assertEquals(idsSelector.getFilterType(), FilterIdsSelector.FilterIdsSelectorType.BATCH); + assertArrayEquals(array, idsSelector.filterIds); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 5d49f052c..49c7c7566 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -157,7 +157,7 @@ public void testQueryScoreForFaissWithModel() { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), any())) .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -300,7 +300,7 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), any())) .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -345,8 +345,13 @@ public 
void testEmptyQueryResults() { public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())) - .thenReturn(getFilteredKNNQueryResults()); + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterBitSet.getBits()), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); final Bits liveDocsBits = mock(Bits.class); @@ -404,7 +409,10 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any())); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterBitSet.getBits()), anyInt(), any()) + ); final List actualDocIds = new ArrayList<>(); final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); @@ -669,14 +677,17 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, 1, INDEX_NAME, null, bitSetProducer); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), eq(parentsFilter))) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), 
any(), anyInt(), anyString(), any(), anyInt(), eq(parentsFilter)) + ).thenReturn(getKNNQueryResults()); // Execute Scorer knnScorer = knnWeight.scorer(leafReaderContext); // Verify - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), eq(parentsFilter))); + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), eq(parentsFilter)) + ); assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); @@ -735,7 +746,7 @@ private void testQueryScore( final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), anyInt(), any())) .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java index cfb66662e..dce703050 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIteratorTests.java @@ -9,6 +9,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNVectorAsArraySerializer; @@ -41,12 +42,14 @@ public void testNextDoc_whenCalled_IterateAllDocs() { .collect(Collectors.toList()); when(values.binaryValue()).thenReturn(byteRefs.get(0), 
byteRefs.get(1), byteRefs.get(2)); + FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); } // Execute and verify - FilteredIdsKNNIterator iterator = new FilteredIdsKNNIterator(filterIds, queryVector, values, spaceType); + FilteredIdsKNNIterator iterator = new FilteredIdsKNNIterator(filterBitSet, queryVector, values, spaceType); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); assertEquals(expectedScores.get(i), (Float) iterator.score()); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java index 90d56fddc..d732376ef 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIteratorTests.java @@ -47,12 +47,20 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { .collect(Collectors.toList()); when(values.binaryValue()).thenReturn(byteRefs.get(0), byteRefs.get(1), byteRefs.get(2)); + FixedBitSet filterBitSet = new FixedBitSet(4); for (int id : filterIds) { when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); } // Execute and verify - NestedFilteredIdsKNNIterator iterator = new NestedFilteredIdsKNNIterator(filterIds, queryVector, values, spaceType, parentBitSet); + NestedFilteredIdsKNNIterator iterator = new NestedFilteredIdsKNNIterator( + filterBitSet, + queryVector, + values, + spaceType, + parentBitSet + ); assertEquals(filterIds[0], iterator.nextDoc()); assertEquals(expectedScores.get(0), iterator.score()); assertEquals(filterIds[2], iterator.nextDoc()); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 2afcff83d..8e3382ece 100644 --- 
a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -537,13 +537,13 @@ public void testQueryIndex_faiss_sqfp16_valid() { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, 0, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new long[] { 0 }, 0, null); assertEquals(0, results.length); } } @@ -726,12 +726,15 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null, null)); + expectThrows( + IllegalArgumentException.class, + () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid" + "-engine", null, 0, null) + ); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName(), null, 0, null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -754,7 +757,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName(), null, 0, 
null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -780,7 +783,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB.getName(), null, 0, null); assertEquals(k, results.length); } } @@ -788,7 +791,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME, null, 0, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -807,7 +810,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME, null, 0, null)); } public void testQueryIndex_faiss_valid() throws IOException { @@ -836,13 +839,13 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, 0, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new int[] { 0 }, 
null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, new long[] { 0 }, 0, null); assertEquals(0, results.length); } } @@ -877,7 +880,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { assertNotEquals(0, pointer); for (float[] query : testDataNested.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, parentIds); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, FAISS_NAME, null, 0, parentIds); // Verify there is no more than one result from same parent Set parentIdSet = toParentIdSet(results, idToParentIdMap); assertEquals(results.length, parentIdSet.size());