From d733ba7fc44dd300fca4d345ab01e2825435b93a Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Fri, 26 Apr 2024 10:17:09 -0700 Subject: [PATCH] * Support filter and nested field in faiss engine radial search (#1652) * Use 0.95 as default ratio for lucene radial search traversal similarity (#1619) Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + jni/CMakeLists.txt | 2 +- jni/cmake/init-faiss.cmake | 3 +- jni/include/faiss_wrapper.h | 20 ++- .../org_opensearch_knn_jni_FaissService.h | 14 +- ...patch-to-support-range-search-params.patch | 53 ++++++ jni/src/faiss_wrapper.cpp | 71 +++++++- .../org_opensearch_knn_jni_FaissService.cpp | 19 +- jni/tests/faiss_wrapper_test.cpp | 156 ++++++++++++++++- .../opensearch/knn/common/KNNConstants.java | 5 + .../knn/index/query/KNNQueryBuilder.java | 132 +++++++------- .../opensearch/knn/index/query/KNNWeight.java | 5 +- .../knn/index/query/RNNQueryFactory.java | 17 +- .../org/opensearch/knn/jni/FaissService.java | 31 +++- .../org/opensearch/knn/jni/JNIService.java | 21 ++- .../org/opensearch/knn/index/FaissIT.java | 163 +++++++++++------- .../opensearch/knn/index/LuceneEngineIT.java | 6 +- .../opensearch/knn/index/NestedSearchIT.java | 75 +++++++- .../knn/index/query/KNNQueryBuilderTests.java | 8 +- .../knn/index/query/KNNWeightTests.java | 9 +- 20 files changed, 653 insertions(+), 158 deletions(-) create mode 100644 jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch diff --git a/CHANGELOG.md b/CHANGELOG.md index 05ba3be4c..c0dc0b523 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Add Clear Cache API [#740](https://github.com/opensearch-project/k-NN/pull/740) * Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617) +* Support filter and nested field in faiss engine radial search [#1652](https://github.com/opensearch-project/k-NN/pull/1652) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 6f0894607..595fa6fea 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -36,7 +36,7 @@ endif () # build workflow once, it can cause issues because git commits require that the user and the user's email be set. # See https://github.com/opensearch-project/k-NN/issues/1651. So, we provide a flag that allows users to select between # the two -if(NOT DEFINED COMMIT_LIB_PATCHES OR ${COMMIT_LIB_PATCHES} STREQUAL true) +if(NOT DEFINED COMMIT_LIB_PATCHES OR "${COMMIT_LIB_PATCHES}" STREQUAL true) set(GIT_PATCH_COMMAND am) else() set(GIT_PATCH_COMMAND apply) diff --git a/jni/cmake/init-faiss.cmake b/jni/cmake/init-faiss.cmake index fefed61e2..bc3836b06 100644 --- a/jni/cmake/init-faiss.cmake +++ b/jni/cmake/init-faiss.cmake @@ -13,13 +13,14 @@ if (NOT EXISTS ${FAISS_REPO_DIR}) endif () # Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu. -find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH) +find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH) # If it exists, apply patches if (EXISTS ${PATCH_FILE}) message(STATUS "Applying custom patches.") execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) if(RESULT_CODE) message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}") endif() diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index da67c0f59..958eca8ac 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -74,6 +74,22 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + /* + * Perform a range search with filter against the index located in memory at indexPointerJ. + * + * @param indexPointerJ - pointer to the index + * @param queryVectorJ - the query vector + * @param radiusJ - the radius for the range search + * @param maxResultsWindowJ - the maximum number of results to return + * @param filterIdsJ - the filter ids + * @param filterIdsTypeJ - the filter ids type + * @param parentIdsJ - the parent ids + * + * @return an array of RangeQueryResults + */ + jobjectArray RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + /* * Perform a range search against the index located in memory at indexPointerJ. * @@ -81,10 +97,12 @@ namespace knn_jni { * @param queryVectorJ - the query vector * @param radiusJ - the radius for the range search * @param maxResultsWindowJ - the maximum number of results to return + * @param parentIdsJ - the parent ids + * * @return an array of RangeQueryResults */ jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultsWindowJ); + jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ); } } diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 3715730ab..e16677db7 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -124,11 +124,19 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors /* * Class: org_opensearch_knn_jni_FaissService -* Method: rangeSearchIndex -* Signature: (J[F[F)J +* Method: rangeSearchIndexWithFilter +* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult; */ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jlongArray, jint, jintArray); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: rangeSearchIndex + * Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult; + */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex - (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint); + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jintArray); #ifdef __cplusplus } diff --git a/jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch b/jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch new file mode 100644 index 000000000..bdc202bf6 --- /dev/null +++ b/jni/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch @@ -0,0 +1,53 @@ +From af6770b505a32b2c4eab2036d2509dec4b137f28 Mon Sep 17 00:00:00 2001 +From: Junqiu Lei +Date: Tue, 23 Apr 2024 17:18:56 -0700 +Subject: [PATCH] Custom patch to support range search params + +Signed-off-by: Junqiu Lei +--- + faiss/IndexIDMap.cpp | 28 ++++++++++++++++++++++++---- + 1 file changed, 24 insertions(+), 4 deletions(-) + +diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp +index 3f375e7b..11f3a847 100644 +--- a/faiss/IndexIDMap.cpp ++++ b/faiss/IndexIDMap.cpp +@@ -176,11 +176,31 @@ void IndexIDMapTemplate::range_search( + RangeSearchResult* result, + const SearchParameters* params) const { + if (params) { +- SearchParameters internal_search_parameters; +- IDSelectorTranslated id_selector_translated(id_map, params->sel); +- internal_search_parameters.sel = &id_selector_translated; ++ IDSelectorTranslated this_idtrans(this->id_map, nullptr); ++ ScopedSelChange sel_change; ++ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr); ++ ScopedGrpChange grp_change; ++ ++ if (params->sel) { ++ auto idtrans = dynamic_cast(params->sel); ++ ++ if (!idtrans) { ++ auto params_non_const = const_cast(params); ++ this_idtrans.sel = params->sel; ++ sel_change.set(params_non_const, &this_idtrans); ++ } ++ } ++ ++ if (params->grp) { ++ auto idtrans = dynamic_cast(params->grp); + +- index->range_search(n, x, radius, result, &internal_search_parameters); ++ if (!idtrans) { ++ auto params_non_const = const_cast(params); ++ this_idgrptrans.grp = params->grp; ++ grp_change.set(params_non_const, &this_idgrptrans); ++ } ++ } ++ index->range_search(n, x, radius, result, params); + } else { + index->range_search(n, x, radius, result); + } +-- +2.39.0 + diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 983cfa8a9..5a0910d9a 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -589,7 +589,12 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) { } jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, - jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) { + jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ) { + return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, nullptr, 0, parentIdsJ); +} + +jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, + jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } @@ -605,7 +610,69 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniU // The res will be freed by ~RangeSearchResult() in FAISS // The second parameter is always true, as lims is allocated by FAISS faiss::RangeSearchResult res(1, true); - indexReader->range_search(1, rawQueryVector, radiusJ, &res); + + if(filterIdsJ != nullptr) { + jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); + std::unique_ptr idSelector; + if(filterIdsTypeJ == BITMAP) { + idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); + } else { + faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); + idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices)); + } + faiss::SearchParameters *searchParameters; + faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + hnswParams.sel = idSelector.get(); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } + searchParameters = &hnswParams; + } else { + auto ivfReader = dynamic_cast(indexReader->index); + auto ivfFlatReader = dynamic_cast(indexReader->index); + if(ivfReader || ivfFlatReader) { + ivfParams.sel = idSelector.get(); + searchParameters = &ivfParams; + } + } + try { + indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + throw; + } + } else { + faiss::SearchParameters *searchParameters = nullptr; + faiss::SearchParametersHNSW hnswParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader!= nullptr && parentIdsJ != nullptr) { + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + searchParameters = &hnswParams; + } + try { + indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters); + } catch (...) { + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT); + throw; + } + } // lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries), // lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN, diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index ab2a37e84..0aa51987d 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -194,11 +194,26 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultWindowJ) + jfloat radiusJ, jint maxResultWindowJ, + jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ); + return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, parentIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls, + jlong indexPointerJ, + jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultWindowJ, + jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) +{ + try { + return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, + maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 07b34976f..e9316dcc2 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -628,7 +628,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { // Define query data float radius = 100000.0; - int numQueries = 2; + int numQueries = 100; std::vector> queries; for (int i = 0; i < numQueries; i++) { @@ -659,7 +659,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow))); + reinterpret_cast(&query), radius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -684,7 +684,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ // Define query data float radius = 100000.0; - int numQueries = 2; + int numQueries = 100; std::vector> queries; for (int i = 0; i < numQueries; i++) { @@ -715,7 +715,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow))); + reinterpret_cast(&query), radius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -728,3 +728,151 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ } } } + +TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 200; + int dim = 2; + std::vector ids = test_util::Range(numIds); + std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 100; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + int num_bits = test_util::bits2words(164); + std::vector bitmap(num_bits,0); + std::vector filterIds; + + for (int64_t i = 154; i < 163; i++) { + filterIds.push_back(i); + test_util::setBitSet(i, bitmap.data(), bitmap.size()); + } + std::unordered_set filterIdSet(filterIds.begin(), filterIds.end()); + + int maxResultWindow = 20000; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearchWithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow, + reinterpret_cast(&bitmap), 0, nullptr))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + ASSERT_TRUE(results->size() <= filterIds.size()); + for (const auto& pairPtr : *results) { + auto it = filterIdSet.find(pairPtr->first); + ASSERT_NE(it, filterIdSet.end()); + } + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} + +TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 100; + std::vector ids; + std::vector vectors; + std::vector parentIds; + int dim = 2; + for (int64_t i = 1; i < numIds + 1; i++) { + if (i % 10 == 0) { + parentIds.push_back(i); + continue; + } + ids.push_back(i); + for (int j = 0; j < dim; j++) { + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 1; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(&parentIds))) + .WillRepeatedly(Return(parentIds.size())); + + int maxResultWindow = 10000; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearchWithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow, nullptr, 0, + reinterpret_cast(&parentIds)))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + // Result should be one for each group + std::set idSet; + for (const auto& pairPtr : *results) { + idSet.insert(pairPtr->first / 10); + } + ASSERT_NE(0, idSet.size()); + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 2c3b03e47..7c5bb61ad 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -132,4 +132,9 @@ public class KNNConstants { // API Constants public static final String CLEAR_CACHE = "clear_cache"; + + // Radial search constants + public static final Float DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO = 0.95f; + public static final String MIN_SCORE = "min_score"; + public static final String MAX_DISTANCE = "max_distance"; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 3a24c1012..3d3b0969f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -37,6 +37,8 @@ import org.opensearch.index.query.QueryShardContext; import static org.opensearch.knn.index.IndexUtil.*; +import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -52,8 +54,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); - public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); - public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); + public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE); + public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE); public static final int K_MAX = 10000; /** * The name for the knn query @@ -65,8 +67,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; - private Float max_distance = null; - private Float min_score = null; + private Float maxDistance = null; + private Float minScore = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; @@ -78,13 +80,13 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { */ public KNNQueryBuilder(String fieldName, float[] vector) { if (Strings.isNullOrEmpty(fieldName)) { - throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); + throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); } if (vector == null) { - throw new IllegalArgumentException("[" + NAME + "] requires query vector"); + throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME)); } if (vector.length == 0) { - throw new IllegalArgumentException("[" + NAME + "] query vector is empty"); + throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME)); } this.fieldName = fieldName; this.vector = vector; @@ -97,44 +99,44 @@ public KNNQueryBuilder(String fieldName, float[] vector) { */ public KNNQueryBuilder k(Integer k) { if (k == null) { - throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); + throw new IllegalArgumentException(String.format("[%s] requires k to be set", NAME)); } - validateSingleQueryType(k, max_distance, min_score); + validateSingleQueryType(k, maxDistance, minScore); if (k <= 0 || k > K_MAX) { - throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); + throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX)); } this.k = k; return this; } /** - * Builder method for max_distance + * Builder method for maxDistance * - * @param max_distance the max_distance threshold for the nearest neighbours + * @param maxDistance the maxDistance threshold for the nearest neighbours */ - public KNNQueryBuilder maxDistance(Float max_distance) { - if (max_distance == null) { - throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set"); + public KNNQueryBuilder maxDistance(Float maxDistance) { + if (maxDistance == null) { + throw new IllegalArgumentException(String.format("[%s] requires maxDistance to be set", NAME)); } - validateSingleQueryType(k, max_distance, min_score); - this.max_distance = max_distance; + validateSingleQueryType(k, maxDistance, minScore); + this.maxDistance = maxDistance; return this; } /** - * Builder method for min_score + * Builder method for minScore * - * @param min_score the min_score threshold for the nearest neighbours + * @param minScore the minScore threshold for the nearest neighbours */ - public KNNQueryBuilder minScore(Float min_score) { - if (min_score == null) { - throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set"); + public KNNQueryBuilder minScore(Float minScore) { + if (minScore == null) { + throw new IllegalArgumentException(String.format("[%s] requires minScore to be set", NAME)); } - validateSingleQueryType(k, max_distance, min_score); - if (min_score <= 0) { - throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0"); + validateSingleQueryType(k, maxDistance, minScore); + if (minScore <= 0) { + throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); } - this.min_score = min_score; + this.minScore = minScore; return this; } @@ -161,19 +163,19 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k) { public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { if (StringUtils.isBlank(fieldName)) { - throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); + throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); } if (vector == null) { - throw new IllegalArgumentException("[" + NAME + "] requires query vector"); + throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME)); } if (vector.length == 0) { - throw new IllegalArgumentException("[" + NAME + "] query vector is empty"); + throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME)); } if (k <= 0) { - throw new IllegalArgumentException("[" + NAME + "] requires k > 0"); + throw new IllegalArgumentException(String.format("[%s] requires k > 0", NAME)); } if (k > K_MAX) { - throw new IllegalArgumentException("[" + NAME + "] requires k <= " + K_MAX); + throw new IllegalArgumentException(String.format("[%s] requires k <= %d", NAME, K_MAX)); } this.fieldName = fieldName; @@ -181,8 +183,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.k = k; this.filter = filter; this.ignoreUnmapped = false; - this.max_distance = null; - this.min_score = null; + this.maxDistance = null; + this.minScore = null; } public static void initialize(ModelDao modelDao) { @@ -222,10 +224,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException { ignoreUnmapped = in.readOptionalBoolean(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - max_distance = in.readOptionalFloat(); + maxDistance = in.readOptionalFloat(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - min_score = in.readOptionalFloat(); + minScore = in.readOptionalFloat(); } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); @@ -237,8 +239,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; Integer k = null; - Float max_distance = null; - Float min_score = null; + Float maxDistance = null; + Float minScore = null; QueryBuilder filter = null; boolean ignoreUnmapped = false; String queryName = null; @@ -264,9 +266,9 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) { if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { ignoreUnmapped = parser.booleanValue(); @@ -320,7 +322,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - validateSingleQueryType(k, max_distance, min_score); + validateSingleQueryType(k, maxDistance, minScore); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .boost(boost) @@ -332,10 +334,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (k != null) { knnQueryBuilder.k(k); - } else if (max_distance != null) { - knnQueryBuilder.maxDistance(max_distance); - } else if (min_score != null) { - knnQueryBuilder.minScore(min_score); + } else if (maxDistance != null) { + knnQueryBuilder.maxDistance(maxDistance); + } else if (minScore != null) { + knnQueryBuilder.minScore(minScore); } return knnQueryBuilder; @@ -355,10 +357,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(ignoreUnmapped); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(max_distance); + out.writeOptionalFloat(maxDistance); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(min_score); + out.writeOptionalFloat(minScore); } } @@ -381,11 +383,11 @@ public int getK() { } public float getMaxDistance() { - return this.max_distance; + return this.maxDistance; } public float getMinScore() { - return this.min_score; + return this.minScore; } public QueryBuilder getFilter() { @@ -416,14 +418,14 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } - if (max_distance != null) { - builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance); + if (maxDistance != null) { + builder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance); } if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } - if (min_score != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score); + if (minScore != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); } printBoostAndQueryName(builder); builder.endObject(); @@ -467,18 +469,22 @@ protected Query doToQuery(QueryShardContext context) { // Currently, k-NN supports distance and score types radial search // We need transform distance/score to right type of engine required radius. Float radius = null; - if (this.max_distance != null) { - if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { - throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType); + if (this.maxDistance != null) { + if (this.maxDistance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException( + String.format("[" + NAME + "] requires distance to be non-negative for space type: %s", spaceType) + ); } - radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType); + radius = knnEngine.distanceToRadialThreshold(this.maxDistance, spaceType); } - if (this.min_score != null) { - if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { - throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); + if (this.minScore != null) { + if (this.minScore > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException( + String.format("[" + NAME + "] requires score to be in the range [0, 1] for space type: %s", spaceType) + ); } - radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType); + radius = knnEngine.scoreToRadialThreshold(this.minScore, spaceType); } if (fieldDimension != vector.length) { @@ -538,7 +544,7 @@ protected Query doToQuery(QueryShardContext context) { .build(); return RNNQueryFactory.create(createQueryRequest); } - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set"); + throw new IllegalArgumentException(String.format("[%s] requires k or distance or score to be set", NAME)); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { @@ -588,9 +594,7 @@ private static void validateSingleQueryType(Integer k, Float distance, Float sco } if (countSetFields != 1) { - throw new IllegalArgumentException( - "[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score" - ); + throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); } } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 8939a569e..bac8c03d4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -294,7 +294,10 @@ private Map doANNSearch(final LeafReaderContext context, final B knnQuery.getQueryVector(), knnQuery.getRadius(), knnEngine, - knnQuery.getContext().getMaxResultWindow() + knnQuery.getContext().getMaxResultWindow(), + filterIds, + filterType.getValue(), + parentIds ); } } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index cd32ac4f3..db8084864 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -118,7 +119,13 @@ private static Query getFloatVectorSimilarityQuery( final float resultSimilarity, final Query filterQuery ) { - return new FloatVectorSimilarityQuery(fieldName, floatVector, resultSimilarity, filterQuery); + return new FloatVectorSimilarityQuery( + fieldName, + floatVector, + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity, + resultSimilarity, + filterQuery + ); } /** @@ -131,6 +138,12 @@ private static Query getByteVectorSimilarityQuery( final float resultSimilarity, final Query filterQuery ) { - return new ByteVectorSimilarityQuery(fieldName, byteVector, resultSimilarity, filterQuery); + return new ByteVectorSimilarityQuery( + fieldName, + byteVector, + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity, + resultSimilarity, + filterQuery + ); } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index b59ac4bcf..53980bbb7 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -191,6 +191,28 @@ public static native KNNQueryResult[] queryIndexWithFilter( @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); + /** + * Range search index with filter + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param indexMaxResultWindow maximum number of results to return + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType type of filter ids + * @param parentIds list of parent doc ids when the knn field is a nested field + * @return KNNQueryResult array of neighbors within radius + */ + public static native KNNQueryResult[] rangeSearchIndexWithFilter( + long indexPointer, + float[] queryVector, + float radius, + int indexMaxResultWindow, + long[] filteredIds, + int filterIdsType, + int[] parentIds + ); + /** * Range search index * @@ -198,7 +220,14 @@ public static native KNNQueryResult[] queryIndexWithFilter( * @param queryVector vector to be used for query * @param radius search within radius threshold * @param indexMaxResultWindow maximum number of results to return + * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of neighbors within radius */ - public static native KNNQueryResult[] rangeSearchIndex(long indexPointer, float[] queryVector, float radius, int indexMaxResultWindow); + public static native KNNQueryResult[] rangeSearchIndex( + long indexPointer, + float[] queryVector, + float radius, + int indexMaxResultWindow, + int[] parentIds + ); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index e846f02d1..20c418819 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -271,6 +271,9 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData) * @param radius search within radius threshold * @param knnEngine engine to query index * @param indexMaxResultWindow maximum number of results to return + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType how to filter ids: Batch or BitMap + * @param parentIds parent ids of the vectors * @return KNNQueryResult array of neighbors within radius */ public static KNNQueryResult[] radiusQueryIndex( @@ -278,10 +281,24 @@ public static KNNQueryResult[] radiusQueryIndex( float[] queryVector, float radius, KNNEngine knnEngine, - int indexMaxResultWindow + int indexMaxResultWindow, + long[] filteredIds, + int filterIdsType, + int[] parentIds ) { if (KNNEngine.FAISS == knnEngine) { - return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow); + if (ArrayUtils.isNotEmpty(filteredIds)) { + return FaissService.rangeSearchIndexWithFilter( + indexPointer, + queryVector, + radius, + indexMaxResultWindow, + filteredIds, + filterIdsType, + parentIds + ); + } + return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow, parentIds); } throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 61b0461ef..2349bb8d0 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -24,6 +24,7 @@ import org.opensearch.client.ResponseException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; @@ -34,6 +35,7 @@ import java.io.IOException; import java.net.URL; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; @@ -55,12 +57,14 @@ import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -91,7 +95,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { String indexName = "test-index-1"; String fieldName = "test-field-1"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -108,10 +112,10 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -180,7 +184,7 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { String indexName = "test-index-1"; String fieldName = "test-field-1"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -197,10 +201,10 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -279,7 +283,7 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -296,10 +300,10 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -330,7 +334,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); float distance = 300000000000f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -338,7 +342,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -355,10 +359,10 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -390,7 +394,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF float score = 0.00001f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -398,7 +402,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.INNER_PRODUCT; List mValues = ImmutableList.of(16, 32, 64, 128); @@ -415,10 +419,10 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -450,14 +454,14 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe float score = 5f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null); // Delete index deleteKNNIndex(INDEX_NAME); } @SneakyThrows - public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { + public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; String fieldName = "test-field"; String trainingIndexName = "training-index"; @@ -535,7 +539,7 @@ public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsH float distance = 300000000000f; - validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType); + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType, null); // Delete index deleteKNNIndex(indexName); @@ -554,6 +558,32 @@ public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsH fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testRadialQuery_withFilter_thenSuccess() { + setupKNNIndexForFilterQuery(); + + final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + List expectedDocIds = Arrays.asList(DOC_ID_3); + + float distance = 15f; + List> queryResult = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + searchVector, + distance, + null, + SpaceType.L2, + termQueryBuilder + ); + + assertEquals(1, queryResult.get(0).size()); + assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId()); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; @@ -670,7 +700,7 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { String indexName = "test-index-hnsw-sqfp16"; String fieldName = "test-field-hnsw-sqfp16"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; Random random = new Random(); SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; @@ -690,10 +720,10 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -784,7 +814,7 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { String indexName = "test-index-sqfp16"; String fieldName = "test-field-sqfp16"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; Random random = new Random(); SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; @@ -803,10 +833,10 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -886,7 +916,7 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then String indexName = "test-index-sqfp16-clip-fp16"; String fieldName = "test-field-sqfp16"; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); Random random = new Random(); List mValues = ImmutableList.of(16, 32, 64, 128); @@ -903,10 +933,10 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) @@ -1351,7 +1381,7 @@ public void testDocUpdate() throws IOException { String fieldName = "test-field-1"; Integer dimension = 2; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; // Create an index @@ -1362,9 +1392,9 @@ public void testDocUpdate() throws IOException { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1387,7 +1417,7 @@ public void testDocDeletion() throws IOException { String fieldName = "test-field-1"; Integer dimension = 2; - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); SpaceType spaceType = SpaceType.L2; // Create an index @@ -1398,9 +1428,9 @@ public void testDocDeletion() throws IOException { .field("type", "knn_vector") .field("dimension", dimension) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1574,9 +1604,9 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore .field("type", "knn_vector") .field("dimension", 2) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1625,9 +1655,9 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .field("type", "knn_vector") .field("dimension", 3) .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .endObject() .endObject() .endObject() @@ -1686,26 +1716,31 @@ private void validateGraphEviction() throws Exception { fail("Graphs are not getting evicted"); } - private void validateRadiusSearchResults( + private List> validateRadiusSearchResults( String indexName, String fieldName, float[][] queryVectors, Float distanceThreshold, Float scoreThreshold, - final SpaceType spaceType + final SpaceType spaceType, + TermQueryBuilder filterQuery ) throws IOException, ParseException { + List> queryResults = new ArrayList<>(); for (float[] queryVector : queryVectors) { XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); queryBuilder.startObject("knn"); queryBuilder.startObject(fieldName); queryBuilder.field("vector", queryVector); if (distanceThreshold != null) { - queryBuilder.field("max_distance", distanceThreshold); + queryBuilder.field(MAX_DISTANCE, distanceThreshold); } else if (scoreThreshold != null) { - queryBuilder.field("min_score", scoreThreshold); + queryBuilder.field(MIN_SCORE, scoreThreshold); } else { throw new IllegalArgumentException("Invalid threshold"); } + if (filterQuery != null) { + queryBuilder.field("filter", filterQuery); + } queryBuilder.endObject(); queryBuilder.endObject(); queryBuilder.endObject().endObject(); @@ -1724,6 +1759,8 @@ private void validateRadiusSearchResults( throw new IllegalArgumentException("Invalid space type"); } } + queryResults.add(knnResults); } + return queryResults; } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 4cc856613..bf9a6b776 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -34,7 +34,9 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -629,9 +631,9 @@ private void validateRadiusSearchResults( builder.startObject(FIELD_NAME); builder.field("vector", searchVectors[i]); if (distanceThreshold != null) { - builder.field("max_distance", distanceThreshold); + builder.field(MAX_DISTANCE, distanceThreshold); } else if (scoreThreshold != null) { - builder.field("min_score", scoreThreshold); + builder.field(MIN_SCORE, scoreThreshold); } else { throw new IllegalArgumentException("Either distance or score must be provided"); } diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index e4d828a01..d97aa4d40 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -32,6 +32,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.PATH; @@ -152,7 +153,64 @@ public void testNestedSearchWithFaiss_whenDoingExactSearch_thenReturnCorrectResu updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 100)); Float[] queryVector = { 3f, 3f, 3f }; - Response response = queryNestedField(INDEX_NAME, 3, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + Response response = queryNestedField(INDEX_NAME, 3, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE, null); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(2, docIds.size()); + assertEquals("3", docIds.get(0)); + assertEquals("1", docIds.get(1)); + assertEquals(2, parseTotalSearchHits(entity)); + } + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector": { + * "vector": [ + * 1, 1, 1 + * ], + * "min_score": 0.00001, + * "filter": { + * "term": { + * "parking": "true" + * } + * } + * } + * } + * } + * } + * } + * } + * + */ + @SneakyThrows + public void testNestedWithFaiss_whenFilter_whenDoRadialSearch_thenReturnCorrectResults() { + createKnnIndex(3, KNNEngine.FAISS.getName()); + + for (int i = 1; i < 4; i++) { + float value = (float) i; + String doc = NestedKnnDocBuilder.create(FIELD_NAME_NESTED) + .addVectors( + FIELD_NAME_VECTOR, + new Float[] { value, value, value }, + new Float[] { value, value, value }, + new Float[] { value, value, value } + ) + .addTopLevelField(FIELD_NAME_PARKING, i % 2 == 1 ? FIELD_VALUE_TRUE : FIELD_VALUE_FALSE) + .build(); + addKnnDoc(INDEX_NAME, String.valueOf(i), doc); + } + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + Float[] queryVector = { 3f, 3f, 3f }; + Float minScore = 0.00001f; + Response response = queryNestedField(INDEX_NAME, null, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE, minScore); + String entity = EntityUtils.toString(response.getEntity()); List docIds = parseIds(entity); assertEquals(2, docIds.size()); @@ -215,22 +273,29 @@ private void createKnnIndex(final int dimension, final String engine) throws Exc } private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException { - return queryNestedField(index, k, vector, null, null); + return queryNestedField(index, k, vector, null, null, null); } private Response queryNestedField( final String index, - final int k, + final Integer k, final Object[] vector, final String filterName, - final String filterValue + final String filterValue, + final Float minScore ) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); builder.startObject(TYPE_NESTED); builder.field(PATH, FIELD_NAME_NESTED); builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); builder.field(VECTOR, vector); - builder.field(K, k); + if (minScore != null) { + builder.field(MIN_SCORE, minScore); + } else if (k != null) { + builder.field(K, k); + } else { + throw new IllegalArgumentException("k or minScore must be provided in the query"); + } if (filterName != null && filterValue != null) { builder.startObject(FIELD_FILTER); builder.startObject(FIELD_TERM); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index ddc961093..4b9872131 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -53,6 +53,7 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -417,7 +418,12 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2))); + float resultSimilarity = KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2); + + assertTrue(query.toString().contains("resultSimilarity=" + resultSimilarity)); + assertTrue( + query.toString().contains("traversalSimilarity=" + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity) + ); } public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index e80190ff8..0d15b5f5f 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -699,8 +699,9 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final float[] queryVector = new float[] { 0.1f, 0.3f }; final float radius = 0.5f; final int maxResults = 1000; - jniServiceMockedStatic.when(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt(), any(), anyInt(), any()) + ).thenReturn(getKNNQueryResults()); KNNQuery.Context context = mock(KNNQuery.Context.class); when(context.getMaxResultWindow()).thenReturn(maxResults); @@ -742,7 +743,9 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); assertNotNull(knnScorer); - jniServiceMockedStatic.verify(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())); + jniServiceMockedStatic.verify( + () -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt(), any(), anyInt(), any()) + ); final DocIdSetIterator docIdSetIterator = knnScorer.iterator();