Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable script score to work with model based indices #1649

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
* Enable script score to work with model based indices [#1649](https://github.com/opensearch-project/k-NN/pull/1649)
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
* Add arm64 check when SIMD is disabled [#1618](https://github.com/opensearch-project/k-NN/pull/1618)
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.knn.plugin.rest.RestTrainModelHandler;
import org.opensearch.knn.plugin.rest.RestClearCacheHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelTransportAction;
Expand Down Expand Up @@ -204,6 +205,7 @@ public Collection<Object> createComponents(
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -60,7 +61,7 @@ public L2(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
Expand Down Expand Up @@ -96,7 +97,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
SpaceType.COSINESIMIL.validateVector(processedQuery);
Expand Down Expand Up @@ -191,7 +192,7 @@ public L1(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
Expand Down Expand Up @@ -226,7 +227,7 @@ public LInf(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
Expand Down Expand Up @@ -263,7 +264,7 @@ public InnerProd(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import java.util.List;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.index.mapper.BinaryFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -21,6 +24,12 @@

public class KNNScoringSpaceUtil {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved

private static ModelDao modelDao;

public static void initialize(ModelDao modelDao) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
KNNScoringSpaceUtil.modelDao = modelDao;
}

/**
* Check if the passed in fieldType is of type NumberFieldType with numericType being Long
*
Expand Down Expand Up @@ -137,4 +146,43 @@ public static float getVectorMagnitudeSquared(float[] inputVector) {
}
return normInputVector;
}

/**
* Get the expected dimensions from a specified knn vector field type.
*
* If the field is model-based, get dimensions from model metadata.
* @param knnVectorFieldType knn vector field type
* @return expected dimensions
*/
public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
int expectedDimensions = knnVectorFieldType.getDimension();
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// Value will be -1 when a model-based index is used. In this case, retrieve expected dimensions from model metadata.
if (expectedDimensions == -1) {
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
expectedDimensions = modelMetadata.getDimension();
}
return expectedDimensions;
}

/**
 * Looks up the model metadata referenced by a knn vector field.
 *
 * @param knnVectorField knn vector field whose model id is used for the lookup
 * @return metadata of the model identified by the field's model id
 * @throws IllegalArgumentException if the field has no model id, or the model is not created
 */
private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
    final String fieldModelId = knnVectorField.getModelId();
    if (fieldModelId == null) {
        // NOTE(review): the value reported as "Field" is the method component context's name,
        // not obviously the field's own name — confirm this is the intended identifier.
        final String reportedName = knnVectorField.getKnnMethodContext().getMethodComponentContext().getName();
        throw new IllegalArgumentException(String.format("Field '%s' does not have model.", reportedName));
    }

    final ModelMetadata metadata = modelDao.getMetadata(fieldModelId);
    if (ModelUtil.isModelCreated(metadata)) {
        return metadata;
    }
    throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", fieldModelId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
package org.opensearch.knn.plugin.script;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.index.mapper.BinaryFieldMapper;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;

import java.math.BigInteger;
import java.util.ArrayList;
Expand Down Expand Up @@ -75,4 +80,44 @@ public void testParseKNNVectorQuery() {
String invalidObject = "invalidObject";
expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT));
}

public void testGetExpectedDimensions() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldType.getDimension()).thenReturn(3);

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getDimension()).thenReturn(4);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNScoringSpaceUtil.initialize(modelDao);

assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldType));
assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));

when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);

IllegalArgumentException e = expectThrows(
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
IllegalArgumentException.class,
() -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage());

when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
MethodComponentContext methodComponentContext = mock(MethodComponentContext.class);
String fieldName = "test-field";
when(methodComponentContext.getName()).thenReturn(fieldName);
when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext);
when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext);

e = expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));
assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage());
}
}
Loading