From e05c94f9871b64065a3da62799d932de9293d10c Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Thu, 12 Dec 2024 00:38:50 +0800 Subject: [PATCH] Add check to directly use ANN Search when filters match all docs. Signed-off-by: Wei Wang --- .../java/org/opensearch/knn/index/query/KNNWeight.java | 7 +++++++ .../org/opensearch/knn/index/query/KNNWeightTests.java | 8 +++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index b64472994..fb73088ac 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -127,7 +127,14 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ public Map searchLeaf(LeafReaderContext context, int k) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); + final int maxDoc = context.reader().maxDoc(); int cardinality = filterBitSet.cardinality(); + /* + * If the filters match all docs in this segment, no extra step is needed, + * so we run ANN search directly. */ + if (cardinality == maxDoc){ + return doANNSearch(context, filterBitSet, cardinality, k); + } // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. 
For now I feel this is a good // place, diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 511895026..d72003820 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -815,7 +815,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -891,6 +891,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final FSDirectory directory = mock(FSDirectory.class); when(reader.directory()).thenReturn(directory); @@ -968,7 +969,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -1168,6 +1169,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = 
mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); @@ -1202,7 +1204,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { // We will have 0, 1 for filteredIds and 2 will be the parent id for both of them final Scorer filterScorer = mock(Scorer.class); when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2)); - when(reader.maxDoc()).thenReturn(2); + when(reader.maxDoc()).thenReturn(3); // Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result final List vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f });