Skip to content

Commit

Permalink
Add parent join support for lucene knn
Browse files Browse the repository at this point in the history
Call DiversifyingChildren[Byte|Float]KnnVectorQuery for nested field so that k number of parent document can be returned in search result

Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Oct 3, 2023
1 parent 78aba55 commit 9cc75c7
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
* Add parent join support for lucene knn [#1182](https://github.com/opensearch-project/k-NN/pull/1182)
### Enhancements
* Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071)
### Bug Fixes
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ public class KNNConstants {
public static final String NAME = "name";
public static final String PARAMETERS = "parameters";
public static final String METHOD_HNSW = "hnsw";
public static final String TYPE = "type";
public static final String TYPE_NESTED = "nested";
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search";
public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction";
public static final String METHOD_PARAMETER_M = "m";
Expand Down
65 changes: 37 additions & 28 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -86,10 +89,12 @@ public static Query create(CreateQueryRequest createQueryRequest) {
return new KNNQuery(fieldName, vector, k, indexName);
}

log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery);
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
} else if (VectorDataType.FLOAT == vectorDataType) {
return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery);
return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter);
} else {
throw new IllegalArgumentException(
String.format(
Expand All @@ -102,38 +107,40 @@ public static Query create(CreateQueryRequest createQueryRequest) {
}
}

private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnByteVectorQuery(
final String fieldName,
final byte[] byteVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
) {
if (parentFilter == null) {
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnByteVectorQuery(fieldName, byteVector, k);
}

private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnFloatVectorQuery(
final String fieldName,
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
) {
if (parentFilter == null) {
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnFloatVectorQuery(fieldName, floatVector, k);
}

private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
Expand Down Expand Up @@ -181,6 +188,8 @@ static class CreateQueryRequest {
@Getter
private int k;
// can be null in cases filter not passed with the knn query
@Getter
private BitSetProducer parentFilter;
private QueryBuilder filter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;
Expand Down
202 changes: 202 additions & 0 deletions src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.K;
import static org.opensearch.knn.common.KNNConstants.KNN;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.PATH;
import static org.opensearch.knn.common.KNNConstants.QUERY;
import static org.opensearch.knn.common.KNNConstants.TYPE;
import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR;
import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED;
import static org.opensearch.knn.common.KNNConstants.VECTOR;

public class NestedSearchIT extends KNNRestTestCase {
private static final String INDEX_NAME = "test-index-nested-search";
private static final String FIELD_NAME_NESTED = "test-nested";
private static final String FIELD_NAME_VECTOR = "test-vector";
private static final String PROPERTIES_FIELD = "properties";
private static final int EF_CONSTRUCTION = 128;
private static final int M = 16;
private static final SpaceType SPACE_TYPE = SpaceType.L2;

@After
@SneakyThrows
public final void cleanUp() {
deleteKNNIndex(INDEX_NAME);
}

@SneakyThrows
public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.LUCENE.getName());

String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
.build();
addNestedKnnDoc(INDEX_NAME, "1", doc1);

String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
.build();
addNestedKnnDoc(INDEX_NAME, "2", doc2);

Float[] queryVector = { 1f, 1f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);

List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
MediaTypeRegistry.getDefaultMediaType().xContent(),
EntityUtils.toString(response.getEntity())
).map().get("hits")).get("hits");
assertEquals(2, hits.size());
}

/**
* {
* "properties": {
* "test-nested": {
* "type": "nested",
* "properties": {
* "test-vector": {
* "type": "knn_vector",
* "dimension": 3,
* "method": {
* "name": "hnsw",
* "space_type": "l2",
* "engine": "lucene",
* "parameters": {
* "ef_construction": 128,
* "m": 24
* }
* }
* }
* }
* }
* }
* }
*/
private void createKnnIndex(final int dimension, final String engine) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME_NESTED)
.field(TYPE, TYPE_NESTED)
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME_VECTOR)
.field(TYPE, TYPE_KNN_VECTOR)
.field(DIMENSION, dimension)
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(METHOD_PARAMETER_SPACE_TYPE, SPACE_TYPE)
.field(KNN_ENGINE, engine)
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_M, M)
.field(METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

String mapping = builder.toString();
createKnnIndex(INDEX_NAME, mapping);
}

@SneakyThrows
private void ingestTestData() {
String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
.build();
addNestedKnnDoc(INDEX_NAME, "1", doc1);

String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.add(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
.build();
addNestedKnnDoc(INDEX_NAME, "2", doc2);
}

private void addNestedKnnDoc(final String index, final String docId, final String document) throws IOException {
Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

request.setJsonEntity(document);
client().performRequest(request);

request = new Request("POST", "/" + index + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

private Response queryNestedField(final String index, final int k, final Object[] vector) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY);
builder.startObject(TYPE_NESTED);
builder.field(PATH, FIELD_NAME_NESTED);
builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR);
builder.field(VECTOR, vector);
builder.field(K, k);
builder.endObject().endObject().endObject().endObject().endObject().endObject();

Request request = new Request("POST", "/" + index + "/_search");
request.setJsonEntity(builder.toString());

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

return response;
}

private static class NestedKnnDocBuilder {
private XContentBuilder builder;

public NestedKnnDocBuilder(final String fieldName) throws IOException {
builder = XContentFactory.jsonBuilder().startObject().startArray(fieldName);
}

public static NestedKnnDocBuilder create(final String fieldName) throws IOException {
return new NestedKnnDocBuilder(fieldName);
}

public NestedKnnDocBuilder add(final String fieldName, final Object[]... vectors) throws IOException {
for (Object[] vector : vectors) {
builder.startObject();
builder.field(fieldName, vector);
builder.endObject();
}
return this;
}

public String build() throws IOException {
builder.endArray().endObject();
return builder.toString();
}

}
}
Loading

0 comments on commit 9cc75c7

Please sign in to comment.