From a8b3c1df3a478efb574e81a5598e55086834e53d Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Thu, 14 Nov 2024 11:36:56 -0800 Subject: [PATCH] Refactor scoring to map leafreader context with results We map order of results to order of segments, and finally rely on that order to build top docs. Refactor method to use map.Entry to map leafreader context with results from those leaves. This is required when we want to split segments based on approx search or exact search to reduce rescoring twice by exact search Signed-off-by: Vijayan Balasubramanian --- CHANGELOG.md | 1 + .../knn/index/query/ResultUtil.java | 12 +++--- .../nativelib/NativeEngineKnnVectorQuery.java | 37 ++++++++++--------- .../knn/index/query/ResultUtilTests.java | 26 +++++++++---- .../NativeEngineKNNVectorQueryTests.java | 6 +-- 5 files changed, 49 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c57523c3..af4068451 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,3 +25,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Maintenance * Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236) ### Refactoring +* Refactor scoring to map leaf reader context with results[2271](https://github.com/opensearch-project/k-NN/pull/2271) diff --git a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java index f62c09cb0..9454b94a2 100644 --- a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -30,14 +31,15 @@ public final class ResultUtil { * @param perLeafResults Results from the list * @param k the number of results across all leaf results to return */ - public static void reduceToTopK(List> perLeafResults, int k) { + public static void reduceToTopK(List>> perLeafResults, int k) { // Iterate over all scores to get min competitive score PriorityQueue topKMinQueue = new PriorityQueue<>(k); int count = 0; - for (Map perLeafResult : perLeafResults) { - count += perLeafResult.size(); - for (Float score : perLeafResult.values()) { + for (Map.Entry> perLeafResult : perLeafResults) { + Map docIdScoreMap = perLeafResult.getValue(); + count += docIdScoreMap.size(); + for (Float score : docIdScoreMap.values()) { if (topKMinQueue.size() < k) { topKMinQueue.add(score); } else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) { @@ -54,7 +56,7 @@ public static void reduceToTopK(List> perLeafResults, int k) // Reduce the results based on min competitive score float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek(); - perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore)); + perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore)); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..5435451dd 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -28,6 +28,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); List leafReaderContexts = reader.leaves(); - List> perLeafResults; + List>> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); final int finalK = knnQuery.getK(); if (rescoreContext == null) { @@ -70,14 +71,15 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo } StopWatch stopWatch = new StopWatch().start(); - perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); + perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); long rescoreTime = stopWatch.stop().totalTime().millis(); - log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); + log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size()); } ResultUtil.reduceToTopK(perLeafResults, finalK); TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; - for (int i = 0; i < perLeafResults.size(); i++) { - topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase); + int i = 0; + for (Map.Entry> entry : perLeafResults) { + topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase); } TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs); @@ -87,32 +89,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); } - private List> doSearch( + private List>> doSearch( final IndexSearcher indexSearcher, List leafReaderContexts, KNNWeight knnWeight, int k ) throws IOException { - List>> tasks = new ArrayList<>(leafReaderContexts.size()); + List>>> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext leafReaderContext : leafReaderContexts) { tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k)); } return indexSearcher.getTaskExecutor().invokeAll(tasks); } - private List> doRescore( + private List>> doRescore( final IndexSearcher indexSearcher, - List leafReaderContexts, KNNWeight knnWeight, - List> perLeafResults, + List>> perLeafResults, int k ) throws IOException { - List>> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); - for (int i = 0; i < perLeafResults.size(); i++) { - LeafReaderContext leafReaderContext = leafReaderContexts.get(i); - int finalI = i; + List>>> rescoreTasks = new ArrayList<>(perLeafResults.size()); + for (Map.Entry> entry : perLeafResults) { rescoreTasks.add(() -> { - BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue()); final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() .matchedDocs(convertedBitSet) // setting to false because in re-scoring we want to do exact search on full precision vectors @@ -121,7 +120,8 @@ private List> doRescore( .isParentHits(false) .knnQuery(knnQuery) .build(); - return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + final Map docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext); + return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap); }); } return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); @@ -158,13 +158,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) { return starts; } - private Map searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException { + private Map.Entry> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) + throws IOException { final Map leafDocScores = queryWeight.searchLeaf(ctx, k); final Bits liveDocs = ctx.reader().getLiveDocs(); if (liveDocs != null) { leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false); } - return leafDocScores; + return new AbstractMap.SimpleEntry<>(ctx, leafDocScores); } @Override diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 70cb86e02..24b26973a 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -12,6 +13,7 @@ import org.opensearch.knn.KNNTestCase; import java.io.IOException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -19,6 +21,8 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.mockito.Mockito.mock; + public class ResultUtilTests extends KNNTestCase { public void testReduceToTopK() { @@ -27,7 +31,9 @@ public void testReduceToTopK() { int segmentCount = 5; List> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - List> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); + List>> reducedLeafResults = initialLeafResults.stream() + .map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item)) + .collect(Collectors.toList()); ResultUtil.reduceToTopK(reducedLeafResults, finalK); assertTopK(initialLeafResults, reducedLeafResults, finalK); @@ -36,7 +42,9 @@ public void testReduceToTopK() { segmentCount = 1; initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); + reducedLeafResults = initialLeafResults.stream() + .map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item)) + .collect(Collectors.toList()); ResultUtil.reduceToTopK(reducedLeafResults, finalK); assertTopK(initialLeafResults, reducedLeafResults, firstPassK); } @@ -75,9 +83,13 @@ private void assertResultMapToTopDocs(Map perLeafResults, TopDoc } } - private void assertTopK(List> beforeResults, List> reducedResults, int expectedK) { + private void assertTopK( + List> beforeResults, + List>> reducedResults, + int expectedK + ) { assertEquals(beforeResults.size(), reducedResults.size()); - assertEquals(expectedK, reducedResults.stream().map(Map::size).reduce(Integer::sum).orElse(-1).intValue()); + assertEquals(expectedK, reducedResults.stream().map(row -> row.getValue().size()).reduce(Integer::sum).orElse(-1).intValue()); float minScore = getMinScore(reducedResults); int count = 0; for (Map result : beforeResults) { @@ -126,10 +138,10 @@ private Map getRandomResults(int k) { return results; } - private float getMinScore(List> perLeafResults) { + private float getMinScore(List>> perLeafResults) { float minScore = Float.MAX_VALUE; - for (Map result : perLeafResults) { - for (float score : result.values()) { + for (Map.Entry> result : perLeafResults) { + for (float score : result.getValue().values()) { if (score < minScore) { minScore = score; } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 4577a34d4..d02fc5e07 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -91,9 +91,9 @@ public void setUp() throws Exception { when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { - List>> callables = invocationOnMock.getArgument(0); - List> results = new ArrayList<>(); - for (Callable> callable : callables) { + List>>> callables = invocationOnMock.getArgument(0); + List>> results = new ArrayList<>(); + for (Callable>> callable : callables) { results.add(callable.call()); } return results;