Skip to content

Commit

Permalink
Multiple innerHit in nested fields
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Dec 10, 2024
1 parent 9276c77 commit 991f0c8
Show file tree
Hide file tree
Showing 45 changed files with 2,082 additions and 210 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4'
testFixturesImplementation 'com.jayway.jsonpath:json-path:2.8.0'
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
api "net.java.dev.jna:jna:5.13.0"
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class KNNConstants {
public static final String QFRAMEWORK_CONFIG = "qframe_config";

public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
public static final String EXPAND_NESTED = "expand_nested_docs";
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;
public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public static class CreateQueryRequest {
private QueryBuilder filter;
private QueryShardContext context;
private RescoreContext rescoreContext;
private boolean expandNested;

public Optional<QueryBuilder> getFilter() {
return Optional.ofNullable(filter);
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -68,8 +67,8 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
if (exactSearcherContext.getMatchedDocsIterator() != null
&& exactSearcherContext.numberOfMatchedDocs <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
Expand Down Expand Up @@ -155,7 +154,7 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
Expand Down Expand Up @@ -245,7 +244,8 @@ public static class ExactSearcherContext {
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
DocIdSetIterator matchedDocsIterator;
long numberOfMatchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNQuery extends Query {

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
Expand Down
22 changes: 19 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
Expand All @@ -74,6 +75,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField EXPAND_NESTED_FIELD = new ParseField(EXPAND_NESTED);
public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE);
public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE);
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
Expand Down Expand Up @@ -106,6 +108,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private boolean ignoreUnmapped;
@Getter
private RescoreContext rescoreContext;
@Getter
private boolean expandNested;

/**
* Constructs a new query with the given field name and vector
Expand Down Expand Up @@ -147,6 +151,7 @@ public static class Builder {
private String queryName;
private float boost = DEFAULT_BOOST;
private RescoreContext rescoreContext;
private boolean expandNested;

public Builder() {}

Expand Down Expand Up @@ -205,6 +210,11 @@ public Builder rescoreContext(RescoreContext rescoreContext) {
return this;
}

public Builder expandNested(boolean expandNested) {
this.expandNested = expandNested;
return this;
}

public KNNQueryBuilder build() {
validate();
int k = this.k == null ? 0 : this.k;
Expand All @@ -217,7 +227,8 @@ public KNNQueryBuilder build() {
methodParameters,
filter,
ignoreUnmapped,
rescoreContext
rescoreContext,
expandNested
).boost(boost).queryName(queryName);
}

Expand Down Expand Up @@ -319,6 +330,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
this.maxDistance = null;
this.minScore = null;
this.rescoreContext = null;
this.expandNested = false;
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -341,6 +353,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
minScore = builder.minScore;
methodParameters = builder.methodParameters;
rescoreContext = builder.rescoreContext;
expandNested = builder.expandNested;
}

@Override
Expand Down Expand Up @@ -536,6 +549,7 @@ protected Query doToQuery(QueryShardContext context) {
.filter(this.filter)
.context(context)
.rescoreContext(processedRescoreContext)
.expandNested(expandNested)
.build();
return KNNQueryFactory.create(createQueryRequest);
}
Expand Down Expand Up @@ -621,7 +635,8 @@ protected boolean doEquals(KNNQueryBuilder other) {
&& Objects.equals(methodParameters, other.methodParameters)
&& Objects.equals(filter, other.filter)
&& Objects.equals(ignoreUnmapped, other.ignoreUnmapped)
&& Objects.equals(rescoreContext, other.rescoreContext);
&& Objects.equals(rescoreContext, other.rescoreContext)
&& Objects.equals(expandNested, other.expandNested);
}

@Override
Expand All @@ -635,7 +650,8 @@ protected int doHashCode() {
ignoreUnmapped,
maxDistance,
minScore,
rescoreContext
rescoreContext,
expandNested
);
}

Expand Down
70 changes: 55 additions & 15 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,28 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
*/
@Log4j2
public class KNNQueryFactory extends BaseQueryFactory {

/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
Expand All @@ -48,13 +49,25 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final Query filterQuery = getFilterQuery(createQueryRequest);
final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);

final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
final boolean expandNested = createQueryRequest.isExpandNested();
BitSetProducer parentFilter = null;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
}

if (parentFilter == null && expandNested) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for the [%s] field. [%s] is only supported with a nested field.",
EXPAND_NESTED,
EXPAND_NESTED
)
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine());

Expand Down Expand Up @@ -95,7 +108,16 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

return knnQuery;
}

Integer requestEfSearch = null;
Expand All @@ -106,9 +128,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested);
default:
throw new IllegalArgumentException(
String.format(
Expand All @@ -131,38 +153,56 @@ private static Query validateFilterQuerySupport(final Query filterQuery, final K
}

/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
* If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory}
* which will create query to dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnByteVectorQuery(
final String fieldName,
final byte[] byteVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean expandNested
) {
if (parentFilter == null) {
assert expandNested == false : "expandNested is allowed to be true only for nested fields.";
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
byteVector,
k,
filterQuery,
parentFilter,
expandNested
);
}
}

/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
* If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory}
* which will create query to dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnFloatVectorQuery(
final String fieldName,
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean expandNested
) {
if (parentFilter == null) {
assert expandNested == false : "expandNested is allowed to be true only for nested fields.";
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
floatVector,
k,
filterQuery,
parentFilter,
expandNested
);
}
}
}
Loading

0 comments on commit 991f0c8

Please sign in to comment.