Refactor scoring to map leaf reader context with results #2271

Open
wants to merge 1 commit into base: main
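
At a glance, the refactor pairs each per-leaf score map with its originating LeafReaderContext instead of relying on list position against reader.leaves(). A minimal sketch of the type change, simplified from the diff below (the wrapper class name is illustrative only, not from this PR):

```java
import org.apache.lucene.index.LeafReaderContext;

import java.util.List;
import java.util.Map;

// Illustrative sketch of the before/after per-leaf result types in this PR.
class PerLeafResultShape {
    // Before: per-leaf score maps, aligned by position with reader.leaves()
    List<Map<Integer, Float>> before;

    // After: each docId -> score map is paired with its own LeafReaderContext
    List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> after;
}
```
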
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,3 +25,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Maintenance
* Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236)
### Refactoring
* Refactor scoring to map leaf reader context with results[2271](https://github.com/opensearch-project/k-NN/pull/2271)
12 changes: 7 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
@@ -5,6 +5,7 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
@@ -30,14 +31,15 @@ public final class ResultUtil {
* @param perLeafResults Results from the list
* @param k the number of results across all leaf results to return
*/
public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) {
public static void reduceToTopK(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults, int k) {
// Iterate over all scores to get min competitive score
PriorityQueue<Float> topKMinQueue = new PriorityQueue<>(k);

int count = 0;
for (Map<Integer, Float> perLeafResult : perLeafResults) {
count += perLeafResult.size();
for (Float score : perLeafResult.values()) {
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> perLeafResult : perLeafResults) {
Map<Integer, Float> docIdScoreMap = perLeafResult.getValue();
count += docIdScoreMap.size();
for (Float score : docIdScoreMap.values()) {
if (topKMinQueue.size() < k) {
topKMinQueue.add(score);
} else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) {
@@ -54,7 +56,7 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)

// Reduce the results based on min competitive score
float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek();
perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore));
perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore));
}

/**
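
For context, a minimal usage sketch of the updated reduceToTopK signature shown above; the surrounding class, method, and variable names are illustrative and not part of this PR:

```java
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.knn.index.query.ResultUtil;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ReduceToTopKSketch {
    // Assemble per-leaf entries, then keep only the global top-k scores across all leaves.
    static void example(IndexReader reader, int k) {
        List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults = new ArrayList<>();
        for (LeafReaderContext ctx : reader.leaves()) {
            Map<Integer, Float> docIdToScore = new HashMap<>();
            // ... populate docIdToScore from a per-leaf search ...
            perLeafResults.add(new AbstractMap.SimpleEntry<>(ctx, docIdToScore));
        }
        // Entries scoring below the k-th best score overall are removed from each map in place.
        ResultUtil.reduceToTopK(perLeafResults, k);
    }
}
```
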
@@ -28,6 +28,7 @@
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Map<Integer, Float>> perLeafResults;
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
final int finalK = knnQuery.getK();
if (rescoreContext == null) {
@@ -70,14 +71,15 @@
}

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK);
long rescoreTime = stopWatch.stop().totalTime().millis();
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size());
}
ResultUtil.reduceToTopK(perLeafResults, finalK);
TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
for (int i = 0; i < perLeafResults.size(); i++) {
topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase);
int i = 0;
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase);
}

TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs);
@@ -87,32 +89,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
}

private List<Map<Integer, Float>> doSearch(
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doSearch(
final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> tasks = new ArrayList<>(leafReaderContexts.size());
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k));
}
return indexSearcher.getTaskExecutor().invokeAll(tasks);
}

private List<Map<Integer, Float>> doRescore(
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doRescore(
Collaborator:
This data structure is getting complex, can we look into simplifying it in favor of readability?
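
For illustration only (not part of this PR), one possible simplification along those lines is a small value class so callers do not handle nested Map.Entry types; the PerLeafResult name below is hypothetical:

```java
import org.apache.lucene.index.LeafReaderContext;

import java.util.Map;

// Hypothetical helper, not in this PR: pairs a leaf reader context with its docId -> score map.
final class PerLeafResult {
    private final LeafReaderContext leafReaderContext;
    private final Map<Integer, Float> docIdToScore;

    PerLeafResult(LeafReaderContext leafReaderContext, Map<Integer, Float> docIdToScore) {
        this.leafReaderContext = leafReaderContext;
        this.docIdToScore = docIdToScore;
    }

    LeafReaderContext getLeafReaderContext() {
        return leafReaderContext;
    }

    Map<Integer, Float> getDocIdToScore() {
        return docIdToScore;
    }
}
```

Methods such as doSearch, doRescore, and reduceToTopK could then accept List<PerLeafResult> instead of List<Map.Entry<LeafReaderContext, Map<Integer, Float>>>.
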

final IndexSearcher indexSearcher,
List<LeafReaderContext> leafReaderContexts,
KNNWeight knnWeight,
List<Map<Integer, Float>> perLeafResults,
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults,
int k
) throws IOException {
List<Callable<Map<Integer, Float>>> rescoreTasks = new ArrayList<>(leafReaderContexts.size());
for (int i = 0; i < perLeafResults.size(); i++) {
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> rescoreTasks = new ArrayList<>(perLeafResults.size());
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) {
rescoreTasks.add(() -> {
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue());
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocs(convertedBitSet)
// setting to false because in re-scoring we want to do exact search on full precision vectors
@@ -121,7 +120,8 @@ private List<Map<Integer, Float>> doRescore(
.isParentHits(false)
.knnQuery(knnQuery)
.build();
return knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
final Map<Integer, Float> docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext);
return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap);
});
}
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
@@ -158,13 +158,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) {
return starts;
}

private Map<Integer, Float> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException {
private Map.Entry<LeafReaderContext, Map<Integer, Float>> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k)
throws IOException {
final Map<Integer, Float> leafDocScores = queryWeight.searchLeaf(ctx, k);
final Bits liveDocs = ctx.reader().getLiveDocs();
if (liveDocs != null) {
leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false);
}
return leafDocScores;
return new AbstractMap.SimpleEntry<>(ctx, leafDocScores);
}

@Override
@@ -5,20 +5,24 @@

package org.opensearch.knn.index.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.mockito.Mockito.mock;

public class ResultUtilTests extends KNNTestCase {

public void testReduceToTopK() {
@@ -27,7 +31,9 @@ public void testReduceToTopK() {
int segmentCount = 5;

List<Map<Integer, Float>> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount);
List<Map<Integer, Float>> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList());
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> reducedLeafResults = initialLeafResults.stream()
.map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item))
.collect(Collectors.toList());
ResultUtil.reduceToTopK(reducedLeafResults, finalK);
assertTopK(initialLeafResults, reducedLeafResults, finalK);

@@ -36,7 +42,9 @@
segmentCount = 1;

initialLeafResults = getRandomListOfResults(firstPassK, segmentCount);
reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList());
reducedLeafResults = initialLeafResults.stream()
.map(item -> new AbstractMap.SimpleEntry<>(mock(LeafReaderContext.class), item))
.collect(Collectors.toList());
ResultUtil.reduceToTopK(reducedLeafResults, finalK);
assertTopK(initialLeafResults, reducedLeafResults, firstPassK);
}
@@ -75,9 +83,13 @@ private void assertResultMapToTopDocs(Map<Integer, Float> perLeafResults, TopDoc
}
}

private void assertTopK(List<Map<Integer, Float>> beforeResults, List<Map<Integer, Float>> reducedResults, int expectedK) {
private void assertTopK(
List<Map<Integer, Float>> beforeResults,
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> reducedResults,
int expectedK
) {
assertEquals(beforeResults.size(), reducedResults.size());
assertEquals(expectedK, reducedResults.stream().map(Map::size).reduce(Integer::sum).orElse(-1).intValue());
assertEquals(expectedK, reducedResults.stream().map(row -> row.getValue().size()).reduce(Integer::sum).orElse(-1).intValue());
float minScore = getMinScore(reducedResults);
int count = 0;
for (Map<Integer, Float> result : beforeResults) {
@@ -126,10 +138,10 @@ private Map<Integer, Float> getRandomResults(int k) {
return results;
}

private float getMinScore(List<Map<Integer, Float>> perLeafResults) {
private float getMinScore(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults) {
float minScore = Float.MAX_VALUE;
for (Map<Integer, Float> result : perLeafResults) {
for (float score : result.values()) {
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> result : perLeafResults) {
for (float score : result.getValue().values()) {
if (score < minScore) {
minScore = score;
}
@@ -91,9 +91,9 @@ public void setUp() throws Exception {

when(searcher.getTaskExecutor()).thenReturn(taskExecutor);
when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> {
List<Callable<Map<Integer, Float>>> callables = invocationOnMock.getArgument(0);
List<Map<Integer, Float>> results = new ArrayList<>();
for (Callable<Map<Integer, Float>> callable : callables) {
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> callables = invocationOnMock.getArgument(0);
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> results = new ArrayList<>();
for (Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>> callable : callables) {
results.add(callable.call());
}
return results;