From 96fa384e1b60fd9cdb86216c69f0ba06693a395f Mon Sep 17 00:00:00 2001 From: Varun Jain <varunudr@amazon.com> Date: Tue, 14 Jan 2025 12:44:35 -0800 Subject: [PATCH] [Backport 2.x] Pagination in hybrid query (#1099) * Pagination in Hybrid query (#1048) * Pagination in Hybrid query Signed-off-by: Varun Jain <varunudr@amazon.com> * Remove unwanted code Signed-off-by: Varun Jain <varunudr@amazon.com> * Adding hybrid query context dto Signed-off-by: Varun Jain <varunudr@amazon.com> * Adding javadoc in hybridquerycontext and addressing few comments from review Signed-off-by: Varun Jain <varunudr@amazon.com> * rename hybrid query extraction method Signed-off-by: Varun Jain <varunudr@amazon.com> * Refactoring to optimize extractHybridQuery method calls Signed-off-by: Varun Jain <varunudr@amazon.com> * Changes in tests to adapt with builder pattern in querybuilder Signed-off-by: Varun Jain <varunudr@amazon.com> * Add mapper service mock in tests Signed-off-by: Varun Jain <varunudr@amazon.com> * Fix error message of index.max_result_window setting Signed-off-by: Varun Jain <varunudr@amazon.com> * Fix error message of index.max_result_window setting Signed-off-by: Varun Jain <varunudr@amazon.com> * Fixing validation condition for lower bound Signed-off-by: Varun Jain <varunudr@amazon.com> * fix tests Signed-off-by: Varun Jain <varunudr@amazon.com> * Removing version check from doEquals and doHashCode method Signed-off-by: Varun Jain <varunudr@amazon.com> --------- Signed-off-by: Varun Jain <varunudr@amazon.com> * Update pagination_depth datatype from int to Integer (#1094) * Update pagination_depth datatype from int to Integer Signed-off-by: Varun Jain <varunudr@amazon.com> --------- Signed-off-by: Varun Jain <varunudr@amazon.com> --- CHANGELOG.md | 1 + .../common/MinClusterVersionUtil.java | 5 + .../processor/NormalizationProcessor.java | 1 + .../NormalizationProcessorWorkflow.java | 71 ++++++-- ...zationProcessorWorkflowExecuteRequest.java | 2 + .../combination/CombineScoresDto.java | 1 + .../processor/combination/ScoreCombiner.java | 10 +- .../neuralsearch/query/HybridQuery.java | 18 +- .../query/HybridQueryBuilder.java | 57 +++++- .../query/HybridQueryContext.java | 17 ++ .../search/query/HybridCollectorManager.java | 52 +++++- .../query/HybridQueryPhaseSearcher.java | 18 +- .../neuralsearch/util/HybridQueryUtil.java | 16 +- .../NormalizationProcessorTests.java | 5 +- .../NormalizationProcessorWorkflowTests.java | 128 +++++++++++-- .../query/HybridQueryBuilderTests.java | 122 +++++++++++++ .../neuralsearch/query/HybridQueryIT.java | 168 ++++++++++++++++-- .../neuralsearch/query/HybridQueryTests.java | 50 ++++-- .../query/HybridQueryWeightTests.java | 9 +- .../HybridAggregationProcessorTests.java | 12 +- .../query/HybridCollectorManagerTests.java | 154 ++++++++++++++-- .../query/HybridQueryPhaseSearcherTests.java | 28 +++ .../util/HybridQueryUtilTests.java | 40 ++++- .../neuralsearch/BaseNeuralSearchIT.java | 2 - 24 files changed, 884 insertions(+), 103 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java diff --git a/CHANGELOG.md b/CHANGELOG.md index cf648c641..12248ccad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048)) ### Enhancements - 
Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) diff --git a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java index a17e138e2..13410d1c7 100644 --- a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java @@ -24,6 +24,7 @@ public final class MinClusterVersionUtil { private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; private static final Version MINIMAL_SUPPORTED_VERSION_QUERY_IMAGE_FIX = Version.V_2_19_0; + private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0; // Note this minimal version will act as a override private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder() @@ -41,6 +42,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); } + public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY); + } + public static boolean isClusterOnOrAfterMinReqVersion(String key) { Version version; if (MINIMAL_VERSION_NEURAL.containsKey(key)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2008ae97..d2fa03fde 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -93,6 +93,7 @@ private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWo .combinationTechnique(combinationTechnique) .explain(explain) .pipelineProcessingContext(requestContextOptional.orElse(null)) + .searchPhaseContext(searchPhaseContext) .build(); normalizationWorkflow.execute(request); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index f2699d967..db3747a13 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -19,6 +19,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; @@ -64,7 +65,8 @@ public void execute( final List<QuerySearchResult> querySearchResults, final Optional<FetchSearchResult> fetchSearchResultOptional, final ScoreNormalizationTechnique normalizationTechnique, - final 
ScoreCombinationTechnique combinationTechnique
+ final ScoreCombinationTechnique combinationTechnique,
+ final SearchPhaseContext searchPhaseContext
 ) {
 NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
 .querySearchResults(querySearchResults)
@@ -72,17 +74,21 @@ public void execute(
 .normalizationTechnique(normalizationTechnique)
 .combinationTechnique(combinationTechnique)
 .explain(false)
+ .searchPhaseContext(searchPhaseContext)
 .build();
 execute(request);
 }

 public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
+ List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
+ Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();
+
 // save original state
- List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
+ List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

 // pre-process data
 log.debug("Pre-process query results");
- List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(request.getQuerySearchResults());
+ List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

 explain(request, queryTopDocs);

@@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
 CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
 .queryTopDocs(queryTopDocs)
 .scoreCombinationTechnique(request.getCombinationTechnique())
- .querySearchResults(request.getQuerySearchResults())
- .sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs))
+ .querySearchResults(querySearchResults)
+ .sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
+ .fromValueForSingleShard(getFromValueIfSingleShard(request))
 .build();

 // combine
@@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
 // post-process data
 log.debug("Post-process query results after score normalization and combination");
- updateOriginalQueryResults(combineScoresDTO);
- updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
+ updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent());
+ updateOriginalFetchResults(
+ querySearchResults,
+ fetchSearchResultOptional,
+ unprocessedDocIds,
+ combineScoresDTO.getFromValueForSingleShard()
+ );
+ }
+
+ /**
+ * Get the value of the from parameter when there is a single shard
+ * and the fetch phase has already executed
+ * Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715
+ */
+ private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) {
+ final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext();
+ if (searchPhaseContext.getNumShards() > 1 || request.getFetchSearchResultOptional().isEmpty()) {
+ return -1;
+ }
+ return searchPhaseContext.getRequest().source().from();
 }

 /**
@@ -173,19 +198,33 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
 return queryTopDocs;
 }

- private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) {
+ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) {
 final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
 final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
 final Sort sort = combineScoresDTO.getSort();
+ int totalScoreDocsCount = 0;
 for (int index = 0; index < querySearchResults.size(); index++) {
 QuerySearchResult querySearchResult = querySearchResults.get(index);
 CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
+ totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
 TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
 buildTopDocs(updatedTopDocs, sort),
 maxScoreForShard(updatedTopDocs, sort != null)
 );
+ // The fetch phase ran before the normalization phase, therefore update the from value in the result of each shard.
+ // This ensures correct trimming of the search results.
+ if (isFetchPhaseExecuted) {
+ querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
+ }
 querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
 }
+
+ final int from = querySearchResults.get(0).from();
+ if (from > totalScoreDocsCount) {
+ throw new IllegalArgumentException(
+ String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
+ );
+ }
 }

 private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
@@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
 private void updateOriginalFetchResults(
 final List<QuerySearchResult> querySearchResults,
 final Optional<FetchSearchResult> fetchSearchResultOptional,
- final List<Integer> docIds
+ final List<Integer> docIds,
+ final int fromValueForSingleShard
 ) {
 if (fetchSearchResultOptional.isEmpty()) {
 return;
@@ -276,14 +316,21 @@ private void updateOriginalFetchResults(
 QuerySearchResult querySearchResult = querySearchResults.get(0);
 TopDocs topDocs = querySearchResult.topDocs().topDocs;

+ // When the normalization process runs after the fetch phase, the search hits have already been fetched. Therefore, use the
+ // from value sent in the search request to calculate the effective length of the updated search hits array.
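+ // Example with hypothetical numbers: with from = 2 and 7 entries in topDocs.scoreDocs, fromValueForSingleShard is 2,
+ // so trimmedLengthOfSearchHits is 5 and the loop below keeps scoreDocs[2..6], dropping the two already-served hits.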
+ int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard;
 // iterate over the normalized/combined scores, that solves (1) and (3)
- SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
+ SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits];
+ for (int i = 0; i < trimmedLengthOfSearchHits; i++) {
+ // read topDocs entries starting at the from offset
+ ScoreDoc scoreDoc = topDocs.scoreDocs[i + fromValueForSingleShard];
 // get fetched hit content by doc_id
 SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
 // update score to normalized/combined value (3)
 searchHit.score(scoreDoc.score);
- return searchHit;
- }).toArray(SearchHit[]::new);
+ updatedSearchHitArray[i] = searchHit;
+ }
 SearchHits updatedSearchHits = new SearchHits(
 updatedSearchHitArray,
 querySearchResult.getTotalHits(),
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java
index ea0b54b9c..e818c1b31 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java
@@ -7,6 +7,7 @@
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Getter;
+import org.opensearch.action.search.SearchPhaseContext;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.search.fetch.FetchSearchResult;
@@ -29,4 +30,5 @@ public class NormalizationProcessorWorkflowExecuteRequest {
 final ScoreCombinationTechnique combinationTechnique;
 boolean explain;
 final PipelineProcessingContext pipelineProcessingContext;
+ final SearchPhaseContext searchPhaseContext;
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java
index c4783969b..fecf5ca09 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java
@@ -29,4 +29,5 @@ public class CombineScoresDto {
 private List<QuerySearchResult> querySearchResults;
 @Nullable
 private Sort sort;
+ private int fromValueForSingleShard;
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
index 1779f20f7..40625adfb 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -70,14 +70,10 @@ public class ScoreCombiner {
 public void combineScores(final CombineScoresDto combineScoresDTO) {
 // iterate over results from each shard.
Every CompoundTopDocs object has results from // multiple sub queries, doc ids may repeat for each sub query results + ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique(); + Sort sort = combineScoresDTO.getSort(); combineScoresDTO.getQueryTopDocs() - .forEach( - compoundQueryTopDocs -> combineShardScores( - combineScoresDTO.getScoreCombinationTechnique(), - compoundQueryTopDocs, - combineScoresDTO.getSort() - ) - ); + .forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort)); } private void combineShardScores( diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 60d5870da..d1e339bd5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -34,17 +34,22 @@ public final class HybridQuery extends Query implements Iterable<Query> { private final List<Query> subQueries; + private final HybridQueryContext queryContext; /** * Create new instance of hybrid query object based on collection of sub queries and filter query * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is */ - public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) { + public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final HybridQueryContext hybridQueryContext) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } + Integer paginationDepth = hybridQueryContext.getPaginationDepth(); + if (Objects.nonNull(paginationDepth) && paginationDepth == 0) { + throw new IllegalArgumentException("pagination_depth must not be zero"); + } if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { this.subQueries = new ArrayList<>(subQueries); } else { @@ -57,10 +62,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ } this.subQueries = modifiedSubQueries; } + this.queryContext = hybridQueryContext; } - public HybridQuery(final Collection<Query> subQueries) { - this(subQueries, List.of()); + public HybridQuery(final Collection<Query> subQueries, final HybridQueryContext hybridQueryContext) { + this(subQueries, List.of(), hybridQueryContext); } /** @@ -128,7 +134,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); } final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors); - return new HybridQuery(rewrittenSubQueries); + return new HybridQuery(rewrittenSubQueries, queryContext); } private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) { @@ -190,6 +196,10 @@ public Collection<Query> getSubQueries() { return Collections.unmodifiableCollection(subQueries); } + public HybridQueryContext getQueryContext() { + return queryContext; + } + /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java 
b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 338758802..bea94e603 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -22,6 +22,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexSettings; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; @@ -35,6 +36,8 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery; + /** * Class abstract creation of a Query type "hybrid". Hybrid query will allow execution of multiple sub-queries and * collects score for each of those sub-query. @@ -48,14 +51,22 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu public static final String NAME = "hybrid"; private static final ParseField QUERIES_FIELD = new ParseField("queries"); + private static final ParseField PAGINATION_DEPTH_FIELD = new ParseField("pagination_depth"); private final List<QueryBuilder> queries = new ArrayList<>(); + private Integer paginationDepth; + static final int MAX_NUMBER_OF_SUB_QUERIES = 5; + private final static int DEFAULT_PAGINATION_DEPTH = 10; + private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0; public HybridQueryBuilder(StreamInput in) throws IOException { super(in); queries.addAll(readQueries(in)); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + paginationDepth = in.readOptionalInt(); + } } /** @@ -66,6 +77,9 @@ public HybridQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + out.writeOptionalInt(paginationDepth); + } } /** @@ -95,6 +109,10 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep queryBuilder.toXContent(builder, params); } builder.endArray(); + // TODO https://github.com/opensearch-project/neural-search/issues/1097 + if (Objects.nonNull(paginationDepth)) { + builder.field(PAGINATION_DEPTH_FIELD.getPreferredName(), paginationDepth); + } printBoostAndQueryName(builder); builder.endObject(); } @@ -111,7 +129,9 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio if (queryCollection.isEmpty()) { return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME)); } - return new HybridQuery(queryCollection); + validatePaginationDepth(paginationDepth, queryShardContext); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(paginationDepth).build(); + return new HybridQuery(queryCollection, hybridQueryContext); } /** @@ -147,6 +167,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException { float boost = AbstractQueryBuilder.DEFAULT_BOOST; + int paginationDepth = DEFAULT_PAGINATION_DEPTH; final List<QueryBuilder> queries = new ArrayList<>(); String queryName = null; @@ -194,6 +215,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) 
throws IOEx
 }
 } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
 queryName = parser.text();
+ } else if (PAGINATION_DEPTH_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+ paginationDepth = parser.intValue();
 } else {
 log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName));
 throw new ParsingException(
@@ -214,6 +237,9 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx
 HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder();
 compoundQueryBuilder.queryName(queryName);
 compoundQueryBuilder.boost(boost);
+ if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) {
+ compoundQueryBuilder.paginationDepth(paginationDepth);
+ }
 for (QueryBuilder query : queries) {
 compoundQueryBuilder.add(query);
 }
@@ -233,6 +259,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I
 if (changed) {
 newBuilder.queryName(queryName);
 newBuilder.boost(boost);
+ if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) {
+ newBuilder.paginationDepth(paginationDepth);
+ }
 return newBuilder;
 } else {
 return this;
@@ -254,6 +283,7 @@ protected boolean doEquals(HybridQueryBuilder obj) {
 }
 EqualsBuilder equalsBuilder = new EqualsBuilder();
 equalsBuilder.append(queries, obj.queries);
+ equalsBuilder.append(paginationDepth, obj.paginationDepth);
 return equalsBuilder.isEquals();
 }

@@ -263,7 +293,7 @@ protected boolean doEquals(HybridQueryBuilder obj) {
 */
 @Override
 protected int doHashCode() {
- return Objects.hash(queries);
+ return Objects.hash(queries, paginationDepth);
 }

 /**
@@ -294,6 +324,29 @@ private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, Quer
 return queries;
 }

+ private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) {
+ if (Objects.isNull(paginationDepth)) {
+ return;
+ }
+ if (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH) {
+ throw new IllegalArgumentException(
+ String.format(Locale.ROOT, "pagination_depth should be greater than %s", LOWER_BOUND_OF_PAGINATION_DEPTH)
+ );
+ }
+ // compare pagination depth with OpenSearch setting index.max_result_window
+ // see https://opensearch.org/docs/latest/install-and-configure/configuring-opensearch/index-settings/
+ int maxResultWindowIndexSetting = queryShardContext.getIndexSettings().getMaxResultWindow();
+ if (paginationDepth > maxResultWindowIndexSetting) {
+ throw new IllegalArgumentException(
+ String.format(
+ Locale.ROOT,
+ "pagination_depth should be less than or equal to %s setting",
+ IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey()
+ )
+ );
+ }
+ }
+
 /**
 * visit method to parse the HybridQueryBuilder by a visitor
 */
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java
new file mode 100644
index 000000000..34706e6e7
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java
@@ -0,0 +1,17 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.query;
+
+import lombok.Builder;
+import lombok.Getter;
+
+/**
+ * Class that holds low-level hybrid query information in the form of a query context
+ */
+@Builder
+@Getter
+public class HybridQueryContext {
+ private Integer paginationDepth;
+}
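For context, a minimal usage sketch of the new parameter (not part of this patch; the field names, values, and enclosing class are assumed for illustration):

    import org.opensearch.index.query.QueryBuilders;
    import org.opensearch.neuralsearch.query.HybridQueryBuilder;
    import org.opensearch.search.builder.SearchSourceBuilder;

    // Build a hybrid query that pages into the combined, normalized result set.
    static SearchSourceBuilder pagedHybridSearch() {
        HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
        hybridQuery.add(QueryBuilders.matchQuery("title", "wild west"));   // first subquery (field is hypothetical)
        hybridQuery.add(QueryBuilders.termQuery("category", "story"));     // second subquery (field is hypothetical)
        // Each subquery collects up to 50 hits per shard before normalization and combination.
        hybridQuery.paginationDepth(50);
        // Request the second page of 10 combined results.
        return new SearchSourceBuilder().query(hybridQuery).from(10).size(10);
    }

Validation in HybridQueryBuilder rejects negative depths and depths above index.max_result_window, the HybridQuery constructor rejects an explicit zero, and paging past the collected window fails with "Reached end of search result, increase pagination_depth value to see more results".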
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
index f9457f6ca..3c6a7271f 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -8,6 +8,7 @@
 import lombok.RequiredArgsConstructor;
 import lombok.extern.log4j.Log4j2;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.Weight;
@@ -22,6 +23,7 @@
 import org.opensearch.common.Nullable;
 import org.opensearch.common.lucene.search.FilteredCollector;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.neuralsearch.query.HybridQuery;
 import org.opensearch.neuralsearch.search.HitsThresholdChecker;
 import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
 import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
@@ -52,6 +54,7 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults;
+import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQueryWrappedInBooleanQuery;

 /**
 * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits.
@@ -80,14 +83,28 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
 * @throws IOException
 */
 public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
+ if (searchContext.scrollContext() != null) {
+ throw new IllegalArgumentException(String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"));
+ }
 final IndexReader reader = searchContext.searcher().getIndexReader();
 final int totalNumDocs = Math.max(0, reader.numDocs());
- int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
+ int numDocs = Math.min(getSubqueryResultsRetrievalSize(searchContext), totalNumDocs);
 int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
 if (searchContext.sort() != null) {
 validateSortCriteria(searchContext, searchContext.trackScores());
 }
+ boolean isSingleShard = searchContext.numberOfShards() == 1;
+ // On a single shard the fetch phase can execute before the normalization phase, and the pagination logic
+ // lives in the fetch phase. If fetch runs first, the results would be paginated before scores are normalized.
+ // To avoid that, reset the from value in the search context to 0, which stops the fetch phase from trimming
+ // results prematurely. Later, the normalization phase sets the correct from value on the QuerySearchResult
+ // object to handle the effective trimming of results.
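+ // Example with hypothetical numbers: for from = 10 and size = 10 on one shard, each subquery collects up to
+ // pagination_depth hits, from is reset to 0 here so the full window survives the fetch phase, and the
+ // normalization processor re-applies from = 10 when trimming the fetched hits.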
+ if (isSingleShard && searchContext.from() > 0) {
+ searchContext.from(0);
+ }
+
 Weight filteringWeight = null;
 // Check for post filter to create weight for filter query and later use that weight in the search workflow
 if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) {
@@ -461,6 +478,39 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
 };
 }

+ /**
+ * Get the maximum number of subquery results to be collected from each shard.
+ * @param searchContext search context that contains pagination depth
+ * @return results size to be collected
+ */
+ private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
+ HybridQuery hybridQuery = unwrapHybridQuery(searchContext);
+ int paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
+
+ // A from value of 0 means standard hybrid query execution, so retrieving size results per subquery is enough.
+ if (searchContext.from() == 0) {
+ return searchContext.size();
+ }
+ log.info("pagination_depth is {}", paginationDepth);
+ return paginationDepth;
+ }
+
+ /**
+ * Unwraps a HybridQuery from either a direct query or a nested BooleanQuery
+ */
+ private static HybridQuery unwrapHybridQuery(final SearchContext searchContext) {
+ HybridQuery hybridQuery;
+ Query query = searchContext.query();
+ // In case of nested fields and alias filter, hybrid query is wrapped under bool query and lies in the first clause.
+ if (isHybridQueryWrappedInBooleanQuery(searchContext, searchContext.query())) {
+ BooleanQuery booleanQuery = (BooleanQuery) query;
+ hybridQuery = (HybridQuery) booleanQuery.clauses().get(0).getQuery();
+ } else {
+ hybridQuery = (HybridQuery) query;
+ }
+ return hybridQuery;
+ }
+
 /**
 * Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
 * use saved state of collector
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index 411127507..aca93b77b 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -29,9 +29,8 @@
 import lombok.extern.log4j.Log4j2;

-import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasAliasFilter;
-import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasNestedFieldOrNestedDocs;
 import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery;
+import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQueryWrappedInBooleanQuery;

 /**
 * Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the
@@ -60,10 +59,6 @@ public boolean searchWith(
 validateQuery(searchContext, query);
 return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
 } else {
- // TODO remove this check after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved.
- if (searchContext.from() != 0) {
- throw new IllegalArgumentException("In the current OpenSearch version pagination is not supported with hybrid query");
- }
 Query hybridQuery = extractHybridQuery(searchContext, query);
 QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
 queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
@@ -78,16 +73,9 @@ private QueryPhaseSearcher getQueryPhaseSearcher(final SearchContext searchConte
 : defaultQueryPhaseSearcherWithEmptyCollectorContext;
 }

- private static boolean isWrappedHybridQuery(final Query query) {
- return query instanceof BooleanQuery
- && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
- }
-
 @VisibleForTesting
 protected Query extractHybridQuery(final SearchContext searchContext, final Query query) {
- if ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext))
- && isWrappedHybridQuery(query)
- && !((BooleanQuery) query).clauses().isEmpty()) {
+ if (isHybridQueryWrappedInBooleanQuery(searchContext, query)) {
 List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
 if (!(booleanClauses.get(0).getQuery() instanceof HybridQuery)) {
 throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level query");
 }
@@ -97,7 +85,7 @@ && isWrappedHybridQuery(query)
 .filter(clause -> BooleanClause.Occur.FILTER == clause.getOccur())
 .map(BooleanClause::getQuery)
 .collect(Collectors.toList());
- HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries);
+ HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries, hybridQuery.getQueryContext());
 return hybridQueryWithFilter;
 }
 return query;
diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java
index d19985f5c..e8794131f 100644
--- a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java
+++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java
@@ -20,6 +20,9 @@
 @NoArgsConstructor(access = AccessLevel.PRIVATE)
 public class HybridQueryUtil {

+ /**
+ * Checks whether the query object is an instance of a hybrid query
+ */
 public static boolean isHybridQuery(final Query query, final SearchContext searchContext) {
 if (query instanceof HybridQuery) {
 return true;
@@ -52,7 +55,7 @@ public static boolean isHybridQuery(final Query query, final SearchContext searc
 return false;
 }

- public static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
+ private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
 return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query);
 }

@@ -61,7 +64,16 @@ private static boolean isWrappedHybridQuery(final Query query) {
 && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
 }

- public static boolean hasAliasFilter(final Query query, final SearchContext searchContext) {
+ private static boolean hasAliasFilter(final Query query, final SearchContext searchContext) {
 return Objects.nonNull(searchContext.aliasFilter());
 }
+
+ /**
+ * Checks whether a hybrid query is wrapped inside a boolean query object
+ */
+ public static boolean
isHybridQueryWrappedInBooleanQuery(final SearchContext searchContext, final Query query) { + return ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext)) + && isWrappedHybridQuery(query) + && !((BooleanQuery) query).clauses().isEmpty()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 5f45b14fe..87dac8674 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -274,7 +274,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -330,7 +330,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -346,6 +346,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..9969081a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -19,6 +20,8 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; import org.opensearch.neuralsearch.util.TestUtils; @@ -29,6 +32,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.query.QuerySearchResult; @@ -71,12 +75,18 @@ public void 
testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -113,12 +123,18 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); @@ -172,12 +188,18 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo new SearchHit(0, "10", Map.of(), Map.of()), }; SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -232,12 +254,18 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom new SearchHit(-1, "10", Map.of(), Map.of()), }; SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); 
TestUtils.assertQueryResultScores(querySearchResults); @@ -284,14 +312,20 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); expectThrows( IllegalStateException.class, () -> normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ) ); } @@ -336,18 +370,88 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List<QuerySearchResult> querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + null + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + // requested page is out of bound for the total number of results + querySearchResult.from(17); + querySearchResults.add(querySearchResult); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getNumShards()).thenReturn(4); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(17); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationProcessorWorkflowExecuteRequest 
normalizationExecuteDto = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto) + ); + + assertEquals( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results"), + illegalArgumentException.getMessage() + ); + } + private static SearchHits getSearchHits() { SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(-1, "10", Map.of(), Map.of()), diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 1640d8e02..a6cf4d29e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -11,6 +11,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; +import static org.opensearch.index.remote.RemoteStoreEnums.PathType.HASHED_PREFIX; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; @@ -33,7 +34,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -50,6 +53,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -57,6 +61,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -119,6 +124,7 @@ public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() { @SneakyThrows public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); @@ -130,6 +136,10 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { 
 when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
 when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
 when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
+ IndexMetadata indexMetadata = getIndexMetadata();
+ Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build();
+ IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
+ when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

 NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
 .fieldName(VECTOR_FIELD_NAME)
 .queryText(QUERY_TEXT)
 .modelId(MODEL_ID)
 .k(K)
 .vectorSupplier(TEST_VECTOR_SUPPLIER)
 .build();
@@ -154,6 +164,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
 @SneakyThrows
 public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
 HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
+ queryBuilder.paginationDepth(10);
 Index dummyIndex = new Index("dummy", "dummy");
 QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
 KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
@@ -165,6 +176,10 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
 when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
 when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
 when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
+ IndexMetadata indexMetadata = getIndexMetadata();
+ Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build();
+ IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
+ when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

 NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
 .fieldName(VECTOR_FIELD_NAME)
 .queryText(QUERY_TEXT)
 .modelId(MODEL_ID)
 .k(K)
 .vectorSupplier(TEST_VECTOR_SUPPLIER)
 .build();
@@ -201,6 +216,81 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
 assertEquals(TERM_QUERY_TEXT, termQuery.getTerm().text());
 }

+ @SneakyThrows
+ public void testDoToQuery_whenPaginationDepthIsGreaterThan10000_thenFail() {
+ HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
+ queryBuilder.paginationDepth(10001);
+ Index dummyIndex = new Index("dummy", "dummy");
+ QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+ KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
+ KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
+ KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY);
+ when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
+ when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext));
+ when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+ when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
+ when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
+ when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
+ IndexMetadata indexMetadata = getIndexMetadata();
+ Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build();
+ IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
+ when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);
+
+ NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
+ .fieldName(VECTOR_FIELD_NAME)
+ .queryText(QUERY_TEXT)
+ .modelId(MODEL_ID)
+ .k(K)
+ .vectorSupplier(TEST_VECTOR_SUPPLIER)
+ .build();
+
+ queryBuilder.add(neuralQueryBuilder);
+ IllegalArgumentException exception = expectThrows(
+ IllegalArgumentException.class,
+ () -> queryBuilder.doToQuery(mockQueryShardContext)
+ );
+ assertThat(
+ exception.getMessage(),
+ containsString("pagination_depth should be less than or equal to index.max_result_window setting")
+ );
+ }
+
+ @SneakyThrows
+ public void testDoToQuery_whenPaginationDepthIsLessThanZero_thenFail() {
+ HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
+ queryBuilder.paginationDepth(-1);
+ Index dummyIndex = new Index("dummy", "dummy");
+ QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+ KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
+ KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
+ KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY);
+ when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
+ when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext));
+ when(mockQueryShardContext.index()).thenReturn(dummyIndex);
+ when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
+ when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
+ when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
+ IndexMetadata indexMetadata = getIndexMetadata();
+ Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build();
+ IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
+ when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);
+
+ NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
+ .fieldName(VECTOR_FIELD_NAME)
+ .queryText(QUERY_TEXT)
+ .modelId(MODEL_ID)
+ .k(K)
+ .vectorSupplier(TEST_VECTOR_SUPPLIER)
+ .build();
+
+ queryBuilder.add(neuralQueryBuilder);
+ IllegalArgumentException exception = expectThrows(
+ IllegalArgumentException.class,
+ () -> queryBuilder.doToQuery(mockQueryShardContext)
+ );
+ assertThat(exception.getMessage(), containsString("pagination_depth should be greater than 0"));
+ }
+
 @SneakyThrows
 public void testDoToQuery_whenTooManySubqueries_thenFail() {
 // create query with 6 sub-queries, which is more than current max allowed
@@ -336,6 +426,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
 assertEquals(2, queryTwoSubQueries.queries().size());
 assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder);
 assertTrue(queryTwoSubQueries.queries().get(1) instanceof TermQueryBuilder);
+ assertEquals(10, queryTwoSubQueries.paginationDepth().intValue());
 // verify knn vector query
 NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0);
 assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName());
@@ -409,6 +500,7 @@ public void testFromXContent_whenIncorrectFormat_thenFail() {
 @SneakyThrows
 public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() {
+ setUpClusterService();
 HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
 Index dummyIndex = new Index("dummy", "dummy");
 QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
@@ -537,6 +629,7 @@ public void
testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { } public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { + setUpClusterService(); String modelId = "testModelId"; String fieldName = "fieldTwo"; String queryText = "query text"; @@ -637,6 +730,7 @@ public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { @SneakyThrows public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() .fieldName(VECTOR_FIELD_NAME) @@ -744,6 +838,7 @@ public void testBoost_whenNonDefaultBoostSet_thenFail() { @SneakyThrows public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { + setUpClusterService(); // create query with 6 sub-queries, which is more than current max allowed XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder() .startObject() @@ -794,6 +889,10 @@ public void testBuild_whenValidParameters_thenCreateQuery() { MappedFieldType fieldType = mock(MappedFieldType.class); when(context.fieldMapper(fieldName)).thenReturn(fieldType); when(fieldType.typeName()).thenReturn("rank_features"); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(context.getIndexSettings()).thenReturn(indexSettings); // Create HybridQueryBuilder instance (no spy since it's final) NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); @@ -802,6 +901,7 @@ public void testBuild_whenValidParameters_thenCreateQuery() { .modelId(modelId) .queryTokensSupplier(() -> Map.of("token1", 1.0f, "token2", 0.5f)); HybridQueryBuilder builder = new HybridQueryBuilder().add(neuralSparseQueryBuilder); + builder.paginationDepth(10); // Build query Query query = builder.toQuery(context); @@ -813,6 +913,7 @@ public void testBuild_whenValidParameters_thenCreateQuery() { @SneakyThrows public void testDoEquals_whenSameParameters_thenEqual() { + setUpClusterService(); // Create neural queries NeuralQueryBuilder neuralQueryBuilder1 = NeuralQueryBuilder.builder() .fieldName("test") @@ -894,4 +995,25 @@ private void initKNNSettings() { when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings)); KNNSettings.state().setClusterService(clusterService); } + + private static IndexMetadata getIndexMetadata() { + Map<String, String> remoteCustomData = Map.of( + RemoteStoreEnums.PathType.NAME, + HASHED_PREFIX.name(), + RemoteStoreEnums.PathHashAlgorithm.NAME, + RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(), + IndexMetadata.TRANSLOG_METADATA_KEY, + "false" + ); + Settings idxSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder("test").settings(idxSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putCustom(IndexMetadata.REMOTE_STORE_CUSTOM_KEY, remoteCustomData) + .build(); + return indexMetadata; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 610e08dd0..c3087a1e4 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -28,6 +28,7 @@ import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -793,21 +794,130 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS } } - // TODO remove this test after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved. @SneakyThrows - public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { + public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); - hybridQueryBuilderOnlyTerm.add(matchQueryBuilder); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ResponseException exceptionNoNestedTypes = expectThrows( + @SneakyThrows + public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(10); + + Map<String, Object> 
searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map<String, Object> total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + Map<String, Object> searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map<String, Object> total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + ResponseException responseException = assertThrows( ResponseException.class, () -> search( TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - hybridQueryBuilderOnlyTerm, + hybridQueryBuilderOnlyMatchAll, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE), @@ -816,18 +926,50 @@ public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { null, false, null, - 10 + 5 ) - ); org.hamcrest.MatcherAssert.assertThat( - exceptionNoNestedTypes.getMessage(), - allOf( - containsString("In the current OpenSearch version pagination is not supported with hybrid query"), - containsString("illegal_argument_exception") + responseException.getMessage(), + allOf(containsString("Reached end of search result, increase pagination_depth value to see more results")) + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(100001); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + 
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 0 ) ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("pagination_depth should be less than or equal to index.max_result_window setting")) + ); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index 15f0621e8..26babdbce 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -72,16 +72,19 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery query1 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); HybridQuery query2 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); HybridQuery query3 = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + new HybridQueryContext(10) ); QueryUtils.check(query1); QueryUtils.checkEqual(query1, query2); @@ -96,6 +99,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); + assertEquals(10, query3.getQueryContext().getPaginationDepth().intValue()); } @SneakyThrows @@ -103,6 +107,7 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + String field1Value = "text1"; Directory directory = newDirectory(); final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); @@ -120,14 +125,18 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { // Test with TermQuery HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it assertSame(hybridQueryWithTerm, rewritten); // Test empty query list - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery(List.of(), new HybridQueryContext(10)) + ); assertThat(exception.getMessage(), containsString("collection of queries 
must not be empty")); w.close(); @@ -160,7 +169,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), + new HybridQueryContext(10) ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -206,7 +216,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)))); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), new HybridQueryContext(10)); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -242,7 +252,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + new HybridQueryContext(10) ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -256,10 +267,25 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() { - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery(List.of(), new HybridQueryContext(10)) + ); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } + @SneakyThrows + public void testWithRandomDocuments_whenPaginationDepthIsZero_thenFail() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + new HybridQueryContext(0) + ) + ); + assertThat(exception.getMessage(), containsString("pagination_depth must not be zero")); + } + @SneakyThrows public void testToString_whenCallQueryToString_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -273,7 +299,8 @@ public void testToString_whenCallQueryToString_thenSuccessful() { new BoolQueryBuilder().should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) - ) + ), + new HybridQueryContext(10) ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -293,7 +320,8 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), - List.of(filter) 
+ List.of(filter), + new HybridQueryContext(10) ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 0e32b5e78..024c5e6e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -61,7 +61,8 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -117,7 +118,8 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -164,7 +166,8 @@ public void testExplain_whenCallExplain_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index f44e762f0..acbc2148c 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -15,11 +15,13 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchShardTarget; @@ -69,9 +71,12 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new 
HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = mock(MapperService.class); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -129,9 +134,12 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 24ebebe5b..e0d95f24e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -36,6 +36,7 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoostingQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -44,6 +45,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.ParsedQuery; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.HybridQueryWeight; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; @@ -52,12 +54,14 @@ import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import static 
org.mockito.ArgumentMatchers.any; @@ -88,11 +92,14 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -122,8 +129,11 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -153,8 +163,11 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); searchContext.parsedQuery(parsedQuery); @@ -197,7 +210,11 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); 
TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -240,9 +257,14 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -343,9 +365,13 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -380,9 +406,13 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + 
when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -410,8 +440,12 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); - HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithMatchAll = new HybridQuery( + List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), + hybridQueryContext + ); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -420,6 +454,9 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { when(searchContext.searcher()).thenReturn(indexSearcher); when(searchContext.size()).thenReturn(1); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>(); when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); @@ -503,14 +540,18 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -593,9 +634,15 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); - HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithTerm = new HybridQuery( + 
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), + hybridQueryContext + ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(2); @@ -718,14 +765,18 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -835,15 +886,19 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(2); @@ -979,14 +1034,18 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + 
when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -1042,4 +1101,73 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { reader.close(); directory.close(); } + + @SneakyThrows + public void testCreateCollectorManager_whenFromIsEqualToZeroAndPaginationDepthInRange_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + // pagination_depth=10 + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + + when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testScrollWithHybridQuery_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + ScrollContext scrollContext = new ScrollContext(); + when(searchContext.scrollContext()).thenReturn(scrollContext); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + 
+ Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"), + illegalArgumentException.getMessage() + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index a8cad5ec7..2aafa2ece 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -138,6 +138,10 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList<QueryCollectorContext> collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -150,6 +154,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); queryBuilder.add(termSubQuery1); queryBuilder.add(termSubQuery2); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -287,6 +292,10 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { when(searchContext.queryResult()).thenReturn(querySearchResult); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList<QueryCollectorContext> collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -296,6 +305,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); queryBuilder.add(termSubQuery); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -372,6 +382,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(searchContext.indexShard()).thenReturn(indexShard); 
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList<QueryCollectorContext> collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -382,6 +396,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); queryBuilder.add(QueryBuilders.matchAllQuery()); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -473,6 +488,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); @@ -578,6 +594,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER) @@ -694,6 +711,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.MUST) @@ -868,6 +886,10 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList<QueryCollectorContext> collectors = new LinkedList<>(); @@ -881,6 +903,7 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); queryBuilder.add(termSubQuery1); queryBuilder.add(termSubQuery2); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -965,6 +988,10 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the 
when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList<QueryCollectorContext> collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -974,6 +1001,7 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); Query termFilter = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext); BooleanQuery.Builder builder = new BooleanQuery.Builder(); diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index be9dbc2cc..ab882b388 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -6,20 +6,29 @@ import lombok.SneakyThrows; import org.apache.lucene.search.Query; +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.internal.SearchContext; import java.util.List; +import java.util.Map; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.index.remote.RemoteStoreEnums.PathType.HASHED_PREFIX; public class HybridQueryUtilTests extends OpenSearchQueryTestCase { @@ -34,6 +43,7 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery query = new HybridQuery( List.of( @@ -45,7 +55,8 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); SearchContext searchContext = mock(SearchContext.class); @@ -58,13 +69,17 @@ public void 
testIsHybridQueryCheck_whenHybridWrappedIntoBoolAndNoNested_thenSucc MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); hybridQueryBuilder.add( QueryBuilders.rangeQuery(RANGE_FIELD).from(FROM_TEXT).to(TO_TEXT).rewrite(mockQueryShardContext).rewrite(mockQueryShardContext) ); - + hybridQueryBuilder.paginationDepth(10); Query booleanQuery = QueryBuilders.boolQuery() .should(hybridQueryBuilder) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) @@ -97,4 +112,25 @@ public void testIsHybridQueryCheck_whenNoHybridQuery_thenSuccess() { assertFalse(HybridQueryUtil.isHybridQuery(booleanQuery, searchContext)); } + + private static IndexMetadata getIndexMetadata() { + Map<String, String> remoteCustomData = Map.of( + RemoteStoreEnums.PathType.NAME, + HASHED_PREFIX.name(), + RemoteStoreEnums.PathHashAlgorithm.NAME, + RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(), + IndexMetadata.TRANSLOG_METADATA_KEY, + "false" + ); + Settings idxSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder("test").settings(idxSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putCustom(IndexMetadata.REMOTE_STORE_CUSTOM_KEY, remoteCustomData) + .build(); + return indexMetadata; + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index f4d4a3c40..509527aeb 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -600,13 +600,11 @@ protected Map<String, Object> search( if (requestParams != null && !requestParams.isEmpty()) { requestParams.forEach(request::addParameter); } - logger.info("Sorting request " + builder.toString()); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); - logger.info("Response " + responseBody); return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); }
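
For anyone trying the feature end to end, below is a minimal usage sketch of the pagination contract these tests exercise. It assumes the standard OpenSearch SearchSourceBuilder client API; the class name, field names, and query text are hypothetical, and the normalization search pipeline is assumed to be configured separately (as the integration tests do via createSearchPipelineWithResultsPostProcessor).

import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;

public class HybridPaginationSketch {
    public static SearchSourceBuilder paginatedHybridSearch() {
        // Hybrid query with two sub-queries; each shard collects up to
        // pagination_depth hits per sub-query before score normalization
        // and combination.
        HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
        hybridQuery.add(QueryBuilders.matchQuery("passage_text", "hybrid search")); // hypothetical field
        hybridQuery.add(QueryBuilders.termQuery("category", "docs"));               // hypothetical field

        // pagination_depth must be greater than 0 and must not exceed the
        // index.max_result_window setting; out-of-range values fail in
        // HybridQueryBuilder#doToQuery with an IllegalArgumentException.
        hybridQuery.paginationDepth(50);

        // Plain from/size pagination then applies to the combined result set.
        // Paging past the collected depth fails with "Reached end of search
        // result, increase pagination_depth value to see more results".
        return new SearchSourceBuilder().query(hybridQuery).from(10).size(10);
    }
}

As testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful shows, pagination_depth is optional and from > 0 also works with the default depth, while scroll remains unsupported for hybrid queries.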