-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split segment by search type #2273
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,6 +153,26 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I | |
return docIdsToScoreMap; | ||
} | ||
|
||
/** | ||
* For given {@link LeafReaderContext}, this api will return whether KNNWeight will perform exact search or not | ||
* always. This decision is based on two properties: 1) if there are no native engine files in segments, | ||
* exact search will always be performed, 2) if number of docs after filter is less than 'k' | ||
* @param context | ||
* @return | ||
* @throws IOException | ||
*/ | ||
public boolean isExactSearchPreferred(LeafReaderContext context) throws IOException { | ||
final BitSet filterBitSet = getFilteredDocsBitSet(context); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For each ANN leaf, getting the filter BitSet is now happening twice: once for this check, and then for getting the actual filter bitset. Considering the worst case scenario, this goes through filterWeight.scorer twice and then creates a bitset, which involves a linear loop. How confident are we that this won't impact latencies for filtering cases? Can we avoid this duplicate work? One way is to do only the engine-files-empty check here, and then pass the bitset in searchLeaf and in the exact search context. Let me know if I am missing something |
||
int cardinality = filterBitSet.cardinality(); | ||
if (isFilteredExactSearchPreferred(cardinality)) { | ||
return true; | ||
} | ||
if (isMissingNativeEngineFiles(context)) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { | ||
if (this.filterWeight == null) { | ||
return new FixedBitSet(0); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|
||
package org.opensearch.knn.index.query; | ||
|
||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.search.ScoreDoc; | ||
import org.apache.lucene.search.TopDocs; | ||
|
@@ -30,14 +31,15 @@ public final class ResultUtil { | |
* @param perLeafResults Results from the list | ||
* @param k the number of results across all leaf results to return | ||
*/ | ||
public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) { | ||
public static void reduceToTopK(List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults, int k) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should move have a
this will help in future to abstract more details related to search for a segment. Feel free to have a better name for classes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
// Iterate over all scores to get min competitive score | ||
PriorityQueue<Float> topKMinQueue = new PriorityQueue<>(k); | ||
|
||
int count = 0; | ||
for (Map<Integer, Float> perLeafResult : perLeafResults) { | ||
count += perLeafResult.size(); | ||
for (Float score : perLeafResult.values()) { | ||
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> perLeafResult : perLeafResults) { | ||
Map<Integer, Float> docIdScoreMap = perLeafResult.getValue(); | ||
count += docIdScoreMap.size(); | ||
for (Float score : docIdScoreMap.values()) { | ||
if (topKMinQueue.size() < k) { | ||
topKMinQueue.add(score); | ||
} else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) { | ||
|
@@ -54,7 +56,7 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) | |
|
||
// Reduce the results based on min competitive score | ||
float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek(); | ||
perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore)); | ||
perLeafResults.forEach(results -> results.getValue().entrySet().removeIf(entry -> entry.getValue() < minScore)); | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
import org.opensearch.knn.index.query.rescore.RescoreContext; | ||
|
||
import java.io.IOException; | ||
import java.util.AbstractMap; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.Comparator; | ||
|
@@ -55,7 +56,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo | |
final IndexReader reader = indexSearcher.getIndexReader(); | ||
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); | ||
List<LeafReaderContext> leafReaderContexts = reader.leaves(); | ||
List<Map<Integer, Float>> perLeafResults; | ||
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults; | ||
RescoreContext rescoreContext = knnQuery.getRescoreContext(); | ||
final int finalK = knnQuery.getK(); | ||
if (rescoreContext == null) { | ||
|
@@ -64,20 +65,33 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo | |
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); | ||
int dimension = knnQuery.getQueryVector().length; | ||
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); | ||
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); | ||
// split segments into whether exact search will be performed or not | ||
List<LeafReaderContext> exactSearchSegments = new ArrayList<>(); | ||
List<LeafReaderContext> approxSearchSegments = new ArrayList<>(); | ||
for (LeafReaderContext leafReaderContext : leafReaderContexts) { | ||
if (knnWeight.isExactSearchPreferred(leafReaderContext)) { | ||
exactSearchSegments.add(leafReaderContext); | ||
} else { | ||
approxSearchSegments.add(leafReaderContext); | ||
} | ||
} | ||
perLeafResults = doSearch(indexSearcher, approxSearchSegments, knnWeight, firstPassK); | ||
if (isShardLevelRescoringEnabled == true) { | ||
ResultUtil.reduceToTopK(perLeafResults, firstPassK); | ||
} | ||
|
||
StopWatch stopWatch = new StopWatch().start(); | ||
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK); | ||
perLeafResults = doRescore(indexSearcher, knnWeight, perLeafResults, finalK); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously, we did exact search on finalK. After this change, we still do exact search on finalK. Could you tell me how this will improve the latency? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doSearch can call either ApproxSearch or Exact Search based on conditions like whether engine files exist or not, or whether the number of docs after filtering is less than k. In those cases, we will quantize the query vector and every vector from segments, and then perform distance computation using Hamming distance for firstPassK. With this approach, we only call doSearch for those segments which we know will always call approx search, and, for other segments we will call exact search without quantization with finalK. The optimization is at https://github.com/opensearch-project/k-NN/pull/2273/files#diff-9cfe412357ba56b3ef216427d491fc653535686a760e8ba19ea1aa00fc0e0338R68-R78 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you assuming that an exact search on full precision vectors will be faster than an exact search with quantized vectors due to the slower quantization process? It would be interesting to see the benchmark results for this. If that’s the case, an alternative could be to retrieve quantized values directly from the Faiss file instead of performing on-the-fly quantization. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, exact search on full precision for k is faster than exact search on quantization for firstPassK + rescoring matched docs on full precision. 
The linked GitHub issue actually shows how performance got impacted 10x when there are segments with no faiss engine files. In my POC, I saw improvements but recall was poor because of using order as the link between results and leaf reader context. I am rerunning experiments with my change to collect metrics on latency and recall There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is one case where we are running exact search; when the returned result is less than k. Are we going to handle that case as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe this happens with filter; if so, yes, it was already taken care of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. It happens regardless of whether there is a filter or not. https://github.com/opensearch-project/k-NN/blob/main/src/main/java/org/opensearch/knn/index/query/KNNWeight.java#L149 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. It can happen either when there are no engine files or when, after filtering, the number of matched documents is less than k. We only decided to call doSearch if we know that it will call the Approx Search API. For other segments we will directly call Exact Search; this PR is about that only https://github.com/opensearch-project/k-NN/pull/2273/files#diff-9cfe412357ba56b3ef216427d491fc653535686a760e8ba19ea1aa00fc0e0338R72-R75 |
||
long rescoreTime = stopWatch.stop().totalTime().millis(); | ||
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); | ||
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, perLeafResults.size()); | ||
// do exact search on rest of segments and append to result lists | ||
perLeafResults.addAll(doExactSearch(indexSearcher, knnWeight, exactSearchSegments)); | ||
} | ||
ResultUtil.reduceToTopK(perLeafResults, finalK); | ||
TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; | ||
for (int i = 0; i < perLeafResults.size(); i++) { | ||
topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase); | ||
int i = 0; | ||
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) { | ||
topDocs[i++] = ResultUtil.resultMapToTopDocs(entry.getValue(), entry.getKey().docBase); | ||
} | ||
|
||
TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs); | ||
|
@@ -87,32 +101,29 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo | |
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); | ||
} | ||
|
||
private List<Map<Integer, Float>> doSearch( | ||
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doSearch( | ||
final IndexSearcher indexSearcher, | ||
List<LeafReaderContext> leafReaderContexts, | ||
KNNWeight knnWeight, | ||
int k | ||
) throws IOException { | ||
List<Callable<Map<Integer, Float>>> tasks = new ArrayList<>(leafReaderContexts.size()); | ||
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> tasks = new ArrayList<>(leafReaderContexts.size()); | ||
for (LeafReaderContext leafReaderContext : leafReaderContexts) { | ||
tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k)); | ||
} | ||
return indexSearcher.getTaskExecutor().invokeAll(tasks); | ||
} | ||
|
||
private List<Map<Integer, Float>> doRescore( | ||
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doRescore( | ||
final IndexSearcher indexSearcher, | ||
List<LeafReaderContext> leafReaderContexts, | ||
KNNWeight knnWeight, | ||
List<Map<Integer, Float>> perLeafResults, | ||
List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> perLeafResults, | ||
int k | ||
) throws IOException { | ||
List<Callable<Map<Integer, Float>>> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); | ||
for (int i = 0; i < perLeafResults.size(); i++) { | ||
LeafReaderContext leafReaderContext = leafReaderContexts.get(i); | ||
int finalI = i; | ||
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> rescoreTasks = new ArrayList<>(perLeafResults.size()); | ||
for (Map.Entry<LeafReaderContext, Map<Integer, Float>> entry : perLeafResults) { | ||
rescoreTasks.add(() -> { | ||
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); | ||
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(entry.getValue()); | ||
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() | ||
.matchedDocs(convertedBitSet) | ||
// setting to false because in re-scoring we want to do exact search on full precision vectors | ||
|
@@ -121,12 +132,35 @@ private List<Map<Integer, Float>> doRescore( | |
.isParentHits(false) | ||
.knnQuery(knnQuery) | ||
.build(); | ||
return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); | ||
final Map<Integer, Float> docIdScoreMap = knnWeight.exactSearch(entry.getKey(), exactSearcherContext); | ||
return new AbstractMap.SimpleEntry<>(entry.getKey(), docIdScoreMap); | ||
}); | ||
} | ||
return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); | ||
} | ||
|
||
private List<Map.Entry<LeafReaderContext, Map<Integer, Float>>> doExactSearch( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doRescore and this method are more or less similar, any possibility that we can reuse something here? |
||
final IndexSearcher indexSearcher, | ||
KNNWeight knnWeight, | ||
List<LeafReaderContext> leafReaderContexts | ||
) throws IOException { | ||
List<Callable<Map.Entry<LeafReaderContext, Map<Integer, Float>>>> exactSearchTasks = new ArrayList<>(leafReaderContexts.size()); | ||
for (LeafReaderContext context : leafReaderContexts) { | ||
exactSearchTasks.add(() -> { | ||
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() | ||
// setting to false because we want to do exact search on full precision vectors | ||
.useQuantizedVectorsForSearch(false) | ||
.k(knnQuery.getK()) | ||
.knnQuery(knnQuery) | ||
.isParentHits(true) | ||
.build(); | ||
final Map<Integer, Float> searchResults = knnWeight.exactSearch(context, exactSearcherContext); | ||
return new AbstractMap.SimpleEntry<>(context, searchResults); | ||
}); | ||
} | ||
return indexSearcher.getTaskExecutor().invokeAll(exactSearchTasks); | ||
} | ||
|
||
private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { | ||
int len = topK.scoreDocs.length; | ||
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); | ||
|
@@ -158,13 +192,14 @@ static int[] findSegmentStarts(IndexReader reader, int[] docs) { | |
return starts; | ||
} | ||
|
||
private Map<Integer, Float> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException { | ||
private Map.Entry<LeafReaderContext, Map<Integer, Float>> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) | ||
throws IOException { | ||
final Map<Integer, Float> leafDocScores = queryWeight.searchLeaf(ctx, k); | ||
final Bits liveDocs = ctx.reader().getLiveDocs(); | ||
if (liveDocs != null) { | ||
leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false); | ||
} | ||
return leafDocScores; | ||
return new AbstractMap.SimpleEntry<>(ctx, leafDocScores); | ||
} | ||
|
||
@Override | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this java doc is incorrect. We have bit more logic around this. So lets fix this java doc.