Skip to content

Commit

Permalink
Improving Defaults Hyper Parameter for Binary Quantization Indexes (#…
Browse files Browse the repository at this point in the history
…2087) (#2095)

Signed-off-by: VIKASH TIWARI <[email protected]>
(cherry picked from commit e55e49b)

Co-authored-by: Vikasht34 <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and Vikasht34 authored Sep 11, 2024
1 parent d40b5da commit 272324c
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 13 deletions.
36 changes: 31 additions & 5 deletions src/main/java/org/opensearch/knn/index/engine/MethodComponent.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;

import java.util.HashMap;
Expand Down Expand Up @@ -328,19 +330,43 @@ public static Map<String, Object> getParameterMapWithDefaultsAdded(
Map<String, Object> parametersWithDefaultsMap = new HashMap<>();
Map<String, Object> userProvidedParametersMap = methodComponentContext.getParameters();
Version indexCreationVersion = knnMethodConfigContext.getVersionCreated();
Mode mode = knnMethodConfigContext.getMode();
CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel();

// Check whether the compression level is one of the binary quantization levels (x32, x16, or x8).
// Note: despite the flag's name, only the compression level is inspected here; the mode is not part of the condition.
// This determines whether to use binary quantization-specific values for parameters like ef_search and ef_construction.
boolean isOnDiskWithBinaryQuantization = (compressionLevel == CompressionLevel.x32
|| compressionLevel == CompressionLevel.x16
|| compressionLevel == CompressionLevel.x8);

for (Parameter<?> parameter : methodComponent.getParameters().values()) {
if (methodComponentContext.getParameters().containsKey(parameter.getName())) {
parametersWithDefaultsMap.put(parameter.getName(), userProvidedParametersMap.get(parameter.getName()));
} else {
// Picking the right values for the parameters whose values are different based on different index
// created version.
if (parameter.getName().equals(KNNConstants.METHOD_PARAMETER_EF_SEARCH)) {
parametersWithDefaultsMap.put(parameter.getName(), IndexHyperParametersUtil.getHNSWEFSearchValue(indexCreationVersion));
if (isOnDiskWithBinaryQuantization) {
parametersWithDefaultsMap.put(parameter.getName(), IndexHyperParametersUtil.getBinaryQuantizationEFSearchValue());
} else {
parametersWithDefaultsMap.put(
parameter.getName(),
IndexHyperParametersUtil.getHNSWEFSearchValue(indexCreationVersion)
);
}
} else if (parameter.getName().equals(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
parametersWithDefaultsMap.put(
parameter.getName(),
IndexHyperParametersUtil.getHNSWEFConstructionValue(indexCreationVersion)
);
if (isOnDiskWithBinaryQuantization) {
parametersWithDefaultsMap.put(
parameter.getName(),
IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue()
);
} else {
parametersWithDefaultsMap.put(
parameter.getName(),
IndexHyperParametersUtil.getHNSWEFConstructionValue(indexCreationVersion)
);
}

} else {
Object value = parameter.getDefaultValue();
if (value != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ public ResolvedMethodContext resolveMethod(

// Fill in parameters for the encoder and then the method.
resolveEncoder(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);
resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, method);

// From the resolved method context, get the compression level and validate it against the passed in
// configuration
CompressionLevel resolvedCompressionLevel = resolveCompressionLevelFromMethodContext(
Expand All @@ -77,6 +75,9 @@ public ResolvedMethodContext resolveMethod(

// Validate that resolved compression doesnt have any conflicts
validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel);
knnMethodConfigContext.setCompressionLevel(resolvedCompressionLevel);
resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, method);

return ResolvedMethodContext.builder()
.knnMethodContext(resolvedKNNMethodContext)
.compressionLevel(resolvedCompressionLevel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ public enum CompressionLevel {
x1(1, "1x", null, Collections.emptySet()),
x2(2, "2x", null, Collections.emptySet()),
x4(4, "4x", null, Collections.emptySet()),
x8(8, "8x", new RescoreContext(1.5f), Set.of(Mode.ON_DISK)),
x16(16, "16x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)),
x32(32, "32x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK));
x8(8, "8x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)),
x16(16, "16x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)),
x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK));

// Internally, an empty string is easier to deal with than null. However, from the mapping,
// we do not want users to pass in the empty string and instead want null. So we make the conversion here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public final class RescoreContext {

public static final int MAX_FIRST_PASS_RESULTS = 10000;

// TODO: We will improve this in upcoming releases
public static final int MIN_FIRST_PASS_RESULTS = 100;

@Builder.Default
private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR;

Expand All @@ -40,6 +43,6 @@ public static RescoreContext getDefault() {
* @return The number of results to return for the first pass of rescoring
*/
public int getFirstPassK(int finalK) {
    // Oversample the requested k, rounding up so a fractional product never under-fetches.
    int oversampledK = (int) Math.ceil(finalK * oversampleFactor);
    // Apply the lower bound first and the upper bound last, so the cap wins if the two ever conflict.
    // A single return replaces the stale unclamped return that made the clamped one unreachable.
    int flooredK = Math.max(MIN_FIRST_PASS_RESULTS, oversampledK);
    return Math.min(MAX_FIRST_PASS_RESULTS, flooredK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class IndexHyperParametersUtil {

private static final int INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION_OLD_VALUE = 512;
private static final int INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH_OLD_VALUE = 512;
private static final int INDEX_BINARY_QUANTIZATION_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION = 256;
private static final int INDEX_BINARY_QUANTIZATION_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 256;

/**
* Returns the default value of EF Construction that should be used for the input index version. After version 2.12.0
Expand Down Expand Up @@ -76,4 +78,22 @@ public static int getHNSWEFSearchValue(@NonNull final Version indexVersion) {
);
return KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH;
}

/**
 * Returns the default value of EF Construction that should be used with Binary Quantization.
 * The value is a fixed constant; this method takes no index version argument.
 *
 * @return default value of EF Construction
 */
public static int getBinaryQuantizationEFConstructionValue() {
return INDEX_BINARY_QUANTIZATION_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION;
}

/**
 * Returns the default value of EF Search that should be used with Binary Quantization.
 * The value is a fixed constant; this method takes no index version argument.
 *
 * @return default value of EF Search
 */
public static int getBinaryQuantizationEFSearchValue() {
return INDEX_BINARY_QUANTIZATION_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -214,4 +217,41 @@ public void testBuilder() {
.getLibraryParameters()
);
}

/**
 * Verifies that, for an ON_DISK configuration using the x32 compression level, ef_search and
 * ef_construction parameters that the user did not supply are filled in with the binary
 * quantization values from {@link IndexHyperParametersUtil}.
 */
public void testGetParameterMapWithDefaultsAdded_forOnDiskWithBinaryQuantization() {
    final String componentName = "test-method";
    final String efSearchKey = "ef_search";
    final String efConstructionKey = "ef_construction";

    // Configuration under test: ON_DISK mode combined with a binary quantization compression level
    final KNNMethodConfigContext configContext = KNNMethodConfigContext.builder()
        .versionCreated(Version.CURRENT)
        .mode(Mode.ON_DISK)
        .compressionLevel(CompressionLevel.x32)
        .build();

    // Both parameters are registered with 512, so a binary quantization value in the result is distinguishable
    final MethodComponent component = MethodComponent.Builder.builder(componentName)
        .addParameter(efSearchKey, new Parameter.IntegerParameter(efSearchKey, 512, (value, ctx) -> value > 0))
        .addParameter(efConstructionKey, new Parameter.IntegerParameter(efConstructionKey, 512, (value, ctx) -> value > 0))
        .build();

    // The user supplies no parameters at all, so every value must come from the defaulting logic
    final MethodComponentContext componentContext = new MethodComponentContext(componentName, Map.of());

    final Map<String, Object> resolvedParameters = MethodComponent.getParameterMapWithDefaultsAdded(
        componentContext,
        component,
        configContext
    );

    // Both entries must carry the binary quantization specific values
    assertEquals(IndexHyperParametersUtil.getBinaryQuantizationEFSearchValue(), resolvedParameters.get(efSearchKey));
    assertEquals(IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue(), resolvedParameters.get(efConstructionKey));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,33 @@ public void testResolveMethod_whenValid_thenResolve() {
SpaceType.L2
);
validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x1, SpaceType.L2, ENCODER_FLAT, false);

KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(VectorDataType.FLOAT)
.versionCreated(Version.CURRENT)
.build();

resolvedMethodContext = TEST_RESOLVER.resolveMethod(
new KNNMethodContext(
KNNEngine.FAISS,
SpaceType.L2,
new MethodComponentContext(
METHOD_HNSW,
Map.of(
METHOD_ENCODER_PARAMETER,
new MethodComponentContext(
QFrameBitEncoder.NAME,
Map.of(QFrameBitEncoder.BITCOUNT_PARAM, CompressionLevel.x8.numBitsForFloat32())
)
)
)
),
knnMethodConfigContext,
false,
SpaceType.L2
);
assertEquals(knnMethodConfigContext.getCompressionLevel(), CompressionLevel.x8);
validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x8, SpaceType.L2, QFrameBitEncoder.NAME, true);
}

private void validateResolveMethodContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.opensearch.core.common.Strings;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.query.rescore.RescoreContext;

public class CompressionLevelTests extends KNNTestCase {

Expand Down Expand Up @@ -39,4 +40,30 @@ public void testIsConfigured() {
assertFalse(CompressionLevel.isConfigured(null));
assertTrue(CompressionLevel.isConfigured(CompressionLevel.x1));
}

public void testGetDefaultRescoreContext() {
    // Every lookup in this test is made for ON_DISK mode
    final Mode onDisk = Mode.ON_DISK;

    // x32: a rescore context is expected, with an oversample factor of 3.0f
    final RescoreContext x32Context = CompressionLevel.x32.getDefaultRescoreContext(onDisk);
    assertNotNull(x32Context);
    assertEquals(3.0f, x32Context.getOversampleFactor(), 0.0f);

    // x16: a rescore context is expected, with an oversample factor of 3.0f
    final RescoreContext x16Context = CompressionLevel.x16.getDefaultRescoreContext(onDisk);
    assertNotNull(x16Context);
    assertEquals(3.0f, x16Context.getOversampleFactor(), 0.0f);

    // x8: a rescore context is expected, with an oversample factor of 2.0f
    final RescoreContext x8Context = CompressionLevel.x8.getDefaultRescoreContext(onDisk);
    assertNotNull(x8Context);
    assertEquals(2.0f, x8Context.getOversampleFactor(), 0.0f);

    // The remaining compression levels must not yield a rescore context for ON_DISK mode
    final CompressionLevel[] levelsWithoutRescore = {
        CompressionLevel.x4,
        CompressionLevel.x2,
        CompressionLevel.x1,
        CompressionLevel.NOT_CONFIGURED };
    for (CompressionLevel level : levelsWithoutRescore) {
        assertNull(level.getDefaultRescoreContext(onDisk));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.knn.KNNTestCase;

import static org.opensearch.knn.index.query.rescore.RescoreContext.MAX_FIRST_PASS_RESULTS;
import static org.opensearch.knn.index.query.rescore.RescoreContext.MIN_FIRST_PASS_RESULTS;

public class RescoreContextTests extends KNNTestCase {

Expand All @@ -17,10 +18,43 @@ public void testGetFirstPassK() {
int finalK = 100;
assertEquals(260, rescoreContext.getFirstPassK(finalK));
finalK = 1;
assertEquals(3, rescoreContext.getFirstPassK(finalK));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
finalK = 0;
assertEquals(0, rescoreContext.getFirstPassK(finalK));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
finalK = MAX_FIRST_PASS_RESULTS;
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
}

public void testGetFirstPassKWithMinPassK() {
    final RescoreContext oversampledContext = RescoreContext.builder().oversampleFactor(2.6f).build();

    // Case 1: the oversampled value (260) already lies above MIN_FIRST_PASS_RESULTS and is returned unchanged
    assertEquals(260, oversampledContext.getFirstPassK(100));

    // Case 2: a very small finalK oversamples to less than MIN_FIRST_PASS_RESULTS, so the minimum is returned
    assertEquals(MIN_FIRST_PASS_RESULTS, oversampledContext.getFirstPassK(1));

    // Case 3: finalK = 0 is also raised to MIN_FIRST_PASS_RESULTS rather than returning 0
    assertEquals(MIN_FIRST_PASS_RESULTS, oversampledContext.getFirstPassK(0));

    // Case 4: finalK = MAX_FIRST_PASS_RESULTS is capped at MAX_FIRST_PASS_RESULTS
    assertEquals(MAX_FIRST_PASS_RESULTS, oversampledContext.getFirstPassK(MAX_FIRST_PASS_RESULTS));

    // Case 5: an oversample factor below 1 (10 * 0.5 = 5) still yields MIN_FIRST_PASS_RESULTS
    final RescoreContext halvingContext = RescoreContext.builder().oversampleFactor(0.5f).build();
    assertEquals(MIN_FIRST_PASS_RESULTS, halvingContext.getFirstPassK(10));

    // Case 6: finalK * oversample lands exactly on MIN_FIRST_PASS_RESULTS (100 * 1.0 = 100)
    final RescoreContext identityContext = RescoreContext.builder().oversampleFactor(1.0f).build();
    assertEquals(MIN_FIRST_PASS_RESULTS, identityContext.getFirstPassK(100));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,12 @@ public void testGetHNSWEFSearchValue_withDifferentValues_thenSuccess() {
IndexHyperParametersUtil.getHNSWEFConstructionValue(Version.CURRENT)
);
}

public void testGetBinaryQuantizationEFValues_thenSuccess() {
    // Both binary quantization defaults are expected to be the same value
    final int expectedDefault = 256;

    // EF Construction default for Binary Quantization
    Assert.assertEquals(expectedDefault, IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue());

    // EF Search default for Binary Quantization
    Assert.assertEquals(expectedDefault, IndexHyperParametersUtil.getBinaryQuantizationEFSearchValue());
}
}

0 comments on commit 272324c

Please sign in to comment.