From ffc81e787c0a46a992cfeefa09573bc957fe2577 Mon Sep 17 00:00:00 2001 From: sahil <61558528+buddharajusahil@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:34:53 -0800 Subject: [PATCH] Allow method parameter override for training based indices (solves issue #2246) (#2290) * Allow method parameter override for training based indices Signed-off-by: Sahil Buddharaju * Fixed code squashing imports Signed-off-by: Sahil Buddharaju * Changed changelog Signed-off-by: Sahil Buddharaju * spotlessApply styling Signed-off-by: Sahil Buddharaju --------- Signed-off-by: Sahil Buddharaju Co-authored-by: Sahil Buddharaju (cherry picked from commit 19f045d763cd7f302c2b96c8731b725484b14bec) --- CHANGELOG.md | 1 + .../plugin/rest/RestTrainModelHandler.java | 2 - .../knn/integ/ModeAndCompressionIT.java | 24 +++++- .../action/RestTrainModelHandlerIT.java | 80 +++++++++++++++++++ 4 files changed, 103 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a39546843..ba83529d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. 
(#2241)[https://github.com/opensearch-project/k-NN/pull/2241] +- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index 4380310c3..837bb3f43 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -135,8 +135,6 @@ && ensureSpaceTypeNotSet(topLevelSpaceType)) { } ensureAtleastOneSet(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode, COMPRESSION_LEVEL_PARAMETER, compressionLevel); - ensureMutualExclusion(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode); - ensureMutualExclusion(KNN_METHOD, knnMethodContext, COMPRESSION_LEVEL_PARAMETER, compressionLevel); ensureSet(DIMENSION, dimension); ensureSet(TRAIN_INDEX_PARAMETER, trainingIndex); diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 91ee89aeb..262221a4c 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -26,17 +26,23 @@ import java.util.List; import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static 
org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST_DEFAULT; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -260,17 +266,31 @@ public void testCompressionIndexWithNonVectorFieldsSegment_whenValid_ThenSucceed public void testTraining_whenInvalid_thenFail() { setupTrainingIndex(); String modelId = "test"; + XContentBuilder builder1 = XContentFactory.jsonBuilder() .startObject() .field(TRAIN_INDEX_PARAMETER, TRAINING_INDEX_NAME) .field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME) .field(KNNConstants.DIMENSION, DIMENSION) + .field(VECTOR_DATA_TYPE_FIELD, "float") + .field(MODEL_DESCRIPTION, "") + .field(MODE_PARAMETER, Mode.ON_DISK) + .field(COMPRESSION_LEVEL_PARAMETER, "16x") .startObject(KNN_METHOD) .field(NAME, METHOD_IVF) .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) 
+ .field(ENCODER_PARAMETER_PQ_M, 8) + .endObject() + .endObject() + .endObject() .endObject() - .field(MODEL_DESCRIPTION, "") - .field(MODE_PARAMETER, Mode.ON_DISK) .endObject(); expectThrows(ResponseException.class, () -> trainModel(modelId, builder1)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index 1ba6eae9b..89c8113ee 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -11,7 +11,9 @@ package org.opensearch.knn.plugin.action; + import org.apache.http.util.EntityUtils; +import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.core.xcontent.XContentBuilder; @@ -22,15 +24,22 @@ import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static 
org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; public class RestTrainModelHandlerIT extends KNNRestTestCase { @@ -472,4 +481,75 @@ public void testTrainModel_success_nestedField() throws Exception { assertTrainingSucceeds(modelId, 30, 1000); } + + // Tests that training succeeds when the user specifies compression/mode together with a method + public void testTrainModel_success_methodOverrideWithCompressionMode() throws Exception { + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String nestedFieldPath = "a.b.train-field"; + int dimension = 8; + + // Create a training index and randomly ingest data into it + String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath); + createKnnIndex(trainingIndexName, mapping); + int trainingDataCount = 200; + bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension); + + // Call the train API with this definition: + + /* + POST /_plugins/_knn/models/test-model-id/_train + { + "training_index": "train-index", + "training_field": "a.b.train-field", + "dimension": 8, + "description": "dummy description", + "space_type": "innerproduct", + "mode": "on_disk", + "method": { + "name": "ivf", + "params": { + "nlist": 16 + } + } + } + + */ + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 16) + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + XContentBuilder outerParams = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, nestedFieldPath) + .field(DIMENSION, dimension) + .field(COMPRESSION_LEVEL_PARAMETER, "16x") + .field(MODE_PARAMETER, "on_disk") + .field(KNN_METHOD, method) + .field(MODEL_DESCRIPTION, "dummy description") + .endObject(); + + Request request = new 
Request("POST", "/_plugins/_knn/models/" + modelId + "/_train"); + request.setJsonEntity(outerParams.toString()); + + Response trainResponse = client().performRequest(request); + + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + + Response getResponse = getModel(modelId, null); + String responseBody = EntityUtils.toString(getResponse.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); + + assertTrainingSucceeds(modelId, 30, 1000); + } }