Skip to content

Commit

Permalink
Refactor scoring to map leafreader context with results
Browse files Browse the repository at this point in the history
We map the order of results to the order of segments, and finally
rely on that order to build the top docs. Refactor the methods to use
Map.Entry to associate each LeafReaderContext with the results from
that leaf.
This is required when we want to split segments based on
approximate search or exact search, to avoid rescoring twice
with exact search.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Nov 14, 2024
1 parent a07bad1 commit a8b3c1d
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Maintenance
* Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236)
### Refactoring
* Refactor scoring to map leaf reader context with results[2271](https://github.com/opensearch-project/k-NN/pull/2271)
12 changes: 7 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
Expand All @@ -30,14 +31,15 @@ public final class ResultUtil {
* @param perLeafResults Results from the list
* @param k the number of results across all leaf results to return
*/
public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) {
public static void reduceToTopK(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults, int k) {
// Iterate over all scores to get min competitive score
PriorityQueue<Float> topKMinQueue = new PriorityQueue<>(k);

int count = 0;
for (Map<Integer, Float> perLeafResult : perLeafResults) {
count += perLeafResult.size();
for (Float score : perLeafResult.values()) {
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> perLeafResult : perLeafResults) {
Map<Integer, Float> docIdScoreMap = perLeafResult.getValue();
count += docIdScoreMap.size();
for (Float score : docIdScoreMap.values()) {
if (topKMinQueue.size() < k) {
topKMinQueue.add(score);
} else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) {
Expand All @@ -54,7 +56,7 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)

// Reduce the results based on min competitive score
float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek();
perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore));
perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
Expand Down Expand Up @@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Map<Integer, Float>> perLeafResults;
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
final int finalK = knnQuery.getK();
if (rescoreContext == null) {
Expand All @@ -70,14 +71,15 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
}

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK);
long rescoreTime = stopWatch.stop().totalTime().millis();
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size());
}
ResultUtil.reduceToTopK(perLeafResults, finalK);
TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
for (int i = 0; i < perLeafResults.size(); i++) {
topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase);
int i = 0;
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase);
}

TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs);
Expand All @@ -87,32 +89,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
}

private List<Map<Integer, Float>> doSearch(
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doSearch(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> tasks = new ArrayList<>(leafReaderContexts.size());
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k));
}
return indexSearcher.getTaskExecutor().invokeAll(tasks);
}

private List<Map<Integer, Float>> doRescore(
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doRescore(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
List<Map<Integer, Float>> perLeafResults,
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> rescoreTasks = new ArrayList<>(leafReaderContexts.size());
for (int i = 0; i < perLeafResults.size(); i++) {
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> rescoreTasks = new ArrayList<>(perLeafResults.size());
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
rescoreTasks.add(() -> {
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue());
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocs(convertedBitSet)
// setting to false because in re-scoring we want to do exact search on full precision vectors
Expand All @@ -121,7 +120,8 @@ private List<Map<Integer, Float>> doRescore(
.isParentHits(false)
.knnQuery(knnQuery)
.build();
return knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
final Map<Integer, Float> docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext);
return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap);
});
}
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
Expand Down Expand Up @@ -158,13 +158,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) {
return starts;
}

private Map<Integer, Float> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException {
private Map.Entry<LeafReaderContext, Map<Integer, Float>> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k)
throws IOException {
final Map<Integer, Float> leafDocScores = queryWeight.searchLeaf(ctx, k);
final Bits liveDocs = ctx.reader().getLiveDocs();
if (liveDocs != null) {
leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false);
}
return leafDocScores;
return new AbstractMap.SimpleEntry<>(ctx, leafDocScores);
}

@Override
Expand Down
26 changes: 19 additions & 7 deletions src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.mockito.Mockito.mock;

public class ResultUtilTests extends KNNTestCase {

public void testReduceToTopK() {
Expand All @@ -27,7 +31,9 @@ public void testReduceToTopK() {
int segmentCount = 5;

List<Map<Integer, Float>> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount);
List<Map<Integer, Float>> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList());
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> reducedLeafResults = initialLeafResults.stream()
.map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item))
.collect(Collectors.toList());
ResultUtil.reduceToTopK(reducedLeafResults, finalK);
assertTopK(initialLeafResults, reducedLeafResults, finalK);

Expand All @@ -36,7 +42,9 @@ public void testReduceToTopK() {
segmentCount = 1;

initialLeafResults = getRandomListOfResults(firstPassK, segmentCount);
reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList());
reducedLeafResults = initialLeafResults.stream()
.map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item))
.collect(Collectors.toList());
ResultUtil.reduceToTopK(reducedLeafResults, finalK);
assertTopK(initialLeafResults, reducedLeafResults, firstPassK);
}
Expand Down Expand Up @@ -75,9 +83,13 @@ private void assertResultMapToTopDocs(Map<Integer, Float> perLeafResults, TopDoc
}
}

private void assertTopK(List<Map<Integer, Float>> beforeResults, List<Map<Integer, Float>> reducedResults, int expectedK) {
private void assertTopK(
List<Map<Integer, Float>> beforeResults,
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> reducedResults,
int expectedK
) {
assertEquals(beforeResults.size(), reducedResults.size());
assertEquals(expectedK, reducedResults.stream().map(Map::size).reduce(Integer::sum).orElse(-1).intValue());
assertEquals(expectedK, reducedResults.stream().map(row -> row.getValue().size()).reduce(Integer::sum).orElse(-1).intValue());
float minScore = getMinScore(reducedResults);
int count = 0;
for (Map<Integer, Float> result : beforeResults) {
Expand Down Expand Up @@ -126,10 +138,10 @@ private Map<Integer, Float> getRandomResults(int k) {
return results;
}

private float getMinScore(List<Map<Integer, Float>> perLeafResults) {
private float getMinScore(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults) {
float minScore = Float.MAX_VALUE;
for (Map<Integer, Float> result : perLeafResults) {
for (float score : result.values()) {
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> result : perLeafResults) {
for (float score : result.getValue().values()) {
if (score < minScore) {
minScore = score;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ public void setUp() throws Exception {

when(searcher.getTaskExecutor()).thenReturn(taskExecutor);
when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> {
List<Callable<Map<Integer, Float>>> callables = invocationOnMock.getArgument(0);
List<Map<Integer, Float>> results = new ArrayList<>();
for (Callable<Map<Integer, Float>> callable : callables) {
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> callables = invocationOnMock.getArgument(0);
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> results = new ArrayList<>();
for (Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>> callable : callables) {
results.add(callable.call());
}
return results;
Expand Down

0 comments on commit a8b3c1d

Please sign in to comment.