Skip to content

Commit

Permalink
Fix NPE in ANN search when a segment doesn't contain vector field
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Nov 19, 2024
1 parent 4992736 commit 6b758f6
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 84 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
* Fix NPE in ANN search when a segment doesn't contain vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ dependencies {
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4'
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
api "net.java.dev.jna:jna:5.13.0"
Expand Down
167 changes: 96 additions & 71 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
Expand All @@ -38,9 +37,11 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

@Log4j2
Expand All @@ -59,23 +60,27 @@ public class ExactSearcher {
*/
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
final Optional<KNNIterator> iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
// because of any reason if we are not able to get KNNIterator, return an empty map
if (iterator.isEmpty()) {
return Collections.emptyMap();
}
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator.get());
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
return scoreAllDocs(iterator.get());
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
return searchTopCandidates(iterator.get(), exactSearcherContext.getK(), Predicates.alwaysTrue());
}

/**
* Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search.
* Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores
* to filter out the documents that does not have given min score.
* @param leafReaderContext
* @param exactSearcherContext
* @param leafReaderContext {@link LeafReaderContext}
* @param exactSearcherContext {@link ExactSearcherContext}
* @param iterator {@link KNNIterator}
* @return Map of docId and score
* @throws IOException exception raised by iterator during traversal
Expand Down Expand Up @@ -145,79 +150,99 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K
return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore);
}

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
private Optional<KNNIterator> getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext)
throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
if (fieldInfo == null) {
log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return Optional.empty();
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;

if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedBinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new BinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}

if (VectorDataType.BYTE == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
final KNNIterator knnIterator;
KNNVectorValues<?> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
switch (knnQuery.getVectorDataType()) {
case BINARY:
if (isNestedRequired) {
knnIterator = new NestedBinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
} else {
knnIterator = new BinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}
return Optional.of(knnIterator);
case BYTE:
if (isNestedRequired) {
knnIterator = new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
} else {
knnIterator = new ByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType
);
}
return Optional.of(knnIterator);
case FLOAT:
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
// Build Segment Level Quantization info.
segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
// Quantize the Query Vector Once.
quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(
knnQuery.getQueryVector(),
segmentLevelQuantizationInfo
);
} else {
segmentLevelQuantizationInfo = null;
quantizedQueryVector = null;
}
if (isNestedRequired) {
knnIterator = new NestedVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext),
quantizedQueryVector,
segmentLevelQuantizationInfo
);
} else {
knnIterator = new VectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}
return Optional.of(knnIterator);
default:
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Vector data type [%s] is not supported", knnQuery.getVectorDataType())
);
}
return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
}
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
// Build Segment Level Quantization info.
segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
// Quantize the Query Vector Once.
quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo);
} else {
segmentLevelQuantizationInfo = null;
quantizedQueryVector = null;
}

final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext),
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}
return new VectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
spaceType,
quantizedQueryVector,
segmentLevelQuantizationInfo
);
}

/**
Expand Down
15 changes: 8 additions & 7 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;

/**
Expand Down Expand Up @@ -58,19 +59,19 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)
}

/**
* Convert map to bit set
* Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to
* ensure that the caller is aware that BitSet may not be present
*
* @param resultMap Map of results
* @return BitSet of results
* @return Optional BitSet of results
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return BitSet.of(DocIdSetIterator.empty(), 0);
public static Optional<BitSet> resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap == null || resultMap.isEmpty()) {
return Optional.empty();
}

final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
return Optional.of(BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Callable;

/**
Expand Down Expand Up @@ -112,9 +114,14 @@ private List<Map<Integer, Float>> doRescore(
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
rescoreTasks.add(() -> {
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
final Optional<BitSet> convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
// if there is no docIds to re-score from a segment we should return early to ensure that we are not
// wasting any computation
if (convertedBitSet.isEmpty()) {
return Collections.emptyMap();
}
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocs(convertedBitSet)
.matchedDocs(convertedBitSet.get())
// setting to false because in re-scoring we want to do exact search on full precision vectors
.useQuantizedVectorsForSearch(false)
.k(k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
Expand Down Expand Up @@ -50,6 +51,50 @@ public class ExactSearcherTests extends KNNTestCase {

private static final String SEGMENT_NAME = "0";

@SneakyThrows
public void testExactSearch_whenSegmentHasNoVectorField_thenNoDocsReturned() {
final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f };
final KNNQuery query = KNNQuery.builder().field(FIELD_NAME).queryVector(queryVector).k(10).indexName(INDEX_NAME).build();

final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder =
ExactSearcher.ExactSearcherContext.builder().knnQuery(query);

ExactSearcher exactSearcher = new ExactSearcher(null);
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final FSDirectory directory = mock(FSDirectory.class);
final SegmentInfo segmentInfo = new SegmentInfo(
directory,
Version.LATEST,
Version.LATEST,
SEGMENT_NAME,
100,
false,
false,
KNNCodecVersion.current().getDefaultCodecDelegate(),
Map.of(),
new byte[StringHelper.ID_LENGTH],
Map.of(),
Sort.RELEVANCE
);
segmentInfo.setFiles(Set.of());
final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);
when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);

final Path path = mock(Path.class);
when(directory.getDirectory()).thenReturn(path);
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(query.getField())).thenReturn(null);
when(fieldInfo.attributes()).thenReturn(Collections.emptyMap());
Map<Integer, Float> docIds = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build());
Mockito.verify(fieldInfos).fieldInfo(query.getField());
assertEquals(0, docIds.size());
}

@SneakyThrows
public void testRadialSearch_whenNoEngineFiles_thenSuccess() {
try (MockedStatic<KNNVectorValuesFactory> valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) {
Expand All @@ -75,6 +120,7 @@ public void testRadialSearch_whenNoEngineFiles_thenSuccess() {
.queryVector(queryVector)
.radius(radius)
.indexName(INDEX_NAME)
.vectorDataType(VectorDataType.FLOAT)
.context(context)
.build();

Expand Down
Loading

0 comments on commit 6b758f6

Please sign in to comment.