From fbc6cf943b05f46ce0440ec091a96bb11604e053 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Thu, 12 Dec 2024 00:38:50 +0800 Subject: [PATCH 1/3] Add check to directly use ANN Search when filters match all docs. Signed-off-by: Wei Wang --- .../org/opensearch/knn/index/query/KNNWeight.java | 14 ++++++++++++++ .../opensearch/knn/index/query/KNNWeightTests.java | 8 +++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index b64472994..9baf946a6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -127,7 +127,14 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ public Map searchLeaf(LeafReaderContext context, int k) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); + final int maxDoc = context.reader().maxDoc(); int cardinality = filterBitSet.cardinality(); + /* + * If filters match all docs in this segment, then there is no need to do any extra step + * and should directly do ANN Search*/ + if (cardinality == maxDoc){ + return doANNSearch(context, filterBitSet, cardinality, k); + } // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, @@ -142,6 +149,12 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I if (isFilteredExactSearchPreferred(cardinality)) { return doExactSearch(context, filterBitSet, k); } + /* + * If filters match all docs in this segment, then there is no need to do any extra step + * and should directly do ANN Search*/ + if (filterWeight != null && cardinality == maxDoc) { + return doANNSearch(context, new FixedBitSet(0), 0, k); + } Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned @@ -312,6 +325,7 @@ private Map doANNSearch( // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); indexAllocation.incRef(); + try { if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 511895026..d72003820 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -815,7 +815,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -891,6 +891,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final FSDirectory directory = mock(FSDirectory.class); when(reader.directory()).thenReturn(directory); @@ -968,7 +969,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -1168,6 +1169,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); @@ -1202,7 +1204,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { // We will have 0, 1 for filteredIds and 2 will be the parent id for both of them final Scorer filterScorer = mock(Scorer.class); when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2)); - when(reader.maxDoc()).thenReturn(2); + when(reader.maxDoc()).thenReturn(3); // Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result final List vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f }); From 105c39a199397901405f430bf39db1f7bdce74c4 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Wed, 18 Dec 2024 23:35:14 +0800 Subject: [PATCH 2/3] Fix failed tests and rebase on main branch Signed-off-by: Wei Wang --- CHANGELOG.md | 1 + .../opensearch/knn/index/query/KNNWeight.java | 8 +- .../knn/index/query/KNNWeightTests.java | 93 ++++++++++++++++++- 3 files changed, 94 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a5fda2a2..17b1b55c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] - Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305] +- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] * Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315] diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 497593a00..3ef716a9c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -131,12 +131,6 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep final BitSet filterBitSet = getFilteredDocsBitSet(context); final int maxDoc = context.reader().maxDoc(); int cardinality = filterBitSet.cardinality(); - /* - * If filters match all docs in this segment, then there is no need to do any extra step - * and should directly do ANN Search*/ - if (cardinality == maxDoc){ - return doANNSearch(context, filterBitSet, cardinality, k); - } // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, @@ -156,7 +150,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep * If filters match all docs in this segment, then there is no need to do any extra step * and should directly do ANN Search*/ if (filterWeight != null && cardinality == maxDoc) { - return doANNSearch(context, new FixedBitSet(0), 0, k); + return new PerLeafResult(new FixedBitSet(0), doANNSearch(context, new FixedBitSet(0), 0, k)); } Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); // See whether we have to perform exact search based on approx search results diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index d72003820..69298a994 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is when(liveDocsBits.length()).thenReturn(1000); final SegmentReader reader = mockSegmentReader(); - when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.maxDoc()).thenReturn(filterDocIds.length + 1); when(reader.getLiveDocs()).thenReturn(liveDocsBits); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -758,6 +758,97 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } + @SneakyThrows + public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() { + // Given + int k = 3; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + + jniServiceMockedStatic.when( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(new FixedBitSet(0).getBits()), + anyInt(), + any() + ) + ).thenReturn(getFilteredKNNQueryResults()); + + final Bits liveDocsBits = mock(Bits.class); + for (int filterDocId : filterDocIds) { + when(liveDocsBits.get(filterDocId)).thenReturn(true); + } + when(liveDocsBits.length()).thenReturn(1000); + + final SegmentReader reader = mockSegmentReader(); + when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // Just to make sure that we are not hitting the exact search condition + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + private SegmentReader mockSegmentReader() { Path path = mock(Path.class); From a5187d2d8b59beb9b3a297d479c28665f3b45e6f Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Sat, 28 Dec 2024 01:20:21 +0800 Subject: [PATCH 3/3] pass filterbitset as null and add integ tests. Signed-off-by: Wei Wang --- .../knn/index/query/FilterIdsSelector.java | 5 +- .../opensearch/knn/index/query/KNNWeight.java | 14 ++--- .../knn/index/query/KNNWeightTests.java | 11 +--- .../knn/integ/FilteredSearchANNSearchIT.java | 57 +++++++++++++++++++ 4 files changed, 69 insertions(+), 18 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java diff --git a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java index bf06e8c5e..12711911a 100644 --- a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java +++ b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java @@ -78,7 +78,10 @@ public enum FilterIdsSelectorType { public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException { long[] filterIds; FilterIdsSelector.FilterIdsSelectorType filterType; - if (filterIdsBitSet instanceof FixedBitSet) { + if (filterIdsBitSet == null) { + filterIds = null; + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } else if (filterIdsBitSet instanceof FixedBitSet) { /** * When filterIds is dense filter, using fixed bitset */ diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 3ef716a9c..37b5cc9ad 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -146,13 +146,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep Map result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k); return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); } + /* - * If filters match all docs in this segment, then there is no need to do any extra step - * and should directly do ANN Search*/ - if (filterWeight != null && cardinality == maxDoc) { - return new PerLeafResult(new FixedBitSet(0), doANNSearch(context, new FixedBitSet(0), 0, k)); - } - Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); + * If filters match all docs in this segment, then null should be passed as filterBitSet + * so that it will not do a bitset look up in bottom search layer. + */ + final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; + final Map docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); + // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs @@ -327,7 +328,6 @@ private Map doANNSearch( // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); indexAllocation.incRef(); - try { if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 69298a994..8011cc08c 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -769,16 +769,7 @@ public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() { } jniServiceMockedStatic.when( - () -> JNIService.queryIndex( - anyLong(), - eq(QUERY_VECTOR), - eq(k), - eq(HNSW_METHOD_PARAMETERS), - any(), - eq(new FixedBitSet(0).getBits()), - anyInt(), - any() - ) + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) ).thenReturn(getFilteredKNNQueryResults()); final Bits liveDocsBits = mock(Bits.class); diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java new file mode 100644 index 000000000..191ab944c --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.KNNSettings; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class FilteredSearchANNSearchIT extends KNNRestTestCase { + @SneakyThrows + public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() { + String filterFieldName = "color"; + final int expectResultSize = randomIntBetween(1, 3); + final String filterValue = "red"; + createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), createKnnIndexMapping(FIELD_NAME, 3, METHOD_HNSW, FAISS_NAME)); + + // ingest 4 vector docs into the index with the same field {"color": "red"} + for (int i = 0; i < 4; i++) { + addKnnDocWithAttributes(String.valueOf(i), new float[] { i, i, i }, ImmutableMap.of(filterFieldName, filterValue)); + } + + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)); + + Float[] queryVector = { 3f, 3f, 3f }; + // All docs in one segment will match the filters value + String query = KNNJsonQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(expectResultSize) + .filterFieldName(filterFieldName) + .filterValue(filterValue) + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, expectResultSize); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(expectResultSize, docIds.size()); + assertEquals(expectResultSize, parseTotalSearchHits(entity)); + } +}