From 88792e42f121b050f2fc9cf32b039052aab62128 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Wed, 11 Dec 2024 09:35:54 -0800 Subject: [PATCH] Multiple innerHit in nested fields (#2283) Signed-off-by: Heemin Kim --- CHANGELOG.md | 1 + build.gradle | 1 + .../opensearch/knn/common/KNNConstants.java | 1 + .../knn/index/engine/KNNEngine.java | 1 + .../knn/index/query/BaseQueryFactory.java | 1 + .../knn/index/query/ExactSearcher.java | 10 +- .../opensearch/knn/index/query/KNNQuery.java | 1 + .../knn/index/query/KNNQueryBuilder.java | 22 +- .../knn/index/query/KNNQueryFactory.java | 70 +++- .../opensearch/knn/index/query/KNNWeight.java | 32 +- .../knn/index/query/PerLeafResult.java | 28 ++ .../knn/index/query/ResultUtil.java | 24 +- .../DocAndScoreQuery.java | 4 +- .../knn/index/query/common/QueryUtils.java | 173 ++++++++ .../iterators/BinaryVectorIdsKNNIterator.java | 12 +- .../iterators/ByteVectorIdsKNNIterator.java | 12 +- .../GroupedNestedDocIdSetIterator.java | 122 ++++++ .../NestedBinaryVectorIdsKNNIterator.java | 4 +- .../NestedByteVectorIdsKNNIterator.java | 4 +- .../iterators/NestedVectorIdsKNNIterator.java | 8 +- .../query/iterators/VectorIdsKNNIterator.java | 16 +- .../lucenelib/ExpandNestedDocsQuery.java | 141 +++++++ .../InternalNestedKnnByteVectoryQuery.java | 57 +++ .../InternalNestedKnnFloatVectoryQuery.java | 57 +++ .../InternalNestedKnnVectorQuery.java | 63 +++ .../NestedKnnVectorQueryFactory.java | 77 ++++ .../nativelib/NativeEngineKnnVectorQuery.java | 158 ++++--- .../query/parser/KNNQueryBuilderParser.java | 15 + .../opensearch/knn/index/util/IndexUtil.java | 3 + .../knn/index/query/KNNQueryFactoryTests.java | 36 ++ .../knn/index/query/ResultUtilTests.java | 41 +- .../DocAndScoreQueryTests.java | 2 +- .../index/query/common/QueryUtilsTests.java | 210 ++++++++++ .../BinaryVectorIdsKNNIteratorTests.java | 8 +- .../ByteVectorIdsKNNIteratorTests.java | 8 +- .../GroupedNestedDocIdSetIteratorTests.java | 73 ++++ ...NestedBinaryVectorIdsKNNIteratorTests.java | 3 +- .../NestedByteVectorIdsKNNIteratorTests.java | 3 +- .../NestedVectorIdsKNNIteratorTests.java | 9 +- .../iterators/VectorIdsKNNIteratorTests.java | 8 +- .../ExpandNestedEDocsQueryTests.java | 132 ++++++ .../NestedKnnVectorQueryFactoryTests.java | 64 +++ .../NativeEngineKNNVectorQueryTests.java | 162 ++++++-- .../knn/integ/ExpandNestedDocsIT.java | 392 ++++++++++++++++++ .../org/opensearch/knn/KNNRestTestCase.java | 23 +- 45 files changed, 2082 insertions(+), 210 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/PerLeafResult.java rename src/main/java/org/opensearch/knn/index/query/{nativelib => common}/DocAndScoreQuery.java (97%) create mode 100644 src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java create mode 100644 src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java create mode 100644 src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java create mode 100644 src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java create mode 100644 src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java create mode 100644 src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnVectorQuery.java create mode 100644 src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java rename src/test/java/org/opensearch/knn/index/query/{nativelib => common}/DocAndScoreQueryTests.java (98%) create mode 100644 src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f614a2368..7e6016ff9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x) ### Features +- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/build.gradle b/build.gradle index 25ff04032..7fd67d1b2 100644 --- a/build.gradle +++ b/build.gradle @@ -298,6 +298,7 @@ dependencies { testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10' testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3' testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4' + testFixturesImplementation 'com.jayway.jsonpath:json-path:2.8.0' testFixturesImplementation "org.opensearch:common-utils:${version}" implementation 'com.github.oshi:oshi-core:6.4.13' api "net.java.dev.jna:jna:5.13.0" diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 16084499c..ce6095fd0 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -71,6 +71,7 @@ public class KNNConstants { public static final String QFRAMEWORK_CONFIG = "qframe_config"; public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; + public static final String EXPAND_NESTED = "expand_nested_docs"; public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD; public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature"; diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index 1e560a11b..f75c7f1d9 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); + public static final Set ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index cfb604c18..984ed00bc 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -50,6 +50,7 @@ public static class CreateQueryRequest { private QueryBuilder filter; private QueryShardContext context; private RescoreContext rescoreContext; + private boolean expandNested; public Optional getFilter() { return Optional.ofNullable(filter); diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 6a97b4083..7f0330432 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -17,7 +17,6 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.util.BitSet; import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; @@ -68,8 +67,8 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, if (exactSearcherContext.getKnnQuery().getRadius() != null) { return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); } - if (exactSearcherContext.getMatchedDocs() != null - && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { + if (exactSearcherContext.getMatchedDocsIterator() != null + && exactSearcherContext.numberOfMatchedDocs <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); @@ -155,7 +154,7 @@ private Map filterDocsByMinScore(ExactSearcherContext context, K private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); - final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); + final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator(); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); if (fieldInfo == null) { @@ -245,7 +244,8 @@ public static class ExactSearcherContext { */ boolean useQuantizedVectorsForSearch; int k; - BitSet matchedDocs; + DocIdSetIterator matchedDocsIterator; + long numberOfMatchedDocs; KNNQuery knnQuery; /** * whether the matchedDocs contains parent ids or child ids. This is relevant in the case of diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index f0974f7e9..1a03f4b99 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -48,6 +48,7 @@ public class KNNQuery extends Query { @Setter private Query filterQuery; + @Getter private BitSetProducer parentsFilter; private Float radius; private Context context; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 8f7c5a3ff..063842a7f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -49,6 +49,7 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; @@ -74,6 +75,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); + public static final ParseField EXPAND_NESTED_FIELD = new ParseField(EXPAND_NESTED); public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE); public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE); public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); @@ -106,6 +108,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private boolean ignoreUnmapped; @Getter private RescoreContext rescoreContext; + @Getter + private boolean expandNested; /** * Constructs a new query with the given field name and vector @@ -147,6 +151,7 @@ public static class Builder { private String queryName; private float boost = DEFAULT_BOOST; private RescoreContext rescoreContext; + private boolean expandNested; public Builder() {} @@ -205,6 +210,11 @@ public Builder rescoreContext(RescoreContext rescoreContext) { return this; } + public Builder expandNested(boolean expandNested) { + this.expandNested = expandNested; + return this; + } + public KNNQueryBuilder build() { validate(); int k = this.k == null ? 0 : this.k; @@ -217,7 +227,8 @@ public KNNQueryBuilder build() { methodParameters, filter, ignoreUnmapped, - rescoreContext + rescoreContext, + expandNested ).boost(boost).queryName(queryName); } @@ -319,6 +330,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.maxDistance = null; this.minScore = null; this.rescoreContext = null; + this.expandNested = false; } public static void initialize(ModelDao modelDao) { @@ -341,6 +353,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException { minScore = builder.minScore; methodParameters = builder.methodParameters; rescoreContext = builder.rescoreContext; + expandNested = builder.expandNested; } @Override @@ -536,6 +549,7 @@ protected Query doToQuery(QueryShardContext context) { .filter(this.filter) .context(context) .rescoreContext(processedRescoreContext) + .expandNested(expandNested) .build(); return KNNQueryFactory.create(createQueryRequest); } @@ -621,7 +635,8 @@ protected boolean doEquals(KNNQueryBuilder other) { && Objects.equals(methodParameters, other.methodParameters) && Objects.equals(filter, other.filter) && Objects.equals(ignoreUnmapped, other.ignoreUnmapped) - && Objects.equals(rescoreContext, other.rescoreContext); + && Objects.equals(rescoreContext, other.rescoreContext) + && Objects.equals(expandNested, other.expandNested); } @Override @@ -635,7 +650,8 @@ protected int doHashCode() { ignoreUnmapped, maxDistance, minScore, - rescoreContext + rescoreContext, + expandNested ); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index dab2e08c8..bac6e95b5 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -10,27 +10,28 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.common.QueryUtils; +import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; import java.util.Locale; import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; +import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS; /** * Creates the Lucene k-NN queries */ @Log4j2 public class KNNQueryFactory extends BaseQueryFactory { - /** * Creates a Lucene query for a particular engine. * @param createQueryRequest request object that has all required fields to construct the query @@ -48,13 +49,25 @@ public static Query create(CreateQueryRequest createQueryRequest) { final Query filterQuery = getFilterQuery(createQueryRequest); final Map methodParameters = createQueryRequest.getMethodParameters(); final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null); - + final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); + final boolean expandNested = createQueryRequest.isExpandNested(); BitSetProducer parentFilter = null; if (createQueryRequest.getContext().isPresent()) { QueryShardContext context = createQueryRequest.getContext().get(); parentFilter = context.getParentFilter(); } + if (parentFilter == null && expandNested) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for the [%s] field. [%s] is only supported with a nested field.", + EXPAND_NESTED, + EXPAND_NESTED + ) + ); + } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine()); @@ -95,7 +108,16 @@ public static Query create(CreateQueryRequest createQueryRequest) { .rescoreContext(rescoreContext) .build(); } - return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; + + if (createQueryRequest.getRescoreContext().isPresent()) { + return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); + } + + if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) { + return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); + } + + return knnQuery; } Integer requestEfSearch = null; @@ -106,9 +128,9 @@ public static Query create(CreateQueryRequest createQueryRequest) { log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: - return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter); + return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested); case FLOAT: - return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter); + return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested); default: throw new IllegalArgumentException( String.format( @@ -131,38 +153,56 @@ private static Query validateFilterQuerySupport(final Query filterQuery, final K } /** - * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery} - * which will dedupe search result per parent so that we can get k parent results at the end. + * If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory} + * which will create query to dedupe search result per parent so that we can get k parent results at the end. */ private static Query getKnnByteVectorQuery( final String fieldName, final byte[] byteVector, final int k, final Query filterQuery, - final BitSetProducer parentFilter + final BitSetProducer parentFilter, + final boolean expandNested ) { if (parentFilter == null) { + assert expandNested == false : "expandNested is allowed to be true only for nested fields."; return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery); } else { - return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter); + return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery( + fieldName, + byteVector, + k, + filterQuery, + parentFilter, + expandNested + ); } } /** - * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery} - * which will dedupe search result per parent so that we can get k parent results at the end. + * If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory} + * which will create query to dedupe search result per parent so that we can get k parent results at the end. */ private static Query getKnnFloatVectorQuery( final String fieldName, final float[] floatVector, final int k, final Query filterQuery, - final BitSetProducer parentFilter + final BitSetProducer parentFilter, + final boolean expandNested ) { if (parentFilter == null) { + assert expandNested == false : "expandNested is allowed to be true only for nested fields."; return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery); } else { - return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); + return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery( + fieldName, + floatVector, + k, + filterQuery, + parentFilter, + expandNested + ); } } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index b64472994..891f9325c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.annotations.VisibleForTesting; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -66,6 +67,7 @@ public class KNNWeight extends Weight { private final float boost; private final NativeMemoryCacheManager nativeMemoryCacheManager; + @Getter private final Weight filterWeight; private final ExactSearcher exactSearcher; @@ -109,7 +111,7 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - final Map docIdToScoreMap = searchLeaf(context, knnQuery.getK()); + final Map docIdToScoreMap = searchLeaf(context, knnQuery.getK()).getResult(); if (docIdToScoreMap.isEmpty()) { return KNNScorer.emptyScorer(this); } @@ -125,14 +127,14 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * @param k Number of results to return * @return A Map of docId to scores for top k results */ - public Map searchLeaf(LeafReaderContext context, int k) throws IOException { + public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); int cardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, if (filterWeight != null && cardinality == 0) { - return Collections.emptyMap(); + return PerLeafResult.EMPTY_RESULT; } /* * The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph @@ -140,17 +142,19 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I * This improves the recall. */ if (isFilteredExactSearchPreferred(cardinality)) { - return doExactSearch(context, filterBitSet, k); + Map result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k); + return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); } Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) { - final BitSet docs = filterWeight != null ? filterBitSet : null; - return doExactSearch(context, docs, k); + final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, cardinality) : null; + Map result = doExactSearch(context, docs, cardinality, k); + return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); } - return docIdsToScoreMap; + return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap); } private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { @@ -205,17 +209,21 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } - private Map doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException { + private Map doExactSearch( + final LeafReaderContext context, + final DocIdSetIterator acceptedDocs, + final long numberOfAcceptedDocs, + int k + ) throws IOException { final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() .isParentHits(true) .k(k) // setting to true, so that if quantization details are present we want to do search on the quantized // vectors as this flow is used in first pass of search. .useQuantizedVectorsForSearch(true) - .knnQuery(knnQuery); - if (acceptedDocs != null) { - exactSearcherContextBuilder.matchedDocs(acceptedDocs); - } + .knnQuery(knnQuery) + .matchedDocsIterator(acceptedDocs) + .numberOfMatchedDocs(numberOfAcceptedDocs); return exactSearch(context, exactSearcherContextBuilder.build()); } diff --git a/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java b/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java new file mode 100644 index 000000000..8434d6a17 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/PerLeafResult.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.util.Bits; +import org.opensearch.common.Nullable; + +import java.util.Collections; +import java.util.Map; + +@Getter +public class PerLeafResult { + public static final PerLeafResult EMPTY_RESULT = new PerLeafResult(new Bits.MatchNoBits(0), Collections.emptyMap()); + @Nullable + private final Bits filterBits; + @Setter + private Map result; + + public PerLeafResult(final Bits filterBits, final Map result) { + this.filterBits = filterBits == null ? new Bits.MatchAllBits(0) : filterBits; + this.result = result; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java index df1ce3827..5c66eaaa2 100644 --- a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java @@ -9,7 +9,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.util.BitSet; import org.apache.lucene.util.DocIdSetBuilder; import java.io.IOException; @@ -30,14 +29,14 @@ public final class ResultUtil { * @param perLeafResults Results from the list * @param k the number of results across all leaf results to return */ - public static void reduceToTopK(List> perLeafResults, int k) { + public static void reduceToTopK(List perLeafResults, int k) { // Iterate over all scores to get min competitive score PriorityQueue topKMinQueue = new PriorityQueue<>(k); int count = 0; - for (Map perLeafResult : perLeafResults) { - count += perLeafResult.size(); - for (Float score : perLeafResult.values()) { + for (PerLeafResult perLeafResult : perLeafResults) { + count += perLeafResult.getResult().size(); + for (Float score : perLeafResult.getResult().values()) { if (topKMinQueue.size() < k) { topKMinQueue.add(score); } else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) { @@ -54,23 +53,22 @@ public static void reduceToTopK(List> perLeafResults, int k) // Reduce the results based on min competitive score float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek(); - perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore)); + perLeafResults.forEach(results -> results.getResult().entrySet().removeIf(entry -> entry.getValue() < minScore)); } /** - * Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to - * ensure that the caller is aware that BitSet may not be present + * Convert map of docs to doc id set iterator * * @param resultMap Map of results - * @return BitSet of results; null is returned if the result map is empty + * @return Doc id set iterator * @throws IOException If an error occurs during the search. */ - public static BitSet resultMapToMatchBitSet(Map resultMap) throws IOException { - if (resultMap == null || resultMap.isEmpty()) { - return null; + public static DocIdSetIterator resultMapToDocIds(Map resultMap) throws IOException { + if (resultMap.isEmpty()) { + return DocIdSetIterator.empty(); } final int maxDoc = Collections.max(resultMap.keySet()) + 1; - return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc); + return resultMapToDocIds(resultMap, maxDoc); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java similarity index 97% rename from src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java rename to src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java index b94264b4d..f38cc96c6 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.query.nativelib; +package org.opensearch.knn.index.query.common; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; @@ -32,7 +32,7 @@ final class DocAndScoreQuery extends Query { private final int[] segmentStarts; private final Object contextIdentity; - DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { this.k = k; this.docs = docs; this.scores = scores; diff --git a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java new file mode 100644 index 000000000..5fc0fb077 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.common; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FilteredDocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Callable; + +/** + * This class contains utility methods that help customize the search results + */ +public class QueryUtils { + public static QueryUtils INSTANCE = new QueryUtils(); + + /** + * Returns a query that represents the specified TopDocs + * This is copied from {@link org.apache.lucene.search.AbstractKnnVectorQuery#createRewrittenQuery} + * + * @param reader the index reader + * @param topDocs the documents to be retured by the query + * @return a query representing the given TopDocs + */ + public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs) { + int len = topDocs.scoreDocs.length; + Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(a -> a.doc)); + int[] docs = new int[len]; + float[] scores = new float[len]; + for (int i = 0; i < len; i++) { + docs[i] = topDocs.scoreDocs[i].doc; + scores[i] = topDocs.scoreDocs[i].score; + } + int[] segmentStarts = findSegmentStarts(reader, docs); + return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id()); + } + + private int[] findSegmentStarts(final IndexReader reader, final int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; + } + + /** + * Performs the search in parallel. + * + * @param indexSearcher the index searcher + * @param leafReaderContexts the leaf reader contexts + * @param weight the search weight + * @return a list of maps, each mapping document IDs to their scores + * @throws IOException + */ + public List> doSearch( + final IndexSearcher indexSearcher, + final List leafReaderContexts, + final Weight weight + ) throws IOException { + List>> tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + tasks.add(() -> searchLeaf(leafReaderContext, weight)); + } + return indexSearcher.getTaskExecutor().invokeAll(tasks); + } + + private Map searchLeaf(final LeafReaderContext ctx, final Weight weight) throws IOException { + Map leafDocScores = new HashMap<>(); + Scorer scorer = weight.scorer(ctx); + if (scorer == null) { + return Collections.emptyMap(); + } + + DocIdSetIterator iterator = scorer.iterator(); + iterator.nextDoc(); + while (iterator.docID() != DocIdSetIterator.NO_MORE_DOCS) { + leafDocScores.put(scorer.docID(), scorer.score()); + iterator.nextDoc(); + } + return leafDocScores; + } + + /** + * For the specified nested field document IDs, retrieves all sibling nested field document IDs. + * + * @param leafReaderContext the leaf reader context + * @param docIds the document IDs of the nested field + * @param parentsFilter a bitset mapping parent document IDs to their nested field document IDs + * @return an iterator of document IDs for all filtered sibling nested field documents corresponding to the given document IDs + * @throws IOException + */ + public DocIdSetIterator getAllSiblings( + final LeafReaderContext leafReaderContext, + final Set docIds, + final BitSetProducer parentsFilter, + final Bits queryFilter + ) throws IOException { + if (docIds.isEmpty()) { + return DocIdSetIterator.empty(); + } + + BitSet parentBitSet = parentsFilter.getBitSet(leafReaderContext); + return new GroupedNestedDocIdSetIterator(parentBitSet, docIds, queryFilter); + } + + /** + * Converts the specified search weight into a {@link Bits} containing document IDs. + * + * @param leafReaderContext the leaf reader context + * @param filterWeight the search weight + * @return a {@link Bits} of document IDs derived from the search weight + * @throws IOException + */ + public Bits createBits(final LeafReaderContext leafReaderContext, final Weight filterWeight) throws IOException { + if (filterWeight == null) { + return new Bits.MatchAllBits(0); + } + + final Scorer scorer = filterWeight.scorer(leafReaderContext); + if (scorer == null) { + return new Bits.MatchNoBits(0); + } + + final Bits liveDocs = leafReaderContext.reader().getLiveDocs(); + final int maxDoc = leafReaderContext.reader().maxDoc(); + DocIdSetIterator filteredDocIdsIterator = scorer.iterator(); + if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return ((BitSetIterator) filteredDocIdsIterator).getBitSet(); + } + // Create a new BitSet from matching and live docs + FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + return BitSet.of(filterIterator, maxDoc); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java index 5bab5b573..b6eaf182f 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java @@ -6,8 +6,6 @@ package org.opensearch.knn.index.query.iterators; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BitSetIterator; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; @@ -21,7 +19,7 @@ * The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided */ public class BinaryVectorIdsKNNIterator implements KNNIterator { - protected final BitSetIterator bitSetIterator; + protected final DocIdSetIterator docIdSetIterator; protected final byte[] queryVector; protected final KNNBinaryVectorValues binaryVectorValues; protected final SpaceType spaceType; @@ -29,12 +27,12 @@ public class BinaryVectorIdsKNNIterator implements KNNIterator { protected int docId; public BinaryVectorIdsKNNIterator( - @Nullable final BitSet filterIdsBitSet, + @Nullable final DocIdSetIterator docIdSetIterator, final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType ) throws IOException { - this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.docIdSetIterator = docIdSetIterator; this.queryVector = queryVector; this.binaryVectorValues = binaryVectorValues; this.spaceType = spaceType; @@ -79,10 +77,10 @@ protected float computeScore() throws IOException { } protected int getNextDocId() throws IOException { - if (bitSetIterator == null) { + if (docIdSetIterator == null) { return binaryVectorValues.nextDoc(); } - int nextDocID = this.bitSetIterator.nextDoc(); + int nextDocID = this.docIdSetIterator.nextDoc(); // For filter case, advance vector values to corresponding doc id from filter bit set if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { binaryVectorValues.advance(nextDocID); diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java index 0e8005163..5030e1c7b 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java @@ -6,8 +6,6 @@ package org.opensearch.knn.index.query.iterators; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BitSetIterator; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; @@ -21,7 +19,7 @@ * The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided */ public class ByteVectorIdsKNNIterator implements KNNIterator { - protected final BitSetIterator bitSetIterator; + protected final DocIdSetIterator filterIdsIterator; protected final float[] queryVector; protected final KNNByteVectorValues byteVectorValues; protected final SpaceType spaceType; @@ -29,12 +27,12 @@ public class ByteVectorIdsKNNIterator implements KNNIterator { protected int docId; public ByteVectorIdsKNNIterator( - @Nullable final BitSet filterIdsBitSet, + @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType ) throws IOException { - this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.filterIdsIterator = filterIdsIterator; this.queryVector = queryVector; this.byteVectorValues = byteVectorValues; this.spaceType = spaceType; @@ -89,10 +87,10 @@ protected float computeScore() throws IOException { } protected int getNextDocId() throws IOException { - if (bitSetIterator == null) { + if (filterIdsIterator == null) { return byteVectorValues.nextDoc(); } - int nextDocID = this.bitSetIterator.nextDoc(); + int nextDocID = this.filterIdsIterator.nextDoc(); // For filter case, advance vector values to corresponding doc id from filter bit set if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { byteVectorValues.advance(nextDocID); diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java new file mode 100644 index 000000000..19842a67a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Set; + +/** + * A `DocIdSetIterator` that iterates over all nested document IDs belongs to the same parent document for a given + * set of nested document IDs. + * + * The {@link #docIds} should include only a single nested document ID per parent document. Otherwise, the nested documents + * of that parent document will be iterated multiple times. + * + */ +public class GroupedNestedDocIdSetIterator extends DocIdSetIterator { + private final BitSet parentBitSet; + private final Bits filterBits; + private final List docIds; + private long cost; + private int currentIndex; + private int currentDocId; + private int currentParentId; + + public GroupedNestedDocIdSetIterator(final BitSet parentBitSet, final Set docIds, final Bits filterBits) { + this.parentBitSet = parentBitSet; + this.docIds = new ArrayList<>(docIds); + this.docIds.sort(Comparator.naturalOrder()); + this.filterBits = filterBits; + currentIndex = -1; + currentDocId = -1; + cost = -1; + } + + @Override + public int docID() { + return currentDocId; + } + + @Override + public int nextDoc() throws IOException { + while (true) { + if (doNextDoc() != NO_MORE_DOCS) { + if (!filterBits.get(currentDocId)) { + continue; + } + + return currentDocId; + } + + return currentDocId; + } + } + + public int doNextDoc() throws IOException { + if (currentDocId == NO_MORE_DOCS) { + return currentDocId; + } + + if (currentDocId == -1) { + moveToNextIndex(); + return currentDocId; + } + + currentDocId++; + assert currentDocId <= currentParentId; + if (currentDocId == currentParentId) { + moveToNextIndex(); + } + return currentDocId; + } + + @Override + public int advance(final int i) throws IOException { + if (currentDocId == NO_MORE_DOCS) { + return currentDocId; + } + + return slowAdvance(i); + } + + @Override + public long cost() { + if (cost == -1) { + cost = calculateCost(); + } + return cost; + } + + private long calculateCost() { + long numDocs = 0; + for (int docId : docIds) { + for (int i = parentBitSet.prevSetBit(docId) + 1; i < parentBitSet.nextSetBit(docId); i++) { + if (filterBits.get(i)) { + numDocs++; + } + } + } + return numDocs; + } + + private void moveToNextIndex() { + currentIndex++; + if (currentIndex >= docIds.size()) { + currentDocId = NO_MORE_DOCS; + return; + } + currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; + currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); + assert currentParentId != NO_MORE_DOCS; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java index 97bf3517e..eb285814a 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java @@ -22,13 +22,13 @@ public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator private final BitSet parentBitSet; public NestedBinaryVectorIdsKNNIterator( - @Nullable final BitSet filterIdsArray, + @Nullable final DocIdSetIterator filterIdsIterator, final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) throws IOException { - super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + super(filterIdsIterator, queryVector, binaryVectorValues, spaceType); this.parentBitSet = parentBitSet; } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java index 9644b620f..645133ba2 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java @@ -22,13 +22,13 @@ public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator { private final BitSet parentBitSet; public NestedByteVectorIdsKNNIterator( - @Nullable final BitSet filterIdsArray, + @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) throws IOException { - super(filterIdsArray, queryVector, byteVectorValues, spaceType); + super(filterIdsIterator, queryVector, byteVectorValues, spaceType); this.parentBitSet = parentBitSet; } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java index 692793b99..f356fa02e 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java @@ -23,13 +23,13 @@ public class NestedVectorIdsKNNIterator extends VectorIdsKNNIterator { private final BitSet parentBitSet; public NestedVectorIdsKNNIterator( - @Nullable final BitSet filterIdsArray, + @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) throws IOException { - this(filterIdsArray, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null); + this(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null); } public NestedVectorIdsKNNIterator( @@ -42,7 +42,7 @@ public NestedVectorIdsKNNIterator( } public NestedVectorIdsKNNIterator( - @Nullable final BitSet filterIdsArray, + @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType, @@ -50,7 +50,7 @@ public NestedVectorIdsKNNIterator( final byte[] quantizedVector, final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo ) throws IOException { - super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType, quantizedVector, segmentLevelQuantizationInfo); + super(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, quantizedVector, segmentLevelQuantizationInfo); this.parentBitSet = parentBitSet; } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java index 9fb354242..8f7f287e1 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java @@ -6,8 +6,6 @@ package org.opensearch.knn.index.query.iterators; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BitSetIterator; import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo; @@ -23,7 +21,7 @@ * The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided */ public class VectorIdsKNNIterator implements KNNIterator { - protected final BitSetIterator bitSetIterator; + protected final DocIdSetIterator filterIdsIterator; protected final float[] queryVector; private final byte[] quantizedQueryVector; protected final KNNFloatVectorValues knnFloatVectorValues; @@ -33,12 +31,12 @@ public class VectorIdsKNNIterator implements KNNIterator { private final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; public VectorIdsKNNIterator( - @Nullable final BitSet filterIdsBitSet, + @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType ) throws IOException { - this(filterIdsBitSet, queryVector, knnFloatVectorValues, spaceType, null, null); + this(filterIdsIterator, queryVector, knnFloatVectorValues, spaceType, null, null); } public VectorIdsKNNIterator(final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType) @@ -47,14 +45,14 @@ public VectorIdsKNNIterator(final float[] queryVector, final KNNFloatVectorValue } public VectorIdsKNNIterator( - @Nullable final BitSet filterIdsBitSet, + @Nullable final DocIdSetIterator filterIdsIterator, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType, final byte[] quantizedQueryVector, final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo ) throws IOException { - this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.filterIdsIterator = filterIdsIterator; this.queryVector = queryVector; this.knnFloatVectorValues = knnFloatVectorValues; this.spaceType = spaceType; @@ -101,10 +99,10 @@ protected float computeScore() throws IOException { } protected int getNextDocId() throws IOException { - if (bitSetIterator == null) { + if (filterIdsIterator == null) { return knnFloatVectorValues.nextDoc(); } - int nextDocID = this.bitSetIterator.nextDoc(); + int nextDocID = this.filterIdsIterator.nextDoc(); // For filter case, advance vector values to corresponding doc id from filter bit set if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { knnFloatVectorValues.advance(nextDocID); diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java new file mode 100644 index 000000000..863fd39ed --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java @@ -0,0 +1,141 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import lombok.Builder; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.common.QueryUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +/** + * This query is for a nested k-NN field to return multiple nested field documents + * rather than only the highest-scoring nested field document. + * + * It begins by performing an approximate nearest neighbor search. Once results are gathered from all segments, + * they are reduced to the top k results. Then, it constructs filtered document IDs for nested field documents + * from these top k parent documents. Using these document IDs, it executes an exact nearest neighbor search + * with a k value of Integer.MAX_VALUE, which provides scores for all specified nested field documents. + */ +@Builder +public class ExpandNestedDocsQuery extends Query { + final private InternalNestedKnnVectorQuery internalNestedKnnVectorQuery; + final private QueryUtils queryUtils; + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Query docAndScoreQuery = internalNestedKnnVectorQuery.knnRewrite(searcher); + Weight weight = docAndScoreQuery.createWeight(searcher, scoreMode, boost); + IndexReader reader = searcher.getIndexReader(); + List leafReaderContexts = reader.leaves(); + List> perLeafResults; + perLeafResults = queryUtils.doSearch(searcher, leafReaderContexts, weight); + TopDocs[] topDocs = retrieveAll(searcher, leafReaderContexts, perLeafResults); + int sum = 0; + for (TopDocs topDoc : topDocs) { + sum += topDoc.scoreDocs.length; + } + TopDocs topK = TopDocs.merge(sum, topDocs); + if (topK.scoreDocs.length == 0) { + return new MatchNoDocsQuery().createWeight(searcher, scoreMode, boost); + } + return queryUtils.createDocAndScoreQuery(reader, topK).createWeight(searcher, scoreMode, boost); + } + + private TopDocs[] retrieveAll( + final IndexSearcher indexSearcher, + final List leafReaderContexts, + final List> perLeafResults + ) throws IOException { + // Construct query + List> nestedQueryTasks = new ArrayList<>(leafReaderContexts.size()); + Weight filterWeight = getFilterWeight(indexSearcher); + for (int i = 0; i < perLeafResults.size(); i++) { + LeafReaderContext leafReaderContext = leafReaderContexts.get(i); + int finalI = i; + nestedQueryTasks.add(() -> { + Bits queryFilter = queryUtils.createBits(leafReaderContext, filterWeight); + DocIdSetIterator allSiblings = queryUtils.getAllSiblings( + leafReaderContext, + perLeafResults.get(finalI).keySet(), + internalNestedKnnVectorQuery.getParentFilter(), + queryFilter + ); + TopDocs topDocs = internalNestedKnnVectorQuery.knnExactSearch(leafReaderContext, allSiblings); + // Update doc id from segment id to shard id + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + scoreDoc.doc = scoreDoc.doc + leafReaderContext.docBase; + } + return topDocs; + }); + } + return indexSearcher.getTaskExecutor().invokeAll(nestedQueryTasks).toArray(TopDocs[]::new); + } + + /** + * This is copied from {@link org.apache.lucene.search.AbstractKnnVectorQuery#rewrite} + */ + private Weight getFilterWeight(final IndexSearcher indexSearcher) throws IOException { + if (internalNestedKnnVectorQuery.getFilter() == null) { + return null; + } + + BooleanQuery booleanQuery = (new BooleanQuery.Builder()).add(internalNestedKnnVectorQuery.getFilter(), BooleanClause.Occur.FILTER) + .add(new FieldExistsQuery(internalNestedKnnVectorQuery.getField()), BooleanClause.Occur.FILTER) + .build(); + Query rewritten = indexSearcher.rewrite(booleanQuery); + return indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0F); + } + + @Override + public void visit(final QueryVisitor queryVisitor) { + queryVisitor.visitLeaf(this); + } + + @Override + public boolean equals(final Object o) { + if (!sameClassAs(o)) { + return false; + } + ExpandNestedDocsQuery other = (ExpandNestedDocsQuery) o; + return internalNestedKnnVectorQuery.equals(other.internalNestedKnnVectorQuery); + } + + @Override + public int hashCode() { + return internalNestedKnnVectorQuery.hashCode(); + } + + @Override + public String toString(final String s) { + return this.getClass().getSimpleName() + + "[" + + internalNestedKnnVectorQuery.getField() + + "]..." + + internalNestedKnnVectorQuery.getClass().getSimpleName() + + "[" + + internalNestedKnnVectorQuery.toString() + + "]"; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java new file mode 100644 index 000000000..e9d022232 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import lombok.Getter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; + +import java.io.IOException; + +/** + * InternalNestedKnnVectorQuery for byte vector + */ +@Getter +public class InternalNestedKnnByteVectoryQuery extends KnnByteVectorQuery implements InternalNestedKnnVectorQuery { + private final String field; + private final byte[] target; + private final Query filter; + private final int k; + private final BitSetProducer parentFilter; + private final DiversifyingChildrenByteKnnVectorQuery diversifyingChildrenByteKnnVectorQuery; + + public InternalNestedKnnByteVectoryQuery( + final String field, + final byte[] target, + final Query filter, + final int k, + final BitSetProducer parentFilter + ) { + super(field, target, Integer.MAX_VALUE, filter); + this.field = field; + this.target = target; + this.filter = filter; + this.k = k; + this.parentFilter = parentFilter; + this.diversifyingChildrenByteKnnVectorQuery = new DiversifyingChildrenByteKnnVectorQuery(field, target, filter, k, parentFilter); + } + + @Override + public Query knnRewrite(final IndexSearcher searcher) throws IOException { + return diversifyingChildrenByteKnnVectorQuery.rewrite(searcher); + } + + @Override + public TopDocs knnExactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException { + return super.exactSearch(context, acceptIterator, null); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java new file mode 100644 index 000000000..6e5408bb5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import lombok.Getter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; + +import java.io.IOException; + +/** + * InternalNestedKnnVectorQuery for float vector + */ +@Getter +public class InternalNestedKnnFloatVectoryQuery extends KnnFloatVectorQuery implements InternalNestedKnnVectorQuery { + private final String field; + private final float[] target; + private final Query filter; + private final int k; + private final BitSetProducer parentFilter; + private final DiversifyingChildrenFloatKnnVectorQuery diversifyingChildrenFloatKnnVectorQuery; + + public InternalNestedKnnFloatVectoryQuery( + final String field, + final float[] target, + final Query filter, + final int k, + final BitSetProducer parentFilter + ) { + super(field, target, Integer.MAX_VALUE, filter); + this.field = field; + this.target = target; + this.filter = filter; + this.k = k; + this.parentFilter = parentFilter; + this.diversifyingChildrenFloatKnnVectorQuery = new DiversifyingChildrenFloatKnnVectorQuery(field, target, filter, k, parentFilter); + } + + @Override + public Query knnRewrite(final IndexSearcher searcher) throws IOException { + return diversifyingChildrenFloatKnnVectorQuery.rewrite(searcher); + } + + @Override + public TopDocs knnExactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException { + return super.exactSearch(context, acceptIterator, null); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnVectorQuery.java new file mode 100644 index 000000000..e5ea319e4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnVectorQuery.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; + +import java.io.IOException; + +/** + * Query interface to support k-NN nested field + */ +public interface InternalNestedKnnVectorQuery { + /** + * Return a rewritten query of nested knn search + * + * @param searcher index searcher + * @return rewritten query of nested knn search + * @throws IOException + */ + Query knnRewrite(final IndexSearcher searcher) throws IOException; + + /** + * Return a result of exact knn search + * + * @param leafReaderContext segment context + * @param iterator filtered doc ids + * @return + * @throws IOException + */ + TopDocs knnExactSearch(final LeafReaderContext leafReaderContext, final DocIdSetIterator iterator) throws IOException; + + /** + * Return a field name + * @return field name + */ + String getField(); + + /** + * Return a filter query + * @return filter query + */ + Query getFilter(); + + /** + * Return k value + * @return k value + */ + int getK(); + + /** + * Return a parent filter + * @return parent filter + */ + BitSetProducer getParentFilter(); +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java new file mode 100644 index 000000000..7d5c78180 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.opensearch.knn.index.query.common.QueryUtils; + +/** + * A class to create a nested knn vector query for lucene + */ +public class NestedKnnVectorQueryFactory { + /** + * Create a query for k-NN nested field. + * + * The query is generated two times when inner_hits() parameter exist in the request. + * For inner hit, we return all filtered nested field documents belongs to the final result of parent documents. + * + * @param fieldName field name for search + * @param vector target vector for search + * @param k k nearest neighbor for search + * @param filterQuery efficient filtering query + * @param parentFilter has mapping data between parent doc and child doc + * @param expandNestedDocs tells if this query is for expanding nested docs + * @return Query for k-NN nested field + */ + public static Query createNestedKnnVectorQuery( + final String fieldName, + final byte[] vector, + final int k, + final Query filterQuery, + final BitSetProducer parentFilter, + final boolean expandNestedDocs + ) { + if (expandNestedDocs) { + return new ExpandNestedDocsQuery.ExpandNestedDocsQueryBuilder().internalNestedKnnVectorQuery( + new InternalNestedKnnByteVectoryQuery(fieldName, vector, filterQuery, k, parentFilter) + ).queryUtils(QueryUtils.INSTANCE).build(); + } + return new DiversifyingChildrenByteKnnVectorQuery(fieldName, vector, filterQuery, k, parentFilter); + } + + /** + * Create a query for k-NN nested field. + * + * The query is generated two times when inner_hits() parameter exist in the request. + * For inner hit, we return all filtered nested field documents belongs to the final result of parent documents. + * + * @param fieldName field name for search + * @param vector target vector for search + * @param k k nearest neighbor for search + * @param filterQuery efficient filtering query + * @param parentFilter has mapping data between parent doc and child doc + * @param expandNestedDocs tells if this query is for expanding nested docs + * @return Query for k-NN nested field + */ + public static Query createNestedKnnVectorQuery( + final String fieldName, + final float[] vector, + final int k, + final Query filterQuery, + final BitSetProducer parentFilter, + final boolean expandNestedDocs + ) { + if (expandNestedDocs) { + return new ExpandNestedDocsQuery.ExpandNestedDocsQueryBuilder().internalNestedKnnVectorQuery( + new InternalNestedKnnFloatVectoryQuery(fieldName, vector, filterQuery, k, parentFilter) + ).queryUtils(QueryUtils.INSTANCE).build(); + } + return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, filterQuery, k, parentFilter); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index f782b0180..47ea215f3 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -10,6 +10,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -17,21 +18,19 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; -import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.opensearch.common.StopWatch; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.query.ExactSearcher; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.index.query.PerLeafResult; import org.opensearch.knn.index.query.ResultUtil; +import org.opensearch.knn.index.query.common.QueryUtils; import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -50,13 +49,15 @@ public class NativeEngineKnnVectorQuery extends Query { private final KNNQuery knnQuery; + private final QueryUtils queryUtils; + private final boolean expandNestedDocs; @Override public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException { final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); List leafReaderContexts = reader.leaves(); - List> perLeafResults; + List perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); final int finalK = knnQuery.getK(); if (rescoreContext == null) { @@ -76,101 +77,148 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size()); } ResultUtil.reduceToTopK(perLeafResults, finalK); + + if (expandNestedDocs) { + StopWatch stopWatch = new StopWatch().start(); + perLeafResults = retrieveAll(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, rescoreContext == null); + long time_in_millis = stopWatch.stop().totalTime().millis(); + if (log.isDebugEnabled()) { + long totalNestedDocs = perLeafResults.stream().mapToLong(perLeafResult -> perLeafResult.getResult().size()).sum(); + log.debug("Expanding of nested docs took {} ms. totalNestedDocs:{} ", time_in_millis, totalNestedDocs); + } + } + TopDocs[] topDocs = new TopDocs[perLeafResults.size()]; for (int i = 0; i < perLeafResults.size(); i++) { - topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), leafReaderContexts.get(i).docBase); + topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i).getResult(), leafReaderContexts.get(i).docBase); } - TopDocs topK = TopDocs.merge(knnQuery.getK(), topDocs); + TopDocs topK = TopDocs.merge(getTotalTopDoc(topDocs), topDocs); + if (topK.scoreDocs.length == 0) { return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); } - return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + return queryUtils.createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + } + + /** + * When expandNestedDocs is set to true, additional nested documents are retrieved. + * As a result, the total number of documents will exceed k. + * Instead of relying on the k value, we must count the total number of documents + * to accurately determine how many are in topDocs. + * The theoretical maximum value this method could return is Integer.MAX_VALUE, + * as a single shard cannot have more documents than Integer.MAX_VALUE. + * + * @param topDocs the top documents + * @return the total number of documents in the topDocs + */ + private int getTotalTopDoc(TopDocs[] topDocs) { + if (expandNestedDocs == false) { + return knnQuery.getK(); + } + + int sum = 0; + for (TopDocs topDoc : topDocs) { + sum += topDoc.scoreDocs.length; + } + return sum; + } + + private List retrieveAll( + final IndexSearcher indexSearcher, + List leafReaderContexts, + KNNWeight knnWeight, + List perLeafResults, + boolean useQuantizedVectors + ) throws IOException { + List> nestedQueryTasks = new ArrayList<>(leafReaderContexts.size()); + for (int i = 0; i < perLeafResults.size(); i++) { + LeafReaderContext leafReaderContext = leafReaderContexts.get(i); + int finalI = i; + nestedQueryTasks.add(() -> { + PerLeafResult perLeafResult = perLeafResults.get(finalI); + if (perLeafResult.getResult().isEmpty()) { + return perLeafResult; + } + + DocIdSetIterator allSiblings = queryUtils.getAllSiblings( + leafReaderContext, + perLeafResult.getResult().keySet(), + knnQuery.getParentsFilter(), + perLeafResult.getFilterBits() + ); + + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + .matchedDocsIterator(allSiblings) + .numberOfMatchedDocs(allSiblings.cost()) + // setting to false because in re-scoring we want to do exact search on full precision vectors + .useQuantizedVectorsForSearch(useQuantizedVectors) + .k((int) allSiblings.cost()) + .isParentHits(false) + .knnQuery(knnQuery) + .build(); + Map rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + perLeafResult.setResult(rescoreResult); + return perLeafResult; + }); + } + return indexSearcher.getTaskExecutor().invokeAll(nestedQueryTasks); } - private List> doSearch( + private List doSearch( final IndexSearcher indexSearcher, List leafReaderContexts, KNNWeight knnWeight, int k ) throws IOException { - List>> tasks = new ArrayList<>(leafReaderContexts.size()); + List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext leafReaderContext : leafReaderContexts) { tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k)); } return indexSearcher.getTaskExecutor().invokeAll(tasks); } - private List> doRescore( + private List doRescore( final IndexSearcher indexSearcher, List leafReaderContexts, KNNWeight knnWeight, - List> perLeafResults, + List perLeafResults, int k ) throws IOException { - List>> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); + List> rescoreTasks = new ArrayList<>(leafReaderContexts.size()); for (int i = 0; i < perLeafResults.size(); i++) { LeafReaderContext leafReaderContext = leafReaderContexts.get(i); int finalI = i; rescoreTasks.add(() -> { - final BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); - // if there is no docIds to re-score from a segment we should return early to ensure that we are not - // wasting any computation - if (convertedBitSet == null) { - return Collections.emptyMap(); + PerLeafResult perLeafeResult = perLeafResults.get(finalI); + if (perLeafeResult.getResult().isEmpty()) { + return perLeafeResult; } + DocIdSetIterator matchedDocs = ResultUtil.resultMapToDocIds(perLeafeResult.getResult()); final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() - .matchedDocs(convertedBitSet) + .matchedDocsIterator(matchedDocs) + .numberOfMatchedDocs(perLeafResults.get(finalI).getResult().size()) // setting to false because in re-scoring we want to do exact search on full precision vectors .useQuantizedVectorsForSearch(false) .k(k) .isParentHits(false) .knnQuery(knnQuery) .build(); - return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + Map rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext); + perLeafeResult.setResult(rescoreResult); + return perLeafeResult; }); } return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); } - private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) { - int len = topK.scoreDocs.length; - Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); - int[] docs = new int[len]; - float[] scores = new float[len]; - for (int i = 0; i < len; i++) { - docs[i] = topK.scoreDocs[i].doc; - scores[i] = topK.scoreDocs[i].score; - } - int[] segmentStarts = findSegmentStarts(reader, docs); - return new DocAndScoreQuery(knnQuery.getK(), docs, scores, segmentStarts, reader.getContext().id()); - } - - static int[] findSegmentStarts(IndexReader reader, int[] docs) { - int[] starts = new int[reader.leaves().size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = reader.leaves().get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; - } - starts[i] = resultIndex; - } - return starts; - } - - private Map searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException { - final Map leafDocScores = queryWeight.searchLeaf(ctx, k); + private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException { + final PerLeafResult perLeafResult = queryWeight.searchLeaf(ctx, k); final Bits liveDocs = ctx.reader().getLiveDocs(); if (liveDocs != null) { - leafDocScores.entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false); + perLeafResult.getResult().entrySet().removeIf(entry -> liveDocs.get(entry.getKey()) == false); } - return leafDocScores; + return perLeafResult; } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java index 02fbd0113..376f60334 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java @@ -29,7 +29,9 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion; @@ -89,6 +91,8 @@ private static ObjectParser createInternalObjectP RESCORE_FIELD ); + internalParser.declareBoolean(KNNQueryBuilder.Builder::expandNested, EXPAND_NESTED_FIELD); + // Declare fields that cannot be set at the same time. Right now, rescore and radial is not supported internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MAX_DISTANCE_FIELD.getPreferredName()); internalParser.declareExclusiveFieldSet(RESCORE_FIELD.getPreferredName(), MIN_SCORE_FIELD.getPreferredName()); @@ -128,6 +132,10 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); public static final Set VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE); @@ -397,6 +399,7 @@ private static Map initializeMinimalRequiredVersionMap() { put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE); put(KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE, MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE); put(KNNConstants.MODEL_VERSION, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION); + put(EXPAND_NESTED, MINIMAL_EXPAND_NESTED_FEATURE); } }; diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 96493acec..1836ddb7e 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -29,6 +29,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.lucenelib.ExpandNestedDocsQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -482,4 +483,39 @@ public void testCreate_whenRescoreContextPassed_thenSuccess() { // Then assertEquals(expected, ((NativeEngineKnnVectorQuery) query).getKnnQuery()); } + + public void testCreate_whenExpandNestedDocsQueryWithFaiss_thenCreateNativeEngineKNNVectorQuery() { + testExpandNestedDocsQuery(KNNEngine.FAISS, NativeEngineKnnVectorQuery.class, VectorDataType.values()[randomInt(2)]); + } + + public void testCreate_whenExpandNestedDocsQueryWithNmslib_thenCreateKNNQuery() { + testExpandNestedDocsQuery(KNNEngine.NMSLIB, KNNQuery.class, VectorDataType.FLOAT); + } + + public void testCreate_whenExpandNestedDocsQueryWithLucene_thenCreateExpandNestedDocsQuery() { + testExpandNestedDocsQuery(KNNEngine.LUCENE, ExpandNestedDocsQuery.class, VectorDataType.BYTE); + testExpandNestedDocsQuery(KNNEngine.LUCENE, ExpandNestedDocsQuery.class, VectorDataType.FLOAT); + } + + private void testExpandNestedDocsQuery(KNNEngine knnEngine, Class klass, VectorDataType vectorDataType) { + QueryShardContext queryShardContext = mock(QueryShardContext.class); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(queryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .vector(testQueryVector) + .byteVector(testByteQueryVector) + .vectorDataType(vectorDataType) + .k(testK) + .expandNested(true) + .context(queryShardContext) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + + // Then + assertEquals(klass, query.getClass()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 7cda1ed79..a3b8c6989 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -8,8 +8,6 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.util.BitSet; -import org.junit.Assert; import org.opensearch.knn.KNNTestCase; import java.io.IOException; @@ -28,8 +26,13 @@ public void testReduceToTopK() { int segmentCount = 5; List> initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - List> reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); - ResultUtil.reduceToTopK(reducedLeafResults, finalK); + List perLeafLeafResults = initialLeafResults.stream() + .map(result -> new PerLeafResult(null, new HashMap<>(result))) + .collect(Collectors.toList()); + ResultUtil.reduceToTopK(perLeafLeafResults, finalK); + List> reducedLeafResults = perLeafLeafResults.stream() + .map(PerLeafResult::getResult) + .collect(Collectors.toList()); assertTopK(initialLeafResults, reducedLeafResults, finalK); firstPassK = 5; @@ -37,27 +40,22 @@ public void testReduceToTopK() { segmentCount = 1; initialLeafResults = getRandomListOfResults(firstPassK, segmentCount); - reducedLeafResults = initialLeafResults.stream().map(HashMap::new).collect(Collectors.toList()); - ResultUtil.reduceToTopK(reducedLeafResults, finalK); + perLeafLeafResults = initialLeafResults.stream() + .map(result -> new PerLeafResult(null, new HashMap<>(result))) + .collect(Collectors.toList()); + ResultUtil.reduceToTopK(perLeafLeafResults, finalK); + reducedLeafResults = perLeafLeafResults.stream().map(PerLeafResult::getResult).collect(Collectors.toList()); assertTopK(initialLeafResults, reducedLeafResults, firstPassK); } - public void testResultMapToMatchBitSet() throws IOException { + public void testResultMapToDocIds() throws IOException { int firstPassK = 35; Map perLeafResults = getRandomResults(firstPassK); - BitSet resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults); - assertResultMapToMatchBitSet(perLeafResults, resultBitset); - } - - public void testResultMapToMatchBitSet_whenResultMapEmpty_thenReturnEmptyOptional() throws IOException { - BitSet resultBitset = ResultUtil.resultMapToMatchBitSet(Collections.emptyMap()); - Assert.assertNull(resultBitset); - - BitSet resultBitset2 = ResultUtil.resultMapToMatchBitSet(null); - Assert.assertNull(resultBitset2); + DocIdSetIterator resultDocIdSetIterator = ResultUtil.resultMapToDocIds(perLeafResults); + assertResultMapToDocIdSetIterator(perLeafResults, resultDocIdSetIterator); } - public void testResultMapToDocIds() throws IOException { + public void testResultMapToDocIdsWithMaxDoc() throws IOException { int firstPassK = 42; Map perLeafResults = getRandomResults(firstPassK); final int maxDoc = Collections.max(perLeafResults.keySet()) + 1; @@ -99,13 +97,6 @@ private void assertTopK(List> beforeResults, List resultsMap, BitSet resultBitset) { - assertEquals(resultsMap.size(), resultBitset.cardinality()); - for (Integer docId : resultsMap.keySet()) { - assertTrue(resultBitset.get(docId)); - } - } - private void assertResultMapToDocIdSetIterator(Map resultsMap, DocIdSetIterator resultDocIdSetIterator) throws IOException { int count = 0; diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java similarity index 98% rename from src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java rename to src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java index 185cb5d47..b32496138 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.index.query.nativelib; +package org.opensearch.knn.index.query.common; import lombok.SneakyThrows; import org.apache.lucene.index.IndexReader; diff --git a/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java b/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java new file mode 100644 index 000000000..d804b9ab8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java @@ -0,0 +1,210 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.common; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.junit.Before; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class QueryUtilsTests extends TestCase { + private Executor executor; + private TaskExecutor taskExecutor; + private QueryUtils queryUtils; + + @Before + public void setUp() throws Exception { + executor = Executors.newSingleThreadExecutor(); + taskExecutor = new TaskExecutor(executor); + queryUtils = QueryUtils.INSTANCE; + } + + @SneakyThrows + public void testDoSearch_whenExecuted_thenSucceed() { + IndexSearcher indexSearcher = mock(IndexSearcher.class); + when(indexSearcher.getTaskExecutor()).thenReturn(taskExecutor); + + LeafReaderContext leafReaderContext1 = mock(LeafReaderContext.class); + LeafReaderContext leafReaderContext2 = mock(LeafReaderContext.class); + List leafReaderContexts = Arrays.asList(leafReaderContext1, leafReaderContext2); + + DocIdSetIterator docIdSetIterator = mock(DocIdSetIterator.class); + when(docIdSetIterator.docID()).thenReturn(0, 1, DocIdSetIterator.NO_MORE_DOCS); + Scorer scorer = mock(Scorer.class); + when(scorer.iterator()).thenReturn(docIdSetIterator); + when(scorer.docID()).thenReturn(0, 1, DocIdSetIterator.NO_MORE_DOCS); + when(scorer.score()).thenReturn(10.f, 11.f, -1f); + + Weight weight = mock(Weight.class); + when(weight.scorer(leafReaderContext1)).thenReturn(null); + when(weight.scorer(leafReaderContext2)).thenReturn(scorer); + + // Run + List> results = queryUtils.doSearch(indexSearcher, leafReaderContexts, weight); + + // Verify + assertEquals(2, results.size()); + assertEquals(0, results.get(0).size()); + assertEquals(2, results.get(1).size()); + assertEquals(10.f, results.get(1).get(0)); + assertEquals(11.f, results.get(1).get(1)); + + } + + @SneakyThrows + public void testGetAllSiblings_whenEmptyDocIds_thenEmptyIterator() { + LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + BitSetProducer bitSetProducer = mock(BitSetProducer.class); + Bits bits = mock(Bits.class); + + // Run + DocIdSetIterator docIdSetIterator = queryUtils.getAllSiblings(leafReaderContext, Collections.emptySet(), bitSetProducer, bits); + + // Verify + assertEquals(DocIdSetIterator.NO_MORE_DOCS, docIdSetIterator.nextDoc()); + } + + @SneakyThrows + public void testGetAllSiblings_whenNonEmptyDocIds_thenReturnAllSiblings() { + LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + // 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent) + BitSet bitSet = new FixedBitSet(new long[1], 11); + bitSet.set(2); + bitSet.set(7); + bitSet.set(10); + BitSetProducer bitSetProducer = mock(BitSetProducer.class); + when(bitSetProducer.getBitSet(leafReaderContext)).thenReturn(bitSet); + + BitSet filterBits = new FixedBitSet(new long[1], 11); + filterBits.set(1); + filterBits.set(8); + filterBits.set(9); + + // Run + Set docIds = Set.of(1, 8); + DocIdSetIterator docIdSetIterator = queryUtils.getAllSiblings(leafReaderContext, docIds, bitSetProducer, filterBits); + + // Verify + Set expectedDocIds = Set.of(1, 8, 9); + Set returnedDocIds = new HashSet<>(); + docIdSetIterator.nextDoc(); + while (docIdSetIterator.docID() != DocIdSetIterator.NO_MORE_DOCS) { + returnedDocIds.add(docIdSetIterator.docID()); + docIdSetIterator.nextDoc(); + } + assertEquals(expectedDocIds, returnedDocIds); + } + + @SneakyThrows + public void testCreateBits_whenWeightIsNull_thenMatchAllBits() { + LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + + // Run + Bits bits = queryUtils.createBits(leafReaderContext, null); + + // Verify + assertEquals(Bits.MatchAllBits.class, bits.getClass()); + + } + + @SneakyThrows + public void testCreateBits_whenScoreIsNull_thenMatchNoBits() { + LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + Weight weight = mock(Weight.class); + when(weight.scorer(leafReaderContext)).thenReturn(null); + + // Run + Bits bits = queryUtils.createBits(leafReaderContext, weight); + + // Verify + assertEquals(Bits.MatchNoBits.class, bits.getClass()); + } + + @SneakyThrows + public void testCreateBits_whenCalled_thenReturnBits() { + FixedBitSet liveDocBitSet = new FixedBitSet(new long[1], 11); + liveDocBitSet.set(2); + liveDocBitSet.set(7); + liveDocBitSet.set(10); + + FixedBitSet matchedBitSet = new FixedBitSet(new long[1], 11); + matchedBitSet.set(1); + matchedBitSet.set(2); + matchedBitSet.set(4); + matchedBitSet.set(9); + matchedBitSet.set(10); + + BitSetIterator matchedBitSetIterator = new BitSetIterator(matchedBitSet, 5); + + LeafReader leafReader = mock(LeafReader.class); + when(leafReader.getLiveDocs()).thenReturn(liveDocBitSet); + when(leafReader.maxDoc()).thenReturn(11); + + LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(leafReader); + + Scorer scorer = mock(Scorer.class); + when(scorer.iterator()).thenReturn(matchedBitSetIterator); + + Weight weight = mock(Weight.class); + when(weight.scorer(leafReaderContext)).thenReturn(scorer); + + // Run + Bits bits = queryUtils.createBits(leafReaderContext, weight); + + // Verify + FixedBitSet expectedBitSet = matchedBitSet.clone(); + expectedBitSet.and(liveDocBitSet); + assertTrue(areSetBitsEqual(expectedBitSet, bits)); + } + + private boolean areSetBitsEqual(Bits bits1, Bits bits2) { + int minLength = Math.min(bits1.length(), bits2.length()); + + for (int i = 0; i < minLength; i++) { + if (bits1.get(i) != bits2.get(i)) { + return false; + } + } + + for (int i = minLength; i < bits1.length(); i++) { + if (bits1.get(i)) { + return false; + } + } + + for (int i = minLength; i < bits2.length(); i++) { + if (bits2.get(i)) { + return false; + } + } + + return true; + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java index 6d5dffa98..5cd16a090 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIteratorTests.java @@ -8,6 +8,7 @@ import junit.framework.TestCase; import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.mockito.stubbing.OngoingStubbing; import org.opensearch.knn.index.SpaceType; @@ -45,7 +46,12 @@ public void testNextDoc_whenCalled_IterateAllDocs() { } // Execute and verify - BinaryVectorIdsKNNIterator iterator = new BinaryVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); + BinaryVectorIdsKNNIterator iterator = new BinaryVectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType + ); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); assertEquals(expectedScores.get(i), (Float) iterator.score()); diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java index 60169b95f..91d13e1ec 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIteratorTests.java @@ -8,6 +8,7 @@ import junit.framework.TestCase; import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.mockito.stubbing.OngoingStubbing; import org.opensearch.knn.index.SpaceType; @@ -46,7 +47,12 @@ public void testNextDoc_whenCalled_IterateAllDocs() { } // Execute and verify - ByteVectorIdsKNNIterator iterator = new ByteVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); + ByteVectorIdsKNNIterator iterator = new ByteVectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType + ); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); assertEquals(expectedScores.get(i), (Float) iterator.score()); diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java new file mode 100644 index 000000000..55f3d91d9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.iterators; + +import junit.framework.TestCase; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.FixedBitSet; + +import java.util.HashSet; +import java.util.Set; + +public class GroupedNestedDocIdSetIteratorTests extends TestCase { + public void testGroupedNestedDocIdSetIterator_whenNextDocIsCalled_thenBehaveAsExpected() throws Exception { + // 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent) + BitSet parentBitSet = new FixedBitSet(new long[1], 11); + parentBitSet.set(2); + parentBitSet.set(7); + parentBitSet.set(10); + + BitSet filterBits = new FixedBitSet(new long[1], 11); + filterBits.set(1); + filterBits.set(8); + filterBits.set(9); + + // Run + Set docIds = Set.of(1, 8); + GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits); + + // Verify + Set expectedDocIds = Set.of(1, 8, 9); + Set returnedDocIds = new HashSet<>(); + groupedNestedDocIdSetIterator.nextDoc(); + while (groupedNestedDocIdSetIterator.docID() != DocIdSetIterator.NO_MORE_DOCS) { + returnedDocIds.add(groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + } + assertEquals(expectedDocIds, returnedDocIds); + assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); + } + + public void testGroupedNestedDocIdSetIterator_whenAdvanceIsCalled_thenBehaveAsExpected() throws Exception { + // 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent) + BitSet parentBitSet = new FixedBitSet(new long[1], 11); + parentBitSet.set(2); + parentBitSet.set(7); + parentBitSet.set(10); + + BitSet filterBits = new FixedBitSet(new long[1], 11); + filterBits.set(1); + filterBits.set(8); + filterBits.set(9); + + // Run + Set docIds = Set.of(1, 8); + GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits); + + // Verify + Set expectedDocIds = Set.of(1, 8, 9); + groupedNestedDocIdSetIterator.advance(1); + assertEquals(1, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.advance(8); + assertEquals(8, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.advance(9); + assertEquals(9, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); + assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java index a39a3b2e9..32ac08156 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIteratorTests.java @@ -9,6 +9,7 @@ import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; @@ -49,7 +50,7 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { // Execute and verify NestedBinaryVectorIdsKNNIterator iterator = new NestedBinaryVectorIdsKNNIterator( - filterBitSet, + new BitSetIterator(filterBitSet, filterBitSet.length()), queryVector, values, spaceType, diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java index 08c859779..fcc635aaa 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIteratorTests.java @@ -9,6 +9,7 @@ import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; @@ -50,7 +51,7 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { // Execute and verify NestedByteVectorIdsKNNIterator iterator = new NestedByteVectorIdsKNNIterator( - filterBitSet, + new BitSetIterator(filterBitSet, filterBitSet.length()), queryVector, values, spaceType, diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java index f94ddb4e1..b44a90c5f 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIteratorTests.java @@ -9,6 +9,7 @@ import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; @@ -56,7 +57,13 @@ public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { } // Execute and verify - NestedVectorIdsKNNIterator iterator = new NestedVectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType, parentBitSet); + NestedVectorIdsKNNIterator iterator = new NestedVectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType, + parentBitSet + ); assertEquals(filterIds[0], iterator.nextDoc()); assertEquals(expectedScores.get(0), iterator.score()); assertEquals(filterIds[2], iterator.nextDoc()); diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIteratorTests.java index 96932d0f1..dc79a20ca 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIteratorTests.java @@ -7,6 +7,7 @@ import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.FixedBitSet; import org.mockito.stubbing.OngoingStubbing; import org.opensearch.knn.KNNTestCase; @@ -48,7 +49,12 @@ public void testNextDoc_whenCalledWithFilters_thenIterateAllDocs() { } // Execute and verify - VectorIdsKNNIterator iterator = new VectorIdsKNNIterator(filterBitSet, queryVector, values, spaceType); + VectorIdsKNNIterator iterator = new VectorIdsKNNIterator( + new BitSetIterator(filterBitSet, filterBitSet.length()), + queryVector, + values, + spaceType + ); for (int i = 0; i < filterIds.length; i++) { assertEquals(filterIds[i], iterator.nextDoc()); assertEquals(expectedScores.get(i), (Float) iterator.score()); diff --git a/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java b/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java new file mode 100644 index 000000000..55a110f6a --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedEDocsQueryTests.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.util.Bits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.knn.index.query.ResultUtil; +import org.opensearch.knn.index.query.common.QueryUtils; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ExpandNestedEDocsQueryTests extends TestCase { + private Executor executor; + private TaskExecutor taskExecutor; + + @Before + public void setUp() throws Exception { + executor = Executors.newSingleThreadExecutor(); + taskExecutor = new TaskExecutor(executor); + } + + @SneakyThrows + public void testCreateWeight_whenCalled_thenSucceed() { + LeafReaderContext leafReaderContext1 = mock(LeafReaderContext.class); + LeafReaderContext leafReaderContext2 = mock(LeafReaderContext.class); + List leafReaderContexts = Arrays.asList(leafReaderContext1, leafReaderContext2); + + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.leaves()).thenReturn(leafReaderContexts); + + Weight filterWeight = mock(Weight.class); + + IndexSearcher indexSearcher = mock(IndexSearcher.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(indexSearcher.getTaskExecutor()).thenReturn(taskExecutor); + when(indexSearcher.createWeight(any(), eq(ScoreMode.COMPLETE_NO_SCORES), eq(1.0F))).thenReturn(filterWeight); + + Weight queryWeight = mock(Weight.class); + ScoreMode scoreMode = mock(ScoreMode.class); + float boost = 1.f; + Query docAndScoreQuery = mock(Query.class); + when(docAndScoreQuery.createWeight(indexSearcher, scoreMode, boost)).thenReturn(queryWeight); + + TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(Map.of(1, 20f), 0); + TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(Map.of(0, 21f), 4); + + Query filterQuery = mock(Query.class); + BitSetProducer parentFilter = mock(BitSetProducer.class); + + InternalNestedKnnVectorQuery internalQuery = mock(InternalNestedKnnVectorQuery.class); + when(internalQuery.knnRewrite(indexSearcher)).thenReturn(docAndScoreQuery); + when(internalQuery.getK()).thenReturn(2); + when(internalQuery.knnExactSearch(any(), any())).thenReturn(topDocs1, topDocs2); + when(internalQuery.getFilter()).thenReturn(filterQuery); + when(internalQuery.getField()).thenReturn("field"); + when(internalQuery.getParentFilter()).thenReturn(parentFilter); + + Map initialLeaf1Results = new HashMap<>(Map.of(0, 19f, 1, 20f, 2, 17f, 3, 15f)); + Map initialLeaf2Results = new HashMap<>(Map.of(0, 21f, 1, 18f, 2, 16f, 3, 14f)); + List> perLeafResults = Arrays.asList(initialLeaf1Results, initialLeaf2Results); + + Bits queryFilterBits = mock(Bits.class); + DocIdSetIterator allSiblings = mock(DocIdSetIterator.class); + when(allSiblings.nextDoc()).thenReturn(1, 2, DocIdSetIterator.NO_MORE_DOCS); + + Weight expectedWeight = mock(Weight.class); + TopDocs topK = TopDocs.merge(2, new TopDocs[] { topDocs1, topDocs2 }); + Query finalQuery = mock(Query.class); + when(finalQuery.createWeight(indexSearcher, scoreMode, boost)).thenReturn(expectedWeight); + + QueryUtils queryUtils = mock(QueryUtils.class); + when(queryUtils.doSearch(indexSearcher, leafReaderContexts, queryWeight)).thenReturn(perLeafResults); + when(queryUtils.createBits(any(), any())).thenReturn(queryFilterBits); + when(queryUtils.getAllSiblings(any(), any(), any(), any())).thenReturn(allSiblings); + when(queryUtils.createDocAndScoreQuery(eq(indexReader), any())).thenReturn(finalQuery); + + // Run + ExpandNestedDocsQuery query = new ExpandNestedDocsQuery(internalQuery, queryUtils); + Weight finalWeigh = query.createWeight(indexSearcher, scoreMode, 1.f); + + // Verify + assertEquals(expectedWeight, finalWeigh); + verify(queryUtils).createBits(leafReaderContext1, filterWeight); + verify(queryUtils).createBits(leafReaderContext2, filterWeight); + verify(queryUtils).getAllSiblings(leafReaderContext1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); + verify(queryUtils).getAllSiblings(leafReaderContext2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); + ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); + verify(queryUtils).createDocAndScoreQuery(eq(indexReader), topDocsCaptor.capture()); + TopDocs capturedTopDocs = topDocsCaptor.getValue(); + assertEquals(topK.totalHits, capturedTopDocs.totalHits); + for (int i = 0; i < topK.scoreDocs.length; i++) { + assertEquals(topK.scoreDocs[i].doc, capturedTopDocs.scoreDocs[i].doc); + assertEquals(topK.scoreDocs[i].score, capturedTopDocs.scoreDocs[i].score, 0.01f); + assertEquals(topK.scoreDocs[i].shardIndex, capturedTopDocs.scoreDocs[i].shardIndex); + } + + // Verify acceptedDocIds is intersection of allSiblings and filteredDocIds + ArgumentCaptor iteratorCaptor = ArgumentCaptor.forClass(DocIdSetIterator.class); + verify(internalQuery, times(perLeafResults.size())).knnExactSearch(any(), iteratorCaptor.capture()); + assertEquals(1, iteratorCaptor.getValue().nextDoc()); + assertEquals(2, iteratorCaptor.getValue().nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iteratorCaptor.getValue().nextDoc()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java new file mode 100644 index 000000000..5e6570a74 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucenelib; + +import junit.framework.TestCase; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; + +import static org.mockito.Mockito.mock; + +public class NestedKnnVectorQueryFactoryTests extends TestCase { + public void testCreate_whenCalled_thenCreateQuery() { + String fieldName = "field"; + byte[] byteVectors = new byte[3]; + float[] floatVectors = new float[3]; + int k = 3; + Query queryFilter = mock(Query.class); + BitSetProducer parentFilter = mock(BitSetProducer.class); + boolean expandNestedDocs = true; + + ExpandNestedDocsQuery expectedByteQuery = new ExpandNestedDocsQuery.ExpandNestedDocsQueryBuilder().internalNestedKnnVectorQuery( + new InternalNestedKnnByteVectoryQuery(fieldName, byteVectors, queryFilter, k, parentFilter) + ).queryUtils(null).build(); + assertEquals( + expectedByteQuery, + NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, byteVectors, k, queryFilter, parentFilter, expandNestedDocs) + ); + + ExpandNestedDocsQuery expectedFloatQuery = new ExpandNestedDocsQuery.ExpandNestedDocsQueryBuilder().internalNestedKnnVectorQuery( + new InternalNestedKnnFloatVectoryQuery(fieldName, floatVectors, queryFilter, k, parentFilter) + ).queryUtils(null).build(); + assertEquals( + expectedFloatQuery, + NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, floatVectors, k, queryFilter, parentFilter, expandNestedDocs) + ); + } + + public void testCreate_whenNoExpandNestedDocs_thenDiversifyingQuery() { + String fieldName = "field"; + byte[] byteVectors = new byte[3]; + float[] floatVectors = new float[3]; + int k = 3; + Query queryFilter = mock(Query.class); + BitSetProducer parentFilter = mock(BitSetProducer.class); + boolean expandNestedDocs = false; + + assertEquals( + DiversifyingChildrenByteKnnVectorQuery.class, + NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, byteVectors, k, queryFilter, parentFilter, expandNestedDocs) + .getClass() + ); + + assertEquals( + DiversifyingChildrenFloatKnnVectorQuery.class, + NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, floatVectors, k, queryFilter, parentFilter, expandNestedDocs) + .getClass() + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 01e3fa6f9..789bd1054 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -10,41 +10,49 @@ import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.Bits; -import org.mockito.InjectMocks; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.invocation.InvocationOnMock; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.query.ExactSearcher; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.index.query.PerLeafResult; import org.opensearch.knn.index.query.ResultUtil; +import org.opensearch.knn.index.query.common.QueryUtils; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.test.OpenSearchTestCase; import java.util.ArrayList; -import java.util.Collections; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; -import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.when; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @@ -73,7 +81,6 @@ public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @Mock private ClusterService clusterService; - @InjectMocks private NativeEngineKnnVectorQuery objectUnderTest; private static ScoreMode scoreMode = ScoreMode.TOP_SCORES; @@ -82,7 +89,7 @@ public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); openMocks(this); - + objectUnderTest = new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, false); when(leaf1.reader()).thenReturn(leafReader1); when(leaf2.reader()).thenReturn(leafReader2); @@ -91,9 +98,9 @@ public void setUp() throws Exception { when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { - List>> callables = invocationOnMock.getArgument(0); - List> results = new ArrayList<>(); - for (Callable> callable : callables) { + List> callables = invocationOnMock.getArgument(0); + List results = new ArrayList<>(); + for (Callable callable : callables) { results.add(callable.call()); } return results; @@ -115,8 +122,11 @@ public void testMultiLeaf() { List leaves = List.of(leaf1, leaf2); when(reader.leaves()).thenReturn(leaves); - when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); - when(knnWeight.searchLeaf(leaf2, 4)).thenReturn(new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))); + PerLeafResult leaf1Result = new PerLeafResult(null, new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); + PerLeafResult leaf2Result = new PerLeafResult(null, new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))); + + when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leaf1Result); + when(knnWeight.searchLeaf(leaf2, 4)).thenReturn(leaf2Result); // Making sure there is deleted docs in one of the segments Bits liveDocs = mock(Bits.class); @@ -129,17 +139,19 @@ public void testMultiLeaf() { // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves when(knnQuery.getK()).thenReturn(4); - when(indexReaderContext.id()).thenReturn(1); - int[] expectedDocs = { 0, 3, 4 }; - float[] expectedScores = { 1.2f, 5.1f, 3.4f }; - int[] findSegments = { 0, 1, 3 }; - Query expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + Map leaf1ResultLive = Map.of(0, 1.2f); + TopDocs[] topDocs = { + ResultUtil.resultMapToTopDocs(leaf1ResultLive, leaf1.docBase), + ResultUtil.resultMapToTopDocs(leaf2Result.getResult(), leaf2.docBase) }; + TopDocs expectedTopDocs = TopDocs.merge(4, topDocs); // When Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); // Then + Query expected = QueryUtils.INSTANCE.createDocAndScoreQuery(reader, expectedTopDocs); assertEquals(expected, actual.getQuery()); } @@ -150,12 +162,13 @@ public void testRescoreWhenShardLevelRescoringEnabled() { when(reader.leaves()).thenReturn(leaves); int k = 2; - int firstPassK = 3; - Map initialLeaf1Results = new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f)); - Map initialLeaf2Results = new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f)); + PerLeafResult initialLeaf1Results = new PerLeafResult(null, new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f))); + PerLeafResult initialLeaf2Results = new PerLeafResult(null, new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f))); Map rescoredLeaf1Results = new HashMap<>(Map.of(0, 18f, 1, 20f)); Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(1.5f).build(); + int firstPassK = rescoreContext.getFirstPassK(k, true, 1); when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); when(knnQuery.getK()).thenReturn(k); when(knnWeight.getQuery()).thenReturn(knnQuery); @@ -189,21 +202,21 @@ public void testRescoreWhenShardLevelRescoringEnabled() { @SneakyThrows public void testSingleLeaf() { // Given + int k = 4; + float boost = 1; + PerLeafResult leaf1Result = new PerLeafResult(null, new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); List leaves = List.of(leaf1); when(reader.leaves()).thenReturn(leaves); - when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(new HashMap<>(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f))); - when(knnQuery.getK()).thenReturn(4); - + when(knnWeight.searchLeaf(leaf1, k)).thenReturn(leaf1Result); + when(knnQuery.getK()).thenReturn(k); when(indexReaderContext.id()).thenReturn(1); - int[] expectedDocs = { 0, 1, 2 }; - float[] expectedScores = { 1.2f, 5.1f, 2.2f }; - int[] findSegments = { 0, 3 }; - Query expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + TopDocs expectedTopDocs = ResultUtil.resultMapToTopDocs(leaf1Result.getResult(), leaf1.docBase); // When - Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, boost); // Then + Query expected = QueryUtils.INSTANCE.createDocAndScoreQuery(reader, expectedTopDocs); assertEquals(expected, actual.getQuery()); } @@ -212,7 +225,7 @@ public void testNoMatch() { // Given List leaves = List.of(leaf1); when(reader.leaves()).thenReturn(leaves); - when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(Collections.emptyMap()); + when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(PerLeafResult.EMPTY_RESULT); when(knnQuery.getK()).thenReturn(4); // When @@ -230,14 +243,12 @@ public void testRescore() { int k = 2; int firstPassK = 100; - Map initialLeaf1Results = new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f, 3, 15f)); - Map initialLeaf2Results = new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f, 3, 14f)); + PerLeafResult initialLeaf1Results = new PerLeafResult(null, new HashMap<>(Map.of(0, 21f, 1, 19f, 2, 17f, 3, 15f))); + PerLeafResult initialLeaf2Results = new PerLeafResult(null, new HashMap<>(Map.of(0, 20f, 1, 18f, 2, 16f, 3, 14f))); Map rescoredLeaf1Results = new HashMap<>(Map.of(0, 18f, 1, 20f)); Map rescoredLeaf2Results = new HashMap<>(Map.of(0, 21f)); TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(Map.of(1, 20f), 0); TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(Map.of(0, 21f), 4); - Query expected = new DocAndScoreQuery(2, new int[] { 1, 4 }, new float[] { 20f, 21f }, new int[] { 0, 4, 2 }, 1); - when(indexReaderContext.id()).thenReturn(1); when(knnQuery.getRescoreContext()).thenReturn(RescoreContext.builder().oversampleFactor(1.5f).build()); when(knnQuery.getK()).thenReturn(k); @@ -257,17 +268,86 @@ public void testRescore() { mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); - mockedResultUtil.when(() -> ResultUtil.resultMapToMatchBitSet(any())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf2Results), anyInt())).thenAnswer(t -> topDocs2); - try (MockedStatic mockedStaticNativeKnnVectorQuery = mockStatic(NativeEngineKnnVectorQuery.class)) { - mockedStaticNativeKnnVectorQuery.when(() -> NativeEngineKnnVectorQuery.findSegmentStarts(any(), any())) - .thenReturn(new int[] { 0, 4, 2 }); - Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); - assertEquals(expected, actual.getQuery()); - } + + // Run + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); + + // Verify + TopDocs[] topDocs = { topDocs1, topDocs2 }; + TopDocs expectedTopDocs = TopDocs.merge(k, topDocs); + Query expected = QueryUtils.INSTANCE.createDocAndScoreQuery(reader, expectedTopDocs); + assertEquals(expected, actual.getQuery()); } } + + @SneakyThrows + public void testExpandNestedDocs() { + List leafReaderContexts = Arrays.asList(leaf1, leaf2); + when(reader.leaves()).thenReturn(leafReaderContexts); + Bits queryFilterBits = mock(Bits.class); + PerLeafResult initialLeaf1Results = new PerLeafResult(queryFilterBits, new HashMap<>(Map.of(0, 19f, 1, 20f, 2, 17f, 3, 15f))); + PerLeafResult initialLeaf2Results = new PerLeafResult(queryFilterBits, new HashMap<>(Map.of(0, 21f, 1, 18f, 2, 16f, 3, 14f))); + List> perLeafResults = Arrays.asList(initialLeaf1Results.getResult(), initialLeaf2Results.getResult()); + + Map exactSearchLeaf1Result = new HashMap<>(Map.of(1, 20f)); + Map exactSearchLeaf2Result = new HashMap<>(Map.of(0, 21f)); + + TopDocs topDocs1 = ResultUtil.resultMapToTopDocs(exactSearchLeaf1Result, 0); + TopDocs topDocs2 = ResultUtil.resultMapToTopDocs(exactSearchLeaf2Result, 0); + TopDocs topK = TopDocs.merge(2, new TopDocs[] { topDocs1, topDocs2 }); + + int k = 2; + when(knnQuery.getRescoreContext()).thenReturn(null); + when(knnQuery.getK()).thenReturn(k); + + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(knnQuery.getParentsFilter()).thenReturn(parentFilter); + when(knnWeight.searchLeaf(leaf1, k)).thenReturn(initialLeaf1Results); + when(knnWeight.searchLeaf(leaf2, k)).thenReturn(initialLeaf2Results); + when(knnWeight.exactSearch(any(), any())).thenReturn(exactSearchLeaf1Result, exactSearchLeaf2Result); + Weight filterWeight = mock(Weight.class); + when(knnWeight.getFilterWeight()).thenReturn(filterWeight); + + DocIdSetIterator allSiblings = mock(DocIdSetIterator.class); + when(allSiblings.nextDoc()).thenReturn(1, 2, DocIdSetIterator.NO_MORE_DOCS); + + Weight expectedWeight = mock(Weight.class); + Query finalQuery = mock(Query.class); + when(finalQuery.createWeight(searcher, scoreMode, 1)).thenReturn(expectedWeight); + + QueryUtils queryUtils = mock(QueryUtils.class); + when(queryUtils.getAllSiblings(any(), any(), any(), any())).thenReturn(allSiblings); + when(queryUtils.createDocAndScoreQuery(eq(reader), any())).thenReturn(finalQuery); + + // Run + NativeEngineKnnVectorQuery query = new NativeEngineKnnVectorQuery(knnQuery, queryUtils, true); + Weight finalWeigh = query.createWeight(searcher, scoreMode, 1.f); + + // Verify + assertEquals(expectedWeight, finalWeigh); + verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); + verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); + ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); + verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture()); + TopDocs capturedTopDocs = topDocsCaptor.getValue(); + assertEquals(topK.totalHits, capturedTopDocs.totalHits); + for (int i = 0; i < topK.scoreDocs.length; i++) { + assertEquals(topK.scoreDocs[i].doc, capturedTopDocs.scoreDocs[i].doc); + assertEquals(topK.scoreDocs[i].score, capturedTopDocs.scoreDocs[i].score, 0.01f); + assertEquals(topK.scoreDocs[i].shardIndex, capturedTopDocs.scoreDocs[i].shardIndex); + } + + // Verify acceptedDocIds is intersection of allSiblings and filteredDocIds + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass( + ExactSearcher.ExactSearcherContext.class + ); + verify(knnWeight, times(perLeafResults.size())).exactSearch(any(), contextCaptor.capture()); + assertEquals(1, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); + assertEquals(2, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, contextCaptor.getValue().getMatchedDocsIterator().nextDoc()); + } } diff --git a/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java b/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java new file mode 100644 index 000000000..164aa7100 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java @@ -0,0 +1,392 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.google.common.collect.Multimap; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.NestedKnnDocBuilder; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.Mode; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.opensearch.knn.common.Constants.FIELD_FILTER; +import static org.opensearch.knn.common.Constants.FIELD_TERM; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.EXPAND_NESTED; +import static org.opensearch.knn.common.KNNConstants.K; +import static org.opensearch.knn.common.KNNConstants.KNN; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PATH; +import static org.opensearch.knn.common.KNNConstants.QUERY; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.TYPE_NESTED; +import static org.opensearch.knn.common.KNNConstants.VECTOR; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + +@Log4j2 +@AllArgsConstructor +public class ExpandNestedDocsIT extends KNNRestTestCase { + private static final String INDEX_NAME = "test-index-expand-nested-search"; + private static final String FIELD_NAME_NESTED = "test_nested"; + private static final String FIELD_NAME_VECTOR = "test_vector"; + private static final String FIELD_NAME_PARKING = "parking"; + private static final String FIELD_NAME_STORAGE = "storage"; + private static final String TYPE_BOOLEAN = "boolean"; + private static final String FIELD_VALUE_TRUE = "true"; + private static final String FIELD_VALUE_FALSE = "false"; + private static final String PROPERTIES_FIELD = "properties"; + private static final String INNER_HITS = "inner_hits"; + + private String description; + private KNNEngine engine; + private VectorDataType dataType; + private Mode mode; + private Integer dimension; + + @After + @SneakyThrows + public final void cleanUp() { + deleteKNNIndex(INDEX_NAME); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; engine:%2$s, data_type:%3$s, mode:%4$s, dimension:%5$s") + public static Collection parameters() throws IOException { + int dimension = 1; + return Arrays.asList( + $$( + $("Lucene with byte format and in memory mode", KNNEngine.LUCENE, VectorDataType.BYTE, Mode.NOT_CONFIGURED, dimension), + $("Lucene with float format and in memory mode", KNNEngine.LUCENE, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, dimension), + $( + "Faiss with binary format and in memory mode", + KNNEngine.FAISS, + VectorDataType.BINARY, + Mode.NOT_CONFIGURED, + dimension * 8 + ), + $("Faiss with byte format and in memory mode", KNNEngine.FAISS, VectorDataType.BYTE, Mode.NOT_CONFIGURED, dimension), + $("Faiss with float format and in memory mode", KNNEngine.FAISS, VectorDataType.FLOAT, Mode.IN_MEMORY, dimension), + $( + "Faiss with float format and on disk mode", + KNNEngine.FAISS, + VectorDataType.FLOAT, + Mode.ON_DISK, + // Currently, on disk mode only supports dimension of multiple of 8 + dimension * 8 + ) + ) + ); + } + + @SneakyThrows + public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc() { + int numberOfNestedFields = 2; + createKnnIndex(engine, mode, dimension, dataType); + addRandomVectorsWithTopLevelField(1, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + addRandomVectorsWithTopLevelField(2, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + addRandomVectorsWithTopLevelField(3, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + addRandomVectorsWithTopLevelField(4, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + addRandomVectorsWithTopLevelField(5, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + deleteKnnDoc(INDEX_NAME, String.valueOf(1)); + updateVectorWithTopLevelField(2, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_FALSE); + + // Run + Float[] queryVector = createVector(); + Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, 10, queryVector, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); + + // Verify + String entity = EntityUtils.toString(response.getEntity()); + Multimap docIdToOffsets = parseInnerHits(entity, FIELD_NAME_NESTED); + assertEquals(3, docIdToOffsets.keySet().size()); + for (String key : docIdToOffsets.keySet()) { + assertEquals(numberOfNestedFields, docIdToOffsets.get(key).size()); + } + } + + @SneakyThrows + public void testExpandNestedDocs_whenFilteredOnNestedFieldDoc_thenReturnFilteredNestedDoc() { + int numberOfNestedFields = 2; + createKnnIndex(engine, mode, dimension, dataType); + addRandomVectorsWithMetadata(1, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_FALSE)); + addRandomVectorsWithMetadata(2, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_TRUE, FIELD_VALUE_TRUE)); + addRandomVectorsWithMetadata(3, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_TRUE, FIELD_VALUE_TRUE)); + addRandomVectorsWithMetadata(4, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_TRUE)); + addRandomVectorsWithMetadata(5, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_TRUE, FIELD_VALUE_FALSE)); + deleteKnnDoc(INDEX_NAME, String.valueOf(1)); + addRandomVectorsWithMetadata(2, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_FALSE)); + + // Run + Float[] queryVector = createVector(); + Response response = queryNestedFieldWithExpandNestedDocs( + INDEX_NAME, + 10, + queryVector, + FIELD_NAME_NESTED + "." + FIELD_NAME_STORAGE, + FIELD_VALUE_TRUE + ); + + // Verify + String entity = EntityUtils.toString(response.getEntity()); + Multimap docIdToOffsets = parseInnerHits(entity, FIELD_NAME_NESTED); + assertEquals(3, docIdToOffsets.keySet().size()); + assertEquals(2, docIdToOffsets.get(String.valueOf(3)).size()); + assertEquals(1, docIdToOffsets.get(String.valueOf(4)).size()); + assertEquals(1, docIdToOffsets.get(String.valueOf(5)).size()); + + assertTrue(docIdToOffsets.get(String.valueOf(4)).contains(1)); + assertTrue(docIdToOffsets.get(String.valueOf(5)).contains(0)); + } + + @SneakyThrows + public void testExpandNestedDocs_whenMultiShards_thenReturnCorrectResult() { + int numberOfNestedFields = 10; + int numberOfDocuments = 5; + createKnnIndex(engine, mode, dimension, dataType, 2); + for (int i = 1; i <= numberOfDocuments; i++) { + addSingleRandomVectors(i, numberOfNestedFields); + } + forceMergeKnnIndex(INDEX_NAME); + + // Run + Float[] queryVector = createVector(); + Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, numberOfDocuments, queryVector); + + // Verify + String entity = EntityUtils.toString(response.getEntity()); + Multimap docIdToOffsets = parseInnerHits(entity, FIELD_NAME_NESTED); + assertEquals(numberOfDocuments, docIdToOffsets.keySet().size()); + int defaultInnerHitSize = 3; + for (int i = 1; i <= numberOfDocuments; i++) { + assertEquals(defaultInnerHitSize, docIdToOffsets.get(String.valueOf(i)).size()); + } + } + + private Float[] createVector() { + int vectorSize = VectorDataType.BINARY.equals(dataType) ? dimension / 8 : dimension; + Float[] vector = new Float[vectorSize]; + for (int i = 0; i < vectorSize; i++) { + vector[i] = (float) (randomInt(255) - 128); + } + return vector; + } + + private void updateVectorWithTopLevelField( + final int docId, + final int numOfNestedFields, + final String fieldName, + final String fieldValue + ) throws IOException { + addRandomVectorsWithTopLevelField(docId, numOfNestedFields, fieldName, fieldValue); + } + + private void addRandomVectorsWithTopLevelField( + final int docId, + final int numOfNestedFields, + final String fieldName, + final String fieldValue + ) throws IOException { + + NestedKnnDocBuilder builder = NestedKnnDocBuilder.create(FIELD_NAME_NESTED); + for (int i = 0; i < numOfNestedFields; i++) { + builder.addVectors(FIELD_NAME_VECTOR, createVector()); + } + builder.addTopLevelField(fieldName, fieldValue); + String doc = builder.build(); + addKnnDoc(INDEX_NAME, String.valueOf(docId), doc); + refreshIndex(INDEX_NAME); + } + + private void addSingleRandomVectors(final int docId, final int numOfNestedFields) throws IOException { + NestedKnnDocBuilder builder = NestedKnnDocBuilder.create(FIELD_NAME_NESTED); + Object[] vector = createVector(); + for (int i = 0; i < numOfNestedFields; i++) { + builder.addVectors(FIELD_NAME_VECTOR, vector); + } + String doc = builder.build(); + addKnnDoc(INDEX_NAME, String.valueOf(docId), doc); + refreshIndex(INDEX_NAME); + } + + private void addRandomVectorsWithMetadata( + final int docId, + final int numOfNestedFields, + final String nestedFieldName, + final List nestedFieldValue + ) throws IOException { + assert numOfNestedFields == nestedFieldValue.size(); + + NestedKnnDocBuilder builder = NestedKnnDocBuilder.create(FIELD_NAME_NESTED); + for (int i = 0; i < numOfNestedFields; i++) { + builder.addVectorWithMetadata(FIELD_NAME_VECTOR, createVector(), nestedFieldName, nestedFieldValue.get(i)); + } + String doc = builder.build(); + addKnnDoc(INDEX_NAME, String.valueOf(docId), doc); + refreshIndex(INDEX_NAME); + } + + private void createKnnIndex(final KNNEngine engine, final Mode mode, final int dimension, final VectorDataType vectorDataType) + throws Exception { + createKnnIndex(engine, mode, dimension, vectorDataType, 1); + } + + /** + * { + * "dynamic": false, + * "properties": { + * "test_nested": { + * "type": "nested", + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": 3, + * "mode": "in_memory", + * "data_type: "float", + * "method": { + * "name": "hnsw", + * "engine": "lucene" + * } + * }, + * "storage": { + * "type": "boolean" + * } + * } + * }, + * "parking": { + * "type": "boolean" + * } + * } + * } + */ + private void createKnnIndex( + final KNNEngine engine, + final Mode mode, + final int dimension, + final VectorDataType vectorDataType, + final int numOfShards + ) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("dynamic", false) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_NESTED) + .field(TYPE, TYPE_NESTED) + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME_VECTOR) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .field(MODE_PARAMETER, Mode.NOT_CONFIGURED.equals(mode) ? null : mode.getName()) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, engine.getName()) + .endObject() + .endObject() + .startObject(FIELD_NAME_STORAGE) + .field(TYPE, TYPE_BOOLEAN) + .endObject() + .endObject() + .endObject() + .startObject(FIELD_NAME_PARKING) + .field(TYPE, TYPE_BOOLEAN) + .endObject() + .endObject() + .endObject(); + + String mapping = builder.toString(); + Settings settings = Settings.builder() + .put("number_of_shards", numOfShards) + .put("number_of_replicas", 0) + .put("index.knn", true) + .build(); + createKnnIndex(INDEX_NAME, settings, mapping); + } + + private Response queryNestedFieldWithExpandNestedDocs(final String index, final Integer k, final Object[] vector) throws IOException { + return queryNestedFieldWithExpandNestedDocs(index, k, vector, null, null); + } + + /** + * { + * "query": { + * "nested": { + * "path": "test_nested", + * "query": { + * "knn": { + * "test_nested.test_vector" : { + * "vector: [1, 1, 2] + * "k": 3, + * "filter": { + * "term": { + * "nested_field.storage": true + * } + * } + * } + * } + * }, + * "inner_hits": {} + * } + * } + * } + */ + private Response queryNestedFieldWithExpandNestedDocs( + final String index, + final Integer k, + final Object[] vector, + final String filterName, + final String filterValue + ) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject(QUERY); + builder.startObject(TYPE_NESTED); + builder.field(PATH, FIELD_NAME_NESTED); + builder.startObject(QUERY).startObject(KNN).startObject(FIELD_NAME_NESTED + "." + FIELD_NAME_VECTOR); + builder.field(VECTOR, vector); + builder.field(K, k); + builder.field(EXPAND_NESTED, true); + if (filterName != null && filterValue != null) { + builder.startObject(FIELD_FILTER); + builder.startObject(FIELD_TERM); + builder.field(filterName, filterValue); + builder.endObject(); + builder.endObject(); + } + + builder.endObject().endObject().endObject(); + builder.field(INNER_HITS); + builder.startObject().endObject(); + builder.endObject().endObject().endObject(); + + Request request = new Request("POST", "/" + index + "/_search"); + request.setJsonEntity(builder.toString()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + return response; + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 2afbd9639..632543e43 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -5,9 +5,12 @@ package org.opensearch.knn; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Multimap; import com.google.common.primitives.Bytes; import com.google.common.primitives.Floats; import com.google.common.primitives.Ints; +import com.jayway.jsonpath.JsonPath; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; @@ -52,7 +55,6 @@ import javax.management.remote.JMXConnector; import javax.management.remote.JMXConnectorFactory; import javax.management.remote.JMXServiceURL; - import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; @@ -928,6 +930,25 @@ protected int parseHits(String searchResponseBody) throws IOException { return ((List) responseMap.get("hits")).size(); } + /** + * Get mapping from parent doc Id to inner hits offsets + */ + protected Multimap parseInnerHits(String searchResponseBody, String fieldName) throws IOException { + List ids = JsonPath.read( + searchResponseBody, + String.format(Locale.ROOT, "$.hits.hits[*].inner_hits.%s.hits.hits[*]._id", fieldName) + ); + List offsets = JsonPath.read( + searchResponseBody, + String.format(Locale.ROOT, "$.hits.hits[*].inner_hits.%s.hits.hits[*]._nested.offset", fieldName) + ); + Multimap docIdToOffsets = ArrayListMultimap.create(); + for (int i = 0; i < ids.size(); i++) { + docIdToOffsets.put(ids.get(i), offsets.get(i)); + } + return docIdToOffsets; + } + protected List parseIds(String searchResponseBody) throws IOException { @SuppressWarnings("unchecked") List hits = (List) ((Map) createParser(