diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cbd0ef2f..7a8d987c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x) ### Features - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] +- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java index 7eca6287c..0c54cb370 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java @@ -29,7 +29,8 @@ public float compare(byte[] v1, byte[] v2) { @Override public VectorSimilarityFunction getVectorSimilarityFunction() { - throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space"); + // This is not used in binary case + return VectorSimilarityFunction.EUCLIDEAN; } }; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 4827a4582..2b5022722 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -40,7 +40,7 @@ public enum VectorDataType { @Override public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { - throw new IllegalStateException("Unsupported method"); + return KnnByteVectorField.createFieldType(dimension / Byte.SIZE, vectorSimilarityFunction); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 72187516f..f3a125838 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { } } - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth, + knnMethodContext.getSpaceType() + ); log.debug( "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", field, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990BinaryVectorScorer.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990BinaryVectorScorer.java new file mode 100644 index 000000000..6383252e0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990BinaryVectorScorer.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +import java.io.IOException; + +/** + * A FlatVectorsScorer to be used for scoring binary vectors. Meant to be used with {@link KNN990BinaryVectorScorer} + */ +public class KNN990BinaryVectorScorer implements FlatVectorsScorer { + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues + ) throws IOException { + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { + return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues); + } + throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + float[] queryVector + ) throws IOException { + throw new IllegalArgumentException("binary vectors do not support float[] targets"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + byte[] queryVector + ) throws IOException { + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { + return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector); + } + throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + static class BinaryRandomVectorScorer implements RandomVectorScorer { + private final RandomAccessVectorValues.Bytes vectorValues; + private final byte[] queryVector; + + BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + this.queryVector = query; + this.vectorValues = vectorValues; + } + + @Override + public float score(int node) throws IOException { + return 1 / (float) (1 + VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node))); + } + + @Override + public int maxOrd() { + return vectorValues.size(); + } + + @Override + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return vectorValues.getAcceptOrds(acceptDocs); + } + } + + static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + protected final RandomAccessVectorValues.Bytes vectorValues; + protected final RandomAccessVectorValues.Bytes vectorValues1; + protected final RandomAccessVectorValues.Bytes vectorValues2; + + public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException { + this.vectorValues = vectorValues; + this.vectorValues1 = vectorValues.copy(); + this.vectorValues2 = vectorValues.copy(); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] queryVector = vectorValues1.vectorValue(ord); + return new BinaryRandomVectorScorer(vectorValues2, queryVector); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinaryRandomVectorScorerSupplier(vectorValues.copy()); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990HnswBinaryVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990HnswBinaryVectorsFormat.java new file mode 100644 index 000000000..8d6ac5865 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990HnswBinaryVectorsFormat.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.opensearch.knn.index.engine.KNNEngine.getMaxDimensionByEngine; + +/** + * Custom KnnVectorsFormat implementation to support binary vectors. This class is mostly identical to + * {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat}, however we use the custom {@link KNN990BinaryVectorScorer} + * to perform hamming bit scoring. + */ +public final class KNN990HnswBinaryVectorsFormat extends KnnVectorsFormat { + + private final int maxConn; + private final int beamWidth; + private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN990BinaryVectorScorer()); + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + private static final String NAME = "KNN990HnswBinaryVectorsFormat"; + + /** + * Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat()} + */ + public KNN990HnswBinaryVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int)} + */ + public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, 1, null); + } + + /** + * Constructor logic is identical to {@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat#Lucene99HnswVectorsFormat(int, int, int, java.util.concurrent.ExecutorService)} + */ + public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + this.maxConn, + this.beamWidth, + flatVectorsFormat.fieldsWriter(state), + this.numMergeWorkers, + this.mergeExec + ); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return getMaxDimensionByEngine(KNNEngine.LUCENE); + } + + @Override + public String toString() { + return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn=" + + this.maxConn + + ", beamWidth=" + + this.beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java index f565dfe5b..69f11f4a6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java @@ -8,6 +8,7 @@ import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; import org.opensearch.knn.index.engine.KNNEngine; @@ -24,11 +25,19 @@ public KNN990PerFieldKnnVectorsFormat(final Optional mapperServic mapperService, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, - () -> new Lucene99HnswVectorsFormat(), - knnVectorsFormatParams -> new Lucene99HnswVectorsFormat( - knnVectorsFormatParams.getMaxConnections(), - knnVectorsFormatParams.getBeamWidth() - ), + Lucene99HnswVectorsFormat::new, + knnVectorsFormatParams -> { + // There is an assumption here that hamming space will only be used for binary vectors. This will need to be fixed if that + // changes in the future. + if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) { + return new KNN990HnswBinaryVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ); + } else { + return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth()); + } + }, knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( knnScalarQuantizedVectorsFormatParams.getMaxConnections(), knnScalarQuantizedVectorsFormatParams.getBeamWidth(), diff --git a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java index 52134bc7e..ebf985fbb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java @@ -7,6 +7,7 @@ import lombok.Getter; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.SpaceType; import java.util.Map; @@ -17,10 +18,16 @@ public class KNNVectorsFormatParams { private int maxConnections; private int beamWidth; + private final SpaceType spaceType; public KNNVectorsFormatParams(final Map params, int defaultMaxConnections, int defaultBeamWidth) { + this(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED); + } + + public KNNVectorsFormatParams(final Map params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) { initMaxConnections(params, defaultMaxConnections); initBeamWidth(params, defaultBeamWidth); + this.spaceType = spaceType; } public boolean validate(final Map params) { diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java index 57cc016a6..701f79768 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java @@ -30,13 +30,18 @@ */ public class LuceneHNSWMethod extends AbstractKNNMethod { - private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE); + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of( + VectorDataType.FLOAT, + VectorDataType.BYTE, + VectorDataType.BINARY + ); public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, SpaceType.L2, SpaceType.COSINESIMIL, - SpaceType.INNER_PRODUCT + SpaceType.INNER_PRODUCT, + SpaceType.HAMMING ); final static Encoder SQ_ENCODER = new LuceneSQEncoder(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index d01a9aff6..74b864f98 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -129,6 +129,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: + case BINARY: return new LuceneEngineKnnVectorQuery( getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested) ); diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index d799c3869..fbdb77887 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -10,3 +10,4 @@ # org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat +org.opensearch.knn.index.codec.KNN990Codec.KNN990HnswBinaryVectorsFormat diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index f760a6e88..73af608c1 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -13,7 +13,6 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.util.BytesRef; @@ -109,14 +108,6 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException writer.close(); } - public void testCreateKnnVectorFieldType_whenBinary_thenException() { - Exception ex = expectThrows( - IllegalStateException.class, - () -> VectorDataType.BINARY.createKnnVectorFieldType(1, VectorSimilarityFunction.EUCLIDEAN) - ); - assertTrue(ex.getMessage().contains("Unsupported method")); - } - public void testGetVectorFromBytesRef_whenBinary_thenException() { byte[] vector = { 1, 2, 3 }; float[] expected = { 1, 2, 3 }; diff --git a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java index c5979e576..bc72503a2 100644 --- a/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/KNNMethodContextTests.java @@ -283,13 +283,6 @@ public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { } public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() { - validateValidateVectorDataType( - KNNEngine.LUCENE, - KNNConstants.METHOD_HNSW, - VectorDataType.BINARY, - SpaceType.HAMMING, - "UnsupportedMethod" - ); validateValidateVectorDataType( KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 714723a8e..9e637be9b 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -1528,8 +1528,7 @@ public void testTypeParser_whenBinaryFaissHNSWWithInvalidSpaceType_thenException } } - public void testTypeParser_whenBinaryNonFaiss_thenException() throws IOException { - testTypeParserWithBinaryDataType(KNNEngine.LUCENE, SpaceType.HAMMING, METHOD_HNSW, 8, "is not supported for vector data type"); + public void testTypeParser_whenBinaryNmslib_thenException() throws IOException { testTypeParserWithBinaryDataType(KNNEngine.NMSLIB, SpaceType.HAMMING, METHOD_HNSW, 8, "is not supported for vector data type"); } diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index 4fb267eb5..6f243ff3a 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; import lombok.SneakyThrows; @@ -27,6 +28,7 @@ import java.io.IOException; import java.net.URL; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -34,13 +36,23 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; /** - * This class contains integration tests for binary index with HNSW in Faiss + * This class contains integration tests for binary index with HNSW in Faiss and Lucene */ @Log4j2 public class BinaryIndexIT extends KNNRestTestCase { private static TestUtils.TestData testData; private static final int NEVER_BUILD_GRAPH = -1; private static final int ALWAYS_BUILD_GRAPH = 0; + private final KNNEngine engine; + + public BinaryIndexIT(KNNEngine engine) { + this.engine = engine; + } + + @ParametersFactory + public static Collection parameters() { + return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); + } @BeforeClass public static void setUpClass() throws IOException { @@ -66,9 +78,9 @@ public void cleanUp() { } @SneakyThrows - public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { + public void testHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 16); // Ingest Byte[] vector1 = { 0b00000001, 0b00000001 }; @@ -93,9 +105,9 @@ public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { } @SneakyThrows - public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { + public void testHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128); ingestTestData(INDEX_NAME, FIELD_NAME); int k = 100; @@ -110,9 +122,9 @@ public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { + public void testHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); ingestTestData(INDEX_NAME, FIELD_NAME); assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); @@ -133,9 +145,9 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { + public void testHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); ingestTestData(INDEX_NAME, FIELD_NAME, false); assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); @@ -156,9 +168,9 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { + public void testHnswBinary_whenRadialSearch_thenThrowException() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 16); // Query float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java index 29e710ec1..a706dd0cd 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java @@ -46,11 +46,6 @@ public void cleanUp() { public static Collection parameters() throws IOException { return Arrays.asList( $$( - $( - "Creation of binary index with lucene engine should fail", - createKnnHnswBinaryIndexMapping(KNNEngine.LUCENE, FIELD_NAME, 16, null), - "Validation Failed" - ), $( "Creation of binary index with nmslib engine should fail", createKnnHnswBinaryIndexMapping(KNNEngine.NMSLIB, FIELD_NAME, 16, null),