Skip to content

Commit

Permalink
Allowed using knn field path when train model
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Nov 29, 2023
1 parent 5e2f899 commit e3bded4
Show file tree
Hide file tree
Showing 9 changed files with 369 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Fix reading field values from nested mappings when training a model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
### Documentation
Expand Down
52 changes: 51 additions & 1 deletion src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import java.io.File;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.index.mapper.MapperService.INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING;
import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
Expand Down Expand Up @@ -59,6 +61,38 @@ public static int getFileSizeInKB(String filePath) {
return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up
}

/**
 * Looks up the mapping of a (possibly nested) field inside a map of mapping properties.
 *
 * <p>The path is resolved one segment at a time. After a segment is resolved, if the value found is a map
 * that holds another map under the key {@code "properties"}, the lookup continues from that inner map. Note
 * that this also applies to the last segment: for an object field the returned value is its inner
 * {@code "properties"} map rather than the enclosing field definition.
 *
 * @param properties A map representing properties, where each key is a property name and
 *                   the value is either a sub-map of properties or the property value itself. May be null.
 * @param fieldPaths The ordered path segments that lead to the field mapping, e.g. {@code {"a", "b", "c"}}
 *                   for the field path {@code a.b.c}. May be null.
 * @return The value found at the end of the path, or null if {@code fieldPaths} is null, a segment is
 *         missing, or an intermediate value is not a map. An empty path returns {@code properties} itself.
 */
public static Object getFieldMapping(final Map<String, Object> properties, final String[] fieldPaths) {
    // A missing path cannot identify any field: report "not found" instead of failing with a NullPointerException.
    if (fieldPaths == null) {
        return null;
    }

    Object currentFieldMapping = properties;

    for (String path : fieldPaths) {
        // Covers both a null starting map and a leaf value reached before the path is exhausted.
        if (!(currentFieldMapping instanceof Map<?, ?>)) {
            return null;
        }

        // Only lookups are performed, so a wildcard cast is sufficient and avoids an unchecked cast.
        currentFieldMapping = ((Map<?, ?>) currentFieldMapping).get(path);
        if (currentFieldMapping == null) {
            return null;
        }

        // Object-like fields keep their children under "properties"; step into it so the next segment resolves.
        if (currentFieldMapping instanceof Map<?, ?>) {
            Object possibleProperties = ((Map<?, ?>) currentFieldMapping).get("properties");
            if (possibleProperties instanceof Map<?, ?>) {
                currentFieldMapping = possibleProperties;
            }
        }
    }

    return currentFieldMapping;
}

