From 4734d88b8b12ff3edd676e3a1f337fc814c8932f Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Fri, 15 Mar 2024 11:24:00 -0700 Subject: [PATCH] Persist model definition in model metadata (#1527) * Add MethodComponentContext to ModelMetadata Signed-off-by: Ryan Bogan * Add changelog Signed-off-by: Ryan Bogan * Address PR Comments Signed-off-by: Ryan Bogan * Address PR Comments Signed-off-by: Ryan Bogan * Change fromString Signed-off-by: Ryan Bogan * Address PR Comments Signed-off-by: Ryan Bogan * Address PR Comments Signed-off-by: Ryan Bogan * Address PR Comments Signed-off-by: Ryan Bogan * Fix spotless Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan --- CHANGELOG.md | 1 + .../opensearch/knn/common/KNNConstants.java | 1 + .../org/opensearch/knn/index/IndexUtil.java | 3 + .../knn/index/MethodComponentContext.java | 202 +++++++++++ .../org/opensearch/knn/indices/ModelDao.java | 11 + .../opensearch/knn/indices/ModelMetadata.java | 140 ++++++-- .../opensearch/knn/training/TrainingJob.java | 3 +- src/main/resources/mappings/model-index.json | 6 + .../index/KNNCreateIndexFromModelTests.java | 3 +- .../index/MethodComponentContextTests.java | 69 ++++ .../KNN80DocValuesConsumerTests.java | 12 +- .../knn/index/codec/KNNCodecTestCase.java | 3 +- .../mapper/KNNVectorFieldMapperTests.java | 6 +- .../knn/indices/ModelCacheTests.java | 38 ++- .../opensearch/knn/indices/ModelDaoTests.java | 44 ++- .../knn/indices/ModelMetadataTests.java | 316 ++++++++++++++++-- .../opensearch/knn/indices/ModelTests.java | 43 ++- .../transport/GetModelResponseTests.java | 17 +- ...oveModelFromCacheTransportActionTests.java | 13 +- .../transport/TrainingModelRequestTests.java | 4 +- .../UpdateModelMetadataRequestTests.java | 10 +- ...dateModelMetadataTransportActionTests.java | 4 +- .../knn/training/TrainingJobTests.java | 6 +- 23 files changed, 834 insertions(+), 121 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c345fb1ce..a669bfb2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402) * Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502) * Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501) +* Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527) ### Bug Fixes * Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518) * Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index d6c6f3450..269f774b5 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -48,6 +48,7 @@ public class KNNConstants { public static final String MODEL_DESCRIPTION = "description"; public static final String MODEL_ERROR = "error"; public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment"; + public static final String MODEL_METHOD_COMPONENT_CONTEXT = "model_definition"; public static final String PARAM_SIZE = "size"; public static final Integer SEARCH_MODEL_MIN_SIZE = 1; public static final Integer SEARCH_MODEL_MAX_SIZE = 1000; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index e98c00197..1b385319a 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -37,13 +37,16 @@ public class IndexUtil { public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT; + public static final String MODEL_METHOD_COMPONENT_CONTEXT_KEY = KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; private static final Map minimalRequiredVersionMap = new HashMap() { { put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); + put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); } }; diff --git a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java index 66952f448..b9fd56b72 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java @@ -11,24 +11,29 @@ package org.opensearch.knn.index; +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.Setter; +import org.apache.commons.lang.math.NumberUtils; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; +import org.opensearch.knn.indices.ModelMetadata; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -41,6 +46,13 @@ @RequiredArgsConstructor public class MethodComponentContext implements ToXContentFragment, Writeable { + // EMPTY method component context can only occur if a model originated on a cluster before 2.13.0 and the cluster is then upgraded to + // 2.13.0 + public static final MethodComponentContext EMPTY = new MethodComponentContext("", Collections.emptyMap()); + + private static final String DELIMITER = ";"; + private static final String DELIMITER_PLACEHOLDER = "$%$"; + @Getter private final String name; private final Map parameters; @@ -161,6 +173,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + public static MethodComponentContext fromXContent(XContentParser xContentParser) throws IOException { + // If it is a fresh parser, move to the first token + if (xContentParser.currentToken() == null) { + xContentParser.nextToken(); + } + Map parsedMap = xContentParser.map(); + return MethodComponentContext.parse(parsedMap); + } + @Override public boolean equals(Object obj) { if (this == obj) return true; @@ -193,6 +214,187 @@ public Map getParameters() { return parameters; } + /** + * + * Provides a String representation of MethodComponentContext + * Sample return: + * {name=ivf;parameters=[nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]} + * + * @return string representation + */ + public String toClusterStateString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("{name=").append(name).append(DELIMITER); + stringBuilder.append("parameters=["); + if (Objects.nonNull(parameters)) { + for (Map.Entry entry : parameters.entrySet()) { + stringBuilder.append(entry.getKey()).append("="); + Object objectValue = entry.getValue(); + String value; + if (objectValue instanceof MethodComponentContext) { + value = ((MethodComponentContext) objectValue).toClusterStateString(); + } else { + value = entry.getValue().toString(); + } + // Model Metadata uses a delimiter to split the input string in its fromString method + // https://github.com/opensearch-project/k-NN/blob/2.12/src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L265 + // If any of the values in the method component context contain this delimiter, + // then the method will not work correctly. Therefore, we replace the delimiter with an uncommon + // sequence that is very unlikely to appear in the value itself. + // https://github.com/opensearch-project/k-NN/issues/1337 + value = value.replace(ModelMetadata.DELIMITER, DELIMITER_PLACEHOLDER); + stringBuilder.append(value).append(DELIMITER); + } + } + stringBuilder.append("]}"); + return stringBuilder.toString(); + } + + /** + * This method converts a string created by the toClusterStateString() method of MethodComponentContext + * to a MethodComponentContext object. + * + * @param in a string representation of MethodComponentContext + * @return a MethodComponentContext object + */ + public static MethodComponentContext fromClusterStateString(String in) { + String stringToParse = unwrapString(in, '{', '}'); + + // Parse name from string + String[] nameAndParameters = stringToParse.split(DELIMITER, 2); + checkExpectedArrayLength(nameAndParameters, 2); + String name = parseName(nameAndParameters[0]); + String parametersString = nameAndParameters[1]; + Map parameters = parseParameters(parametersString); + return new MethodComponentContext(name, parameters); + } + + private static String parseName(String candidateNameString) { + // Expecting candidateNameString to look like "name=ivf" + checkStringNotEmpty(candidateNameString); + String[] nameKeyAndValue = candidateNameString.split("="); + checkStringMatches(nameKeyAndValue[0], "name"); + if (nameKeyAndValue.length == 1) { + return ""; + } + checkExpectedArrayLength(nameKeyAndValue, 2); + return nameKeyAndValue[1]; + } + + private static Map parseParameters(String candidateParameterString) { + checkStringNotEmpty(candidateParameterString); + String[] parametersKeyAndValue = candidateParameterString.split("=", 2); + checkStringMatches(parametersKeyAndValue[0], "parameters"); + if (parametersKeyAndValue.length == 1) { + return Collections.emptyMap(); + } + checkExpectedArrayLength(parametersKeyAndValue, 2); + return parseParametersValue(parametersKeyAndValue[1]); + } + + private static Map parseParametersValue(String candidateParameterValueString) { + // Expected input is [nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};] + checkStringNotEmpty(candidateParameterValueString); + candidateParameterValueString = unwrapString(candidateParameterValueString, '[', ']'); + Map parameters = new HashMap<>(); + while (!candidateParameterValueString.isEmpty()) { + String[] keyAndValueToParse = candidateParameterValueString.split("=", 2); + if (keyAndValueToParse.length == 1 && keyAndValueToParse[0].charAt(0) == ';') { + break; + } + String key = keyAndValueToParse[0]; + ValueAndRestToParse parsed = parseParameterValueAndRestToParse(keyAndValueToParse[1]); + parameters.put(key, parsed.getValue()); + candidateParameterValueString = parsed.getRestToParse(); + } + + return parameters; + } + + private static ValueAndRestToParse parseParameterValueAndRestToParse(String candidateParameterValueAndRestToParse) { + if (candidateParameterValueAndRestToParse.charAt(0) == '{') { + int endOfNestedMap = findClosingPosition(candidateParameterValueAndRestToParse, '{', '}'); + String nestedMethodContext = candidateParameterValueAndRestToParse.substring(0, endOfNestedMap + 1); + Object nestedParse = fromClusterStateString(nestedMethodContext); + String restToParse = candidateParameterValueAndRestToParse.substring(endOfNestedMap + 1); + return new ValueAndRestToParse(nestedParse, restToParse); + } + + String[] stringValueAndRestToParse = candidateParameterValueAndRestToParse.split(DELIMITER, 2); + String stringValue = stringValueAndRestToParse[0]; + Object value; + if (NumberUtils.isNumber(stringValue)) { + value = Integer.parseInt(stringValue); + } else if (stringValue.equals("true") || stringValue.equals("false")) { + value = Boolean.parseBoolean(stringValue); + } else { + stringValue = stringValue.replace(DELIMITER_PLACEHOLDER, ModelMetadata.DELIMITER); + value = stringValue; + } + + return new ValueAndRestToParse(value, stringValueAndRestToParse[1]); + } + + private static String unwrapString(String in, char expectedStart, char expectedEnd) { + if (in.length() < 2) { + throw new IllegalArgumentException("Invalid string."); + } + + if (in.charAt(0) != expectedStart || in.charAt(in.length() - 1) != expectedEnd) { + throw new IllegalArgumentException("Invalid string." + in); + } + return in.substring(1, in.length() - 1); + } + + private static int findClosingPosition(String in, char expectedStart, char expectedEnd) { + int nestedLevel = 0; + for (int i = 0; i < in.length(); i++) { + if (in.charAt(i) == expectedStart) { + nestedLevel++; + continue; + } + + if (in.charAt(i) == expectedEnd) { + nestedLevel--; + } + + if (nestedLevel == 0) { + return i; + } + } + + throw new IllegalArgumentException("Invalid string. No end to the nesting"); + } + + private static void checkStringNotEmpty(String string) { + if (string.isEmpty()) { + throw new IllegalArgumentException("Unable to parse MethodComponentContext"); + } + } + + private static void checkStringMatches(String string, String expected) { + if (!Objects.equals(string, expected)) { + throw new IllegalArgumentException("Unexpected key in MethodComponentContext. Expected 'name' or 'parameters'"); + } + } + + private static void checkExpectedArrayLength(String[] array, int expectedLength) { + if (null == array) { + throw new IllegalArgumentException("Error parsing MethodComponentContext. Array is null."); + } + + if (array.length != expectedLength) { + throw new IllegalArgumentException("Error parsing MethodComponentContext. Array is not expected length."); + } + } + + @AllArgsConstructor + @Getter + private static class ValueAndRestToParse { + private final Object value; + private final String restToParse; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(this.name); diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index e1c30cc86..0c0f08545 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -43,10 +43,14 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; @@ -288,6 +292,13 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); + + MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); + if (!methodComponentContext.getName().isEmpty()) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); + put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, builder.toString()); + } } }; diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index 04836f184..fa88c8416 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -14,13 +14,17 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -29,19 +33,12 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; -import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; -import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; -import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT; +import static org.opensearch.core.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; @Log4j2 public class ModelMetadata implements Writeable, ToXContentObject { - private static final String DELIMITER = ","; + public static final String DELIMITER = ","; final private KNNEngine knnEngine; final private SpaceType spaceType; @@ -51,6 +48,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { final private String timestamp; final private String description; final private String trainingNodeAssignment; + private MethodComponentContext methodComponentContext; private String error; /** @@ -76,6 +74,12 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.trainingNodeAssignment = ""; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { + this.methodComponentContext = new MethodComponentContext(in); + } else { + this.methodComponentContext = MethodComponentContext.EMPTY; + } } /** @@ -88,6 +92,8 @@ public ModelMetadata(StreamInput in) throws IOException { * @param timestamp timevalue when model was created * @param description information about the model * @param error error message associated with model + * @param trainingNodeAssignment node assignment for the model + * @param methodComponentContext method component context associated with model */ public ModelMetadata( KNNEngine knnEngine, @@ -97,7 +103,8 @@ public ModelMetadata( String timestamp, String description, String error, - String trainingNodeAssignment + String trainingNodeAssignment, + MethodComponentContext methodComponentContext ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -118,6 +125,7 @@ public ModelMetadata( this.description = Objects.requireNonNull(description, "description must not be null"); this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); + this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); } /** @@ -192,6 +200,15 @@ public String getNodeAssignment() { return trainingNodeAssignment; } + /** + * getter for model's method context + * + * @return knnMethodContext + */ + public MethodComponentContext getMethodComponentContext() { + return methodComponentContext; + } + /** * setter for model's state * @@ -221,7 +238,8 @@ public String toString() { timestamp, description, error, - trainingNodeAssignment + trainingNodeAssignment, + methodComponentContext.toClusterStateString() ); } @@ -252,6 +270,7 @@ public int hashCode() { .append(getTimestamp()) .append(getDescription()) .append(getError()) + .append(getMethodComponentContext()) .toHashCode(); } @@ -268,7 +287,9 @@ public static ModelMetadata fromString(String modelMetadataString) { // Because models can be created on older versions and the cluster can be upgraded after, // we need to accept model metadata arrays both with and without the training node assignment. if (modelMetadataArray.length == 7) { - log.debug("Model metadata array does not contain training node assignment. Assuming empty string."); + log.debug( + "Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context." + ); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); int dimension = Integer.parseInt(modelMetadataArray[2]); @@ -276,9 +297,19 @@ public static ModelMetadata fromString(String modelMetadataString) { String timestamp = modelMetadataArray[4]; String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; - return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, ""); + return new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + "", + MethodComponentContext.EMPTY + ); } else if (modelMetadataArray.length == 8) { - log.debug("Model metadata contains training node assignment"); + log.debug("Model metadata contains training node assignment. Assuming empty method component context."); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); int dimension = Integer.parseInt(modelMetadataArray[2]); @@ -287,11 +318,43 @@ public static ModelMetadata fromString(String modelMetadataString) { String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; - return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, trainingNodeAssignment); + return new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + trainingNodeAssignment, + MethodComponentContext.EMPTY + ); + } else if (modelMetadataArray.length == 9) { + log.debug("Model metadata contains training node assignment and method context"); + KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); + SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); + int dimension = Integer.parseInt(modelMetadataArray[2]); + ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); + String timestamp = modelMetadataArray[4]; + String description = modelMetadataArray[5]; + String error = modelMetadataArray[6]; + String trainingNodeAssignment = modelMetadataArray[7]; + MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]); + return new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + trainingNodeAssignment, + methodComponentContext + ); } else { throw new IllegalArgumentException( "Illegal format for model metadata. Must be of the form " - + "\",,,,,,\" or \",,,,,,,\"." + + "\",,,,,,\" or \",,,,,,,\" or \",,,,,,,,\"." ); } } @@ -321,11 +384,27 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION); Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); + Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; } + if (Objects.nonNull(methodComponentContext)) { + try { + XContentParser xContentParser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + IGNORE_DEPRECATIONS, + objectToString(methodComponentContext) + ); + methodComponentContext = MethodComponentContext.fromXContent(xContentParser); + } catch (IOException e) { + throw new IllegalArgumentException("Error parsing method component context"); + } + } else { + methodComponentContext = MethodComponentContext.EMPTY; + } + ModelMetadata modelMetadata = new ModelMetadata( KNNEngine.getEngine(objectToString(engine)), SpaceType.getSpace(objectToString(space)), @@ -334,7 +413,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(timestamp), objectToString(description), objectToString(error), - objectToString(trainingNodeAssignment) + objectToString(trainingNodeAssignment), + (MethodComponentContext) methodComponentContext ); return modelMetadata; } @@ -351,20 +431,28 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) { out.writeString(getNodeAssignment()); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { + getMethodComponentContext().writeTo(out); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(MODEL_STATE, getState().getName()); - builder.field(MODEL_TIMESTAMP, getTimestamp()); - builder.field(MODEL_DESCRIPTION, getDescription()); - builder.field(MODEL_ERROR, getError()); - - builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue()); - builder.field(DIMENSION, getDimension()); - builder.field(KNN_ENGINE, getKnnEngine().getName()); + builder.field(KNNConstants.MODEL_STATE, getState().getName()); + builder.field(KNNConstants.MODEL_TIMESTAMP, getTimestamp()); + builder.field(KNNConstants.MODEL_DESCRIPTION, getDescription()); + builder.field(KNNConstants.MODEL_ERROR, getError()); + + builder.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue()); + builder.field(KNNConstants.DIMENSION, getDimension()); + builder.field(KNNConstants.KNN_ENGINE, getKnnEngine().getName()); if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) { - builder.field(MODEL_NODE_ASSIGNMENT, getNodeAssignment()); + builder.field(KNNConstants.MODEL_NODE_ASSIGNMENT, getNodeAssignment()); + } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { + builder.field(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT).startObject(); + getMethodComponentContext().toXContent(builder, params); + builder.endObject(); } return builder; } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 7b5404f6c..2c86082bb 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -83,7 +83,8 @@ public TrainingJob( ZonedDateTime.now(ZoneOffset.UTC).toString(), description, "", - nodeAssignment + nodeAssignment, + knnMethodContext.getMethodComponentContext() ), null, this.modelId diff --git a/src/main/resources/mappings/model-index.json b/src/main/resources/mappings/model-index.json index a8c7d6528..cd2a50839 100644 --- a/src/main/resources/mappings/model-index.json +++ b/src/main/resources/mappings/model-index.json @@ -26,6 +26,12 @@ }, "model_blob": { "type": "binary" + }, + "node_assignment": { + "type": "keyword" + }, + "method_component_context": { + "type": "keyword" } } } diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index 8fdc55766..710978928 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -62,7 +62,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "test-node" + "test-node", + MethodComponentContext.EMPTY ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java index cbbb872cf..5ce1a76ac 100644 --- a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java +++ b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java @@ -281,4 +281,73 @@ public void testHashCode() { assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode()); assertEquals(methodContext4.hashCode(), methodContext5.hashCode()); } + + public void testToStringFromString() { + HashMap parameters3 = new HashMap() { + { + put("nlist", 4); + put("nprobes", 2); + } + }; + + HashMap parameters4 = new HashMap() { + { + put("nlist", 4); + put("type", "fp16"); + } + }; + + HashMap nestedParameters = new HashMap() { + { + put("nprobes", 2); + put("clip", false); + } + }; + HashMap parameters5 = new HashMap() { + { + put("nlist", 4); + put("type", "fp16"); + put("encoder", new MethodComponentContext("sq", nestedParameters)); + } + }; + + HashMap parameters6 = new HashMap() { + { + put("nlist", 4); + put("encoder", new MethodComponentContext("sq", nestedParameters)); + put("type", "fp16"); + } + }; + + MethodComponentContext methodComponentContext1 = MethodComponentContext.EMPTY; + MethodComponentContext methodComponentContext2 = new MethodComponentContext("ivf", null); + MethodComponentContext methodComponentContext3 = new MethodComponentContext("ivf", parameters3); + MethodComponentContext methodComponentContext4 = new MethodComponentContext("ivf", parameters4); + MethodComponentContext methodComponentContext5 = new MethodComponentContext("ivf", parameters5); + MethodComponentContext methodComponentContext6 = new MethodComponentContext("ivf", parameters6); + + String contextString1 = methodComponentContext1.toClusterStateString(); + String contextString2 = methodComponentContext2.toClusterStateString(); + String contextString3 = methodComponentContext3.toClusterStateString(); + String contextString4 = methodComponentContext4.toClusterStateString(); + String contextString5 = methodComponentContext5.toClusterStateString(); + String contextString6 = methodComponentContext6.toClusterStateString(); + + assertEquals("{name=;parameters=[]}", contextString1); + assertEquals("{name=ivf;parameters=[]}", contextString2); + + MethodComponentContext methodComponentContextFromString1 = MethodComponentContext.fromClusterStateString(contextString1); + MethodComponentContext methodComponentContextFromString2 = MethodComponentContext.fromClusterStateString(contextString2); + MethodComponentContext methodComponentContextFromString3 = MethodComponentContext.fromClusterStateString(contextString3); + MethodComponentContext methodComponentContextFromString4 = MethodComponentContext.fromClusterStateString(contextString4); + MethodComponentContext methodComponentContextFromString5 = MethodComponentContext.fromClusterStateString(contextString5); + MethodComponentContext methodComponentContextFromString6 = MethodComponentContext.fromClusterStateString(contextString6); + + assertEquals(methodComponentContext1, methodComponentContextFromString1); + assertEquals(new MethodComponentContext("ivf", Collections.emptyMap()), methodComponentContextFromString2); + assertEquals(methodComponentContext3, methodComponentContextFromString3); + assertEquals(methodComponentContext4, methodComponentContextFromString4); + assertEquals(methodComponentContext5, methodComponentContextFromString5); + assertEquals(methodComponentContext6, methodComponentContextFromString6); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index eeca1e5ed..7736652ce 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -343,7 +343,17 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio byte[] modelBytes = JNIService.trainIndex(parameters, dimension, trainingPtr, knnEngine.getName()); Model model = new Model( - new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, "timestamp", "Empty description", "", ""), + new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + "timestamp", + "Empty description", + "", + "", + MethodComponentContext.EMPTY + ), modelBytes, modelId ); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 42eb81759..8ab2641db 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -212,7 +212,8 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 278f90ba2..72dcac5c5 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -164,7 +164,8 @@ public void testBuilder_build_fromModel() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); @@ -691,7 +692,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index 3146d898e..3a5255cd3 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -17,6 +17,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -43,7 +44,8 @@ public void testGet_normal() throws ExecutionException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), "hello".getBytes(), modelId @@ -79,7 +81,8 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[BYTES_PER_KILOBYTES + 1], modelId @@ -136,7 +139,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[size1], modelId1 @@ -151,7 +155,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[size2], modelId2 @@ -194,7 +199,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[size1], modelId1 @@ -209,7 +215,9 @@ public void testRemove_normal() throws ExecutionException, InterruptedException ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY + ), new byte[size2], modelId2 @@ -257,7 +265,8 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), "hello".getBytes(), modelId @@ -302,7 +311,8 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[modelSize], modelId @@ -370,7 +380,8 @@ public void testContains() throws ExecutionException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[modelSize1], modelId1 @@ -411,7 +422,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[modelSize1], modelId1 @@ -428,7 +440,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[modelSize2], modelId2 @@ -473,7 +486,8 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[BYTES_PER_KILOBYTES * 2], modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 2af8df953..ee2c77d1a 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -36,6 +36,7 @@ import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.plugin.transport.DeleteModelResponse; @@ -54,6 +55,7 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.Base64; +import java.util.Collections; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -151,7 +153,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -170,7 +173,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -197,7 +201,8 @@ public void testPut_withId() throws InterruptedException, IOException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + new MethodComponentContext("test", Collections.emptyMap()) ), modelBlob, modelId @@ -257,7 +262,8 @@ public void testPut_withoutModel() throws InterruptedException, IOException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -318,7 +324,8 @@ public void testPut_invalid_badState() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, "any-id" @@ -354,7 +361,8 @@ public void testUpdate() throws IOException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), null, modelId @@ -392,7 +400,8 @@ public void testUpdate() throws IOException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -442,7 +451,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -460,7 +470,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), null, modelId @@ -496,7 +507,8 @@ public void testGetMetadata() throws IOException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -572,7 +584,8 @@ public void testDelete() throws IOException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -605,7 +618,8 @@ public void testDelete() throws IOException, InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId1 @@ -672,7 +686,8 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId @@ -713,7 +728,8 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 219710308..da56a8421 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -12,8 +12,12 @@ package org.opensearch.knn.indices; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -21,6 +25,7 @@ import java.time.ZoneId; import java.time.ZoneOffset; import java.time.ZonedDateTime; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -39,7 +44,8 @@ public void testStreams() throws IOException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -60,7 +66,8 @@ public void testGetKnnEngine() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); @@ -76,7 +83,8 @@ public void testGetSpaceType() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); assertEquals(spaceType, modelMetadata.getSpaceType()); @@ -92,7 +100,8 @@ public void testGetDimension() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); assertEquals(dimension, modelMetadata.getDimension()); @@ -108,7 +117,8 @@ public void testGetState() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); assertEquals(modelState, modelMetadata.getState()); @@ -116,7 +126,17 @@ public void testGetState() { public void testGetTimestamp() { String timeValue = ZonedDateTime.now(ZoneOffset.UTC).toString(); - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, timeValue, "", "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + ModelState.CREATED, + timeValue, + "", + "", + "", + MethodComponentContext.EMPTY + ); assertEquals(timeValue, modelMetadata.getTimestamp()); } @@ -131,7 +151,8 @@ public void testDescription() { ZonedDateTime.now(ZoneOffset.UTC).toString(), description, "", - "" + "", + MethodComponentContext.EMPTY ); assertEquals(description, modelMetadata.getDescription()); @@ -147,7 +168,8 @@ public void testGetError() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error, - "" + "", + MethodComponentContext.EMPTY ); assertEquals(error, modelMetadata.getError()); @@ -163,7 +185,8 @@ public void testSetState() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); assertEquals(modelState, modelMetadata.getState()); @@ -183,7 +206,8 @@ public void testSetError() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error, - "" + "", + MethodComponentContext.EMPTY ); assertEquals(error, modelMetadata.getError()); @@ -202,6 +226,7 @@ public void testToString() { String description = "test-description"; String error = "test-error"; String nodeAssignment = ""; + MethodComponentContext methodComponentContext = MethodComponentContext.EMPTY; String expected = knnEngine.getName() + "," @@ -217,7 +242,9 @@ public void testToString() { + "," + error + "," - + nodeAssignment; + + nodeAssignment + + "," + + methodComponentContext.toClusterStateString(); ModelMetadata modelMetadata = new ModelMetadata( knnEngine, @@ -227,7 +254,8 @@ public void testToString() { timestamp, description, error, - nodeAssignment + nodeAssignment, + MethodComponentContext.EMPTY ); assertEquals(expected, modelMetadata.toString()); @@ -238,14 +266,84 @@ public void testEquals() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); String time2 = ZonedDateTime.of(2021, 9, 30, 12, 20, 45, 1, ZoneId.systemDefault()).toString(); - ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata2 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); - ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", "", ""); - ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", "", ""); + ModelMetadata modelMetadata3 = new ModelMetadata( + KNNEngine.NMSLIB, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata4 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L1, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata5 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 129, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata6 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.TRAINING, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata7 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time2, + "", + "", + "", + MethodComponentContext.EMPTY + ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, SpaceType.L2, @@ -254,7 +352,8 @@ public void testEquals() { time1, "diff descript", "", - "" + "", + MethodComponentContext.EMPTY ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -264,7 +363,20 @@ public void testEquals() { time1, "", "diff error", - "" + "", + MethodComponentContext.EMPTY + ); + + ModelMetadata modelMetadata10 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + new MethodComponentContext("test", Collections.emptyMap()) ); assertEquals(modelMetadata1, modelMetadata1); @@ -285,14 +397,84 @@ public void testHashCode() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); String time2 = ZonedDateTime.of(2021, 9, 30, 12, 20, 45, 1, ZoneId.systemDefault()).toString(); - ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata2 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); - ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", "", ""); - ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", "", ""); + ModelMetadata modelMetadata3 = new ModelMetadata( + KNNEngine.NMSLIB, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata4 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L1, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata5 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 129, + ModelState.CREATED, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata6 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.TRAINING, + time1, + "", + "", + "", + MethodComponentContext.EMPTY + ); + ModelMetadata modelMetadata7 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time2, + "", + "", + "", + MethodComponentContext.EMPTY + ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, SpaceType.L2, @@ -301,7 +483,8 @@ public void testHashCode() { time1, "diff descript", "", - "" + "", + MethodComponentContext.EMPTY ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -311,20 +494,33 @@ public void testHashCode() { time1, "", "diff error", - "" + "", + MethodComponentContext.EMPTY + ); + + ModelMetadata modelMetadata10 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "", + "", + new MethodComponentContext("test", Collections.emptyMap()) ); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); assertEquals(modelMetadata1.hashCode(), modelMetadata2.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata3.hashCode()); - assertNotEquals(modelMetadata1.hashCode(), modelMetadata3.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata4.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata5.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata6.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata7.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata8.hashCode()); assertNotEquals(modelMetadata1.hashCode(), modelMetadata9.hashCode()); + assertNotEquals(modelMetadata1.hashCode(), modelMetadata10.hashCode()); } public void testFromString() { @@ -336,6 +532,7 @@ public void testFromString() { String description = "test-description"; String error = "test-error"; String nodeAssignment = "test-node"; + MethodComponentContext methodComponentContext = MethodComponentContext.EMPTY; String stringRep1 = knnEngine.getName() + "," @@ -351,7 +548,9 @@ public void testFromString() { + "," + error + "," - + nodeAssignment; + + nodeAssignment + + "," + + methodComponentContext.toClusterStateString(); String stringRep2 = knnEngine.getName() + "," @@ -375,10 +574,21 @@ public void testFromString() { timestamp, description, error, - nodeAssignment + nodeAssignment, + MethodComponentContext.EMPTY ); - ModelMetadata expected2 = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, ""); + ModelMetadata expected2 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + "", + MethodComponentContext.EMPTY + ); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); ModelMetadata fromString2 = ModelMetadata.fromString(stringRep2); @@ -389,7 +599,7 @@ public void testFromString() { expectThrows(IllegalArgumentException.class, () -> ModelMetadata.fromString("invalid")); } - public void testFromResponseMap() { + public void testFromResponseMap() throws IOException { KNNEngine knnEngine = KNNEngine.DEFAULT; SpaceType spaceType = SpaceType.L2; int dimension = 128; @@ -398,6 +608,21 @@ public void testFromResponseMap() { String description = "test-description"; String error = "test-error"; String nodeAssignment = "test-node"; + Map nestedParameters = new HashMap() { + { + put("testNestedKey1", "testNestedString"); + put("testNestedKey2", 1); + } + }; + Map parameters = new HashMap<>() { + { + put("testKey1", "testString"); + put("testKey2", 0); + put("testKey3", new MethodComponentContext("ivf", nestedParameters)); + } + }; + MethodComponentContext methodComponentContext = new MethodComponentContext("hnsw", parameters); + MethodComponentContext emptyMethodComponentContext = MethodComponentContext.EMPTY; ModelMetadata expected = new ModelMetadata( knnEngine, @@ -407,9 +632,21 @@ public void testFromResponseMap() { timestamp, description, error, - nodeAssignment + nodeAssignment, + methodComponentContext + + ); + ModelMetadata expected2 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + "", + emptyMethodComponentContext ); - ModelMetadata expected2 = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, ""); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); metadataAsMap.put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()); @@ -420,10 +657,15 @@ public void testFromResponseMap() { metadataAsMap.put(KNNConstants.MODEL_ERROR, error); metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); + metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, builder.toString()); + ModelMetadata fromMap = ModelMetadata.getMetadataFromSourceMap(metadataAsMap); assertEquals(expected, fromMap); metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null); + metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, null); assertEquals(expected2, fromMap); } diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index c015e8d62..13579acad 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -13,6 +13,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -39,7 +40,8 @@ public void testInvalidConstructor() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), null, "test-model" @@ -59,7 +61,8 @@ public void testInvalidDimension() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[16], "test-model" @@ -76,7 +79,8 @@ public void testInvalidDimension() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[16], "test-model" @@ -93,7 +97,8 @@ public void testInvalidDimension() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[16], "test-model" @@ -111,7 +116,8 @@ public void testGetModelMetadata() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); @@ -128,7 +134,8 @@ public void testGetModelBlob() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), modelBlob, "test-model" @@ -147,7 +154,8 @@ public void testGetLength() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), new byte[size], "test-model" @@ -163,7 +171,8 @@ public void testGetLength() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), null, "test-model" @@ -182,7 +191,8 @@ public void testSetModelBlob() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ), blob1, "test-model" @@ -199,17 +209,17 @@ public void testEquals() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), new byte[16], "test-model-2" ); @@ -224,17 +234,17 @@ public void testHashCode() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), new byte[16], "test-model-2" ); @@ -263,7 +273,8 @@ public void testModelFromSourceMap() { timestamp, description, error, - nodeAssignment + nodeAssignment, + MethodComponentContext.EMPTY ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index 04c94de7e..a6985e72a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -18,6 +18,7 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; @@ -33,7 +34,17 @@ public class GetModelResponseTests extends KNNTestCase { private ModelMetadata getModelMetadata(ModelState state) { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, state, "2021-03-27 10:15:30 AM +05:30", "test model", "", ""); + return new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + 4, + state, + "2021-03-27 10:15:30 AM +05:30", + "test model", + "", + "", + MethodComponentContext.EMPTY + ); } public void testStreams() throws IOException { @@ -57,7 +68,7 @@ public void testXContent() throws IOException { Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); @@ -73,7 +84,7 @@ public void testXContentWithNoModelBlob() throws IOException { Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index 5d30f54bb..a2da83dad 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; @@ -68,7 +69,17 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup ModelDao modelDao = mock(ModelDao.class); String modelId = "test-model-id"; Model model = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 16, ModelState.CREATED, "timestamp", "description", "", ""), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 16, + ModelState.CREATED, + "timestamp", + "description", + "", + "", + MethodComponentContext.EMPTY + ), new byte[128], modelId ); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 7465ccc58..bdae54cad 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -24,6 +24,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -168,7 +169,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index a41ca900a..3719d124a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -13,6 +13,7 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; @@ -40,7 +41,8 @@ public void testStreams() throws IOException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -64,7 +66,8 @@ public void testValidate() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); @@ -103,7 +106,8 @@ public void testGetModelMetadata() { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index 11961f6f5..ab0e4f506 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; @@ -66,7 +67,8 @@ public void testClusterManagerOperation() throws InterruptedException { ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "", - "" + "", + MethodComponentContext.EMPTY ); // Get update transport action diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 6b07b2dd2..9dc461b97 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -56,6 +56,7 @@ public void testGetModelId() { KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.getKnnEngine()).thenReturn(KNNEngine.DEFAULT); when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.DEFAULT); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); TrainingJob trainingJob = new TrainingJob( modelId, @@ -78,10 +79,12 @@ public void testGetModel() { String description = "test description"; String error = ""; String nodeAssignment = "test-node"; + MethodComponentContext methodComponentContext = MethodComponentContext.EMPTY; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine); when(knnMethodContext.getSpaceType()).thenReturn(spaceType); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); String modelID = "test-model-id"; TrainingJob trainingJob = new TrainingJob( @@ -104,7 +107,8 @@ public void testGetModel() { trainingJob.getModel().getModelMetadata().getTimestamp(), description, error, - nodeAssignment + nodeAssignment, + MethodComponentContext.EMPTY ), null, modelID