Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration and unit tests for missing RRF coverage #997

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Optional;

import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -98,7 +99,8 @@ public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
@VisibleForTesting
<Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
return true;
}
Expand All @@ -111,7 +113,8 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
@VisibleForTesting
boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
Expand All @@ -120,17 +123,16 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
<Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
@VisibleForTesting
<Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;

/**
 * Integration test that runs a hybrid (match + k-NN) query through a default RRF search pipeline
 * and checks the number of hits and the score returned for each of them.
 */
public class RRFProcessorIT extends BaseNeuralSearchIT {

    private static final String RRF_INDEX_NAME = "rrf-index";
    private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline";
    private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline";

    private static final int RRF_DIMENSION = 5;

    // Id assigned to the next ingested document; incremented on every addDocument call.
    private int currentDoc = 1;

    /**
     * Creates the ingest pipeline, the k-NN index, the documents and the RRF search pipeline,
     * then issues a hybrid query and asserts on the three returned hits and their scores.
     * All created resources are wiped in the finally block regardless of the outcome.
     */
    @SneakyThrows
    public void testRRF_whenValidInput_thenSucceed() {
        try {
            createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING);
            KNNFieldConfig embeddingField = new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE);
            prepareKnnIndex(RRF_INDEX_NAME, Collections.singletonList(embeddingField));
            addDocuments();
            createDefaultRRFSearchPipeline();

            HybridQueryBuilder query = getHybridQueryBuilder();
            Map<String, String> requestParams = Map.of("search_pipeline", RRF_SEARCH_PIPELINE);

            Map<String, Object> response = search(RRF_INDEX_NAME, query, null, 5, requestParams);
            Map<String, Object> outerHits = (Map<String, Object>) response.get("hits");
            ArrayList<HashMap<String, Object>> innerHits = (ArrayList<HashMap<String, Object>>) outerHits.get("hits");

            // Expected scores, listed in the order the hits are returned.
            double[] expectedScores = { 0.016393442, 0.016129032, 0.015873017 };
            assertEquals(3, innerHits.size());
            for (int rank = 0; rank < expectedScores.length; rank++) {
                assertEquals(expectedScores[rank], (Double) innerHits.get(rank).get("_score"), DELTA_FOR_SCORE_ASSERTION);
            }
        } finally {
            wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE);
        }
    }

    /**
     * Builds the hybrid query under test: a match query on the "text" field followed by
     * a k-NN query on the "passage_embedding" field.
     *
     * @return hybrid query holding the two sub-queries in that order
     */
    private HybridQueryBuilder getHybridQueryBuilder() {
        HybridQueryBuilder hybrid = new HybridQueryBuilder();
        hybrid.add(new MatchQueryBuilder("text", "cowboy rodeo bronco"));
        hybrid.add(
            new KNNQueryBuilder.Builder().fieldName("passage_embedding").k(5).vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f }).build()
        );
        return hybrid;
    }

    /**
     * Ingests the five fixture documents, in a fixed order, into the test index.
     */
    @SneakyThrows
    private void addDocuments() {
        // Each row is { description, image file name }.
        String[][] fixtures = {
            {
                "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .",
                "4319130149.jpg" },
            { "A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg" },
            {
                "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .",
                "2664027527.jpg" },
            { "A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg" },
            { "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg" } };
        for (String[] fixture : fixtures) {
            addDocument(fixture[0], fixture[1]);
        }
    }

    /**
     * Indexes one document under the next sequential id.
     *
     * @param description value stored in the "text" field
     * @param imageText value stored in the "image_text" field
     */
    @SneakyThrows
    private void addDocument(String description, String imageText) {
        String docId = String.valueOf(currentDoc++);
        addDocument(RRF_INDEX_NAME, docId, "text", description, "image_text", imageText);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RRFProcessorTests extends OpenSearchTestCase {

@Mock
private ScoreNormalizationTechnique mockNormalizationTechnique;
@Mock
private ScoreCombinationTechnique mockCombinationTechnique;
@Mock
private NormalizationProcessorWorkflow mockNormalizationWorkflow;
@Mock
private SearchPhaseResults<SearchPhaseResult> mockSearchPhaseResults;
@Mock
private SearchPhaseContext mockSearchPhaseContext;
@Mock
private QueryPhaseResultConsumer mockQueryPhaseResultConsumer;

private RRFProcessor rrfProcessor;
private static final String TAG = "tag";
private static final String DESCRIPTION = "description";

@Before
@SneakyThrows
public void setUp() {
super.setUp();
MockitoAnnotations.openMocks(this);
rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow);
}

@SneakyThrows
public void testGetType() {
assertEquals(RRFProcessor.TYPE, rrfProcessor.getType());
}

@SneakyThrows
public void testGetBeforePhase() {
assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase());
}

@SneakyThrows
public void testGetAfterPhase() {
assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase());
}

@SneakyThrows
public void testIsIgnoreFailure() {
assertFalse(rrfProcessor.isIgnoreFailure());
}

@SneakyThrows
public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() {
rrfProcessor.process(null, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() {
rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcess_whenValidHybridInput_thenSucceed() {
QuerySearchResult result = createQuerySearchResult(true);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testProcess_whenValidNonHybridInput_thenSucceed() {
QuerySearchResult result = createQuerySearchResult(false);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow, never()).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testGetTag() {
assertEquals(TAG, rrfProcessor.getTag());
}

@SneakyThrows
public void testGetDescription() {
assertEquals(DESCRIPTION, rrfProcessor.getDescription());
}

@SneakyThrows
public void testShouldSkipProcessor() {
assertTrue(rrfProcessor.shouldSkipProcessor(null));
assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults));

AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));

atomicArray.set(0, createQuerySearchResult(true));
assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));
}

@SneakyThrows
public void testGetQueryPhaseSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(2);
atomicArray.set(0, createQuerySearchResult(true));
atomicArray.set(1, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

List<QuerySearchResult> results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer);
assertEquals(2, results.size());
assertNotNull(results.get(0));
assertNotNull(results.get(1));
}

@SneakyThrows
public void testGetFetchSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(true));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

Optional<FetchSearchResult> result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer);
assertFalse(result.isPresent());
}

private QuerySearchResult createQuerySearchResult(boolean isHybrid) {
ShardId shardId = new ShardId("index", "uuid", 0);
OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed());
SearchRequest searchRequest = new SearchRequest("index");
searchRequest.source(new SearchSourceBuilder());
searchRequest.allowPartialSearchResults(true);

int numberOfShards = 1;
AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY);
float indexBoost = 1.0f;
long nowInMillis = System.currentTimeMillis();
String clusterAlias = null;
String[] indexRoutings = Strings.EMPTY_ARRAY;

ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
originalIndices,
searchRequest,
shardId,
numberOfShards,
aliasFilter,
indexBoost,
nowInMillis,
clusterAlias,
indexRoutings
);

QuerySearchResult result = new QuerySearchResult(
new ShardSearchContextId("test", 1),
new SearchShardTarget("node1", shardId, clusterAlias, originalIndices),
shardSearchRequest
);
result.from(0).size(10);

ScoreDoc[] scoreDocs;
if (isHybrid) {
scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) };
} else {
scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) };
}

TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f);
result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]);

return result;
}
}
Loading
Loading