/**
* Validate that a field is a k-NN vector field and has the expected dimension
*
Expand Down Expand Up @@ -100,7 +134,23 @@ public static ValidationException validateKnnField(
return exception;
}

Object fieldMapping = properties.get(field);
// Check field path is valid
if (field.isEmpty()) {
exception.addValidationError("Field path is empty");
return exception;
}

String[] fieldPaths = field.split("\\.");

Long nestedFieldMaxLimit = INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING.get(indexMetadata.getSettings());

// Check field path length is valid
if (fieldPaths.length == 0 || fieldPaths.length > nestedFieldMaxLimit) {
exception.addValidationError(String.format(Locale.ROOT, "Field path length is invalid. Max length is %d", nestedFieldMaxLimit));
return exception;
}

Object fieldMapping = getFieldMapping(properties, fieldPaths);

// Check field existence
if (fieldMapping == null) {
Expand Down
36 changes: 32 additions & 4 deletions src/main/java/org/opensearch/knn/training/VectorReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import org.opensearch.search.sort.SortOrder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

public class VectorReader {
Expand Down Expand Up @@ -182,10 +184,21 @@ public void onResponse(SearchResponse searchResponse) {
int vectorsToAdd = Integer.min(maxVectorCount - collectedVectorCount, hits.length);
List<Float[]> trainingData = new ArrayList<>();

for (int i = 0; i < vectorsToAdd; i++) {
trainingData.add(
((List<Number>) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new)
);
for (int vector = 0; vector < vectorsToAdd; vector++) {
Map<String, Object> originalSourceMap = hits[vector].getSourceAsMap();
Map<String, Object> sourceMap = deepCopyMap(originalSourceMap);
// The field name may be a nested field, so we need to split it and traverse the map.
// Example fieldName: "my_field" or "my_field.nested_field.nested_nested_field"
String[] fieldPath = fieldName.split("\\.");
Map<String, Object> currentMap = sourceMap;

for (int pathPart = 0; pathPart < fieldPath.length - 1; pathPart++) {
currentMap = (Map<String, Object>) currentMap.get(fieldPath[pathPart]);
}

List<Number> fieldList = (List<Number>) currentMap.get(fieldPath[fieldPath.length - 1]);

trainingData.add(fieldList.stream().map(Number::floatValue).toArray(Float[]::new));
}

this.collectedVectorCount += trainingData.size();
Expand Down Expand Up @@ -225,5 +238,20 @@ public void onFailure(Exception e) {
listener.onFailure(e);
}
}

/**
 * Creates a deep copy of the given map.
 *
 * <p>Nested maps and lists are copied recursively, including maps and lists that are elements of a list,
 * so structural changes to the copy never affect {@code original}. All other values are shared by
 * reference.
 *
 * @param original map to copy; must not be null
 * @return a new mutable map with the same keys and recursively copied container values
 */
private Map<String, Object> deepCopyMap(Map<String, Object> original) {
    Map<String, Object> copy = new HashMap<>();
    for (Map.Entry<String, Object> entry : original.entrySet()) {
        copy.put(entry.getKey(), deepCopyValue(entry.getValue()));
    }
    return copy;
}

/**
 * Copies a single value: maps and lists are copied recursively, anything else is returned as is.
 */
@SuppressWarnings("unchecked")
private Object deepCopyValue(Object value) {
    if (value instanceof Map<?, ?>) {
        // Same cast the original code performed; assumes nested maps are keyed by strings.
        return deepCopyMap((Map<String, Object>) value);
    }
    if (value instanceof List<?>) {
        List<?> source = (List<?>) value;
        List<Object> copiedList = new ArrayList<>(source.size());
        // Copy element by element so containers held inside the list are not shared with the original.
        for (Object element : source) {
            copiedList.add(deepCopyValue(element));
        }
        return copiedList;
    }
    return value;
}
}
}
54 changes: 54 additions & 0 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,35 @@ protected void createKnnIndexMapping(String indexName, String fieldName, Integer
OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet());
}

/**
 * Puts a k-NN mapping on an index where the vector field may sit below object fields.
 * Every path segment but the last becomes an object with "properties"; the last segment
 * becomes the knn_vector field, e.g. fieldPath = "a.b.c" maps "c" as knn_vector under "a" and "b".
 */
protected void createKnnNestedIndexMapping(String indexName, String fieldPath, Integer dimensions) throws IOException {
    PutMappingRequest putMappingRequest = new PutMappingRequest(indexName);
    final String[] segments = fieldPath.split("\\.");
    final int leafIndex = segments.length - 1;

    XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");

    // Open one object (plus its "properties") per parent segment.
    for (int depth = 0; depth < leafIndex; depth++) {
        mapping.startObject(segments[depth]).startObject("properties");
    }

    // The leaf carries the actual vector definition; nothing is written when the path has no segments.
    if (leafIndex >= 0) {
        mapping.startObject(segments[leafIndex]).field("type", "knn_vector").field("dimension", dimensions.toString()).endObject();
    }

    // Close each parent's "properties" and then the parent object itself.
    for (int depth = 0; depth < leafIndex; depth++) {
        mapping.endObject().endObject();
    }

    // Close the top-level "properties" and the root object.
    mapping.endObject().endObject();

    putMappingRequest.source(mapping);

    OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(putMappingRequest).actionGet());
}

