From a897b8089046e4b486cf217ec1eb5708bef17eec Mon Sep 17 00:00:00 2001
From: Naveen Tatikonda
Date: Wed, 6 Nov 2024 18:59:47 -0600
Subject: [PATCH] Add support for Lucene SQ 4 bits

Signed-off-by: Naveen Tatikonda
---
 CHANGELOG.md                                 |  1 +
 .../engine/lucene/LuceneMethodResolver.java  | 39 ++++++++++++
 .../index/engine/lucene/LuceneSQEncoder.java | 23 ++++++-
 .../opensearch/knn/index/LuceneEngineIT.java | 60 +++++++++++--------
 .../engine/lucene/LuceneSQEncoderTests.java  | 14 +++++
 5 files changed, 111 insertions(+), 26 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c66eb2184..6c365d42b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
 ### Features
+* Add support for Lucene int4 SQ [2253](https://github.com/opensearch-project/k-NN/pull/2253)
 ### Enhancements
 - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
 ### Bug Fixes
diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java
index 6546d9f93..849856c4c 100644
--- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java
+++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneMethodResolver.java
@@ -18,9 +18,13 @@
 import org.opensearch.knn.index.mapper.Mode;
 
 import java.util.HashMap;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 
+import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
+import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
+import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
 import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
 import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
 import static org.opensearch.knn.index.engine.lucene.LuceneHNSWMethod.HNSW_METHOD_COMPONENT;
@@ -60,6 +64,7 @@ public ResolvedMethodContext resolveMethod(
 
     protected void resolveEncoder(KNNMethodContext resolvedKNNMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
         if (shouldEncoderBeResolved(resolvedKNNMethodContext, knnMethodConfigContext) == false) {
+            validateEncoderDimension(resolvedKNNMethodContext, knnMethodConfigContext);
             return;
         }
 
@@ -94,6 +99,40 @@ private void validateConfig(KNNMethodConfigContext knnMethodConfigContext, boole
         }
     }
 
+    /**
+     * Validates the vector dimension for an SQ encoder that is configured with 4 bits. Lucene rejects this
+     * combination for odd dimensions (see the code reference below), so it is reported as a validation error.
+     *
+     * @param knnMethodContext method context the encoder name and encoder parameters are read from
+     * @param knnMethodConfigContext config context the vector dimension is read from
+     * @throws ValidationException if the encoder is SQ, its bits parameter is 4 and the dimension is odd
+     */
+    private void validateEncoderDimension(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
+        String encoderName = getEncoderName(knnMethodContext);
+        if (ENCODER_SQ.equals(encoderName) == false) {
+            return;
+        }
+
+        MethodComponentContext encoderMethodComponentContext = getEncoderComponentContext(knnMethodContext);
+        // This check is coming from Lucene. Code Ref:
+        // https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java#L200-L206
+        // Comparing from the constant side is null-safe and also covers a missing bits parameter.
+        Object bits = encoderMethodComponentContext.getParameters().get(LUCENE_SQ_BITS);
+        if (Integer.valueOf(4).equals(bits) && knnMethodConfigContext.getDimension() % 2 != 0) {
+            ValidationException validationException = new ValidationException();
+            validationException.addValidationError(
+                String.format(
+                    Locale.ROOT,
+                    "Odd vector dimension is not supported when [%s] is set to [4] for [%s] engine with [%s] encoder",
+                    LUCENE_SQ_BITS,
+                    LUCENE_NAME,
+                    ENCODER_SQ
+                )
+            );
+            throw validationException;
+        }
+    }
+
     private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) {
         if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) {
             return knnMethodConfigContext.getCompressionLevel();
diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java
index 6bd16ebee..e655a47f4 100644
--- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java
+++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoder.java
@@ -6,6 +6,7 @@
 package org.opensearch.knn.index.engine.lucene;
 
 import com.google.common.collect.ImmutableSet;
+import org.opensearch.common.ValidationException;
 import org.opensearch.knn.index.VectorDataType;
 import org.opensearch.knn.index.engine.Encoder;
 import org.opensearch.knn.index.engine.KNNMethodConfigContext;
@@ -31,7 +32,7 @@
 public class LuceneSQEncoder implements Encoder {
     private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT);
 
-    private final static List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(7);
+    private final static List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(4, 7);
     private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(ENCODER_SQ)
         .addSupportedDataTypes(SUPPORTED_DATA_TYPES)
         .addParameter(
@@ -58,7 +59,25 @@ public CompressionLevel calculateCompressionLevel(
         MethodComponentContext methodComponentContext,
         KNNMethodConfigContext knnMethodConfigContext
     ) {
-        // Hard coding to 4x for now, given thats all that is supported.
+        // A null encoder context (accepted before this change) or a missing bits parameter keeps the previous 4x result
+        if (methodComponentContext == null || methodComponentContext.getParameters().containsKey(LUCENE_SQ_BITS) == false) {
+            return CompressionLevel.x4;
+        }
+
+        // Map the number of bits passed in, back to the compression level
+        Object value = methodComponentContext.getParameters().get(LUCENE_SQ_BITS);
+        ValidationException validationException = METHOD_COMPONENT.getParameters()
+            .get(LUCENE_SQ_BITS)
+            .validate(value, knnMethodConfigContext);
+        if (validationException != null) {
+            throw validationException;
+        }
+
+        // 4 bits is reported as NOT_CONFIGURED; comparing from the constant side avoids a cast and unboxing
+        if (Integer.valueOf(4).equals(value)) {
+            return CompressionLevel.NOT_CONFIGURED;
+        }
+        // Return 4x compression for 7 bits
         return CompressionLevel.x4;
     }
 }
diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
index 688d22e74..d2eee9843 100644
--- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
+++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
@@ -90,6 +90,8 @@ public class LuceneEngineIT extends KNNRestTestCase {
     private static final String INTEGER_FIELD_NAME = "int_field";
     private static final String FILED_TYPE_INTEGER = "integer";
     private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field";
+    private static final List<Integer> LUCENE_SQ_BITS_SUPPORTED = ImmutableList.of(4, 7);
+    private static final int DIMENSION_SQ = 2;
 
     @After
     public final void cleanUp() throws IOException {
@@ -592,16 +594,30 @@ public void testSQ_withInvalidParams_thenThrowException() {
         );
     }
 
+    @SneakyThrows
+    public void testSQ_4bits_withOddDimension_thenThrowException() {
+        expectThrows(
+            ResponseException.class,
+            () -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
+                DIMENSION,
+                SpaceType.L2,
+                VectorDataType.FLOAT,
+                4,
+                MINIMUM_CONFIDENCE_INTERVAL
+            )
+        );
+    }
+
     @SneakyThrows
     public void testAddDocWithSQEncoder() {
         createKnnIndexMappingWithLuceneEngineAndSQEncoder(
-            DIMENSION,
+            DIMENSION_SQ,
             SpaceType.L2,
             VectorDataType.FLOAT,
-            LUCENE_SQ_DEFAULT_BITS,
+            LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
             MAXIMUM_CONFIDENCE_INTERVAL
         );
-        Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
+        Float[] vector = new Float[] { 2.0f, 4.5f };
         addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
 
         refreshIndex(INDEX_NAME);
@@ -611,16 +627,16 @@ public void testAddDocWithSQEncoder() {
     @SneakyThrows
     public void testUpdateDocWithSQEncoder() {
         createKnnIndexMappingWithLuceneEngineAndSQEncoder(
-            DIMENSION,
+            DIMENSION_SQ,
             SpaceType.INNER_PRODUCT,
             VectorDataType.FLOAT,
-            LUCENE_SQ_DEFAULT_BITS,
+            LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
             MAXIMUM_CONFIDENCE_INTERVAL
         );
-        Float[] vector = { 6.0f, 6.0f, 7.0f };
+        Float[] vector = { 6.0f, 6.0f };
         addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
 
-        Float[] updatedVector = { 8.0f, 8.0f, 8.0f };
+        Float[] updatedVector = { 8.0f, 8.0f };
         updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector);
 
         refreshIndex(INDEX_NAME);
@@ -630,13 +646,13 @@ public void testUpdateDocWithSQEncoder() {
     @SneakyThrows
     public void testDeleteDocWithSQEncoder() {
         createKnnIndexMappingWithLuceneEngineAndSQEncoder(
-            DIMENSION,
+            DIMENSION_SQ,
             SpaceType.INNER_PRODUCT,
             VectorDataType.FLOAT,
-            LUCENE_SQ_DEFAULT_BITS,
+            LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
             MAXIMUM_CONFIDENCE_INTERVAL
         );
-        Float[] vector = { 6.0f, 6.0f, 7.0f };
+        Float[] vector = { 6.0f, 6.0f };
         addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);
 
         deleteKnnDoc(INDEX_NAME, DOC_ID);
@@ -648,16 +664,16 @@ public void testDeleteDocWithSQEncoder() {
     @SneakyThrows
     public void testIndexingAndQueryingWithSQEncoder() {
         createKnnIndexMappingWithLuceneEngineAndSQEncoder(
-            DIMENSION,
+            DIMENSION_SQ,
             SpaceType.INNER_PRODUCT,
             VectorDataType.FLOAT,
-            LUCENE_SQ_DEFAULT_BITS,
+            LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
             MAXIMUM_CONFIDENCE_INTERVAL
         );
 
         int numDocs = 10;
         for (int i = 0; i < numDocs; i++) {
-            float[] indexVector = new float[DIMENSION];
+            float[] indexVector = new float[DIMENSION_SQ];
             Arrays.fill(indexVector, (float) i);
             addKnnDocWithAttributes(INDEX_NAME, Integer.toString(i), FIELD_NAME, indexVector, ImmutableMap.of("rating", String.valueOf(i)));
         }
@@ -666,7 +682,7 @@ public void testIndexingAndQueryingWithSQEncoder() {
         refreshAllNonSystemIndices();
         assertEquals(numDocs, getDocCount(INDEX_NAME));
 
-        float[] queryVector = new float[DIMENSION];
+        float[] queryVector = new float[DIMENSION_SQ];
         Arrays.fill(queryVector, (float) numDocs);
         int k = 10;
 
@@ -680,24 +696,20 @@ public void testIndexingAndQueryingWithSQEncoder() {
 
     public void testQueryWithFilterUsingSQEncoder() throws Exception {
         createKnnIndexMappingWithLuceneEngineAndSQEncoder(
-            DIMENSION,
+            DIMENSION_SQ,
             SpaceType.INNER_PRODUCT,
             VectorDataType.FLOAT,
-            LUCENE_SQ_DEFAULT_BITS,
+            LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
             MAXIMUM_CONFIDENCE_INTERVAL
         );
 
-        addKnnDocWithAttributes(
-            DOC_ID,
-            new float[] { 6.0f, 7.9f, 3.1f },
-            ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
-        );
-        addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
-        addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
+        addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f }, ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet"));
+        addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
+        addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
 
         refreshIndex(INDEX_NAME);
 
-        final float[] searchVector = { 6.0f, 6.0f, 4.1f };
+        final float[] searchVector = { 6.0f, 6.0f };
         List<String> expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3);
         List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
         validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
diff --git a/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java
index 139f96e8b..86e7c0759 100644
--- a/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java
+++ b/src/test/java/org/opensearch/knn/index/engine/lucene/LuceneSQEncoderTests.java
@@ -6,11 +6,25 @@
 package org.opensearch.knn.index.engine.lucene;
 
 import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.index.engine.MethodComponentContext;
 import org.opensearch.knn.index.mapper.CompressionLevel;
 
+import java.util.Map;
+
+import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
+import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
+
 public class LuceneSQEncoderTests extends KNNTestCase {
     public void testCalculateCompressionLevel() {
         LuceneSQEncoder encoder = new LuceneSQEncoder();
+        // A null context and a context without the bits parameter both fall back to the default level
         assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(null, null));
+        assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(new MethodComponentContext(ENCODER_SQ, Map.of()), null));
+        assertEquals(CompressionLevel.NOT_CONFIGURED, encoder.calculateCompressionLevel(generateMethodComponentContext(4), null));
+        assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(generateMethodComponentContext(7), null));
+    }
+
+    private MethodComponentContext generateMethodComponentContext(int bitCount) {
+        return new MethodComponentContext(ENCODER_SQ, Map.of(LUCENE_SQ_BITS, bitCount));
     }
 }