From 50c96f314da065fabfabbb2b61f61d23b4e5cd07 Mon Sep 17 00:00:00 2001
From: Wei Wang
Date: Sat, 28 Dec 2024 01:20:21 +0800
Subject: [PATCH] pass filterbitset as null and add integ tests.

Signed-off-by: Wei Wang
---
Reviewer notes (this section is ignored by git-am):
  * What the patch does, as visible in the hunks below: searchLeaf() now calls
    doANNSearch() with a null filter bitset when a filter is present and its
    cardinality equals maxDoc, and returns a null bitset in the PerLeafResult
    for that case; doANNSearch() leaves filterIds null and filterType BITMAP
    when the bitset is null; the unit test expects null filter ids; a new
    integration test indexes 4 docs that all match the filter and checks the
    hit count.
  * This copy was recovered from text whose line breaks and indentation had
    been collapsed. Line structure is restored so every hunk matches its @@
    counts again. Indentation of context lines is reconstructed (4-space Java
    style); if it does not match the tree, use `git apply --ignore-whitespace`.
  * NOTE(review): the recovered text has no e-mail on From:/Signed-off-by: and
    shows raw Map/List types. Both look like angle-bracket content lost during
    extraction - TODO confirm against the original commit. They are left
    exactly as found.
  * Spacing fixed on three added lines (`!= null) {`, `== maxDoc)`,
    `randomIntBetween(1, 3)`); the post-image blob ids on the `index` lines
    were not recomputed.

 .../opensearch/knn/index/query/KNNWeight.java | 25 ++++--
 .../knn/index/query/KNNWeightTests.java       |  2 +-
 .../knn/integ/FilteredSearchANNSearchIT.java  | 82 +++++++++++++++++++
 3 files changed, 99 insertions(+), 10 deletions(-)
 create mode 100644 src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java

diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
index 3ef716a9c..b8e654c2b 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -146,13 +146,16 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
             Map result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
             return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
         }
+        final Map docIdsToScoreMap;
         /*
-         * If filters match all docs in this segment, then there is no need to do any extra step
-         * and should directly do ANN Search*/
+         * If filters match all docs in this segment, then null should be passed as filterBitSet
+         * so that it will not do a bitset look up in bottom search layer.
+         */
         if (filterWeight != null && cardinality == maxDoc) {
-            return new PerLeafResult(new FixedBitSet(0), doANNSearch(context, new FixedBitSet(0), 0, k));
+            docIdsToScoreMap = doANNSearch(context, null, 0, k);
+        } else {
+            docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
         }
-        Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
         // See whether we have to perform exact search based on approx search results
         // This is required if there are no native engine files or if approximate search returned
         // results less than K, though we have more than k filtered docs
@@ -161,7 +164,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
             Map result = doExactSearch(context, docs, cardinality, k);
             return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
         }
-        return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap);
+        return new PerLeafResult((filterWeight == null || cardinality == maxDoc) ? null : filterBitSet, docIdsToScoreMap);
     }
 
     private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
@@ -320,10 +323,14 @@ private Map doANNSearch(
             throw new RuntimeException(e);
         }
 
-        // From cardinality select different filterIds type
-        FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality);
-        long[] filterIds = filterIdsSelector.getFilterIds();
-        FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType();
+        long[] filterIds = null;
+        FilterIdsSelector.FilterIdsSelectorType filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP;
+        if (filterIdsBitSet != null) {
+            // From cardinality select different filterIds type
+            FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality);
+            filterIds = filterIdsSelector.getFilterIds();
+            filterType = filterIdsSelector.getFilterType();
+        }
         // Now that we have the allocation, we need to readLock it
         indexAllocation.readLock();
         indexAllocation.incRef();
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
index 69298a994..c77dedc26 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
@@ -775,7 +775,7 @@ public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() {
                     eq(k),
                     eq(HNSW_METHOD_PARAMETERS),
                     any(),
-                    eq(new FixedBitSet(0).getBits()),
+                    eq(null),
                     anyInt(),
                     any()
                 )
diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java
new file mode 100644
index 000000000..f3ebb88e9
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.integ;
+
+import com.google.common.collect.ImmutableMap;
+import lombok.SneakyThrows;
+import lombok.extern.log4j.Log4j2;
+import org.apache.hc.core5.http.io.entity.EntityUtils;
+import org.opensearch.client.Response;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.knn.KNNJsonIndexMappingsBuilder;
+import org.opensearch.knn.KNNJsonQueryBuilder;
+import org.opensearch.knn.KNNRestTestCase;
+import org.opensearch.knn.index.KNNSettings;
+import org.opensearch.knn.index.engine.KNNEngine;
+import java.util.List;
+
+import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
+
+@Log4j2
+public class FilteredSearchANNSearchIT extends KNNRestTestCase {
+    @SneakyThrows
+    public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() {
+        String filterFieldName = "color";
+        final int expectResultSize = randomIntBetween(1, 3);
+        final String filterValue = "red";
+        createKnnIndex(INDEX_NAME, FIELD_NAME, 3, KNNEngine.FAISS);
+
+        // ingest 4 vector docs into the index with the same field {"color": "red"}
+        for (int i = 0; i < 4; i++) {
+            addKnnDocWithAttributes(
+                String.valueOf(i),
+                new float[] { i, i, i },
+                ImmutableMap.of(filterFieldName, filterValue)
+            );
+        }
+
+        refreshIndex(INDEX_NAME);
+        forceMergeKnnIndex(INDEX_NAME);
+
+        updateIndexSettings(
+            INDEX_NAME,
+            Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)
+        );
+
+        Float[] queryVector = { 3f, 3f, 3f };
+        // All docs in one segment will match the filters value
+        String query = KNNJsonQueryBuilder.builder()
+            .fieldName(FIELD_NAME)
+            .vector(queryVector)
+            .k(expectResultSize)
+            .filterFieldName(filterFieldName)
+            .filterValue(filterValue)
+            .build()
+            .getQueryString();
+        Response response = searchKNNIndex(INDEX_NAME, query, expectResultSize);
+        String entity = EntityUtils.toString(response.getEntity());
+        List docIds = parseIds(entity);
+        assertEquals(expectResultSize, docIds.size());
+        assertEquals(expectResultSize, parseTotalSearchHits(entity));
+    }
+
+    private void createKnnIndex(final String indexName, final String fieldName, final int dimension, final KNNEngine knnEngine)
+        throws Exception {
+        KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder()
+            .methodName(METHOD_HNSW)
+            .engine(knnEngine.getName())
+            .build();
+
+        String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder()
+            .fieldName(fieldName)
+            .dimension(dimension)
+            .method(method)
+            .build()
+            .getIndexMapping();
+
+        createKnnIndex(indexName, knnIndexMapping);
+    }
+}