From 075d0b94a12b4dba7d27defbf7e59f63b555eb8b Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Thu, 14 Nov 2024 11:36:56 -0800 Subject: [PATCH 1/2] Refactor scoring to map leafreader context with results We map order of results to order of segments, and finally rely on that order to build top docs. Refactor method to use map.Entry to map leafreader context with results from those leaves. This is required when we want to split segments based on approx search or exact search to reduce rescoring twice by exact search Signed-off-by: Vijayan Balasubramanian --- .../knn/index/query/ResultUtil.java | 12 ++++--- .../nativelib/NativeEngineKnnVectorQuery.java | 35 ++++++++++--------- .../knn/index/query/ResultUtilTests.java | 26 ++++++++++---- .../NativeEngineKNNVectorQueryTests.java | 6 ++-- 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java index f62c09cb0..9454b94a2 100644 --- a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -30,14 +31,15 @@ public final class ResultUtil { * @param perLeafResults Results from the list * @param k the number of results across all leaf results to return */ - public static void reduceToTopK(List> perLeafResults, int k) { + public static void reduceToTopK(List>> perLeafResults, int k) { // Iterate over all scores to get min competitive score PriorityQueue topKMinQueue = new PriorityQueue<>(k); int count = 0; - for (Map perLeafResult : perLeafResults) { - count += perLeafResult.size(); - for (Float score : perLeafResult.values()) { + for (Map.Entry> perLeafResult : perLeafResults) { + Map docIdScoreMap = perLeafResult.getValue(); + count += docIdScoreMap.size(); + for (Float score : docIdScoreMap.values()) { if (topKMinQueue.size() < k) { topKMinQueue.add(score); } else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) { @@ -54,7 +56,7 @@ public static void reduceToTopK(List> perLeafResults, int k) // Reduce the results based on min competitive score float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek(); - perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore)); + perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore)); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..53885850a 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -28,6 +28,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); List leafReaderContexts = reader.leaves(); - List> perLeafResults; + List>> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); final int finalK = knnQuery.getK(); if (rescoreContext == null) { @@ -70,14 +71,15 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo } StopWatch stopWatch = new StopWatch().start(); - perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); + perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); long rescoreTime = stopWatch.stop().totalTime().millis(); log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); } ResultUtil.reduceToTopK(perLeafResults, finalK); TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; - for (int i = 0; i < perLeafResults.size(); i++) { - topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase); + int i = 0; + for (Map.Entry> entry : perLeafResults) { + topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase); } TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs); @@ -87,32 +89,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); } - private List> doSearch( + private List>> doSearch( final IndexSearcher indexSearcher, List leafReaderContexts, KNNWeight knnWeight, int k ) throws IOException { - List>> tasks = new ArrayList<>(leafReaderContexts.size()); + List>>> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext leafReaderContext : leafReaderContexts) { tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k)); } return indexSearcher.getTaskExecutor().invokeAll(tasks); } - private List> doRescore( + private List>> doRescore( final IndexSearcher indexSearcher, - List leafReaderContexts, KNNWeight knnWeight, - List> perLeafResults, + List>> perLeafResults, int k ) throws IOException { - List>> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); - for (int i = 0; i < perLeafResults.size(); i++) { - LeafReaderContext leafReaderContext = leafReaderContexts.get(i); - int finalI = i; + List>>> rescoreTasks = new ArrayList<>(perLeafResults.size()); + for (Map.Entry> entry : perLeafResults) { rescoreTasks.add(() -> { - BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue()); final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() .matchedDocs(convertedBitSet) // setting to false because in re-scoring we want to do exact search on full precision vectors @@ -121,7 +120,8 @@ private List> doRescore( .isParentHits(false) .knnQuery(knnQuery) .build(); - return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + final Map docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext); + return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap); }); } return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); @@ -158,13 +158,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) { return starts; } - private Map searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException { + private Map.Entry> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) + throws IOException { final Map leafDocScores = queryWeight.searchLeaf(ctx, k); final Bits liveDocs = ctx.reader().getLiveDocs(); if (liveDocs != null) { leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false); } - return leafDocScores; + return new AbstractMap.SimpleEntry<>(ctx, leafDocScores); } @Override diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 70cb86e02..24b26973a 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -12,6 +13,7 @@ import org.opensearch.knn.KNNTestCase; import java.io.IOException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -19,6 +21,8 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.mockito.Mockito.mock; + public class ResultUtilTests extends KNNTestCase { public void testReduceToTopK() { @@ -27,7 +31,9 @@ public void testReduceToTopK() { int segmentCount = 5; List> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - List> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); + List>> reducedLeafResults = initialLeafResults.stream() + .map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item)) + .collect(Collectors.toList()); ResultUtil.reduceToTopK(reducedLeafResults, finalK); assertTopK(initialLeafResults, reducedLeafResults, finalK); @@ -36,7 +42,9 @@ public void testReduceToTopK() { segmentCount = 1; initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); + reducedLeafResults = initialLeafResults.stream() + .map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item)) + .collect(Collectors.toList()); ResultUtil.reduceToTopK(reducedLeafResults, finalK); assertTopK(initialLeafResults, reducedLeafResults, firstPassK); } @@ -75,9 +83,13 @@ private void assertResultMapToTopDocs(Map perLeafResults, TopDoc } } - private void assertTopK(List> beforeResults, List> reducedResults, int expectedK) { + private void assertTopK( + List> beforeResults, + List>> reducedResults, + int expectedK + ) { assertEquals(beforeResults.size(), reducedResults.size()); - assertEquals(expectedK, reducedResults.stream().map(Map::size).reduce(Integer::sum).orElse(-1).intValue()); + assertEquals(expectedK, reducedResults.stream().map(row -> row.getValue().size()).reduce(Integer::sum).orElse(-1).intValue()); float minScore = getMinScore(reducedResults); int count = 0; for (Map result : beforeResults) { @@ -126,10 +138,10 @@ private Map getRandomResults(int k) { return results; } - private float getMinScore(List> perLeafResults) { + private float getMinScore(List>> perLeafResults) { float minScore = Float.MAX_VALUE; - for (Map result : perLeafResults) { - for (float score : result.values()) { + for (Map.Entry> result : perLeafResults) { + for (float score : result.getValue().values()) { if (score < minScore) { minScore = score; } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 4577a34d4..d02fc5e07 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -91,9 +91,9 @@ public void setUp() throws Exception { when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { - List>> callables = invocationOnMock.getArgument(0); - List> results = new ArrayList<>(); - for (Callable> callable : callables) { + List>>> callables = invocationOnMock.getArgument(0); + List>> results = new ArrayList<>(); + for (Callable>> callable : callables) { results.add(callable.call()); } return results; From b149309af102790dfce386860ad26b69d6355a75 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Thu, 14 Nov 2024 12:05:17 -0800 Subject: [PATCH 2/2] Segregate segments based on search type For exact search, it is not required to perform qunatization during rescore with oversamples. However, to avoid normalization between segments from approx search and exact search, we will first identify segments that needs approxsearch and will perform oversamples and, at end, after rescore, we will add scores from segments that will perform exact search. Signed-off-by: Vijayan Balasubramanian --- .../opensearch/knn/index/query/KNNWeight.java | 20 ++++++++++ .../nativelib/NativeEngineKnnVectorQuery.java | 38 ++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 04c2ce587..073737476 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -153,6 +153,26 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I return docIdsToScoreMap; } + /** + * For given {@link LeafReaderContext}, this api will return will KNNWeight perform exact search or not + * always. This decision is based on two properties, 1) if there are no native engine files in segments, + * exact search will always be performed, 2) if number of docs after filter is less than 'k' + * @param context + * @return + * @throws IOException + */ + public boolean isExactSearchPreferred(LeafReaderContext context) throws IOException { + final BitSet filterBitSet = getFilteredDocsBitSet(context); + int cardinality = filterBitSet.cardinality(); + if (isFilteredExactSearchPreferred(cardinality)) { + return true; + } + if (isMissingNativeEngineFiles(context)) { + return true; + } + return false; + } + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { if (this.filterWeight == null) { return new FixedBitSet(0); diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 53885850a..143f74fe1 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -65,7 +65,17 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); - perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); + // split segments into whether exact search will be performed or not + List exactSearchSegments = new ArrayList<>(); + List approxSearchSegments = new ArrayList<>(); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + if (knnWeight.isExactSearchPreferred(leafReaderContext)) { + exactSearchSegments.add(leafReaderContext); + } else { + approxSearchSegments.add(leafReaderContext); + } + } + perLeafResults = doSearch(indexSearcher, approxSearchSegments, knnWeight, firstPassK); if (isShardLevelRescoringEnabled == true) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } @@ -73,7 +83,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo StopWatch stopWatch = new StopWatch().start(); perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); long rescoreTime = stopWatch.stop().totalTime().millis(); - log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); + log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size()); + // do exact search on rest of segments and append to result lists + perLeafResults.addAll(doExactSearch(indexSearcher, knnWeight, exactSearchSegments)); } ResultUtil.reduceToTopK(perLeafResults, finalK); TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; @@ -127,6 +139,28 @@ private List>> doRescore( return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); } + private List>> doExactSearch( + final IndexSearcher indexSearcher, + KNNWeight knnWeight, + List leafReaderContexts + ) throws IOException { + List>>> exactSearchTasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext context : leafReaderContexts) { + exactSearchTasks.add(() -> { + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + // setting to false because we want to do exact search on full precision vectors + .useQuantizedVectorsForSearch(false) + .k(knnQuery.getK()) + .knnQuery(knnQuery) + .isParentHits(true) + .build(); + final Map searchResults = knnWeight.exactSearch(context, exactSearcherContext); + return new AbstractMap.SimpleEntry<>(context, searchResults); + }); + } + return indexSearcher.getTaskExecutor().invokeAll(exactSearchTasks); + } + private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));