/**
* Get default k-NN settings for test cases
*/
Expand All @@ -112,6 +141,31 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[]
assertEquals(response.status(), RestStatus.CREATED);
}

/**
 * Indexes a document whose vector is stored under a dotted field path.
 * Each path segment but the last is written as an enclosing object; the last segment holds the vector.
 * The request refreshes immediately and the response status is asserted to be CREATED.
 */
protected void addKnnNestedDoc(String index, String docId, String fieldPath, Object[] vector) throws IOException,
    InterruptedException,
    ExecutionException {
    XContentBuilder document = XContentFactory.jsonBuilder().startObject();
    final String[] segments = fieldPath.split("\\.");
    final int parentCount = segments.length - 1;

    // Open one enclosing object per parent segment.
    int openParents = 0;
    while (openParents < parentCount) {
        document.startObject(segments[openParents]);
        openParents++;
    }

    document.field(segments[parentCount], vector);

    // Close the parents again, innermost first.
    while (openParents > 0) {
        document.endObject();
        openParents--;
    }
    document.endObject();

    IndexRequest request = new IndexRequest().index(index);
    request = request.id(docId).source(document).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

    IndexResponse indexResponse = client().index(request).get();
    assertEquals(indexResponse.status(), RestStatus.CREATED);
}

/**
* Add any document to index
*/
Expand Down
31 changes: 31 additions & 0 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Collections;
import java.util.Map;

import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -67,4 +68,34 @@ public void testGetLoadParameters() {
assertEquals(spaceType2.getValue(), loadParameters.get(SPACE_TYPE));
assertEquals(efSearchValue, loadParameters.get(HNSW_ALGO_EF_SEARCH));
}

// A single-segment path must resolve to the mapping stored directly under that key.
public void testGetFieldMappingNonNestedField() {
    final Map<String, Object> vectorMapping = Map.of("type", "knn_vector", "dimension", 8);
    final Map<String, Object> mappingProperties = Map.of("top_level_field", vectorMapping);

    final Object resolved = IndexUtil.getFieldMapping(mappingProperties, new String[] { "top_level_field" });

    assertEquals(vectorMapping, resolved);
}

// A multi-segment path must be followed through each parent's "properties" down to the leaf mapping.
public void testGetFieldMappingNestedField() {
    final Map<String, Object> vectorMapping = Map.of("type", "knn_vector", "dimension", 8);

    // Mapping shape: a -> properties -> b -> properties -> train-field
    final Map<String, Object> fieldB = Map.of("properties", Map.of("train-field", vectorMapping));
    final Map<String, Object> fieldA = Map.of("properties", Map.of("b", fieldB));
    final Map<String, Object> mappingProperties = Map.of("a", fieldA);

    final String[] path = new String[] { "a", "b", "train-field" };

    assertEquals(vectorMapping, IndexUtil.getFieldMapping(mappingProperties, path));
}

