Add check to directly use ANN Search when filters match all docs. #2320

Merged: 4 commits, Jan 3, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -129,6 +129,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
*/
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
final BitSet filterBitSet = getFilteredDocsBitSet(context);
final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
// We don't need to go to the JNI layer if no documents satisfy the filters
// We should give this condition a deeper look to decide where it should be placed. For now I feel this is a good
@@ -145,6 +146,12 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
/*
 * If the filter matches all docs in this segment, there is no need for any extra
 * filtering step, so we can directly do ANN search.
 */
if (filterWeight != null && cardinality == maxDoc) {
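// An empty FixedBitSet (cardinality 0) is handed to doANNSearch, so the native query runs
// with no filter at all; the new test below asserts this at the JNIService.queryIndex
// boundary via eq(new FixedBitSet(0).getBits()).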
return new PerLeafResult(new FixedBitSet(0), doANNSearch(context, new FixedBitSet(0), 0, k));
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
@@ -320,6 +327,7 @@ private Map<Integer, Float> doANNSearch(
// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();
indexAllocation.incRef();

try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
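With this hunk in place, searchLeaf chooses between four per-segment paths. The sketch below condenses that decision order; it is illustrative only, not the actual KNNWeight code: SearchPath, selectPath, and exactPreferred are hypothetical names, and exactPreferred stands in for the exact-search threshold check that the diff above elides.

// Hypothetical, self-contained condensation of the per-segment path selection
// in KNNWeight#searchLeaf after this PR. All names here are made up.
enum SearchPath { NO_RESULTS, EXACT, ANN_UNFILTERED, ANN_FILTERED }

static SearchPath selectPath(boolean hasFilter, int cardinality, int maxDoc, boolean exactPreferred) {
    if (hasFilter && cardinality == 0) {
        return SearchPath.NO_RESULTS;     // no doc passes the filter: skip the JNI layer
    }
    if (hasFilter && exactPreferred) {
        return SearchPath.EXACT;          // small candidate set: exact search is cheaper
    }
    if (hasFilter && cardinality == maxDoc) {
        return SearchPath.ANN_UNFILTERED; // new in this PR: the filter matches every doc
    }
    return SearchPath.ANN_FILTERED;       // default: ANN constrained by the filter bitset
}

The payoff of the third branch is that the native engine no longer receives, and no longer has to consult, a filter bitset that excludes nothing.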
101 changes: 97 additions & 4 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
@@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
- when(reader.maxDoc()).thenReturn(filterDocIds.length);
+ when(reader.maxDoc()).thenReturn(filterDocIds.length + 1);
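// maxDoc is now one greater than the filter cardinality, so the filter no longer matches
// every doc and this test stays on the filtered ANN path rather than the new match-all
// shortcut; the other maxDoc adjustments below serve the same purpose.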
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
@@ -758,6 +758,97 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

@SneakyThrows
public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() {
// Given
int k = 3;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length);
for (int docId : filterDocIds) {
filterBitSet.set(docId);
}

jniServiceMockedStatic.when(
() -> JNIService.queryIndex(
anyLong(),
eq(QUERY_VECTOR),
eq(k),
eq(HNSW_METHOD_PARAMETERS),
any(),
eq(new FixedBitSet(0).getBits()),
anyInt(),
any()
)
).thenReturn(getFilteredKNNQueryResults());

final Bits liveDocsBits = mock(Bits.class);
for (int filterDocId : filterDocIds) {
when(liveDocsBits.get(filterDocId)).thenReturn(true);
}
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
when(reader.maxDoc()).thenReturn(filterDocIds.length);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(QUERY_VECTOR)
.k(k)
.indexName(INDEX_NAME)
.filterQuery(FILTER_QUERY)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// Just to make sure that we are not hitting the exact search condition
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));

final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);

final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
final Map<String, String> attributesMap = ImmutableMap.of(
KNN_ENGINE,
KNNEngine.FAISS.getName(),
SPACE_TYPE,
SpaceType.L2.getValue()
);

when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(attributesMap);

// When
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);

// Then
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());

jniServiceMockedStatic.verify(
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()),
times(1)
);

final List<Integer> actualDocIds = new ArrayList<>();
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

private SegmentReader mockSegmentReader() {
Path path = mock(Path.class);

@@ -815,7 +906,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// scorer will return 1 document
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
- when(reader.maxDoc()).thenReturn(1);
+ when(reader.maxDoc()).thenReturn(2);
final Bits liveDocsBits = mock(Bits.class);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -891,6 +982,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(1);

final FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
@@ -968,7 +1060,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// scorer will return 1 document
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
- when(reader.maxDoc()).thenReturn(1);
+ when(reader.maxDoc()).thenReturn(2);
final Bits liveDocsBits = mock(Bits.class);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -1168,6 +1260,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(1);

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
@@ -1202,7 +1295,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
// We will have 0, 1 for filteredIds and 2 will be the parent id for both of them
final Scorer filterScorer = mock(Scorer.class);
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2));
- when(reader.maxDoc()).thenReturn(2);
+ when(reader.maxDoc()).thenReturn(3);

// Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result
final List<float[]> vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f });