diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e6016ff9..01d802e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] +- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index bac6e95b5..d01a9aff6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -15,6 +15,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.common.QueryUtils; import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory; +import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -128,9 +129,13 @@ public static Query create(CreateQueryRequest createQueryRequest) { log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: - return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested); + return new LuceneEngineKnnVectorQuery( + getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested) + ); case FLOAT: - return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested); + return new LuceneEngineKnnVectorQuery( + getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested) + ); default: throw new IllegalArgumentException( String.format( diff --git a/src/main/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQuery.java new file mode 100644 index 000000000..40af1f510 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQuery.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +/** + * LuceneEngineKnnVectorQuery is a wrapper around a vector queries for the Lucene engine. + * This enables us to defer rewrites until weight creation to optimize repeated execution + * of Lucene based k-NN queries. + */ +@AllArgsConstructor +@Log4j2 +public class LuceneEngineKnnVectorQuery extends Query { + @Getter + private final Query luceneQuery; + + /* + Prevents repeated rewrites of the query for the Lucene engine. + */ + @Override + public Query rewrite(IndexSearcher indexSearcher) { + return this; + } + + /* + Rewrites the query just before weight creation. + */ + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Query rewrittenQuery = luceneQuery.rewrite(searcher); + return rewrittenQuery.createWeight(searcher, scoreMode, boost); + } + + @Override + public String toString(String s) { + return luceneQuery.toString(); + } + + @Override + public void visit(QueryVisitor queryVisitor) { + queryVisitor.visitLeaf(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LuceneEngineKnnVectorQuery otherQuery = (LuceneEngineKnnVectorQuery) o; + return luceneQuery.equals(otherQuery.luceneQuery); + } + + @Override + public int hashCode() { + return luceneQuery.hashCode(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index b28b790d1..b609bb0df 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -8,7 +8,6 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.apache.lucene.search.FloatVectorSimilarityQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.junit.Before; @@ -33,6 +32,7 @@ import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.mapper.Mode; +import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -512,7 +512,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { // Then assertNotNull(query); - assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); + assertTrue(query.getClass().isAssignableFrom(LuceneEngineKnnVectorQuery.class)); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 1836ddb7e..eff2ca895 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -11,8 +11,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.junit.Before; import org.mockito.Mock; @@ -30,6 +28,7 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.lucenelib.ExpandNestedDocsQuery; +import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -120,7 +119,7 @@ public void testCreateLuceneDefaultQuery() { .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .build() ); - assertEquals(KnnFloatVectorQuery.class, query.getClass()); + assertEquals(LuceneEngineKnnVectorQuery.class, query.getClass()); } } @@ -138,7 +137,7 @@ public void testLuceneFloatVectorQuery() { ); // efsearch > k - Query expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null); + Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null)); assertEquals(expectedQuery1, actualQuery1); // efsearch < k @@ -153,7 +152,7 @@ public void testLuceneFloatVectorQuery() { .vectorDataType(VectorDataType.FLOAT) .build() ); - expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); actualQuery1 = KNNQueryFactory.create( @@ -166,7 +165,7 @@ public void testLuceneFloatVectorQuery() { .vectorDataType(VectorDataType.FLOAT) .build() ); - expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); } @@ -184,7 +183,7 @@ public void testLuceneByteVectorQuery() { ); // efsearch > k - Query expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null); + Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null)); assertEquals(expectedQuery1, actualQuery1); // efsearch < k @@ -199,7 +198,7 @@ public void testLuceneByteVectorQuery() { .vectorDataType(VectorDataType.BYTE) .build() ); - expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); actualQuery1 = KNNQueryFactory.create( @@ -212,7 +211,7 @@ public void testLuceneByteVectorQuery() { .vectorDataType(VectorDataType.BYTE) .build() ); - expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); } @@ -235,7 +234,7 @@ public void testCreateLuceneQueryWithFilter() { .filter(FILTER_QUERY_BUILDER) .build(); Query query = KNNQueryFactory.create(createQueryRequest); - assertEquals(KnnFloatVectorQuery.class, query.getClass()); + assertEquals(LuceneEngineKnnVectorQuery.class, query.getClass()); } } @@ -311,8 +310,8 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc } public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() { - validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, DiversifyingChildrenByteKnnVectorQuery.class); - validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class); + validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, LuceneEngineKnnVectorQuery.class); + validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, LuceneEngineKnnVectorQuery.class); } public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnToChildBlockJoinQueryForFilters() { @@ -515,7 +514,11 @@ private void testExpandNestedDocsQuery(KNNEngine knnEngine, Class klass, VectorD .build(); Query query = KNNQueryFactory.create(createQueryRequest); - // Then - assertEquals(klass, query.getClass()); + if (knnEngine == KNNEngine.LUCENE) { + assertEquals(klass, ((LuceneEngineKnnVectorQuery) query).getLuceneQuery().getClass()); + } else { + // Then + assertEquals(klass, query.getClass()); + } } } diff --git a/src/test/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQueryTests.java new file mode 100644 index 000000000..10be94890 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQueryTests.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Spy; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.MockitoAnnotations.openMocks; + +public class LuceneEngineKnnVectorQueryTests extends OpenSearchTestCase { + + @Mock + IndexSearcher indexSearcher; + + @Mock + Query luceneQuery; + + @Mock + Weight weight; + + @Mock + QueryVisitor queryVisitor; + + @Spy + @InjectMocks + LuceneEngineKnnVectorQuery objectUnderTest; + + @Override + public void setUp() throws Exception { + super.setUp(); + openMocks(this); + when(luceneQuery.rewrite(any(IndexSearcher.class))).thenReturn(luceneQuery); + when(luceneQuery.createWeight(any(IndexSearcher.class), any(ScoreMode.class), anyFloat())).thenReturn(weight); + } + + public void testRewrite() { + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + verifyNoInteractions(luceneQuery); + verify(objectUnderTest, times(3)).rewrite(indexSearcher); + } + + public void testCreateWeight() throws Exception { + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + verifyNoInteractions(luceneQuery); + Weight actualWeight = objectUnderTest.createWeight(indexSearcher, ScoreMode.TOP_DOCS, 1.0f); + verify(luceneQuery, times(1)).rewrite(indexSearcher); + verify(objectUnderTest, times(3)).rewrite(indexSearcher); + assertEquals(weight, actualWeight); + } + + public void testVisit() { + objectUnderTest.visit(queryVisitor); + verify(queryVisitor).visitLeaf(objectUnderTest); + } + + public void testEquals() { + LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + LuceneEngineKnnVectorQuery otherQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + assertEquals(mainQuery, otherQuery); + assertEquals(mainQuery, mainQuery); + assertNotEquals(mainQuery, null); + assertNotEquals(mainQuery, new Object()); + LuceneEngineKnnVectorQuery otherQuery2 = new LuceneEngineKnnVectorQuery(null); + assertNotEquals(mainQuery, otherQuery2); + } + + public void testHashCode() { + LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + assertEquals(mainQuery.hashCode(), luceneQuery.hashCode()); + } + + public void testToString() { + LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + assertEquals(mainQuery.toString(), luceneQuery.toString()); + } +}