Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable script score to work with model based indices #1649

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
* Enable script score to work with model based indices [#1649](https://github.com/opensearch-project/k-NN/pull/1649)
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
* Add arm64 check when SIMD is disabled [#1618](https://github.com/opensearch-project/k-NN/pull/1618)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.util.Arrays;
import java.util.Locale;
Expand All @@ -34,9 +37,22 @@
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;

/**
* Utility class for KNNVectorFieldMapper
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorFieldMapperUtil {

private static ModelDao modelDao;

/**
* Initializes static instance variables
* @param modelDao ModelDao object
*/
public static void initialize(final ModelDao modelDao) {
KNNVectorFieldMapperUtil.modelDao = modelDao;
}

/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range
* or is not within the FP16 range of [-65504 to 65504].
Expand Down Expand Up @@ -171,4 +187,46 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy

return vectorDataType.getVectorFromBytesRef(storedVector);
}

/**
* Get the expected dimensions from a specified knn vector field type.
*
* If the field is model-based, get dimensions from model metadata.
* @param knnVectorFieldType knn vector field type
* @return expected dimensions
*/
public static int getExpectedDimensions(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
int expectedDimensions = knnVectorFieldType.getDimension();
if (isModelBasedIndex(expectedDimensions)) {
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
expectedDimensions = modelMetadata.getDimension();
}
return expectedDimensions;
}

private static boolean isModelBasedIndex(int expectedDimensions) {
return expectedDimensions == -1;
}

/**
* Returns the model metadata for a specified knn vector field
*
* @param knnVectorField knn vector field
* @return the model metadata from knnVectorField
*/
private static ModelMetadata getModelMetadataForField(final KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
String modelId = knnVectorField.getModelId();

if (modelId == null) {
throw new IllegalArgumentException(
String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName())
);
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand Down Expand Up @@ -204,6 +205,7 @@ public Collection<Object> createComponents(
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNVectorFieldMapperUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNWeight;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -28,6 +29,7 @@
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -60,7 +62,7 @@ public L2(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
Expand Down Expand Up @@ -96,7 +98,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
SpaceType.COSINESIMIL.validateVector(processedQuery);
Expand Down Expand Up @@ -191,7 +193,7 @@ public L1(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
Expand Down Expand Up @@ -226,7 +228,7 @@ public LInf(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
Expand Down Expand Up @@ -263,7 +265,7 @@ public InnerProd(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNVectorFieldMapperUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;

/**
* Utility class for KNNScoringSpace
*/
public class KNNScoringSpaceUtil {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@

import org.apache.lucene.document.StoredField;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;

import java.io.ByteArrayInputStream;
import java.util.Arrays;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class KNNVectorFieldMapperUtilTests extends KNNTestCase {

private static final String TEST_FIELD_NAME = "test_field_name";
Expand Down Expand Up @@ -51,4 +59,59 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() {
assertTrue(vector instanceof float[]);
assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f);
}

public void testGetExpectedDimensionsSuccess() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldType.getDimension()).thenReturn(3);

KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelMetadata.getDimension()).thenReturn(4);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNVectorFieldMapperUtil.initialize(modelDao);

assertEquals(3, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldType));
assertEquals(4, KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased));
}

public void testGetExpectedDimensionsFailure() {
KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1);
String modelId = "test-model";
when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);

KNNVectorFieldMapperUtil.initialize(modelDao);

IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage());

when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null);
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
MethodComponentContext methodComponentContext = mock(MethodComponentContext.class);
String fieldName = "test-field";
when(methodComponentContext.getName()).thenReturn(fieldName);
when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext);
when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext);

e = expectThrows(
IllegalArgumentException.class,
() -> KNNVectorFieldMapperUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)
);
assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage());
}
}
Loading