// Looking up any field in an empty properties map must yield null.
public void testGetFieldMappingEmptyProperties() {
    final Map<String, Object> noProperties = Collections.emptyMap();

    assertNull(IndexUtil.getFieldMapping(noProperties, new String[] { "top_level_field" }));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -341,4 +341,74 @@ public void testTrainModel_success_noId() throws Exception {

assertTrainingSucceeds(modelId, 30, 1000);
}

// Checks that a model can be trained when the training field is given as a nested field path
public void testTrainModel_success_nestedField() throws Exception {
    final String modelId = "test-model-id";
    final String trainingIndex = "train-index";
    final String trainingField = "a.b.train-field";
    final int dimension = 8;
    final int trainingDocCount = 200;

    // Create the training index with the nested vector field and fill it with random vectors
    createKnnIndex(trainingIndex, createKnnIndexNestedMapping(dimension, trainingField));
    bulkIngestRandomVectorsWithNestedField(trainingIndex, trainingField, trainingDocCount, dimension);

    // Intended train API request (the "method" section is what the builder below produces):
    /*
        {
            "training_index": "train-index",
            "training_field": "a.b.train-field",
            "dimension": 8,
            "description": "dummy description",
            "method": {
                "name":"ivf",
                "engine":"faiss",
                "space_type": "l2",
                "parameters":{
                    "nlist":1,
                    "encoder":{
                        "name":"pq",
                        "parameters":{
                            "code_size":2,
                            "m": 2
                        }
                    }
                }
            }
        }
    */
    XContentBuilder methodBuilder = XContentFactory.jsonBuilder().startObject();
    methodBuilder.field(NAME, "ivf").field(KNN_ENGINE, "faiss").field(METHOD_PARAMETER_SPACE_TYPE, "l2");
    methodBuilder.startObject(PARAMETERS).field(METHOD_PARAMETER_NLIST, 1);
    methodBuilder.startObject(METHOD_ENCODER_PARAMETER).field(NAME, "pq");
    methodBuilder.startObject(PARAMETERS).field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2).field(ENCODER_PARAMETER_PQ_M, 2);
    // Close encoder parameters, encoder, method parameters and the root object
    methodBuilder.endObject().endObject().endObject().endObject();
    Map<String, Object> methodDefinition = xContentBuilderToMap(methodBuilder);

    Response trainResponse = trainModel(modelId, trainingIndex, trainingField, dimension, methodDefinition, "dummy description");
    int trainStatusCode = trainResponse.getStatusLine().getStatusCode();
    assertEquals(RestStatus.OK, RestStatus.fromCode(trainStatusCode));

    // The model must be retrievable under the requested id
    String modelBody = EntityUtils.toString(getModel(modelId, null).getEntity());
    assertNotNull(modelBody);

    Map<String, Object> modelFields = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), modelBody).map();
    assertEquals(modelId, modelFields.get(MODEL_ID));

    assertTrainingSucceeds(modelId, 30, 1000);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.ValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
Expand Down Expand Up @@ -409,6 +410,7 @@ public void testValidation_invalid_trainingFieldNotKnnVector() {
MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.getSettings()).thenReturn(Settings.builder().put("index.mapping.nested_fields.limit", 5).build());
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
Metadata metadata = mock(Metadata.class);
when(metadata.index(trainingIndex)).thenReturn(indexMetadata);
Expand Down Expand Up @@ -468,6 +470,7 @@ public void testValidation_invalid_dimensionDoesNotMatch() {
when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
when(indexMetadata.getSettings()).thenReturn(Settings.builder().put("index.mapping.nested_fields.limit", 5).build());
Metadata metadata = mock(Metadata.class);
when(metadata.index(trainingIndex)).thenReturn(indexMetadata);
ClusterState clusterState = mock(ClusterState.class);
Expand Down Expand Up @@ -523,6 +526,7 @@ public void testValidation_invalid_preferredNodeDoesNotExist() {
when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
when(indexMetadata.getSettings()).thenReturn(Settings.builder().put("index.mapping.nested_fields.limit", 5).build());
Metadata metadata = mock(Metadata.class);
when(metadata.index(trainingIndex)).thenReturn(indexMetadata);

Expand Down Expand Up @@ -675,6 +679,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {
when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
when(indexMetadata.getSettings()).thenReturn(Settings.builder().put("index.mapping.nested_fields.limit", 5).build());
Metadata metadata = mock(Metadata.class);
when(metadata.index(trainingIndex)).thenReturn(indexMetadata);
DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class);
Expand Down Expand Up @@ -712,6 +717,7 @@ private ClusterService getClusterServiceForValidReturns(String trainingIndex, St
when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
when(indexMetadata.getSettings()).thenReturn(Settings.builder().put("index.mapping.nested_fields.limit", 5).build());
Metadata metadata = mock(Metadata.class);
when(metadata.index(trainingIndex)).thenReturn(indexMetadata);
ClusterState clusterState = mock(ClusterState.class);
Expand Down
Loading

0 comments on commit e3bded4

Please sign in to comment.