Skip to content

Commit

Permalink
Fixing the bug when a segment has no vector field present for disk ba…
Browse files Browse the repository at this point in the history
…sed vector search (#2281)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Nov 19, 2024
1 parent 4992736 commit 2d1a408
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
* Fix NPE in ANN search when a segment doesn't contain vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)
}

/**
* Convert map to bit set
* Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to
* ensure that the caller is aware that BitSet may not be present
*
* @param resultMap Map of results
* @return BitSet of results
* @return BitSet of results; null is returned if the result map is empty
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return BitSet.of(DocIdSetIterator.empty(), 0);
if (resultMap == null || resultMap.isEmpty()) {
return null;
}

final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -112,7 +113,12 @@ private List<Map<Integer, Float>> doRescore(
LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
int finalI = i;
rescoreTasks.add(() -> {
BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
final BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI));
// if there is no docIds to re-score from a segment we should return early to ensure that we are not
// wasting any computation
if (convertedBitSet == null) {
return Collections.emptyMap();
}
final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
.matchedDocs(convertedBitSet)
// setting to false because in re-scoring we want to do exact search on full precision vectors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BitSet;
import org.junit.Assert;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;
Expand Down Expand Up @@ -48,6 +49,14 @@ public void testResultMapToMatchBitSet() throws IOException {
assertResultMapToMatchBitSet(perLeafResults, resultBitset);
}

public void testResultMapToMatchBitSet_whenResultMapEmpty_thenReturnEmptyOptional() throws IOException {
BitSet resultBitset = ResultUtil.resultMapToMatchBitSet(Collections.emptyMap());
Assert.assertNull(resultBitset);

BitSet resultBitset2 = ResultUtil.resultMapToMatchBitSet(null);
Assert.assertNull(resultBitset2);
}

public void testResultMapToDocIds() throws IOException {
int firstPassK = 42;
Map<Integer, Float> perLeafResults = getRandomResults(firstPassK);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ public void testRescore() {
when(reader.leaves()).thenReturn(leaves);

int k = 2;
int firstPassK = 3;
int firstPassK = 100;
Map<Integer, Float> initialLeaf1Results = new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f, 3, 15f));
Map<Integer, Float> initialLeaf2Results = new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f, 3, 14f));
Map<Integer, Float> rescoredLeaf1Results = new HashMap<>(Map.of(0, 18f, 1, 20f));
Expand Down Expand Up @@ -257,6 +257,9 @@ public void testRescore() {
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);

mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToMatchBitSet(any())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);

mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1);
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf2Results), anyInt())).thenAnswer(t -> topDocs2);
try (MockedStatic<NativeEngineKnnVectorQuery> mockedStaticNativeKnnVectorQuery = mockStatic(NativeEngineKnnVectorQuery.class)) {
Expand Down
36 changes: 36 additions & 0 deletions src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -220,6 +221,41 @@ public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() {
validateGreenIndex(indexName);
}

@SneakyThrows
public void testCompressionIndexWithNonVectorFieldsSegment_whenValid_ThenSucceed() {
CompressionLevel compressionLevel = CompressionLevel.x32;
String indexName = INDEX_NAME + compressionLevel;
try (
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName())
.field(MODE_PARAMETER, Mode.ON_DISK.getName())
.endObject()
.endObject()
.endObject()
) {
String mapping = builder.toString();
Settings indexSettings = buildKNNIndexSettings(0);
createKnnIndex(indexName, indexSettings, mapping);
// since we are going to delete a document, so its better to have 1 more extra doc so that we can re-use some tests
addKNNDocs(indexName, FIELD_NAME, DIMENSION, 0, NUM_DOCS + 1);
addNonKNNDoc(indexName, String.valueOf(NUM_DOCS + 2), FIELD_NAME_NON_KNN, "Hello world");
deleteKnnDoc(indexName, "0");
validateGreenIndex(indexName);
validateSearch(
indexName,
METHOD_PARAMETER_EF_SEARCH,
KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH,
compressionLevel.getName(),
Mode.ON_DISK.getName()
);
}
}

@SneakyThrows
public void testTraining_whenInvalid_thenFail() {
setupTrainingIndex();
Expand Down
13 changes: 13 additions & 0 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
public class KNNRestTestCase extends ODFERestTestCase {
public static final String INDEX_NAME = "test_index";
public static final String FIELD_NAME = "test_field";
public static final String FIELD_NAME_NON_KNN = "test_field_non_knn";
public static final String PROPERTIES_FIELD = "properties";
public static final String STORE_FIELD = "store";
public static final String STORED_QUERY_FIELD = "stored_fields";
Expand Down Expand Up @@ -607,6 +608,18 @@ protected <T> void addKnnDoc(String index, String docId, String fieldName, T vec
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected <T> void addNonKNNDoc(String index, String docId, String fieldName, String text) throws IOException {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, text).endObject();
request.setJsonEntity(builder.toString());
client().performRequest(request);

request = new Request("POST", "/" + index + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

/**
* Add a single KNN Doc to an index with a nested vector field
*
Expand Down

0 comments on commit 2d1a408

Please sign in to comment.