diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c57523c3..af4068451 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,3 +25,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Maintenance * Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236) ### Refactoring +* Refactor scoring to map leaf reader context with results[2271](https://github.com/opensearch-project/k-NN/pull/2271) diff --git a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java index f62c09cb0..9454b94a2 100644 --- a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -30,14 +31,15 @@ public final class ResultUtil { * @param perLeafResults Results from the list * @param k the number of results across all leaf results to return */ - public static void reduceToTopK(List> perLeafResults, int k) { + public static void reduceToTopK(List>> perLeafResults, int k) { // Iterate over all scores to get min competitive score PriorityQueue topKMinQueue = new PriorityQueue<>(k); int count = 0; - for (Map perLeafResult : perLeafResults) { - count += perLeafResult.size(); - for (Float score : perLeafResult.values()) { + for (Map.Entry> perLeafResult : perLeafResults) { + Map docIdScoreMap = perLeafResult.getValue(); + count += docIdScoreMap.size(); + for (Float score : docIdScoreMap.values()) { if (topKMinQueue.size() < k) { topKMinQueue.add(score); } else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) { @@ -54,7 +56,7 @@ public static void reduceToTopK(List> perLeafResults, int k) // Reduce the results based on min competitive score float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek(); - perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore)); + perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore)); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..5435451dd 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -28,6 +28,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); List leafReaderContexts = reader.leaves(); - List> perLeafResults; + List>> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); final int finalK = knnQuery.getK(); if (rescoreContext == null) { @@ -70,14 +71,15 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo } StopWatch stopWatch = new StopWatch().start(); - perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); + perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); long rescoreTime = stopWatch.stop().totalTime().millis(); - log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); + log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size()); } ResultUtil.reduceToTopK(perLeafResults, finalK); TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; - for (int i = 0; i < perLeafResults.size(); i++) { - topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase); + int i = 0; + for (Map.Entry> entry : perLeafResults) { + topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase); } TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs); @@ -87,32 +89,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); } - private List> doSearch( + private List>> doSearch( final IndexSearcher indexSearcher, List leafReaderContexts, KNNWeight knnWeight, int k ) throws IOException { - List>> tasks = new ArrayList<>(leafReaderContexts.size()); + List>>> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext leafReaderContext : leafReaderContexts) { tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k)); } return indexSearcher.getTaskExecutor().invokeAll(tasks); } - private List> doRescore( + private List>> doRescore( final IndexSearcher indexSearcher, - List leafReaderContexts, KNNWeight knnWeight, - List> perLeafResults, + List>> perLeafResults, int k ) throws IOException { - List>> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); - for (int i = 0; i < perLeafResults.size(); i++) { - LeafReaderContext leafReaderContext = leafReaderContexts.get(i); - int finalI = i; + List>>> rescoreTasks = new ArrayList<>(perLeafResults.size()); + for (Map.Entry> entry : perLeafResults) { rescoreTasks.add(() -> { - BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue()); final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() .matchedDocs(convertedBitSet) // setting to false because in re-scoring we want to do exact search on full precision vectors @@ -121,7 +120,8 @@ private List> doRescore( .isParentHits(false) .knnQuery(knnQuery) .build(); - return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + final Map docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext); + return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap); }); } return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); @@ -158,13 +158,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) { return starts; } - private Map searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException { + private Map.Entry> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) + throws IOException { final Map leafDocScores = queryWeight.searchLeaf(ctx, k); final Bits liveDocs = ctx.reader().getLiveDocs(); if (liveDocs != null) { leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false); } - return leafDocScores; + return new AbstractMap.SimpleEntry<>(ctx, leafDocScores); } @Override diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 70cb86e02..24b26973a 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -12,6 +13,7 @@ import org.opensearch.knn.KNNTestCase; import java.io.IOException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -19,6 +21,8 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.mockito.Mockito.mock; + public class ResultUtilTests extends KNNTestCase { public void testReduceToTopK() { @@ -27,7 +31,9 @@ public void testReduceToTopK() { int segmentCount = 5; List> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - List> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); + List>> reducedLeafResults = initialLeafResults.stream() + .map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item)) + .collect(Collectors.toList()); ResultUtil.reduceToTopK(reducedLeafResults, finalK); assertTopK(initialLeafResults, reducedLeafResults, finalK); @@ -36,7 +42,9 @@ public void testReduceToTopK() { segmentCount = 1; initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); + reducedLeafResults = initialLeafResults.stream() + .map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item)) + .collect(Collectors.toList()); ResultUtil.reduceToTopK(reducedLeafResults, finalK); assertTopK(initialLeafResults, reducedLeafResults, firstPassK); } @@ -75,9 +83,13 @@ private void assertResultMapToTopDocs(Map perLeafResults, TopDoc } } - private void assertTopK(List> beforeResults, List> reducedResults, int expectedK) { + private void assertTopK( + List> beforeResults, + List>> reducedResults, + int expectedK + ) { assertEquals(beforeResults.size(), reducedResults.size()); - assertEquals(expectedK, reducedResults.stream().map(Map::size).reduce(Integer::sum).orElse(-1).intValue()); + assertEquals(expectedK, reducedResults.stream().map(row -> row.getValue().size()).reduce(Integer::sum).orElse(-1).intValue()); float minScore = getMinScore(reducedResults); int count = 0; for (Map result : beforeResults) { @@ -126,10 +138,10 @@ private Map getRandomResults(int k) { return results; } - private float getMinScore(List> perLeafResults) { + private float getMinScore(List>> perLeafResults) { float minScore = Float.MAX_VALUE; - for (Map result : perLeafResults) { - for (float score : result.values()) { + for (Map.Entry> result : perLeafResults) { + for (float score : result.getValue().values()) { if (score < minScore) { minScore = score; } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 4577a34d4..d02fc5e07 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -91,9 +91,9 @@ public void setUp() throws Exception { when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { - List>> callables = invocationOnMock.getArgument(0); - List> results = new ArrayList<>(); - for (Callable> callable : callables) { + List>>> callables = invocationOnMock.getArgument(0); + List>> results = new ArrayList<>(); + for (Callable>> callable : callables) { results.add(callable.call()); } return results;