diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 595cf086de..9a7aa47845 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -42,7 +42,7 @@ jobs: needs: [Get-Require-Approval, Get-CI-Image-Tag, spotless] strategy: matrix: - java: [21, 23] + java: [21, 24] name: Build and Test MLCommons Plugin on linux if: github.repository == 'opensearch-project/ml-commons' @@ -107,7 +107,7 @@ jobs: needs: [Get-Require-Approval, Build-ml-linux, spotless] strategy: matrix: - java: [21, 23] + java: [21, 24] name: Test MLCommons Plugin on linux docker if: github.repository == 'opensearch-project/ml-commons' @@ -203,7 +203,7 @@ jobs: Build-ml-windows: strategy: matrix: - java: [21, 23] + java: [21, 24] name: Build and Test MLCommons Plugin on Windows if: github.repository == 'opensearch-project/ml-commons' needs: [Get-Require-Approval, spotless] diff --git a/build.gradle b/build.gradle index fb68269d8b..d398793be8 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ buildscript { ext { opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - opensearch_version = System.getProperty("opensearch.version", "3.1.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.2.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") asm_version = "9.7" @@ -60,9 +60,9 @@ buildscript { } plugins { - id 'com.netflix.nebula.ospackage' version "11.5.0" + id 'com.netflix.nebula.ospackage' version "12.0.0" id 'java' - id "io.freefair.lombok" version "8.4" + id "io.freefair.lombok" version "8.14" id 'jacoco' } @@ -80,7 +80,7 @@ allprojects { } plugins.withId('jacoco') { - jacoco.toolVersion = '0.8.12' + jacoco.toolVersion = '0.8.13' } project.getExtensions().getExtraProperties().set("versions", VersionProperties.getVersions()); diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 0e1394b58f..e442b0300c 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -214,7 +214,7 @@ public void registerModelGroup( @Override public void execute(FunctionName name, Input input, ActionListener listener) { MLExecuteTaskRequest mlExecuteTaskRequest = new MLExecuteTaskRequest(name, input); - client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, listener); + client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, getMLExecuteResponseActionListener(listener)); } @Override @@ -345,6 +345,10 @@ private ActionListener getMLRegisterAgentResponseAction return wrapActionListener(listener, MLRegisterAgentResponse::fromActionResponse); } + private ActionListener getMLExecuteResponseActionListener(ActionListener listener) { + return wrapActionListener(listener, MLExecuteTaskResponse::fromActionResponse); + } + private ActionListener getMLTaskResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener .wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure); diff --git a/common/build.gradle b/common/build.gradle index 4a141aad91..24cc63046f 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -47,7 +47,7 @@ dependencies { } lombok { - version = "1.18.30" + version = "1.18.38" } jacocoTestReport { diff --git a/common/src/main/java/org/opensearch/ml/common/AccessMode.java b/common/src/main/java/org/opensearch/ml/common/AccessMode.java index d4195206d5..3cf656fb88 100644 --- a/common/src/main/java/org/opensearch/ml/common/AccessMode.java +++ b/common/src/main/java/org/opensearch/ml/common/AccessMode.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common; diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index c969973dc7..2fe3d3b770 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -43,6 +43,7 @@ public class CommonValue { public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final String ML_MEMORY_CONTAINER_INDEX = ".plugins-ml-memory-container"; public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words"; // index used in 2.19 to track MlTaskBatchUpdate task public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job"; @@ -63,6 +64,7 @@ public class CommonValue { public static final String ML_AGENT_INDEX_MAPPING_PATH = "index-mappings/ml_agent.json"; public static final String ML_MEMORY_META_INDEX_MAPPING_PATH = "index-mappings/ml_memory_meta.json"; public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH = "index-mappings/ml_memory_message.json"; + public static final String ML_MEMORY_CONTAINER_INDEX_MAPPING_PATH = "index-mappings/ml_memory_container.json"; public static final String ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_session_management.json"; public static final String ML_MCP_TOOLS_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_tools.json"; public static final String ML_JOBS_INDEX_MAPPING_PATH = "index-mappings/ml_jobs.json"; diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 96df6baa55..f1ee27efbc 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -31,7 +31,8 @@ public enum FunctionName { TEXT_SIMILARITY, QUESTION_ANSWERING, AGENT, - CONNECTOR; + CONNECTOR, + TOOL; public static FunctionName from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index 1709923c48..4fd60d35a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -28,6 +28,8 @@ import org.opensearch.ml.common.output.MLOutputType; import org.reflections.Reflections; +import com.fasterxml.jackson.core.JsonParseException; + import lombok.extern.log4j.Log4j2; @Log4j2 @@ -255,16 +257,18 @@ public static boolean canInitMLInput(FunctionName functionName) { return mlInputClassMap.containsKey(functionName); } - public static S initConnector(String name, Object[] initArgs, Class... constructorParameterTypes) { + public static S initConnector(String name, Object[] initArgs, Class... constructorParameterTypes) throws JsonParseException { return init(connectorClassMap, name, initArgs, constructorParameterTypes); } @SuppressWarnings("unchecked") - public static , S> S initMLInput(T type, Object[] initArgs, Class... constructorParameterTypes) { + public static , S> S initMLInput(T type, Object[] initArgs, Class... constructorParameterTypes) + throws JsonParseException { return init(mlInputClassMap, type, initArgs, constructorParameterTypes); } - private static S init(Map> map, T type, Object[] initArgs, Class... constructorParameterTypes) { + private static S init(Map> map, T type, Object[] initArgs, Class... constructorParameterTypes) + throws JsonParseException { Class clazz = map.get(type); if (clazz == null) { throw new IllegalArgumentException("Can't find class for type " + type); @@ -278,6 +282,8 @@ private static S init(Map> map, T type, Object[] initArgs, Cl throw (MLException) cause; } else if (cause instanceof IllegalArgumentException) { throw (IllegalArgumentException) cause; + } else if (cause instanceof JsonParseException) { + throw (JsonParseException) cause; } else { log.error("Failed to init instance for type " + type, e); return null; diff --git a/common/src/main/java/org/opensearch/ml/common/MLIndex.java b/common/src/main/java/org/opensearch/ml/common/MLIndex.java index cbe2e43132..566d27db7e 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLIndex.java +++ b/common/src/main/java/org/opensearch/ml/common/MLIndex.java @@ -19,6 +19,8 @@ import static org.opensearch.ml.common.CommonValue.ML_JOBS_INDEX_MAPPING_PATH; import static org.opensearch.ml.common.CommonValue.ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH; import static org.opensearch.ml.common.CommonValue.ML_MCP_TOOLS_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_CONTAINER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_CONTAINER_INDEX_MAPPING_PATH; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; @@ -45,6 +47,7 @@ public enum MLIndex { AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING_PATH), MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING_PATH), MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING_PATH), + MEMORY_CONTAINER(ML_MEMORY_CONTAINER_INDEX, false, ML_MEMORY_CONTAINER_INDEX_MAPPING_PATH), MCP_SESSION_MANAGEMENT(MCP_SESSION_MANAGEMENT_INDEX, false, ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH), MCP_TOOLS(MCP_TOOLS_INDEX, false, ML_MCP_TOOLS_INDEX_MAPPING_PATH), JOBS(ML_JOBS_INDEX, false, ML_JOBS_INDEX_MAPPING_PATH); diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index e0eb7abf68..16d70f1ef9 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -965,7 +965,7 @@ Tags getRemoteModelTags(Connector connector) { * @param url The URL to analyze for service provider identification * @return The identified service provider name, or "unknown" if not found */ - String identifyServiceProvider(String url) { + static String identifyServiceProvider(String url) { for (String provider : MODEL_SERVICE_PROVIDER_KEYWORDS) { if (url.contains(provider)) { return provider; @@ -975,6 +975,10 @@ String identifyServiceProvider(String url) { return TAG_VALUE_UNKNOWN; } + public static String identifyServiceProviderFromUrl(String url) { + return identifyServiceProvider(url); + } + /** * Identifies the model name from the connector configuration using multiple strategies. * The method attempts to extract the model name in the following order: diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 3b275f3e34..c7f22981c8 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -73,6 +73,7 @@ public class MLTask implements ToXContentObject, Writeable { @Setter private String error; private User user; // TODO: support document level access control later + @Setter private boolean async; @Setter private Map remoteJob; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index d8306882a5..5ab6ba18e7 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -182,7 +182,7 @@ default void validateConnectorURL(List urlRegexes) { } } if (!hasMatchedUrl) { - throw new IllegalArgumentException("Connector URL is not matching the trusted connector endpoint regex, URL is: " + url); + throw new IllegalArgumentException("Connector URL is not matching the trusted connector endpoint regex"); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 835c6a6c47..c82f489296 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -9,10 +9,14 @@ import java.io.IOException; import java.util.HashSet; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; +import org.apache.commons.text.StringSubstitutor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -35,6 +39,17 @@ public class ConnectorAction implements ToXContentObject, Writeable { public static final String REQUEST_BODY_FIELD = "request_body"; public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function"; public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function"; + public static final String OPENAI = "openai"; + public static final String COHERE = "cohere"; + public static final String BEDROCK = "bedrock"; + public static final String SAGEMAKER = "sagemaker"; + public static final String SAGEMAKER_PRE_POST_FUNC_TEXT = "default"; + public static final List SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE); + + private static final String INBUILT_FUNC_PREFIX = "connector."; + private static final String PRE_PROCESS_FUNC = "PreProcessFunction"; + private static final String POST_PROCESS_FUNC = "PostProcessFunction"; + private static final Logger logger = LogManager.getLogger(ConnectorAction.class); private ActionType actionType; private String method; @@ -185,6 +200,81 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { .build(); } + /** + * Checks the compatibility of pre and post-process functions with the selected LLM service. + * Each LLM service (eg: Bedrock, OpenAI, SageMaker) has recommended pre and post-process functions + * designed for optimal performance. While it's possible to use functions from other services, + * it's strongly advised to use the corresponding functions for the best results. + * This method logs a warning if non-corresponding functions are detected, but allows the + * configuration to proceed. Users should be aware that using mismatched functions may lead + * to unexpected behavior or reduced performance, though it won't necessarily cause failures. + * + * @param parameters - connector parameters + */ + public void validatePrePostProcessFunctions(Map parameters) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + String endPoint = substitutor.replace(url); + String remoteServer = getRemoteServerFromURL(endPoint); + if (!remoteServer.isEmpty()) { + validateProcessFunctions(remoteServer, preProcessFunction, PRE_PROCESS_FUNC); + validateProcessFunctions(remoteServer, postProcessFunction, POST_PROCESS_FUNC); + } + } + + /** + * To get the remote server name from url + * + * @param url - remote server url + * @return - returns the corresponding remote server name for url, if server is not in the pre-defined list, + * it returns null + */ + public static String getRemoteServerFromURL(String url) { + return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse(""); + } + + private void validateProcessFunctions(String remoteServer, String processFunction, String funcNameForWarnText) { + if (isInBuiltProcessFunction(processFunction)) { + switch (remoteServer) { + case OPENAI: + if (!processFunction.contains(OPENAI)) { + logWarningForInvalidProcessFunc(OPENAI, funcNameForWarnText); + } + break; + case COHERE: + if (!processFunction.contains(COHERE)) { + logWarningForInvalidProcessFunc(COHERE, funcNameForWarnText); + } + break; + case BEDROCK: + if (!processFunction.contains(BEDROCK)) { + logWarningForInvalidProcessFunc(BEDROCK, funcNameForWarnText); + } + break; + case SAGEMAKER: + if (!processFunction.contains(SAGEMAKER_PRE_POST_FUNC_TEXT)) { + logWarningForInvalidProcessFunc(SAGEMAKER, funcNameForWarnText); + } + } + } + } + + private boolean isInBuiltProcessFunction(String processFunction) { + return (processFunction != null && processFunction.startsWith(INBUILT_FUNC_PREFIX)); + } + + private void logWarningForInvalidProcessFunc(String remoteServer, String funcNameForWarnText) { + logger + .warn( + "LLM service is " + + remoteServer + + ", so " + + funcNameForWarnText + + " should be corresponding to " + + remoteServer + + " for better results." + ); + } + public enum ActionType { PREDICT, EXECUTE, diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 2e2f56c7b7..e00dcb1ffa 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -70,6 +70,11 @@ public HttpConnector( String tenantId ) { validateProtocol(protocol); + if (actions != null) { + for (ConnectorAction action : actions) { + action.validatePrePostProcessFunctions(parameters); + } + } this.name = name; this.description = description; this.version = version; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/McpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/McpConnector.java index 5f9f72bf97..d0ce976e58 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/McpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/McpConnector.java @@ -438,7 +438,7 @@ public void validateConnectorURL(List urlRegexes) { } } if (!hasMatchedUrl) { - throw new IllegalArgumentException("Connector URL is not matching the trusted connector endpoint regex, URL is: " + url); + throw new IllegalArgumentException("Connector URL is not matching the trusted connector endpoint regex"); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index d7acc1a70b..b1056df9cf 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.connector.functions.preprocess; diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java new file mode 100644 index 0000000000..5acf865ff9 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.execute.tool; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.Getter; +import lombok.Setter; + +@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.TOOL }) +public class ToolMLInput extends MLInput { + public static final String TOOL_NAME_FIELD = "tool_name"; + public static final String PARAMETERS_FIELD = "parameters"; + + @Getter + @Setter + private String toolName; + + public ToolMLInput(StreamInput in) throws IOException { + super(in); + this.toolName = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(toolName); + } + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + Input.class, + new ParseField(FunctionName.TOOL.name()), + it -> parse(it) + ); + + public static ToolMLInput parse(XContentParser parser) throws IOException { + return new ToolMLInput(parser, FunctionName.TOOL); + } + + public ToolMLInput(XContentParser parser, FunctionName functionName) throws IOException { + this.algorithm = functionName; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TOOL_NAME_FIELD: + toolName = parser.text(); + break; + case PARAMETERS_FIELD: + Map parameters = StringUtils.getParameterMap(parser.map()); + inputDataset = new RemoteInferenceInputDataSet(parameters); + break; + default: + parser.skipChildren(); + break; + } + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java index f73b83e106..bbeceed464 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java @@ -9,7 +9,9 @@ import java.io.IOException; import java.util.Locale; +import java.util.Objects; +import org.opensearch.Version; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -31,9 +33,11 @@ *

* Use this parameter only if the model is asymmetric and has been registered with the corresponding * `query_prefix` and `passage_prefix` configuration parameters. + *

+ * Also supports embedding format control for sparse encoding algorithms. */ @Data -@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING }) +@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE }) public class AsymmetricTextEmbeddingParameters implements MLAlgoParams { public enum EmbeddingContentType { @@ -47,18 +51,44 @@ public enum EmbeddingContentType { new ParseField(PARSE_FIELD_NAME), it -> parse(it) ); + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_SPARSE_ENCODING = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(FunctionName.SPARSE_ENCODING.name()), + it -> parse(it) + ); + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_SPARSE_TOKENIZE = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(FunctionName.SPARSE_TOKENIZE.name()), + it -> parse(it) + ); @Builder(toBuilder = true) + public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType, SparseEmbeddingFormat sparseEmbeddingFormat) { + this.embeddingContentType = embeddingContentType; + this.sparseEmbeddingFormat = sparseEmbeddingFormat != null ? sparseEmbeddingFormat : SparseEmbeddingFormat.WORD; + } + + // Constructor for backward compatibility public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) { this.embeddingContentType = embeddingContentType; + this.sparseEmbeddingFormat = SparseEmbeddingFormat.WORD; } public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException { - this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString()); + Version streamInputVersion = in.getVersion(); + String contentType = in.readOptionalString(); + this.embeddingContentType = contentType != null ? EmbeddingContentType.valueOf(contentType) : null; + if (streamInputVersion.onOrAfter(Version.V_3_2_0)) { + String formatName = in.readOptionalString(); + this.sparseEmbeddingFormat = formatName != null ? SparseEmbeddingFormat.valueOf(formatName) : SparseEmbeddingFormat.WORD; + } else { + this.sparseEmbeddingFormat = SparseEmbeddingFormat.WORD; + } } public static MLAlgoParams parse(XContentParser parser) throws IOException { EmbeddingContentType embeddingContentType = null; + SparseEmbeddingFormat sparseEmbeddingFormat = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -70,19 +100,27 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { String contentType = parser.text(); embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT)); break; + case SPARSE_EMBEDDING_FORMAT_FIELD: + String formatType = parser.text(); + sparseEmbeddingFormat = SparseEmbeddingFormat.valueOf(formatType.toUpperCase(Locale.ROOT)); + break; default: parser.skipChildren(); break; } } - return new AsymmetricTextEmbeddingParameters(embeddingContentType); + return new AsymmetricTextEmbeddingParameters(embeddingContentType, sparseEmbeddingFormat); } public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type"; + public static final String SPARSE_EMBEDDING_FORMAT_FIELD = "sparse_embedding_format"; // The type of the content to be embedded private EmbeddingContentType embeddingContentType; + // The format of the embedding output + private SparseEmbeddingFormat sparseEmbeddingFormat; + @Override public int getVersion() { return 1; @@ -95,7 +133,11 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(embeddingContentType.name()); + Version streamOutputVersion = out.getVersion(); + out.writeOptionalString(embeddingContentType != null ? embeddingContentType.name() : null); + if (streamOutputVersion.onOrAfter(Version.V_3_2_0)) { + out.writeOptionalString(sparseEmbeddingFormat != null ? sparseEmbeddingFormat.name() : null); + } } @Override @@ -104,6 +146,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (embeddingContentType != null) { xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name()); } + xContentBuilder.field(SPARSE_EMBEDDING_FORMAT_FIELD, sparseEmbeddingFormat.name()); xContentBuilder.endObject(); return xContentBuilder; } @@ -111,4 +154,21 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params public EmbeddingContentType getEmbeddingContentType() { return embeddingContentType; } + + public SparseEmbeddingFormat getSparseEmbeddingFormat() { + return sparseEmbeddingFormat; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + AsymmetricTextEmbeddingParameters other = (AsymmetricTextEmbeddingParameters) obj; + return Objects.equals(embeddingContentType, other.embeddingContentType) + && Objects.equals(sparseEmbeddingFormat, other.sparseEmbeddingFormat); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java new file mode 100644 index 0000000000..1e66f825a8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.textembedding; + +/** + * Enum defining the format of sparse embeddings. + */ +public enum SparseEmbeddingFormat { + WORD, + TOKEN_ID +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java new file mode 100644 index 0000000000..b3b0cff6c2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java @@ -0,0 +1,269 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.AGENT_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CREATED_TIME_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LAST_UPDATED_TIME_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_EMBEDDING_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_TYPE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.ROLE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SESSION_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.TAGS_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.USER_ID_FIELD; + +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * Represents a memory entry in a memory container + */ +@Getter +@Setter +@Builder +public class MLMemory implements ToXContentObject, Writeable { + + // Core fields + private String sessionId; + private String memory; + private MemoryType memoryType; + + // Optional fields + private String userId; + private String agentId; + private String role; + private Map tags; + + // System fields + private Instant createdTime; + private Instant lastUpdatedTime; + + // Vector/embedding field (optional, for semantic storage) + private Object memoryEmbedding; + + @Builder + public MLMemory( + String sessionId, + String memory, + MemoryType memoryType, + String userId, + String agentId, + String role, + Map tags, + Instant createdTime, + Instant lastUpdatedTime, + Object memoryEmbedding + ) { + this.sessionId = sessionId; + this.memory = memory; + this.memoryType = memoryType; + this.userId = userId; + this.agentId = agentId; + this.role = role; + this.tags = tags; + this.createdTime = createdTime; + this.lastUpdatedTime = lastUpdatedTime; + this.memoryEmbedding = memoryEmbedding; + } + + public MLMemory(StreamInput in) throws IOException { + this.sessionId = in.readString(); + this.memory = in.readString(); + this.memoryType = in.readEnum(MemoryType.class); + this.userId = in.readOptionalString(); + this.agentId = in.readOptionalString(); + this.role = in.readOptionalString(); + if (in.readBoolean()) { + this.tags = in.readMap(StreamInput::readString, StreamInput::readString); + } + this.createdTime = in.readInstant(); + this.lastUpdatedTime = in.readInstant(); + // Note: memoryEmbedding is not serialized in StreamInput/Output as it's typically handled separately + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(sessionId); + out.writeString(memory); + out.writeEnum(memoryType); + out.writeOptionalString(userId); + out.writeOptionalString(agentId); + out.writeOptionalString(role); + if (tags != null && !tags.isEmpty()) { + out.writeBoolean(true); + out.writeMap(tags, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + out.writeInstant(createdTime); + out.writeInstant(lastUpdatedTime); + // Note: memoryEmbedding is not serialized in StreamInput/Output as it's typically handled separately + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(SESSION_ID_FIELD, sessionId); + builder.field(MEMORY_FIELD, memory); + builder.field(MEMORY_TYPE_FIELD, memoryType.getValue()); + + if (userId != null) { + builder.field(USER_ID_FIELD, userId); + } + if (agentId != null) { + builder.field(AGENT_ID_FIELD, agentId); + } + if (role != null) { + builder.field(ROLE_FIELD, role); + } + if (tags != null && !tags.isEmpty()) { + builder.field(TAGS_FIELD, tags); + } + + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdatedTime.toEpochMilli()); + + if (memoryEmbedding != null) { + builder.field(MEMORY_EMBEDDING_FIELD, memoryEmbedding); + } + + builder.endObject(); + return builder; + } + + public static MLMemory parse(XContentParser parser) throws IOException { + String sessionId = null; + String memory = null; + MemoryType memoryType = null; + String userId = null; + String agentId = null; + String role = null; + Map tags = null; + Instant createdTime = null; + Instant lastUpdatedTime = null; + Object memoryEmbedding = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case SESSION_ID_FIELD: + sessionId = parser.text(); + break; + case MEMORY_FIELD: + memory = parser.text(); + break; + case MEMORY_TYPE_FIELD: + memoryType = MemoryType.fromString(parser.text()); + break; + case USER_ID_FIELD: + userId = parser.text(); + break; + case AGENT_ID_FIELD: + agentId = parser.text(); + break; + case ROLE_FIELD: + role = parser.text(); + break; + case TAGS_FIELD: + Map tagsMap = parser.map(); + if (tagsMap != null) { + tags = new HashMap<>(); + for (Map.Entry entry : tagsMap.entrySet()) { + if (entry.getValue() != null) { + tags.put(entry.getKey(), entry.getValue().toString()); + } + } + } + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdatedTime = Instant.ofEpochMilli(parser.longValue()); + break; + case MEMORY_EMBEDDING_FIELD: + // Parse embedding as generic object (could be array or sparse map) + memoryEmbedding = parser.map(); + break; + default: + parser.skipChildren(); + break; + } + } + + return MLMemory + .builder() + .sessionId(sessionId) + .memory(memory) + .memoryType(memoryType) + .userId(userId) + .agentId(agentId) + .role(role) + .tags(tags) + .createdTime(createdTime) + .lastUpdatedTime(lastUpdatedTime) + .memoryEmbedding(memoryEmbedding) + .build(); + } + + /** + * Convert to a Map for indexing + */ + public Map toIndexMap() { + Map map = Map + .of( + SESSION_ID_FIELD, + sessionId, + MEMORY_FIELD, + memory, + MEMORY_TYPE_FIELD, + memoryType.getValue(), + CREATED_TIME_FIELD, + createdTime.toEpochMilli(), + LAST_UPDATED_TIME_FIELD, + lastUpdatedTime.toEpochMilli() + ); + + // Use mutable map for optional fields + Map result = new java.util.HashMap<>(map); + + if (userId != null) { + result.put(USER_ID_FIELD, userId); + } + if (agentId != null) { + result.put(AGENT_ID_FIELD, agentId); + } + if (role != null) { + result.put(ROLE_FIELD, role); + } + if (tags != null && !tags.isEmpty()) { + result.put(TAGS_FIELD, tags); + } + if (memoryEmbedding != null) { + result.put(MEMORY_EMBEDDING_FIELD, memoryEmbedding); + } + + return result; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemoryContainer.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemoryContainer.java new file mode 100644 index 0000000000..56ef786398 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemoryContainer.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CREATED_TIME_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.DESCRIPTION_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LAST_UPDATED_TIME_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_STORAGE_CONFIG_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.NAME_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.OWNER_FIELD; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; + +/** + * ML Memory Container data model that stores metadata about memory-related objects + */ +@Getter +@Setter +@Builder +@EqualsAndHashCode +public class MLMemoryContainer implements ToXContentObject, Writeable { + + private String name; + private String description; + private User owner; + private String tenantId; + private Instant createdTime; + private Instant lastUpdatedTime; + private MemoryStorageConfig memoryStorageConfig; + + public MLMemoryContainer( + String name, + String description, + User owner, + String tenantId, + Instant createdTime, + Instant lastUpdatedTime, + MemoryStorageConfig memoryStorageConfig + ) { + this.name = name; + this.description = description; + this.owner = owner; + this.tenantId = tenantId; + this.createdTime = createdTime; + this.lastUpdatedTime = lastUpdatedTime; + this.memoryStorageConfig = memoryStorageConfig; + } + + public MLMemoryContainer(StreamInput input) throws IOException { + this.name = input.readOptionalString(); + this.description = input.readOptionalString(); + if (input.readBoolean()) { + this.owner = new User(input); + } + this.tenantId = input.readOptionalString(); + this.createdTime = input.readOptionalInstant(); + this.lastUpdatedTime = input.readOptionalInstant(); + if (input.readBoolean()) { + this.memoryStorageConfig = new MemoryStorageConfig(input); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(name); + out.writeOptionalString(description); + if (owner != null) { + out.writeBoolean(true); + owner.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(tenantId); + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastUpdatedTime); + if (memoryStorageConfig != null) { + out.writeBoolean(true); + memoryStorageConfig.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (name != null) { + builder.field(NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (owner != null) { + builder.field(OWNER_FIELD, owner); + } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastUpdatedTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdatedTime.toEpochMilli()); + } + if (memoryStorageConfig != null) { + builder.field(MEMORY_STORAGE_CONFIG_FIELD, memoryStorageConfig); + } + builder.endObject(); + return builder; + } + + public static MLMemoryContainer parse(XContentParser parser) throws IOException { + String name = null; + String description = null; + User owner = null; + String tenantId = null; + Instant createdTime = null; + Instant lastUpdatedTime = null; + MemoryStorageConfig memoryStorageConfig = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case OWNER_FIELD: + owner = User.parse(parser); + break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_UPDATED_TIME_FIELD: + lastUpdatedTime = Instant.ofEpochMilli(parser.longValue()); + break; + case MEMORY_STORAGE_CONFIG_FIELD: + memoryStorageConfig = MemoryStorageConfig.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + + return MLMemoryContainer + .builder() + .name(name) + .description(description) + .owner(owner) + .tenantId(tenantId) + .createdTime(createdTime) + .lastUpdatedTime(lastUpdatedTime) + .memoryStorageConfig(memoryStorageConfig) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java new file mode 100644 index 0000000000..7965f0e70a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +/** + * Constants for Memory Container feature + */ +public class MemoryContainerConstants { + + // Field names for MemoryContainer + public static final String MEMORY_CONTAINER_ID_FIELD = "memory_container_id"; + public static final String NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String OWNER_FIELD = "owner"; + public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; + public static final String MEMORY_STORAGE_CONFIG_FIELD = "memory_storage_config"; + + // Field names for MemoryStorageConfig + public static final String MEMORY_INDEX_NAME_FIELD = "memory_index_name"; + public static final String SEMANTIC_STORAGE_ENABLED_FIELD = "semantic_storage_enabled"; + public static final String EMBEDDING_MODEL_TYPE_FIELD = "embedding_model_type"; + public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id"; + public static final String LLM_MODEL_ID_FIELD = "llm_model_id"; + public static final String DIMENSION_FIELD = "dimension"; + public static final String MAX_INFER_SIZE_FIELD = "max_infer_size"; + + // Default values + public static final int MAX_INFER_SIZE_DEFAULT_VALUE = 5; + + // Memory index type prefixes + public static final String STATIC_MEMORY_INDEX_PREFIX = "ml-static-memory-"; + public static final String KNN_MEMORY_INDEX_PREFIX = "ml-knn-memory-"; + public static final String SPARSE_MEMORY_INDEX_PREFIX = "ml-sparse-memory-"; + + // Memory data index field names + public static final String USER_ID_FIELD = "user_id"; + public static final String AGENT_ID_FIELD = "agent_id"; + public static final String SESSION_ID_FIELD = "session_id"; + public static final String MEMORY_FIELD = "memory"; + public static final String MEMORY_EMBEDDING_FIELD = "memory_embedding"; + public static final String TAGS_FIELD = "tags"; + public static final String MEMORY_ID_FIELD = "memory_id"; + public static final String MEMORY_TYPE_FIELD = "memory_type"; + public static final String ROLE_FIELD = "role"; + + // Request body field names (different from storage field names) + public static final String MESSAGE_FIELD = "message"; + public static final String MESSAGES_FIELD = "messages"; + public static final String CONTENT_FIELD = "content"; + public static final String INFER_FIELD = "infer"; + public static final String QUERY_FIELD = "query"; + public static final String TEXT_FIELD = "text"; + + // KNN index settings + public static final String KNN_ENGINE = "lucene"; + public static final String KNN_SPACE_TYPE = "cosinesimil"; + public static final String KNN_METHOD_NAME = "hnsw"; + public static final int KNN_EF_SEARCH = 100; + public static final int KNN_EF_CONSTRUCTION = 100; + public static final int KNN_M = 16; + + // REST API paths + public static final String BASE_MEMORY_CONTAINERS_PATH = "/_plugins/_ml/memory_containers"; + public static final String CREATE_MEMORY_CONTAINER_PATH = BASE_MEMORY_CONTAINERS_PATH + "/_create"; + public static final String PARAMETER_MEMORY_CONTAINER_ID = "memory_container_id"; + public static final String PARAMETER_MEMORY_ID = "memory_id"; + public static final String MEMORIES_PATH = BASE_MEMORY_CONTAINERS_PATH + "/{" + PARAMETER_MEMORY_CONTAINER_ID + "}/memories"; + public static final String SEARCH_MEMORIES_PATH = MEMORIES_PATH + "/_search"; + public static final String DELETE_MEMORY_PATH = MEMORIES_PATH + "/{" + PARAMETER_MEMORY_ID + "}"; + public static final String UPDATE_MEMORY_PATH = MEMORIES_PATH + "/{" + PARAMETER_MEMORY_ID + "}"; + public static final String GET_MEMORY_PATH = MEMORIES_PATH + "/{" + PARAMETER_MEMORY_ID + "}"; + + // Memory types are defined in MemoryType enum + + // Response fields + public static final String STATUS_FIELD = "status"; + + // Error messages + public static final String SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR = + "Embedding model type is required when embedding model ID is provided"; + public static final String SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR = + "Embedding model ID is required when embedding model type is provided"; + public static final String TEXT_EMBEDDING_DIMENSION_REQUIRED_ERROR = "Dimension is required for TEXT_EMBEDDING"; + public static final String SPARSE_ENCODING_DIMENSION_NOT_ALLOWED_ERROR = "Dimension is not allowed for SPARSE_ENCODING"; + public static final String INVALID_EMBEDDING_MODEL_TYPE_ERROR = "Embedding model type must be either TEXT_EMBEDDING or SPARSE_ENCODING"; + public static final String MAX_INFER_SIZE_LIMIT_ERROR = "Maximum infer size cannot exceed 10"; + public static final String FIELD_NOT_ALLOWED_SEMANTIC_DISABLED_ERROR = "Field %s is not allowed when semantic storage is disabled"; + + // Model validation error messages + public static final String LLM_MODEL_NOT_FOUND_ERROR = "LLM model with ID %s not found"; + public static final String LLM_MODEL_NOT_REMOTE_ERROR = "LLM model must be a REMOTE model, found: %s"; + public static final String EMBEDDING_MODEL_NOT_FOUND_ERROR = "Embedding model with ID %s not found"; + public static final String EMBEDDING_MODEL_TYPE_MISMATCH_ERROR = "Embedding model must be of type %s or REMOTE, found: %s"; // instead + public static final String INFER_REQUIRES_LLM_MODEL_ERROR = "infer=true requires llm_model_id to be configured in memory storage"; + + // Memory API limits + public static final int MAX_MESSAGES_PER_REQUEST = 10; + public static final String MAX_MESSAGES_EXCEEDED_ERROR = "Cannot process more than 10 messages in a single request"; + + // Memory decision fields + public static final String MEMORY_DECISION_FIELD = "memory_decision"; + public static final String OLD_MEMORY_FIELD = "old_memory"; + public static final String RETRIEVED_FACTS_FIELD = "retrieved_facts"; + public static final String EVENT_FIELD = "event"; + public static final String SCORE_FIELD = "score"; + + // LLM System Prompts + public static final String PERSONAL_INFORMATION_ORGANIZER_PROMPT = + "\nPersonal Information Organizer\nExtract and organize personal information shared within conversations.\n\nCarefully read the conversation.\nIdentify and extract any personal information shared by participants.\nFocus on details that help build a profile of the person, including but not limited to:\n\nNames and relationships\nProfessional information (job, company, role, responsibilities)\nPersonal interests and hobbies\nSkills and expertise\nPreferences and opinions\nGoals and aspirations\nChallenges or pain points\nBackground and experiences\nContact information (if shared)\nAvailability and schedule preferences\n\n\nOrganize each piece of information as a separate fact.\nEnsure facts are specific, clear, and preserve the original context.\nNever answer user's question or fulfill user's requirement. You are a personal information manager, not a helpful assistant.\nInclude the person who shared the information when relevant.\nDo not make assumptions or inferences beyond what is explicitly stated.\nIf no personal information is found, return an empty list.\n\n\nYou should always return and only return the extracted facts as a JSON object with a \"facts\" array.\n\n{\n \"facts\": [\n \"User's name is John Smith\",\n \"John works as a software engineer at TechCorp\",\n \"John enjoys hiking on weekends\",\n \"John is looking to improve his Python skills\"\n ]\n}\n\n\n"; + + public static final String DEFAULT_UPDATE_MEMORY_PROMPT = + "You are a smart memory manager which controls the memory of a system.You will receive: 1. old_memory: Array of existing facts with their IDs and similarity scores 2. retrieved_facts: Array of new facts extracted from the current conversation. Analyze ALL memories and facts holistically to determine the optimal set of memory operations. Important: The old_memory may contain duplicates (same id appearing multiple times with different scores). Consider the highest score for each unique ID. You should only respond and always respond with a JSON object containing a \"memory_decision\" array that covers: - Every unique existing memory ID (with appropriate event: NONE, UPDATE, or DELETE) - New entries for facts that should be added (with event: ADD){\"memory_decision\": [{\"id\": \"existing_id_or_new_id\",\"text\": \"the fact text\",\"event\": \"ADD|UPDATE|DELETE|NONE\",\"old_memory\": \"original text (only for UPDATE events)\"}]}1. **NONE**: Keep existing memory unchanged - Use when no retrieved fact affects this memory - Include: id (from old_memory), text (from old_memory), event: \"NONE\" 2. **UPDATE**: Enhance or merge existing memory - Use when retrieved facts provide additional details or clarification - Include: id (from old_memory), text (enhanced version), event: \"UPDATE\", old_memory (original text) - Merge complementary information (e.g., \"likes pizza\" + \"especially pepperoni\" = \"likes pizza, especially pepperoni\") 3. **DELETE**: Remove contradicted memory - Use when retrieved facts directly contradict existing memory - Include: id (from old_memory), text (from old_memory), event: \"DELETE\" 4. **ADD**: Create new memory - Use for retrieved facts that represent genuinely new information - Include: id (generate new), text (the new fact), event: \"ADD\" - Only add if the fact is not already covered by existing or updated memories- Integrity: Never answer user's question or fulfill user's requirement. You are a smart memory manager, not a helpful assistant. - Process holistically: Consider all facts and memories together before making decisions - Avoid redundancy: Don't ADD a fact if it's already covered by an UPDATE - Merge related facts: If multiple retrieved facts relate to the same topic, consider combining them - Respect similarity scores: Higher scores indicate stronger matches - be more careful about updating high-score memories - Maintain consistency: Ensure your decisions don't create contradictions in the memory set - One decision per unique memory ID: If an ID appears multiple times in old_memory, make only one decision for it{\"old_memory\": [{\"id\": \"fact_001\", \"text\": \"Enjoys Italian food\", \"score\": 0.85},{\"id\": \"fact_002\", \"text\": \"Works at Google\", \"score\": 0.92},{\"id\": \"fact_001\", \"text\": \"Enjoys Italian food\", \"score\": 0.75},{\"id\": \"fact_003\", \"text\": \"Has a dog\", \"score\": 0.65}],\"retrieved_facts\": [\"Loves pasta and pizza\",\"Recently joined Amazon\",\"Has two dogs named Max and Bella\"]}{\"memory_decision\": [{\"id\": \"fact_001\",\"text\": \"Loves Italian food, especially pasta and pizza\",\"event\": \"UPDATE\",\"old_memory\": \"Enjoys Italian food\"},{\"id\": \"fact_002\",\"text\": \"Works at Google\",\"event\": \"DELETE\"},{\"id\": \"fact_003\",\"text\": \"Has two dogs named Max and Bella\",\"event\": \"UPDATE\",\"old_memory\": \"Has a dog\"},{\"id\": \"fact_004\",\"text\": \"Recently joined Amazon\",\"event\": \"ADD\"}]}"; +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryDecision.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryDecision.java new file mode 100644 index 0000000000..4a8e5dd62b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryDecision.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryEvent; + +import lombok.Builder; +import lombok.Data; + +/** + * Represents a memory decision made by the LLM + */ +@Data +@Builder +public class MemoryDecision implements ToXContentObject, Writeable { + + private String id; + private String text; + private MemoryEvent event; + private String oldMemory; // Only for UPDATE events + + public MemoryDecision(String id, String text, MemoryEvent event, String oldMemory) { + this.id = id; + this.text = text; + this.event = event; + this.oldMemory = oldMemory; + } + + public MemoryDecision(StreamInput in) throws IOException { + this.id = in.readString(); + this.text = in.readString(); + this.event = MemoryEvent.valueOf(in.readString()); + this.oldMemory = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeString(text); + out.writeString(event.toString()); + out.writeOptionalString(oldMemory); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(MEMORY_ID_FIELD, id); + builder.field(TEXT_FIELD, text); + builder.field(EVENT_FIELD, event.toString()); + if (oldMemory != null) { + builder.field(OLD_MEMORY_FIELD, oldMemory); + } + builder.endObject(); + return builder; + } + + public static MemoryDecision parse(XContentParser parser) throws IOException { + String id = null; + String text = null; + MemoryEvent event = null; + String oldMemory = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MEMORY_ID_FIELD: + case "id": // Support both formats + id = parser.text(); + break; + case TEXT_FIELD: + text = parser.text(); + break; + case EVENT_FIELD: + event = MemoryEvent.fromString(parser.text()); + break; + case OLD_MEMORY_FIELD: + oldMemory = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return MemoryDecision.builder().id(id).text(text).event(event).oldMemory(oldMemory).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionRequest.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionRequest.java new file mode 100644 index 0000000000..8650159d53 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionRequest.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Data; + +/** + * Request structure for memory decision making + */ +@Data +@Builder +public class MemoryDecisionRequest implements ToXContentObject { + + // List of existing memories with scores + private List oldMemory; + + // List of newly extracted facts + private List retrievedFacts; + + @Data + @Builder + public static class OldMemory implements ToXContentObject { + private String id; + private String text; + private float score; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field("id", id); + builder.field("text", text); + builder.field("score", score); + builder.endObject(); + return builder; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + + // Build old_memory array + builder.startArray(OLD_MEMORY_FIELD); + if (oldMemory != null) { + for (OldMemory memory : oldMemory) { + memory.toXContent(builder, params); + } + } + builder.endArray(); + + // Build retrieved_facts array + builder.startArray(RETRIEVED_FACTS_FIELD); + if (retrievedFacts != null) { + for (String fact : retrievedFacts) { + builder.value(fact); + } + } + builder.endArray(); + + builder.endObject(); + return builder; + } + + /** + * Convert to string for LLM request + */ + public String toJsonString() { + try { + XContentBuilder builder = XContentBuilder.builder(org.opensearch.common.xcontent.json.JsonXContent.jsonXContent); + toXContent(builder, ToXContent.EMPTY_PARAMS); + return builder.toString(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize MemoryDecisionRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryStorageConfig.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryStorageConfig.java new file mode 100644 index 0000000000..e8e518aa87 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryStorageConfig.java @@ -0,0 +1,254 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.DIMENSION_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.EMBEDDING_MODEL_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.EMBEDDING_MODEL_TYPE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INVALID_EMBEDDING_MODEL_TYPE_ERROR; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LLM_MODEL_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MAX_INFER_SIZE_DEFAULT_VALUE; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MAX_INFER_SIZE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MAX_INFER_SIZE_LIMIT_ERROR; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_INDEX_NAME_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_ENABLED_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SPARSE_ENCODING_DIMENSION_NOT_ALLOWED_ERROR; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.TEXT_EMBEDDING_DIMENSION_REQUIRED_ERROR; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; + +/** + * Configuration for memory storage in memory containers + */ +@Getter +@Setter +@Builder +@EqualsAndHashCode +public class MemoryStorageConfig implements ToXContentObject, Writeable { + + private String memoryIndexName; + private boolean semanticStorageEnabled; + private FunctionName embeddingModelType; + private String embeddingModelId; + private String llmModelId; + private Integer dimension; + @Builder.Default + private Integer maxInferSize = MAX_INFER_SIZE_DEFAULT_VALUE; + + public MemoryStorageConfig( + String memoryIndexName, + boolean semanticStorageEnabled, + FunctionName embeddingModelType, + String embeddingModelId, + String llmModelId, + Integer dimension, + Integer maxInferSize + ) { + // Validate first + validateInputs(embeddingModelType, embeddingModelId, dimension, maxInferSize); + + // Auto-determine semantic storage based on embedding configuration + boolean determinedSemanticStorage = (embeddingModelId != null && embeddingModelType != null); + + // Assign values after validation + this.memoryIndexName = memoryIndexName; + this.semanticStorageEnabled = determinedSemanticStorage; + this.embeddingModelType = embeddingModelType; + this.embeddingModelId = embeddingModelId; + this.llmModelId = llmModelId; + this.dimension = dimension; + this.maxInferSize = (llmModelId != null) ? (maxInferSize != null ? maxInferSize : MAX_INFER_SIZE_DEFAULT_VALUE) : null; + } + + public MemoryStorageConfig(StreamInput input) throws IOException { + this.memoryIndexName = input.readOptionalString(); + this.semanticStorageEnabled = input.readBoolean(); + String embeddingModelTypeStr = input.readOptionalString(); + this.embeddingModelType = embeddingModelTypeStr != null ? FunctionName.from(embeddingModelTypeStr) : null; + this.embeddingModelId = input.readOptionalString(); + this.llmModelId = input.readOptionalString(); + this.dimension = input.readOptionalInt(); + this.maxInferSize = input.readOptionalInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(memoryIndexName); + out.writeBoolean(semanticStorageEnabled); + out.writeOptionalString(embeddingModelType != null ? embeddingModelType.name() : null); + out.writeOptionalString(embeddingModelId); + out.writeOptionalString(llmModelId); + out.writeOptionalInt(dimension); + out.writeOptionalInt(maxInferSize); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + + // Always output these fields + if (memoryIndexName != null) { + builder.field(MEMORY_INDEX_NAME_FIELD, memoryIndexName); + } + builder.field(SEMANTIC_STORAGE_ENABLED_FIELD, semanticStorageEnabled); + + // Always output LLM model if present (decoupled from semantic storage) + if (llmModelId != null) { + builder.field(LLM_MODEL_ID_FIELD, llmModelId); + } + + // When semantic storage is enabled, output embedding-related fields + if (semanticStorageEnabled) { + if (embeddingModelType != null) { + builder.field(EMBEDDING_MODEL_TYPE_FIELD, embeddingModelType.name()); + } + if (embeddingModelId != null) { + builder.field(EMBEDDING_MODEL_ID_FIELD, embeddingModelId); + } + if (dimension != null) { + builder.field(DIMENSION_FIELD, dimension); + } + } + + // Output maxInferSize when LLM model is configured + if (llmModelId != null && maxInferSize != null) { + builder.field(MAX_INFER_SIZE_FIELD, maxInferSize); + } + + builder.endObject(); + return builder; + } + + public static MemoryStorageConfig parse(XContentParser parser) throws IOException { + String memoryIndexName = null; + FunctionName embeddingModelType = null; + String embeddingModelId = null; + String llmModelId = null; + Integer dimension = null; + Integer maxInferSize = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MEMORY_INDEX_NAME_FIELD: + memoryIndexName = parser.text(); + break; + case SEMANTIC_STORAGE_ENABLED_FIELD: + // Skip this field - it's now auto-determined + parser.skipChildren(); + break; + case EMBEDDING_MODEL_TYPE_FIELD: + embeddingModelType = FunctionName.from(parser.text()); + break; + case EMBEDDING_MODEL_ID_FIELD: + embeddingModelId = parser.text(); + break; + case LLM_MODEL_ID_FIELD: + llmModelId = parser.text(); + break; + case DIMENSION_FIELD: + dimension = parser.intValue(); + break; + case MAX_INFER_SIZE_FIELD: + maxInferSize = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + + // Note: validation is already called in the constructor + return MemoryStorageConfig + .builder() + .memoryIndexName(memoryIndexName) + .embeddingModelType(embeddingModelType) + .embeddingModelId(embeddingModelId) + .llmModelId(llmModelId) + .dimension(dimension) + .maxInferSize(maxInferSize) + .build(); + } + + /** + * Validates input parameters before construction. + */ + private static void validateInputs(FunctionName embeddingModelType, String embeddingModelId, Integer dimension, Integer maxInferSize) { + validateEmbeddingConfiguration(embeddingModelType, embeddingModelId, dimension); + validateMaxInferSize(maxInferSize); + } + + /** + * Validates embedding configuration including model pairing and dimension requirements. + */ + private static void validateEmbeddingConfiguration(FunctionName embeddingModelType, String embeddingModelId, Integer dimension) { + // Check for partial embedding configuration + if (embeddingModelId != null && embeddingModelType == null) { + throw new IllegalArgumentException(SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR); + } + if (embeddingModelType != null && embeddingModelId == null) { + throw new IllegalArgumentException(SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR); + } + + // If embedding model type is provided, validate it + if (embeddingModelType != null) { + validateEmbeddingModelType(embeddingModelType); + validateDimensionRequirements(embeddingModelType, dimension); + } + } + + /** + * Validates max infer size limit. + */ + private static void validateMaxInferSize(Integer maxInferSize) { + if (maxInferSize != null && maxInferSize > 10) { + throw new IllegalArgumentException(MAX_INFER_SIZE_LIMIT_ERROR); + } + } + + /** + * Validates that the embedding model type is supported. + */ + private static void validateEmbeddingModelType(FunctionName embeddingModelType) { + if (embeddingModelType != FunctionName.TEXT_EMBEDDING && embeddingModelType != FunctionName.SPARSE_ENCODING) { + throw new IllegalArgumentException(INVALID_EMBEDDING_MODEL_TYPE_ERROR); + } + } + + /** + * Validates dimension requirements based on embedding type. + * TEXT_EMBEDDING requires dimension, SPARSE_ENCODING does not allow dimension. + */ + private static void validateDimensionRequirements(FunctionName embeddingModelType, Integer dimension) { + if (embeddingModelType == FunctionName.TEXT_EMBEDDING && dimension == null) { + throw new IllegalArgumentException(TEXT_EMBEDDING_DIMENSION_REQUIRED_ERROR); + } + + if (embeddingModelType == FunctionName.SPARSE_ENCODING && dimension != null) { + throw new IllegalArgumentException(SPARSE_ENCODING_DIMENSION_NOT_ALLOWED_ERROR); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryType.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryType.java new file mode 100644 index 0000000000..7e8053f9fb --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryType.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +/** + * Enum representing the type of memory entry + */ +public enum MemoryType { + RAW_MESSAGE("RAW_MESSAGE"), + FACT("FACT"); + + private final String value; + + MemoryType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + /** + * Parse string value to MemoryType + * @param value string representation of memory type + * @return corresponding MemoryType enum + * @throws IllegalArgumentException if value is invalid + */ + public static MemoryType fromString(String value) { + if (value == null) { + return null; + } + + for (MemoryType type : MemoryType.values()) { + if (type.value.equalsIgnoreCase(value)) { + return type; + } + } + + throw new IllegalArgumentException("Invalid memory type: " + value + ". Must be either RAW_MESSAGE or FACT"); + } + + @Override + public String toString() { + return value; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/BaseModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/BaseModelConfig.java index 056d217181..55693a129c 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/BaseModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/BaseModelConfig.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.VERSION_3_1_0; import java.io.IOException; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -41,13 +42,46 @@ public class BaseModelConfig extends MLModelConfig { it -> parse(it) ); + public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension"; + public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; + public static final String POOLING_MODE_FIELD = "pooling_mode"; + public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; + public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; + public static final String QUERY_PREFIX = "query_prefix"; + public static final String PASSAGE_PREFIX = "passage_prefix"; public static final String ADDITIONAL_CONFIG_FIELD = "additional_config"; + + protected Integer embeddingDimension; + protected FrameworkType frameworkType; + protected PoolingMode poolingMode; + protected boolean normalizeResult; + protected Integer modelMaxLength; + protected String queryPrefix; + protected String passagePrefix; protected Map additionalConfig; @Builder(builderMethodName = "baseModelConfigBuilder") - public BaseModelConfig(String modelType, String allConfig, Map additionalConfig) { + public BaseModelConfig( + String modelType, + String allConfig, + Map additionalConfig, + Integer embeddingDimension, + FrameworkType frameworkType, + PoolingMode poolingMode, + boolean normalizeResult, + Integer modelMaxLength, + String queryPrefix, + String passagePrefix + ) { super(modelType, allConfig); this.additionalConfig = additionalConfig; + this.embeddingDimension = embeddingDimension; + this.frameworkType = frameworkType; + this.poolingMode = poolingMode; + this.normalizeResult = normalizeResult; + this.modelMaxLength = modelMaxLength; + this.queryPrefix = queryPrefix; + this.passagePrefix = passagePrefix; validateNoDuplicateKeys(allConfig, additionalConfig); } @@ -55,6 +89,13 @@ public static BaseModelConfig parse(XContentParser parser) throws IOException { String modelType = null; String allConfig = null; Map additionalConfig = null; + Integer embeddingDimension = null; + FrameworkType frameworkType = null; + PoolingMode poolingMode = null; + boolean normalizeResult = false; + Integer modelMaxLength = null; + String queryPrefix = null; + String passagePrefix = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -71,12 +112,44 @@ public static BaseModelConfig parse(XContentParser parser) throws IOException { case ADDITIONAL_CONFIG_FIELD: additionalConfig = parser.map(); break; + case EMBEDDING_DIMENSION_FIELD: + embeddingDimension = parser.intValue(); + break; + case FRAMEWORK_TYPE_FIELD: + frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); + break; + case POOLING_MODE_FIELD: + poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); + break; + case NORMALIZE_RESULT_FIELD: + normalizeResult = parser.booleanValue(); + break; + case MODEL_MAX_LENGTH_FIELD: + modelMaxLength = parser.intValue(); + break; + case QUERY_PREFIX: + queryPrefix = parser.text(); + break; + case PASSAGE_PREFIX: + passagePrefix = parser.text(); + break; default: parser.skipChildren(); break; } } - return new BaseModelConfig(modelType, allConfig, additionalConfig); + return new BaseModelConfig( + modelType, + allConfig, + additionalConfig, + embeddingDimension, + frameworkType, + poolingMode, + normalizeResult, + modelMaxLength, + queryPrefix, + passagePrefix + ); } @Override @@ -89,6 +162,21 @@ public BaseModelConfig(StreamInput in) throws IOException { if (in.getVersion().onOrAfter(VERSION_3_1_0)) { this.additionalConfig = in.readMap(); } + embeddingDimension = in.readOptionalInt(); + if (in.readBoolean()) { + frameworkType = in.readEnum(FrameworkType.class); + } else { + frameworkType = null; + } + if (in.readBoolean()) { + poolingMode = in.readEnum(PoolingMode.class); + } else { + poolingMode = null; + } + normalizeResult = in.readBoolean(); + modelMaxLength = in.readOptionalInt(); + queryPrefix = in.readOptionalString(); + passagePrefix = in.readOptionalString(); } @Override @@ -97,6 +185,23 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(VERSION_3_1_0)) { out.writeMap(additionalConfig); } + out.writeOptionalInt(embeddingDimension); + if (frameworkType != null) { + out.writeBoolean(true); + out.writeEnum(frameworkType); + } else { + out.writeBoolean(false); + } + if (poolingMode != null) { + out.writeBoolean(true); + out.writeEnum(poolingMode); + } else { + out.writeBoolean(false); + } + out.writeBoolean(normalizeResult); + out.writeOptionalInt(modelMaxLength); + out.writeOptionalString(queryPrefix); + out.writeOptionalString(passagePrefix); } @Override @@ -111,10 +216,72 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (additionalConfig != null) { builder.field(ADDITIONAL_CONFIG_FIELD, additionalConfig); } + if (embeddingDimension != null) { + builder.field(EMBEDDING_DIMENSION_FIELD, embeddingDimension); + } + if (frameworkType != null) { + builder.field(FRAMEWORK_TYPE_FIELD, frameworkType); + } + if (modelMaxLength != null) { + builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength); + } + if (poolingMode != null) { + builder.field(POOLING_MODE_FIELD, poolingMode); + } + if (normalizeResult) { + builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); + } + if (queryPrefix != null) { + builder.field(QUERY_PREFIX, queryPrefix); + } + if (passagePrefix != null) { + builder.field(PASSAGE_PREFIX, passagePrefix); + } builder.endObject(); return builder; } + public enum PoolingMode { + MEAN("mean"), + MEAN_SQRT_LEN("mean_sqrt_len"), + MAX("max"), + WEIGHTED_MEAN("weightedmean"), + CLS("cls"), + LAST_TOKEN("lasttoken"); + + private String name; + + public String getName() { + return name; + } + + PoolingMode(String name) { + this.name = name; + } + + public static PoolingMode from(String value) { + try { + return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong pooling method"); + } + } + } + + public enum FrameworkType { + HUGGINGFACE_TRANSFORMERS, + SENTENCE_TRANSFORMERS, + HUGGINGFACE_TRANSFORMERS_NEURON; + + public static FrameworkType from(String value) { + try { + return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong framework type"); + } + } + } + protected void validateNoDuplicateKeys(String allConfig, Map additionalConfig) { if (allConfig == null || additionalConfig == null || additionalConfig.isEmpty()) { return; diff --git a/common/src/main/java/org/opensearch/ml/common/model/RemoteModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/RemoteModelConfig.java index dd6daa64a9..fe27e3e2ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/RemoteModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/RemoteModelConfig.java @@ -37,36 +37,29 @@ public class RemoteModelConfig extends BaseModelConfig { it -> parse(it) ); - public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension"; - public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; - public static final String POOLING_MODE_FIELD = "pooling_mode"; - public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; - public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; - - private final Integer embeddingDimension; - private final FrameworkType frameworkType; - private final PoolingMode poolingMode; - private final boolean normalizeResult; - private final Integer modelMaxLength; - @Builder(toBuilder = true) public RemoteModelConfig( String modelType, Integer embeddingDimension, - FrameworkType frameworkType, + BaseModelConfig.FrameworkType frameworkType, String allConfig, - PoolingMode poolingMode, + BaseModelConfig.PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, Map additionalConfig ) { - super(modelType, allConfig, additionalConfig); - this.embeddingDimension = embeddingDimension; - this.frameworkType = frameworkType; - this.poolingMode = poolingMode; - this.normalizeResult = normalizeResult; - this.modelMaxLength = modelMaxLength; - + super( + modelType, + allConfig, + additionalConfig, + embeddingDimension, + frameworkType, + poolingMode, + normalizeResult, + modelMaxLength, + null, + null + ); validateNoDuplicateKeys(allConfig, additionalConfig); validateTextEmbeddingConfig(); } @@ -74,9 +67,9 @@ public RemoteModelConfig( public static RemoteModelConfig parse(XContentParser parser) throws IOException { String modelType = null; Integer embeddingDimension = null; - FrameworkType frameworkType = null; + BaseModelConfig.FrameworkType frameworkType = null; String allConfig = null; - PoolingMode poolingMode = null; + BaseModelConfig.PoolingMode poolingMode = null; boolean normalizeResult = false; Integer modelMaxLength = null; Map additionalConfig = null; @@ -94,13 +87,13 @@ public static RemoteModelConfig parse(XContentParser parser) throws IOException embeddingDimension = parser.intValue(); break; case FRAMEWORK_TYPE_FIELD: - frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); + frameworkType = BaseModelConfig.FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); break; case ALL_CONFIG_FIELD: allConfig = parser.text(); break; case POOLING_MODE_FIELD: - poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); + poolingMode = BaseModelConfig.PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); break; case NORMALIZE_RESULT_FIELD: normalizeResult = parser.booleanValue(); @@ -135,39 +128,11 @@ public String getWriteableName() { public RemoteModelConfig(StreamInput in) throws IOException { super(in); - embeddingDimension = in.readOptionalInt(); - if (in.readBoolean()) { - frameworkType = in.readEnum(FrameworkType.class); - } else { - frameworkType = null; - } - if (in.readBoolean()) { - poolingMode = in.readEnum(PoolingMode.class); - } else { - poolingMode = null; - } - normalizeResult = in.readBoolean(); - modelMaxLength = in.readOptionalInt(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeOptionalInt(embeddingDimension); - if (frameworkType != null) { - out.writeBoolean(true); - out.writeEnum(frameworkType); - } else { - out.writeBoolean(false); - } - if (poolingMode != null) { - out.writeBoolean(true); - out.writeEnum(poolingMode); - } else { - out.writeBoolean(false); - } - out.writeBoolean(normalizeResult); - out.writeOptionalInt(modelMaxLength); } @Override @@ -201,47 +166,6 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } - public enum PoolingMode { - MEAN("mean"), - MEAN_SQRT_LEN("mean_sqrt_len"), - MAX("max"), - WEIGHTED_MEAN("weightedmean"), - CLS("cls"), - LAST_TOKEN("lasttoken"); - - private String name; - - public String getName() { - return name; - } - - PoolingMode(String name) { - this.name = name; - } - - public static PoolingMode from(String value) { - try { - return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (Exception e) { - throw new IllegalArgumentException("Wrong pooling method"); - } - } - } - - public enum FrameworkType { - HUGGINGFACE_TRANSFORMERS, - SENTENCE_TRANSFORMERS, - HUGGINGFACE_TRANSFORMERS_NEURON; - - public static FrameworkType from(String value) { - try { - return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (Exception e) { - throw new IllegalArgumentException("Wrong framework type"); - } - } - } - private void validateTextEmbeddingConfig() { if (modelType != null && modelType.equalsIgnoreCase("text_embedding")) { if (embeddingDimension == null) { diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java index d7f136ec74..1e66157b4d 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java @@ -34,22 +34,6 @@ public class TextEmbeddingModelConfig extends BaseModelConfig { it -> parse(it) ); - public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension"; - public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; - public static final String POOLING_MODE_FIELD = "pooling_mode"; - public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; - public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; - public static final String QUERY_PREFIX = "query_prefix"; - public static final String PASSAGE_PREFIX = "passage_prefix"; - - private final Integer embeddingDimension; - private final FrameworkType frameworkType; - private final PoolingMode poolingMode; - private final boolean normalizeResult; - private final Integer modelMaxLength; - private final String queryPrefix; - private final String passagePrefix; - public TextEmbeddingModelConfig( String modelType, Integer embeddingDimension, @@ -78,16 +62,27 @@ public TextEmbeddingModelConfig( public TextEmbeddingModelConfig( String modelType, Integer embeddingDimension, - FrameworkType frameworkType, + BaseModelConfig.FrameworkType frameworkType, String allConfig, Map additionalConfig, - PoolingMode poolingMode, + BaseModelConfig.PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix ) { - super(modelType, allConfig, additionalConfig); + super( + modelType, + allConfig, + additionalConfig, + embeddingDimension, + frameworkType, + poolingMode, + normalizeResult, + modelMaxLength, + queryPrefix, + passagePrefix + ); if (embeddingDimension == null) { throw new IllegalArgumentException("embedding dimension is null"); } @@ -95,22 +90,15 @@ public TextEmbeddingModelConfig( throw new IllegalArgumentException("framework type is null"); } validateNoDuplicateKeys(allConfig, additionalConfig); - this.embeddingDimension = embeddingDimension; - this.frameworkType = frameworkType; - this.poolingMode = poolingMode; - this.normalizeResult = normalizeResult; - this.modelMaxLength = modelMaxLength; - this.queryPrefix = queryPrefix; - this.passagePrefix = passagePrefix; } public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException { String modelType = null; Integer embeddingDimension = null; - FrameworkType frameworkType = null; + BaseModelConfig.FrameworkType frameworkType = null; String allConfig = null; Map additionalConfig = null; - PoolingMode poolingMode = null; + BaseModelConfig.PoolingMode poolingMode = null; boolean normalizeResult = false; Integer modelMaxLength = null; String queryPrefix = null; @@ -129,7 +117,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc embeddingDimension = parser.intValue(); break; case FRAMEWORK_TYPE_FIELD: - frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); + frameworkType = BaseModelConfig.FrameworkType.from(parser.text().toUpperCase(Locale.ROOT)); break; case ALL_CONFIG_FIELD: allConfig = parser.text(); @@ -138,7 +126,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc additionalConfig = parser.map(); break; case POOLING_MODE_FIELD: - poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); + poolingMode = BaseModelConfig.PoolingMode.from(parser.text().toUpperCase(Locale.ROOT)); break; case NORMALIZE_RESULT_FIELD: normalizeResult = parser.booleanValue(); @@ -178,34 +166,11 @@ public String getWriteableName() { public TextEmbeddingModelConfig(StreamInput in) throws IOException { super(in); - embeddingDimension = in.readInt(); - frameworkType = in.readEnum(FrameworkType.class); - if (in.readBoolean()) { - poolingMode = in.readEnum(PoolingMode.class); - } else { - poolingMode = null; - } - normalizeResult = in.readBoolean(); - modelMaxLength = in.readOptionalInt(); - queryPrefix = in.readOptionalString(); - passagePrefix = in.readOptionalString(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeInt(embeddingDimension); - out.writeEnum(frameworkType); - if (poolingMode != null) { - out.writeBoolean(true); - out.writeEnum(poolingMode); - } else { - out.writeBoolean(false); - } - out.writeBoolean(normalizeResult); - out.writeOptionalInt(modelMaxLength); - out.writeOptionalString(queryPrefix); - out.writeOptionalString(passagePrefix); } @Override @@ -244,45 +209,4 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } - - public enum PoolingMode { - MEAN("mean"), - MEAN_SQRT_LEN("mean_sqrt_len"), - MAX("max"), - WEIGHTED_MEAN("weightedmean"), - CLS("cls"), - LAST_TOKEN("lasttoken"); - - private String name; - - public String getName() { - return name; - } - - PoolingMode(String name) { - this.name = name; - } - - public static PoolingMode from(String value) { - try { - return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (Exception e) { - throw new IllegalArgumentException("Wrong pooling method"); - } - } - } - - public enum FrameworkType { - HUGGINGFACE_TRANSFORMERS, - SENTENCE_TRANSFORMERS, - HUGGINGFACE_TRANSFORMERS_NEURON; - - public static FrameworkType from(String value) { - try { - return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (Exception e) { - throw new IllegalArgumentException("Wrong framework type"); - } - } - } } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index 7d631ac2f6..021be0c941 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -216,6 +216,12 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting .boolSetting("plugins.ml_commons.memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_AGENTIC_SEARCH_ENABLED = Setting + .boolSetting("plugins.ml_commons.agentic_search_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final String ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE = + "The QueryPlanningTool tool for Agentic Search is not enabled. To enable, please update the setting " + + ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey(); + public static final Setting ML_COMMONS_MCP_CONNECTOR_ENABLED = Setting .boolSetting("plugins.ml_commons.mcp_connector_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final String ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE = @@ -237,6 +243,12 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED = Setting .boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + // Feature flag for execute tool API + public static final Setting ML_COMMONS_EXECUTE_TOOL_ENABLED = Setting + .boolSetting("plugins.ml_commons.execute_tools_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final String ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE = + "The Execute Tool API is not enabled. To enable, please update the setting " + ML_COMMONS_EXECUTE_TOOL_ENABLED.getKey(); + public static final Setting> ML_COMMONS_REMOTE_JOB_STATUS_FIELD = Setting .listSetting( "plugins.ml_commons.remote_job.status_field", @@ -350,4 +362,38 @@ private MLCommonsSettings() {} // Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor public static final Setting ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting .boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Final); + + // Feature flag for Agentic memory APIs + public static final Setting ML_COMMONS_AGENTIC_MEMORY_ENABLED = Setting + .boolSetting("plugins.ml_commons.agentic_memory_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final String ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE = + "The Agentic Memory APIs are not enabled. To enable, please update the setting " + ML_COMMONS_AGENTIC_MEMORY_ENABLED.getKey(); + + // Feature flag for enabling telemetry tracer + // This setting is Final because it controls the core tracing infrastructure initialization. + // Once the tracer is initialized, changing this setting would require a node restart + // to properly reinitialize the tracing components. + public static final Setting ML_COMMONS_TRACING_ENABLED = Setting + .boolSetting("plugins.ml_commons.tracing_enabled", false, Setting.Property.NodeScope, Setting.Property.Final); + + // Feature flag for enabling telemetry agent tracing + // This setting is Dynamic because agent tracing can be enabled/disabled at runtime + // without requiring a node restart. The MLAgentTracer singleton can be updated + // to switch between real tracer and NoopTracer based on this setting. + public static final Setting ML_COMMONS_AGENT_TRACING_ENABLED = Setting + .boolSetting("plugins.ml_commons.agent_tracing_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // Feature flag for enabling telemetry connector tracing + // This setting is Dynamic because connector tracing can be enabled/disabled at runtime + // without requiring a node restart. The MLConnectorTracer singleton can be updated + // to switch between real tracer and NoopTracer based on this setting. + public static final Setting ML_COMMONS_CONNECTOR_TRACING_ENABLED = Setting + .boolSetting("plugins.ml_commons.connector_tracing_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // Feature flag for enabling telemetry model tracing + // This setting is Dynamic because model tracing can be enabled/disabled at runtime + // without requiring a node restart. The MLModelTracer singleton can be updated + // to switch between real tracer and NoopTracer based on this setting. + public static final Setting ML_COMMONS_MODEL_TRACING_ENABLED = Setting + .boolSetting("plugins.ml_commons.model_tracing_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java index 786af9e29c..799f8c4e47 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java @@ -1,24 +1,30 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.settings; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_TRACING_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_TRACING_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_TRACING_ENABLED; import java.util.ArrayList; import java.util.List; @@ -51,6 +57,19 @@ public class MLFeatureEnabledSetting { private volatile Boolean isMetricCollectionEnabled; private volatile Boolean isStaticMetricCollectionEnabled; + private volatile Boolean isExecuteToolEnabled; + + private volatile Boolean isAgenticSearchEnabled; + + private volatile Boolean isMcpConnectorEnabled; + + private volatile Boolean isAgenticMemoryEnabled; + + private volatile Boolean isTracingEnabled; + private volatile Boolean isAgentTracingEnabled; + private volatile Boolean isConnectorTracingEnabled; + private volatile Boolean isModelTracingEnabled; + private final List listeners = new ArrayList<>(); public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { @@ -66,6 +85,14 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) isRagSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings); isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings); isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings); + isExecuteToolEnabled = ML_COMMONS_EXECUTE_TOOL_ENABLED.get(settings); + isAgenticSearchEnabled = ML_COMMONS_AGENTIC_SEARCH_ENABLED.get(settings); + isMcpConnectorEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(settings); + isAgenticMemoryEnabled = ML_COMMONS_AGENTIC_MEMORY_ENABLED.get(settings); + isTracingEnabled = ML_COMMONS_TRACING_ENABLED.get(settings); + isAgentTracingEnabled = ML_COMMONS_AGENT_TRACING_ENABLED.get(settings); + isConnectorTracingEnabled = ML_COMMONS_CONNECTOR_TRACING_ENABLED.get(settings); + isModelTracingEnabled = ML_COMMONS_MODEL_TRACING_ENABLED.get(settings); clusterService .getClusterSettings() @@ -88,6 +115,19 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_EXECUTE_TOOL_ENABLED, it -> isExecuteToolEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> isAgenticSearchEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isAgenticMemoryEnabled = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED, it -> isAgentTracingEnabled = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_CONNECTOR_TRACING_ENABLED, it -> isConnectorTracingEnabled = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_TRACING_ENABLED, it -> isModelTracingEnabled = it); } /** @@ -178,10 +218,50 @@ public boolean isStaticMetricCollectionEnabled() { return isStaticMetricCollectionEnabled; } + /** + * Whether the execute tool API is enabled. If disabled, execute tool API in ml-commons will be blocked + * @return whether the execute tool API is enabled. + */ + public boolean isToolExecuteEnabled() { + return isExecuteToolEnabled; + } + + /** + * Whether the Agentic memory APIs are enabled. If disabled, Agentic memory APIs in ml-commons will be blocked + * @return whether the agentic memory feature is enabled. + */ + public boolean isAgenticMemoryEnabled() { + return isAgenticMemoryEnabled; + } + + public boolean isTracingEnabled() { + return isTracingEnabled; + } + + public boolean isAgentTracingEnabled() { + return isAgentTracingEnabled; + } + + public boolean isConnectorTracingEnabled() { + return isConnectorTracingEnabled; + } + + public boolean isModelTracingEnabled() { + return isModelTracingEnabled; + } + @VisibleForTesting public void notifyMultiTenancyListeners(boolean isEnabled) { for (SettingsChangeListener listener : listeners) { listener.onMultiTenancyEnabledChanged(isEnabled); } } + + public boolean isAgenticSearchEnabled() { + return isAgenticSearchEnabled; + } + + public boolean isMcpConnectorEnabled() { + return isMcpConnectorEnabled; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index 3d1ff1b163..9a0d6002fd 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -12,6 +12,7 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -41,9 +42,14 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable { public static final String AGENT_NAME_FIELD = "name"; public static final String DESCRIPTION_FIELD = "description"; public static final String LLM_FIELD = "llm"; + public static final String LLM_MODEL_ID_FIELD = "model_id"; + public static final String LLM_PARAMETERS_FIELD = "parameters"; public static final String TOOLS_FIELD = "tools"; public static final String PARAMETERS_FIELD = "parameters"; public static final String MEMORY_FIELD = "memory"; + public static final String MEMORY_TYPE_FIELD = "type"; + public static final String MEMORY_SESSION_ID_FIELD = "session_id"; + public static final String MEMORY_WINDOW_SIZE_FIELD = "window_size"; public static final String APP_TYPE_FIELD = "app_type"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; @@ -51,10 +57,13 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable { private String agentId; private String name; private String description; - private LLMSpec llm; + private String llmModelId; + private Map llmParameters; private List tools; private Map parameters; - private MLMemorySpec memory; + private String memoryType; + private String memorySessionId; + private Integer memoryWindowSize; private String appType; private Instant lastUpdateTime; private String tenantId; @@ -64,10 +73,13 @@ public MLAgentUpdateInput( String agentId, String name, String description, - LLMSpec llm, + String llmModelId, + Map llmParameters, List tools, Map parameters, - MLMemorySpec memory, + String memoryType, + String memorySessionId, + Integer memoryWindowSize, String appType, Instant lastUpdateTime, String tenantId @@ -75,10 +87,13 @@ public MLAgentUpdateInput( this.agentId = agentId; this.name = name; this.description = description; - this.llm = llm; + this.llmModelId = llmModelId; + this.llmParameters = llmParameters; this.tools = tools; this.parameters = parameters; - this.memory = memory; + this.memoryType = memoryType; + this.memorySessionId = memorySessionId; + this.memoryWindowSize = memoryWindowSize; this.appType = appType; this.lastUpdateTime = lastUpdateTime; this.tenantId = tenantId; @@ -90,8 +105,9 @@ public MLAgentUpdateInput(StreamInput in) throws IOException { agentId = in.readString(); name = in.readOptionalString(); description = in.readOptionalString(); + llmModelId = in.readOptionalString(); if (in.readBoolean()) { - llm = new LLMSpec(in); + llmParameters = in.readMap(StreamInput::readString, StreamInput::readOptionalString); } if (in.readBoolean()) { tools = new ArrayList<>(); @@ -103,9 +119,9 @@ public MLAgentUpdateInput(StreamInput in) throws IOException { if (in.readBoolean()) { parameters = in.readMap(StreamInput::readString, StreamInput::readOptionalString); } - if (in.readBoolean()) { - memory = new MLMemorySpec(in); - } + memoryType = in.readOptionalString(); + memorySessionId = in.readOptionalString(); + memoryWindowSize = in.readOptionalInt(); lastUpdateTime = in.readOptionalInstant(); appType = in.readOptionalString(); tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; @@ -121,8 +137,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (description != null) { builder.field(DESCRIPTION_FIELD, description); } - if (llm != null) { - builder.field(LLM_FIELD, llm); + if (llmModelId != null || (llmParameters != null && !llmParameters.isEmpty())) { + builder.startObject(LLM_FIELD); + if (llmModelId != null) { + builder.field(LLM_MODEL_ID_FIELD, llmModelId); + } + if (llmParameters != null && !llmParameters.isEmpty()) { + builder.field(LLM_PARAMETERS_FIELD, llmParameters); + } + builder.endObject(); } if (tools != null && !tools.isEmpty()) { builder.field(TOOLS_FIELD, tools); @@ -130,8 +153,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (parameters != null && !parameters.isEmpty()) { builder.field(PARAMETERS_FIELD, parameters); } - if (memory != null) { - builder.field(MEMORY_FIELD, memory); + if (memoryType != null || memorySessionId != null || memoryWindowSize != null) { + builder.startObject(MEMORY_FIELD); + if (memoryType != null) { + builder.field(MEMORY_TYPE_FIELD, memoryType); + } + if (memorySessionId != null) { + builder.field(MEMORY_SESSION_ID_FIELD, memorySessionId); + } + if (memoryWindowSize != null) { + builder.field(MEMORY_WINDOW_SIZE_FIELD, memoryWindowSize); + } + builder.endObject(); } if (appType != null) { builder.field(APP_TYPE_FIELD, appType); @@ -152,9 +185,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(agentId); out.writeOptionalString(name); out.writeOptionalString(description); - if (llm != null) { + out.writeOptionalString(llmModelId); + if (llmParameters != null && !llmParameters.isEmpty()) { out.writeBoolean(true); - llm.writeTo(out); + out.writeMap(llmParameters, StreamOutput::writeString, StreamOutput::writeOptionalString); } else { out.writeBoolean(false); } @@ -173,12 +207,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } - if (memory != null) { - out.writeBoolean(true); - memory.writeTo(out); - } else { - out.writeBoolean(false); - } + out.writeOptionalString(memoryType); + out.writeOptionalString(memorySessionId); + out.writeOptionalInt(memoryWindowSize); out.writeOptionalInstant(lastUpdateTime); out.writeOptionalString(appType); if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { @@ -190,10 +221,13 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException String agentId = null; String name = null; String description = null; - LLMSpec llm = null; + String llmModelId = null; + Map llmParameters = null; List tools = null; Map parameters = null; - MLMemorySpec memory = null; + String memoryType = null; + String memorySessionId = null; + Integer memoryWindowSize = null; String appType = null; Instant lastUpdateTime = null; String tenantId = null; @@ -213,7 +247,22 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException description = parser.text(); break; case LLM_FIELD: - llm = LLMSpec.parse(parser); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String llmFieldName = parser.currentName(); + parser.nextToken(); + switch (llmFieldName) { + case LLM_MODEL_ID_FIELD: + llmModelId = parser.text(); + break; + case LLM_PARAMETERS_FIELD: + llmParameters = parser.mapStrings(); + break; + default: + parser.skipChildren(); + break; + } + } break; case TOOLS_FIELD: tools = new ArrayList<>(); @@ -226,7 +275,25 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException parameters = parser.mapStrings(); break; case MEMORY_FIELD: - memory = MLMemorySpec.parse(parser); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String memoryFieldName = parser.currentName(); + parser.nextToken(); + switch (memoryFieldName) { + case MEMORY_TYPE_FIELD: + memoryType = parser.text(); + break; + case MEMORY_SESSION_ID_FIELD: + memorySessionId = parser.text(); + break; + case MEMORY_WINDOW_SIZE_FIELD: + memoryWindowSize = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } break; case APP_TYPE_FIELD: appType = parser.text(); @@ -243,10 +310,56 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException } } - return new MLAgentUpdateInput(agentId, name, description, llm, tools, parameters, memory, appType, lastUpdateTime, tenantId); + return new MLAgentUpdateInput( + agentId, + name, + description, + llmModelId, + llmParameters, + tools, + parameters, + memoryType, + memorySessionId, + memoryWindowSize, + appType, + lastUpdateTime, + tenantId + ); } public MLAgent toMLAgent(MLAgent originalAgent) { + LLMSpec finalLlm; + if (llmModelId == null && (llmParameters == null || llmParameters.isEmpty())) { + finalLlm = originalAgent.getLlm(); + } else { + LLMSpec originalLlm = originalAgent.getLlm(); + + String finalModelId = llmModelId != null ? llmModelId : originalLlm.getModelId(); + + Map finalParameters = new HashMap<>(); + if (originalLlm != null && originalLlm.getParameters() != null) { + finalParameters.putAll(originalLlm.getParameters()); + } + if (llmParameters != null) { + finalParameters.putAll(llmParameters); + } + + finalLlm = LLMSpec.builder().modelId(finalModelId).parameters(finalParameters).build(); + } + + MLMemorySpec finalMemory; + if (memoryType == null && memorySessionId == null && memoryWindowSize == null) { + finalMemory = originalAgent.getMemory(); + } else { + MLMemorySpec originalMemory = originalAgent.getMemory(); + + String finalMemoryType = memoryType != null ? memoryType : originalMemory.getType(); + String finalSessionId = memorySessionId != null ? memorySessionId : originalMemory.getSessionId(); + Integer finalWindowSize = memoryWindowSize != null ? memoryWindowSize : originalMemory.getWindowSize(); + + finalMemory = MLMemorySpec.builder().type(finalMemoryType).sessionId(finalSessionId).windowSize(finalWindowSize).build(); + } + return MLAgent .builder() .type(originalAgent.getType()) @@ -254,10 +367,10 @@ public MLAgent toMLAgent(MLAgent originalAgent) { .isHidden(originalAgent.getIsHidden()) .name(name == null ? originalAgent.getName() : name) .description(description == null ? originalAgent.getDescription() : description) - .llm(llm == null ? originalAgent.getLlm() : llm) + .llm(finalLlm) .tools(tools == null ? originalAgent.getTools() : tools) .parameters(parameters == null ? originalAgent.getParameters() : parameters) - .memory(memory == null ? originalAgent.getMemory() : memory) + .memory(finalMemory) .lastUpdateTime(lastUpdateTime) .appType(appType) .tenantId(tenantId) @@ -270,8 +383,8 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - if (memory != null && !memory.getType().equals("conversation_index")) { - throw new IllegalArgumentException(String.format("Invalid memory type: %s", memory.getType())); + if (memoryType != null && !memoryType.equals("conversation_index")) { + throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType)); } if (tools != null) { Set toolNames = new HashSet<>(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index ccf6894b4c..8e47bfb1e6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -109,6 +109,11 @@ public MLCreateConnectorInput( if (credential == null || credential.isEmpty()) { throw new IllegalArgumentException("Connector credential is null or empty list"); } + if (actions != null) { + for (ConnectorAction action : actions) { + action.validatePrePostProcessFunctions(parameters); + } + } } this.name = name; this.description = description; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/McpToolBaseInput.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/McpToolBaseInput.java index 7f89aeefd8..a4dc6c7b70 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/McpToolBaseInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/McpToolBaseInput.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/list/MLMcpToolsListRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/list/MLMcpToolsListRequest.java new file mode 100644 index 0000000000..8dade66492 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/list/MLMcpToolsListRequest.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.mcpserver.requests.list; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; + +public class MLMcpToolsListRequest extends ActionRequest { + + public MLMcpToolsListRequest(StreamInput input) throws IOException { + super(input); + } + + public MLMcpToolsListRequest() { + super(); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequest.java index 444bd6505f..5e9cb827f5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.register; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpToolRegisterInput.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpToolRegisterInput.java index 3b2be93617..f2d819dc8b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpToolRegisterInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/register/McpToolRegisterInput.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.register; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequest.java index f6df217b94..f622ad39bb 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.remove; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequest.java index d131dc6c91..c1da582ebb 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.update; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerAction.java new file mode 100644 index 0000000000..283b1733c9 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import org.opensearch.action.ActionType; + +public class MLCreateMemoryContainerAction extends ActionType { + public static final MLCreateMemoryContainerAction INSTANCE = new MLCreateMemoryContainerAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/create"; + + private MLCreateMemoryContainerAction() { + super(NAME, MLCreateMemoryContainerResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerInput.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerInput.java new file mode 100644 index 0000000000..52becb0d8f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerInput.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig; + +import lombok.Builder; +import lombok.Data; + +@Data +public class MLCreateMemoryContainerInput implements ToXContentObject, Writeable { + + public static final String NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String MEMORY_STORAGE_CONFIG_FIELD = "memory_storage_config"; + + private String name; + private String description; + private MemoryStorageConfig memoryStorageConfig; + private String tenantId; + + @Builder(toBuilder = true) + public MLCreateMemoryContainerInput(String name, String description, MemoryStorageConfig memoryStorageConfig, String tenantId) { + if (name == null) { + throw new IllegalArgumentException("name is null"); + } + this.name = name; + this.description = description; + this.memoryStorageConfig = memoryStorageConfig; + this.tenantId = tenantId; + } + + public MLCreateMemoryContainerInput(StreamInput in) throws IOException { + this.name = in.readString(); + this.description = in.readOptionalString(); + if (in.readBoolean()) { + this.memoryStorageConfig = new MemoryStorageConfig(in); + } else { + this.memoryStorageConfig = null; + } + this.tenantId = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeOptionalString(description); + if (memoryStorageConfig != null) { + out.writeBoolean(true); + memoryStorageConfig.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(tenantId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NAME_FIELD, name); + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (memoryStorageConfig != null) { + builder.field(MEMORY_STORAGE_CONFIG_FIELD, memoryStorageConfig); + } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } + builder.endObject(); + return builder; + } + + public static MLCreateMemoryContainerInput parse(XContentParser parser) throws IOException { + String name = null; + String description = null; + MemoryStorageConfig memoryStorageConfig = null; + String tenantId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case MEMORY_STORAGE_CONFIG_FIELD: + memoryStorageConfig = MemoryStorageConfig.parse(parser); + break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return MLCreateMemoryContainerInput + .builder() + .name(name) + .description(description) + .memoryStorageConfig(memoryStorageConfig) + .tenantId(tenantId) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerRequest.java new file mode 100644 index 0000000000..a1410093f0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerRequest.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class MLCreateMemoryContainerRequest extends ActionRequest { + + private final MLCreateMemoryContainerInput mlCreateMemoryContainerInput; + + @Builder + public MLCreateMemoryContainerRequest(MLCreateMemoryContainerInput mlCreateMemoryContainerInput) { + this.mlCreateMemoryContainerInput = mlCreateMemoryContainerInput; + } + + public MLCreateMemoryContainerRequest(StreamInput in) throws IOException { + super(in); + this.mlCreateMemoryContainerInput = new MLCreateMemoryContainerInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.mlCreateMemoryContainerInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + if (mlCreateMemoryContainerInput == null) { + return addValidationError("Memory container input can't be null", null); + } + + // All MemoryStorageConfig validation is handled by MemoryStorageConfig itself + return null; + } + + public static MLCreateMemoryContainerRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCreateMemoryContainerRequest) { + return (MLCreateMemoryContainerRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateMemoryContainerRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateMemoryContainerRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerResponse.java new file mode 100644 index 0000000000..83e238c43c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerResponse.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.STATUS_FIELD; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLCreateMemoryContainerResponse extends ActionResponse implements ToXContentObject { + + private String memoryContainerId; + private String status; + + public MLCreateMemoryContainerResponse(String memoryContainerId, String status) { + this.memoryContainerId = memoryContainerId; + this.status = status; + } + + public MLCreateMemoryContainerResponse(StreamInput in) throws IOException { + super(in); + this.memoryContainerId = in.readString(); + this.status = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(memoryContainerId); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteAction.java new file mode 100644 index 0000000000..37a04e5dd1 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +public class MLMemoryContainerDeleteAction extends ActionType { + public static final MLMemoryContainerDeleteAction INSTANCE = new MLMemoryContainerDeleteAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/delete"; + + private MLMemoryContainerDeleteAction() { + super(NAME, DeleteResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteRequest.java new file mode 100644 index 0000000000..bdcec9193d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteRequest.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +public class MLMemoryContainerDeleteRequest extends ActionRequest { + @Getter + String memoryContainerId; + + @Getter + String tenantId; + + @Builder + public MLMemoryContainerDeleteRequest(String memoryContainerId, String tenantId) { + this.memoryContainerId = memoryContainerId; + this.tenantId = tenantId; + } + + public MLMemoryContainerDeleteRequest(StreamInput input) throws IOException { + super(input); + this.memoryContainerId = input.readString(); + this.tenantId = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + output.writeString(memoryContainerId); + output.writeOptionalString(tenantId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.memoryContainerId == null) { + exception = addValidationError("ML memory container id can't be null", exception); + } + + return exception; + } + + public static MLMemoryContainerDeleteRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLMemoryContainerDeleteRequest) { + return (MLMemoryContainerDeleteRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLMemoryContainerDeleteRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLMemoryContainerDeleteRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetAction.java new file mode 100644 index 0000000000..d31dc0fb43 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import org.opensearch.action.ActionType; + +public class MLMemoryContainerGetAction extends ActionType { + public static final MLMemoryContainerGetAction INSTANCE = new MLMemoryContainerGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/get"; + + private MLMemoryContainerGetAction() { + super(NAME, MLMemoryContainerGetResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetRequest.java new file mode 100644 index 0000000000..f6cf720bdb --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetRequest.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLMemoryContainerGetRequest extends ActionRequest { + + String memoryContainerId; + String tenantId; + + @Builder + public MLMemoryContainerGetRequest(String memoryContainerId, String tenantId) { + this.memoryContainerId = memoryContainerId; + this.tenantId = tenantId; + } + + public MLMemoryContainerGetRequest(StreamInput in) throws IOException { + super(in); + this.memoryContainerId = in.readString(); + this.tenantId = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.memoryContainerId); + out.writeOptionalString(tenantId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.memoryContainerId == null) { + exception = addValidationError("Memory container id can't be null", exception); + } + + return exception; + } + + public static MLMemoryContainerGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLMemoryContainerGetRequest) { + return (MLMemoryContainerGetRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLMemoryContainerGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLMemoryContainerGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetResponse.java new file mode 100644 index 0000000000..bdb0feed9d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetResponse.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.memorycontainer.MLMemoryContainer; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; + +@Getter +@ToString +public class MLMemoryContainerGetResponse extends ActionResponse implements ToXContentObject { + + MLMemoryContainer mlMemoryContainer; + + @Builder + public MLMemoryContainerGetResponse(MLMemoryContainer mlMemoryContainer) { + this.mlMemoryContainer = mlMemoryContainer; + } + + public MLMemoryContainerGetResponse(StreamInput in) throws IOException { + super(in); + mlMemoryContainer = new MLMemoryContainer(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlMemoryContainer.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlMemoryContainer.toXContent(xContentBuilder, params); + } + + public static MLMemoryContainerGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLMemoryContainerGetResponse) { + return (MLMemoryContainerGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLMemoryContainerGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLMemoryContainerGetResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesAction.java new file mode 100644 index 0000000000..6fa2cbb72e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import org.opensearch.action.ActionType; + +public class MLAddMemoriesAction extends ActionType { + public static final MLAddMemoriesAction INSTANCE = new MLAddMemoriesAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/memories/add"; + + private MLAddMemoriesAction() { + super(NAME, MLAddMemoriesResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java new file mode 100644 index 0000000000..4c250cc46f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.AGENT_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INFER_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MESSAGES_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SESSION_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.TAGS_FIELD; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * Input data for adding memory to a memory container + */ +@Getter +@Setter +@Builder +public class MLAddMemoriesInput implements ToXContentObject, Writeable { + + // Required fields + private String memoryContainerId; + private List messages; + + // Optional fields + private String sessionId; + private String agentId; + private Boolean infer; + private Map tags; + + public MLAddMemoriesInput( + String memoryContainerId, + List messages, + String sessionId, + String agentId, + Boolean infer, + Map tags + ) { + // Note: memoryContainerId validation is removed here since it may come from URL path + if (messages == null || messages.isEmpty()) { + throw new IllegalArgumentException("Messages list cannot be empty"); + } + // MAX_MESSAGES_PER_REQUEST limit removed for performance testing + + this.memoryContainerId = memoryContainerId; + this.messages = messages; + this.sessionId = sessionId; + this.agentId = agentId; + this.infer = infer; + this.tags = tags; + } + + public MLAddMemoriesInput(StreamInput in) throws IOException { + this.memoryContainerId = in.readOptionalString(); + int messagesSize = in.readVInt(); + this.messages = new ArrayList<>(messagesSize); + for (int i = 0; i < messagesSize; i++) { + this.messages.add(new MessageInput(in)); + } + this.sessionId = in.readOptionalString(); + this.agentId = in.readOptionalString(); + this.infer = in.readOptionalBoolean(); + if (in.readBoolean()) { + this.tags = in.readMap(StreamInput::readString, StreamInput::readString); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(memoryContainerId); + out.writeVInt(messages.size()); + for (MessageInput message : messages) { + message.writeTo(out); + } + out.writeOptionalString(sessionId); + out.writeOptionalString(agentId); + out.writeOptionalBoolean(infer); + if (tags != null && !tags.isEmpty()) { + out.writeBoolean(true); + out.writeMap(tags, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (memoryContainerId != null) { + builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); + } + builder.startArray(MESSAGES_FIELD); + for (MessageInput message : messages) { + message.toXContent(builder, params); + } + builder.endArray(); + if (sessionId != null) { + builder.field(SESSION_ID_FIELD, sessionId); + } + if (agentId != null) { + builder.field(AGENT_ID_FIELD, agentId); + } + if (infer != null) { + builder.field(INFER_FIELD, infer); + } + if (tags != null && !tags.isEmpty()) { + builder.field(TAGS_FIELD, tags); + } + builder.endObject(); + return builder; + } + + public static MLAddMemoriesInput parse(XContentParser parser) throws IOException { + String memoryContainerId = null; + List messages = null; + String sessionId = null; + String agentId = null; + Boolean infer = null; + Map tags = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MEMORY_CONTAINER_ID_FIELD: + memoryContainerId = parser.text(); + break; + case MESSAGES_FIELD: + messages = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + messages.add(MessageInput.parse(parser)); + } + break; + case SESSION_ID_FIELD: + sessionId = parser.text(); + break; + case AGENT_ID_FIELD: + agentId = parser.text(); + break; + case INFER_FIELD: + infer = parser.booleanValue(); + break; + case TAGS_FIELD: + tags = new HashMap<>(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String tagKey = parser.currentName(); + parser.nextToken(); + String tagValue = parser.text(); + tags.put(tagKey, tagValue); + } + break; + default: + parser.skipChildren(); + break; + } + } + + return MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .messages(messages) + .sessionId(sessionId) + .agentId(agentId) + .infer(infer) + .tags(tags) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesRequest.java new file mode 100644 index 0000000000..66e04d7b86 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesRequest.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(level = AccessLevel.PRIVATE) +@ToString +public class MLAddMemoriesRequest extends ActionRequest { + + MLAddMemoriesInput mlAddMemoryInput; + + @Builder + public MLAddMemoriesRequest(MLAddMemoriesInput mlAddMemoryInput) { + this.mlAddMemoryInput = mlAddMemoryInput; + } + + public MLAddMemoriesRequest(StreamInput in) throws IOException { + super(in); + this.mlAddMemoryInput = new MLAddMemoriesInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.mlAddMemoryInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlAddMemoryInput == null) { + exception = addValidationError("ML add memory input can't be null", exception); + } + return exception; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesResponse.java new file mode 100644 index 0000000000..b60f99f06a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesResponse.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; + +@Getter +@ToString +public class MLAddMemoriesResponse extends ActionResponse implements ToXContentObject { + + private List results; + private String sessionId; + + @Builder + public MLAddMemoriesResponse(List results, String sessionId) { + this.results = results != null ? results : new ArrayList<>(); + this.sessionId = sessionId; + } + + public MLAddMemoriesResponse(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + this.results = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + this.results.add(new MemoryResult(in)); + } + this.sessionId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(results.size()); + for (MemoryResult result : results) { + result.writeTo(out); + } + out.writeString(sessionId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray("results"); + for (MemoryResult result : results) { + result.toXContent(builder, params); + } + builder.endArray(); + builder.field(SESSION_ID_FIELD, sessionId); + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryAction.java new file mode 100644 index 0000000000..1a9d216fad --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +public class MLDeleteMemoryAction extends ActionType { + public static final MLDeleteMemoryAction INSTANCE = new MLDeleteMemoryAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/memory/delete"; + + private MLDeleteMemoryAction() { + super(NAME, DeleteResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryRequest.java new file mode 100644 index 0000000000..35a8880bb1 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryRequest.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class MLDeleteMemoryRequest extends ActionRequest { + private final String memoryContainerId; + private final String memoryId; + + @Builder + public MLDeleteMemoryRequest(String memoryContainerId, String memoryId) { + this.memoryContainerId = memoryContainerId; + this.memoryId = memoryId; + } + + public MLDeleteMemoryRequest(StreamInput input) throws IOException { + super(input); + this.memoryContainerId = input.readString(); + this.memoryId = input.readString(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + output.writeString(memoryContainerId); + output.writeString(memoryId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.memoryContainerId == null) { + exception = addValidationError("Memory container id can't be null", exception); + } + + if (this.memoryId == null) { + exception = addValidationError("Memory id can't be null", exception); + } + + return exception; + } + + public static MLDeleteMemoryRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLDeleteMemoryRequest) { + return (MLDeleteMemoryRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteMemoryRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLDeleteMemoryRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryAction.java new file mode 100644 index 0000000000..cdbc8d62bf --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import org.opensearch.action.ActionType; + +public class MLGetMemoryAction extends ActionType { + public static final MLGetMemoryAction INSTANCE = new MLGetMemoryAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/memory/get"; + + private MLGetMemoryAction() { + super(NAME, MLGetMemoryResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java new file mode 100644 index 0000000000..b85609d93c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLGetMemoryRequest extends ActionRequest { + + String memoryContainerId; + String memoryId; + + @Builder + public MLGetMemoryRequest(String memoryContainerId, String memoryId) { + this.memoryContainerId = memoryContainerId; + this.memoryId = memoryId; + } + + public MLGetMemoryRequest(StreamInput in) throws IOException { + super(in); + this.memoryContainerId = in.readString(); + this.memoryId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.memoryContainerId); + out.writeString(this.memoryId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.memoryContainerId == null || this.memoryId == null) { + exception = addValidationError("memoryContainerId and memoryId id can not be null", exception); + } + + return exception; + } + + public static MLGetMemoryRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLGetMemoryRequest) { + return (MLGetMemoryRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetMemoryRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLMemoryGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryResponse.java new file mode 100644 index 0000000000..e66525a3af --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.memorycontainer.MLMemory; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; + +@Getter +@ToString +public class MLGetMemoryResponse extends ActionResponse implements ToXContentObject { + MLMemory mlMemory; + + @Builder + public MLGetMemoryResponse(MLMemory mlMemory) { + this.mlMemory = mlMemory; + } + + public MLGetMemoryResponse(StreamInput in) throws IOException { + super(in); + mlMemory = new MLMemory(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlMemory.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlMemory.toXContent(xContentBuilder, params); + } + + public static MLGetMemoryResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLGetMemoryResponse) { + return (MLGetMemoryResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetMemoryResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLMemoryGetResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesAction.java new file mode 100644 index 0000000000..986a7446fa --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import org.opensearch.action.ActionType; + +public class MLSearchMemoriesAction extends ActionType { + public static final MLSearchMemoriesAction INSTANCE = new MLSearchMemoriesAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/memories/search"; + + private MLSearchMemoriesAction() { + super(NAME, MLSearchMemoriesResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesInput.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesInput.java new file mode 100644 index 0000000000..d743c2001e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesInput.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; + +import java.io.IOException; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * Input data for searching memories in a memory container + */ +@Getter +@Setter +@Builder +public class MLSearchMemoriesInput implements ToXContentObject, Writeable { + + // Required fields + private String memoryContainerId; + private String query; + + public MLSearchMemoriesInput(String memoryContainerId, String query) { + if (StringUtils.isBlank(query)) { + throw new IllegalArgumentException("Query cannot be null or empty"); + } + this.memoryContainerId = memoryContainerId; + this.query = query.trim(); + } + + public MLSearchMemoriesInput(StreamInput in) throws IOException { + this.memoryContainerId = in.readOptionalString(); + this.query = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(memoryContainerId); + out.writeString(query); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (memoryContainerId != null) { + builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); + } + builder.field(QUERY_FIELD, query); + builder.endObject(); + return builder; + } + + public static MLSearchMemoriesInput parse(XContentParser parser) throws IOException { + String memoryContainerId = null; + String query = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MEMORY_CONTAINER_ID_FIELD: + memoryContainerId = parser.text(); + break; + case QUERY_FIELD: + query = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return MLSearchMemoriesInput.builder().memoryContainerId(memoryContainerId).query(query).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesRequest.java new file mode 100644 index 0000000000..dac5a642ae --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesRequest.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * ML search memories request + */ +@Getter +@Setter +@Builder +public class MLSearchMemoriesRequest extends ActionRequest { + + private MLSearchMemoriesInput mlSearchMemoriesInput; + private String tenantId; + + public MLSearchMemoriesRequest(MLSearchMemoriesInput mlSearchMemoriesInput, String tenantId) { + this.mlSearchMemoriesInput = mlSearchMemoriesInput; + this.tenantId = tenantId; + } + + public MLSearchMemoriesRequest(StreamInput in) throws IOException { + super(in); + this.mlSearchMemoriesInput = new MLSearchMemoriesInput(in); + this.tenantId = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + mlSearchMemoriesInput.writeTo(out); + out.writeOptionalString(tenantId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlSearchMemoriesInput == null) { + exception = addValidationError("Search memories input can't be null", exception); + } + return exception; + } + + public static MLSearchMemoriesRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLSearchMemoriesRequest) { + return (MLSearchMemoriesRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLSearchMemoriesRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLSearchMemoriesRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesResponse.java new file mode 100644 index 0000000000..7692c8a102 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesResponse.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Getter; + +/** + * ML search memories response + */ +@Getter +@Builder +public class MLSearchMemoriesResponse extends ActionResponse implements ToXContentObject { + + private List hits; + private long totalHits; + private float maxScore; + private boolean timedOut; + + public MLSearchMemoriesResponse(List hits, long totalHits, float maxScore, boolean timedOut) { + this.hits = hits != null ? hits : new ArrayList<>(); + this.totalHits = totalHits; + this.maxScore = maxScore; + this.timedOut = timedOut; + } + + public MLSearchMemoriesResponse(StreamInput in) throws IOException { + super(in); + int hitCount = in.readVInt(); + this.hits = new ArrayList<>(hitCount); + for (int i = 0; i < hitCount; i++) { + this.hits.add(new MemorySearchResult(in)); + } + this.totalHits = in.readVLong(); + this.maxScore = in.readFloat(); + this.timedOut = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(hits.size()); + for (MemorySearchResult hit : hits) { + hit.writeTo(out); + } + out.writeVLong(totalHits); + out.writeFloat(maxScore); + out.writeBoolean(timedOut); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("timed_out", timedOut); + + builder.startObject("hits"); + builder.field("total", totalHits); + builder.field("max_score", maxScore); + + builder.startArray("hits"); + for (MemorySearchResult hit : hits) { + hit.toXContent(builder, params); + } + builder.endArray(); + + builder.endObject(); // end hits object + builder.endObject(); // end root object + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryAction.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryAction.java new file mode 100644 index 0000000000..7a00012ef6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class MLUpdateMemoryAction extends ActionType { + public static final MLUpdateMemoryAction INSTANCE = new MLUpdateMemoryAction(); + public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/memory/update"; + + private MLUpdateMemoryAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryInput.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryInput.java new file mode 100644 index 0000000000..d8c55e1f14 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryInput.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.TEXT_FIELD; + +import java.io.IOException; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * Input data for updating a memory + */ +@Getter +@Setter +public class MLUpdateMemoryInput implements ToXContentObject, Writeable { + + private String text; + + @Builder + public MLUpdateMemoryInput(String text) { + if (StringUtils.isBlank(text)) { + throw new IllegalArgumentException("Text cannot be null or empty"); + } + this.text = text.trim(); + } + + public MLUpdateMemoryInput(StreamInput in) throws IOException { + this.text = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(text); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXT_FIELD, text); + builder.endObject(); + return builder; + } + + public static MLUpdateMemoryInput parse(XContentParser parser) throws IOException { + String text = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + if (TEXT_FIELD.equals(fieldName)) { + text = parser.text(); + } else { + parser.skipChildren(); + } + } + + return MLUpdateMemoryInput.builder().text(text).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java new file mode 100644 index 0000000000..2d8885bede --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +@Getter +public class MLUpdateMemoryRequest extends ActionRequest { + @Setter + private MLUpdateMemoryInput mlUpdateMemoryInput; + private String memoryContainerId; + private String memoryId; + + @Builder + public MLUpdateMemoryRequest(String memoryContainerId, String memoryId, MLUpdateMemoryInput mlUpdateMemoryInput) { + this.memoryContainerId = memoryContainerId; + this.memoryId = memoryId; + this.mlUpdateMemoryInput = mlUpdateMemoryInput; + } + + public MLUpdateMemoryRequest(StreamInput in) throws IOException { + super(in); + this.memoryContainerId = in.readString(); + this.memoryId = in.readString(); + this.mlUpdateMemoryInput = new MLUpdateMemoryInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(memoryContainerId); + out.writeString(memoryId); + mlUpdateMemoryInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlUpdateMemoryInput == null) { + exception = addValidationError("Update memory input can't be null", exception); + } + if (memoryContainerId == null) { + exception = addValidationError("Memory container id can't be null", exception); + } + if (memoryId == null) { + exception = addValidationError("Memory id can't be null", exception); + } + return exception; + } + + public static MLUpdateMemoryRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLUpdateMemoryRequest) { + return (MLUpdateMemoryRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUpdateMemoryRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateMemoryRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryEvent.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryEvent.java new file mode 100644 index 0000000000..5f66094036 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryEvent.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +/** + * Enum representing memory operation events + */ +public enum MemoryEvent { + ADD("ADD"), + UPDATE("UPDATE"), + DELETE("DELETE"), + NONE("NONE"); + + private final String value; + + MemoryEvent(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static MemoryEvent fromString(String value) { + if (value == null) { + throw new IllegalArgumentException("Memory event value cannot be null"); + } + + for (MemoryEvent event : MemoryEvent.values()) { + if (event.value.equalsIgnoreCase(value)) { + return event; + } + } + throw new IllegalArgumentException("Unknown memory event: " + value); + } + + @Override + public String toString() { + return value; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryResult.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryResult.java new file mode 100644 index 0000000000..d0180f3533 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryResult.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; + +/** + * Represents a single memory result in the MLAddMemoryResponse + */ +@Getter +@ToString +@Builder +public class MemoryResult implements ToXContentObject, Writeable { + + private final String memoryId; + private final String memory; + private final MemoryEvent event; + private final String oldMemory; + + public MemoryResult(String memoryId, String memory, MemoryEvent event, String oldMemory) { + this.memoryId = memoryId; + this.memory = memory; + this.event = event; + this.oldMemory = oldMemory; + } + + public MemoryResult(StreamInput in) throws IOException { + this.memoryId = in.readString(); + this.memory = in.readString(); + this.event = MemoryEvent.fromString(in.readString()); + this.oldMemory = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(memoryId); + out.writeString(memory); + out.writeString(event.getValue()); + out.writeOptionalString(oldMemory); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("id", memoryId); + builder.field("text", memory); + builder.field("event", event.getValue()); + if (oldMemory != null) { + builder.field("old_memory", oldMemory); + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemorySearchResult.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemorySearchResult.java new file mode 100644 index 0000000000..b68757b569 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemorySearchResult.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.memorycontainer.MemoryType; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; + +/** + * Represents a single memory search result with relevance score + */ +@Getter +@ToString +@Builder +public class MemorySearchResult implements ToXContentObject, Writeable { + + private final String memoryId; + private final String memory; + private final float score; + private final String sessionId; + private final String agentId; + private final String userId; + private final MemoryType memoryType; + private final String role; + private final Map tags; + private final Instant createdTime; + private final Instant lastUpdatedTime; + + public MemorySearchResult( + String memoryId, + String memory, + float score, + String sessionId, + String agentId, + String userId, + MemoryType memoryType, + String role, + Map tags, + Instant createdTime, + Instant lastUpdatedTime + ) { + this.memoryId = memoryId; + this.memory = memory; + this.score = score; + this.sessionId = sessionId; + this.agentId = agentId; + this.userId = userId; + this.memoryType = memoryType; + this.role = role; + this.tags = tags; + this.createdTime = createdTime; + this.lastUpdatedTime = lastUpdatedTime; + } + + public MemorySearchResult(StreamInput in) throws IOException { + this.memoryId = in.readString(); + this.memory = in.readString(); + this.score = in.readFloat(); + this.sessionId = in.readOptionalString(); + this.agentId = in.readOptionalString(); + this.userId = in.readOptionalString(); + String memoryTypeStr = in.readOptionalString(); + this.memoryType = memoryTypeStr != null ? MemoryType.fromString(memoryTypeStr) : null; + this.role = in.readOptionalString(); + if (in.readBoolean()) { + this.tags = in.readMap(StreamInput::readString, StreamInput::readString); + } else { + this.tags = null; + } + this.createdTime = in.readOptionalInstant(); + this.lastUpdatedTime = in.readOptionalInstant(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(memoryId); + out.writeString(memory); + out.writeFloat(score); + out.writeOptionalString(sessionId); + out.writeOptionalString(agentId); + out.writeOptionalString(userId); + out.writeOptionalString(memoryType != null ? memoryType.toString() : null); + out.writeOptionalString(role); + if (tags != null && !tags.isEmpty()) { + out.writeBoolean(true); + out.writeMap(tags, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastUpdatedTime); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEMORY_ID_FIELD, memoryId); + builder.field(MEMORY_FIELD, memory); + builder.field("_score", score); + if (sessionId != null) { + builder.field(SESSION_ID_FIELD, sessionId); + } + if (agentId != null) { + builder.field(AGENT_ID_FIELD, agentId); + } + if (userId != null) { + builder.field(USER_ID_FIELD, userId); + } + if (memoryType != null) { + builder.field(MEMORY_TYPE_FIELD, memoryType.toString()); + } + if (role != null) { + builder.field(ROLE_FIELD, role); + } + if (tags != null && !tags.isEmpty()) { + builder.field(TAGS_FIELD, tags); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastUpdatedTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdatedTime.toEpochMilli()); + } + builder.endObject(); + return builder; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MessageInput.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MessageInput.java new file mode 100644 index 0000000000..1888645b6d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MessageInput.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CONTENT_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.ROLE_FIELD; + +import java.io.IOException; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +/** + * Represents a single message with role and content + */ +@Getter +@Setter +@Builder +public class MessageInput implements ToXContentObject, Writeable { + + private String role; // Optional when infer=true + private String content; // Required + + public MessageInput(String role, String content) { + if (StringUtils.isBlank(content)) { + throw new IllegalArgumentException("Content is required"); + } + this.role = role; + this.content = content; + } + + public MessageInput(StreamInput in) throws IOException { + this.role = in.readOptionalString(); + this.content = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(role); + out.writeString(content); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (role != null) { + builder.field(ROLE_FIELD, role); + } + builder.field(CONTENT_FIELD, content); + builder.endObject(); + return builder; + } + + public static MessageInput parse(XContentParser parser) throws IOException { + String role = null; + String content = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ROLE_FIELD: + role = parser.text(); + break; + case CONTENT_FIELD: + content = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new MessageInput(role, content); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 4dc54bb23c..ee5a5068c0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -25,10 +25,10 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.controller.MLRateLimiter; +import org.opensearch.ml.common.model.BaseModelConfig; import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLDeploySetting; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -125,7 +125,7 @@ public MLUpdateModelInput(StreamInput in) throws IOException { rateLimiter = new MLRateLimiter(in); } if (in.readBoolean()) { - modelConfig = new TextEmbeddingModelConfig(in); + modelConfig = new BaseModelConfig(in); } if (in.readBoolean()) { updatedConnector = Connector.fromStream(in); @@ -307,7 +307,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException rateLimiter = MLRateLimiter.parse(parser); break; case MODEL_CONFIG_FIELD: - modelConfig = TextEmbeddingModelConfig.parse(parser); + modelConfig = BaseModelConfig.parse(parser); break; case DEPLOY_SETTING_FIELD: deploySetting = MLDeploySetting.parse(parser); diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index a11f5db440..9d4f5ad0c9 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common.utils; +import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.action.ValidateActions.addValidationError; import java.nio.ByteBuffer; @@ -16,6 +17,7 @@ import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -40,6 +42,7 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; +import com.google.gson.reflect.TypeToken; import com.jayway.jsonpath.JsonPath; import com.jayway.jsonpath.PathNotFoundException; import com.networknt.schema.JsonSchema; @@ -111,6 +114,30 @@ public static boolean isJson(String json) { } } + /** + * Ensures that a string is properly JSON escaped. + * + *

This method examines the input string and determines whether it already represents + * valid JSON content. If the input is valid JSON, it is returned unchanged. Otherwise, + * the input is treated as a plain string and escaped according to JSON string literal + * rules.

+ * + *

Examples:

+ *
+     *   prepareJsonValue("hello")        → "\"hello\""
+     *   prepareJsonValue("\"hello\"")        → "\\\"hello\\\""
+     *   prepareJsonValue("{\"key\":123}") → {\"key\":123} (valid JSON object, unchanged)
+     * 
+ * @param input + * @return + */ + public static String prepareJsonValue(String input) { + if (isJson(input)) { + return input; + } + return escapeJson(input); + } + public static String toUTF8(String rawString) { ByteBuffer buffer = StandardCharsets.UTF_8.encode(rawString); @@ -552,4 +579,22 @@ public static boolean matchesSafePattern(String value) { return SAFE_INPUT_PATTERN.matcher(value).matches(); } + /** + * Parses a JSON array string into a List of Strings. + * + * @param jsonArrayString JSON array string to parse (e.g., "[\"item1\", \"item2\"]") + * @return List of strings parsed from the JSON array, or an empty list if the input is + * null, empty, or invalid JSON + */ + public static List parseStringArrayToList(String jsonArrayString) { + if (jsonArrayString == null || jsonArrayString.trim().isEmpty()) { + return Collections.emptyList(); + } + try { + return gson.fromJson(jsonArrayString, TypeToken.getParameterized(List.class, String.class).getType()); + } catch (JsonSyntaxException e) { + log.error("Failed to parse JSON array string: {}", jsonArrayString, e); + return Collections.emptyList(); + } + } } diff --git a/common/src/main/resources/index-mappings/ml_memory_container.json b/common/src/main/resources/index-mappings/ml_memory_container.json new file mode 100644 index 0000000000..8f34de75e6 --- /dev/null +++ b/common/src/main/resources/index-mappings/ml_memory_container.json @@ -0,0 +1,60 @@ +{ + "_meta": { + "schema_version": 1 + }, + "properties": { + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "description": { + "type": "text" + }, + "created_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "last_updated_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "memory_storage_config": { + "type": "object", + "properties": { + "memory_index_name": { + "type": "keyword" + }, + "semantic_storage_enabled": { + "type": "boolean" + }, + "embedding_model_type": { + "type": "keyword" + }, + "embedding_model_id": { + "type": "keyword" + }, + "llm_model_id": { + "type": "keyword" + }, + "dimension": { + "type": "integer" + }, + "max_recent_messages": { + "type": "integer" + }, + "max_infer_size": { + "type": "integer" + } + } + }, + "owner": USER_MAPPING_PLACEHOLDER, + "tenant_id": { + "type": "keyword" + } + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java index 3a3cee9acf..08f368d1b4 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java @@ -40,6 +40,8 @@ import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; import org.opensearch.search.SearchModule; +import com.fasterxml.jackson.core.JsonParseException; + public class MLCommonsClassLoaderTests { private SampleAlgoParams params; @@ -183,7 +185,7 @@ public void testClassLoader_MLInput() throws IOException { } @Test(expected = IllegalArgumentException.class) - public void testConnectorInitializationException() { + public void testConnectorInitializationException() throws JsonParseException { // Example initialization parameters for connectors String initParam1 = "parameter1"; @@ -191,6 +193,22 @@ public void testConnectorInitializationException() { MLCommonsClassLoader.initConnector("Connector", new Object[] { initParam1 }, String.class); } + @Test(expected = JsonParseException.class) + public void testInitMLInput_JsonParseException() throws IOException { + String invalidJsonStr = "invalid-json"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + invalidJsonStr + ); + parser.nextToken(); + + MLCommonsClassLoader + .initMLInput(FunctionName.AGENT, new Object[] { parser, FunctionName.AGENT }, XContentParser.class, FunctionName.class); + } + public enum TestEnum { TEST } diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java index cdad6f21ca..c24d3d943b 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -63,7 +63,7 @@ public void toXContent() throws IOException { .assertEquals( "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + "\"backend_roles\":[\"role1\",\"role2\"]," - + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null}," + "\"access\":\"PUBLIC\"}", content ); @@ -166,7 +166,7 @@ public void toXContent_WithTenantId() throws IOException { .assertEquals( "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + "\"backend_roles\":[\"role1\",\"role2\"]," - + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null}," + "\"access\":\"PUBLIC\",\"tenant_id\":\"test_tenant\"}", content ); @@ -176,7 +176,7 @@ public void toXContent_WithTenantId() throws IOException { public void parse_WithTenantId() throws IOException { String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + "\"backend_roles\":[\"role1\",\"role2\"]," - + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null}," + "\"access\":\"PUBLIC\",\"tenant_id\":\"test_tenant\"}"; XContentParser parser = XContentType.JSON diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index 05a94459a3..e05d8d04d2 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -6,14 +6,46 @@ package org.opensearch.ml.common.connector; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_BATCH_JOB_ARN; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_RERANK; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_RERANK; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.LogEvent; +import org.apache.logging.log4j.core.LoggerContext; +import org.apache.logging.log4j.core.appender.AbstractAppender; +import org.apache.logging.log4j.core.config.LoggerConfig; +import org.apache.logging.log4j.core.layout.PatternLayout; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -31,8 +63,37 @@ public class ConnectorActionTest { private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT; private static final String TEST_METHOD_POST = "post"; private static final String TEST_METHOD_HTTP = "http"; + private static final String LOG_APPENDER_NAME = "TestAppender"; private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}"; private static final String URL = "https://test.com"; + private static final String OPENAI_URL = "https://api.openai.com/v1/chat/completions"; + private static final String COHERE_URL = "https://api.cohere.ai/v1/embed"; + private static final String BEDROCK_URL = "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke"; + private static final String SAGEMAKER_URL = + "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/lmi-model-2023-06-24-01-35-32-275/invocations"; + private static final Logger logger = LogManager.getLogger(ConnectorActionTest.class); + private static TestLogAppender testAppender; + private static LoggerConfig loggerConfig; + + @BeforeClass + public static void setUpClass() { + testAppender = new TestLogAppender(LOG_APPENDER_NAME); + LoggerContext context = (LoggerContext) LogManager.getContext(false); + loggerConfig = context.getConfiguration().getLoggerConfig(logger.getName()); + loggerConfig.addAppender(testAppender, Level.WARN, null); + context.updateLoggers(); + } + + @After + public void tearDown() { + testAppender.clear(); + } + + @AfterClass + public static void tearDownClass() { + loggerConfig.removeAppender(LOG_APPENDER_NAME); + testAppender.stop(); + } @Test public void constructor_NullActionType() { @@ -62,6 +123,369 @@ public void constructor_NullMethod() { assertEquals("method can't be null", exception.getMessage()); } + @Test + public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess() { + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, OPENAI_URL, null, TEST_REQUEST_BODY, null, null); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithExternalServers() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, + OPENAI_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProcessFunctionSuccess() { + String preProcessFunction = + "\"\\n StringBuilder builder = new StringBuilder();\\n builder.append(\\\"\\\\\\\"\\\");\\n String first = params.text_docs[0];\\n builder.append(first);\\n builder.append(\\\"\\\\\\\"\\\");\\n def parameters = \\\"{\\\" +\\\"\\\\\\\"text_inputs\\\\\\\":\\\" + builder + \\\"}\\\";\\n return \\\"{\\\" +\\\"\\\\\\\"parameters\\\\\\\":\\\" + parameters + \\\"}\\\";\""; + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + OPENAI_URL, + null, + TEST_REQUEST_BODY, + preProcessFunction, + null + ); + action.validatePrePostProcessFunctions(null); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + "https://${parameters.endpoint}/v1/chat/completions", + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, + OPENAI_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of("endpoint", "api.openai.com")); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPreProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + OPENAI_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, + OPENAI_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains("LLM service is openai, so PreProcessFunction should be corresponding to openai for better results.") + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPostProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + OPENAI_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, + COHERE_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains("LLM service is openai, so PostProcessFunction should be corresponding to openai for better results.") + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + COHERE_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, + COHERE_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + + action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + COHERE_URL, + null, + TEST_REQUEST_BODY, + IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, + COHERE_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + + action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + COHERE_URL, + null, + TEST_REQUEST_BODY, + TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, + COHERE_RERANK + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPreProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + COHERE_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, + COHERE_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains("LLM service is cohere, so PreProcessFunction should be corresponding to cohere for better results.") + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPostProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + COHERE_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, + OPENAI_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains("LLM service is cohere, so PostProcessFunction should be corresponding to cohere for better results.") + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + BEDROCK_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, + BEDROCK_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + + action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + BEDROCK_URL, + null, + TEST_REQUEST_BODY, + TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, + BEDROCK_BATCH_JOB_ARN + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + + action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + BEDROCK_URL, + null, + TEST_REQUEST_BODY, + TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT, + BEDROCK_RERANK + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPreProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + BEDROCK_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, + BEDROCK_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains("LLM service is bedrock, so PreProcessFunction should be corresponding to bedrock for better results.") + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPostProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + BEDROCK_URL, + null, + TEST_REQUEST_BODY, + TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, + COHERE_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains("LLM service is bedrock, so PostProcessFunction should be corresponding to bedrock for better results.") + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrectInBuiltPrePostProcessFunctionSuccess() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + SAGEMAKER_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, + DEFAULT_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + + action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + SAGEMAKER_URL, + null, + TEST_REQUEST_BODY, + TEXT_SIMILARITY_TO_DEFAULT_INPUT, + DEFAULT_RERANK + ); + action.validatePrePostProcessFunctions(Map.of()); + assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); + } + + @Test + public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPreProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + SAGEMAKER_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, + DEFAULT_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains( + "LLM service is sagemaker, so PreProcessFunction should be corresponding to sagemaker for better results." + ) + ); + assertTrue(isWarningLogged); + } + + @Test + public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPostProcessFunction() { + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + SAGEMAKER_URL, + null, + TEST_REQUEST_BODY, + TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, + BEDROCK_EMBEDDING + ); + action.validatePrePostProcessFunctions(Map.of()); + boolean isWarningLogged = testAppender + .getLogEvents() + .stream() + .anyMatch( + event -> event.getLevel() == Level.WARN + && event + .getMessage() + .getFormattedMessage() + .contains( + "LLM service is sagemaker, so PostProcessFunction should be corresponding to sagemaker for better results." + ) + ); + assertTrue(isWarningLogged); + } + @Test public void writeTo_NullValue() throws IOException { ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); @@ -170,4 +594,30 @@ public void test_invalidActionInModelPrediction() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); assertEquals(isValidActionInModelPrediction(actionType), false); } + + /** + * Log appender class to check the logs printed or not + */ + static class TestLogAppender extends AbstractAppender { + + private final List logEvents = new ArrayList<>(); + + public TestLogAppender(String name) { + super(name, null, PatternLayout.createDefaultLayout(), false); + start(); + } + + @Override + public void append(LogEvent event) { + logEvents.add(event.toImmutable()); + } + + public List getLogEvents() { + return logEvents; + } + + public void clear() { + logEvents.clear(); + } + } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java index b949208472..dc494030ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java @@ -1,6 +1,7 @@ package org.opensearch.ml.common.dataset; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.opensearch.ml.common.TestHelper.contentObjectToString; import static org.opensearch.ml.common.TestHelper.testParseFromString; @@ -11,12 +12,14 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; public class AsymmetricTextEmbeddingParametersTest { @@ -74,6 +77,127 @@ public void readInputStream_Success_EmptyParams() throws IOException { readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()); } + @Test + public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_LEXICAL() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) + .build(); + TestHelper.testParse(params, function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_TOKEN_ID() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + TestHelper.testParse(params, function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_OnlySparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + TestHelper.testParse(params, function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_SparseEmbeddingFormat_Invalid() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule + .expectMessage("No enum constant org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat.INVALID"); + String jsonWithInvalidFormat = "{\"content_type\": \"QUERY\", \"sparse_embedding_format\": \"INVALID\"}"; + testParseFromString(params, jsonWithInvalidFormat, function); + } + + @Test + public void constructor_BackwardCompatibility() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY); + assertEquals(EmbeddingContentType.QUERY, params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.WORD, params.getSparseEmbeddingFormat()); + } + + @Test + public void constructor_WithSparseEmbeddingFormat() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters( + EmbeddingContentType.PASSAGE, + SparseEmbeddingFormat.TOKEN_ID + ); + assertEquals(EmbeddingContentType.PASSAGE, params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.TOKEN_ID, params.getSparseEmbeddingFormat()); + } + + @Test + public void constructor_WithNullSparseEmbeddingFormat_DefaultsToLexical() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY, null); + assertEquals(EmbeddingContentType.QUERY, params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.WORD, params.getSparseEmbeddingFormat()); + } + + @Test + public void constructor_NullContentType_WithSparseEmbeddingFormat() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(null, SparseEmbeddingFormat.TOKEN_ID); + assertNull(params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.TOKEN_ID, params.getSparseEmbeddingFormat()); + } + + @Test + public void readInputStream_WithSparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + readInputStream(params); + } + + @Test + public void readInputStream_OnlySparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + readInputStream(params); + } + + @Test + public void readInputStream_VersionCompatibility_Pre_V_3_2_0() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + bytesStreamOutput.setVersion(Version.V_3_1_0); + params.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + streamInput.setVersion(Version.V_3_1_0); + AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); + + assertEquals(EmbeddingContentType.QUERY, parsedParams.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.WORD, parsedParams.getSparseEmbeddingFormat()); + } + + @Test + public void toXContent_IncludesSparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + String jsonStr = contentObjectToString(params); + assert (jsonStr.contains("\"content_type\":\"QUERY\"")); + assert (jsonStr.contains("\"sparse_embedding_format\":\"TOKEN_ID\"")); + } + private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); params.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/tool/ToolMLInputTests.java b/common/src/test/java/org/opensearch/ml/common/input/execute/tool/ToolMLInputTests.java new file mode 100644 index 0000000000..5f21ac2a41 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/tool/ToolMLInputTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.execute.tool; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.search.SearchModule; + +public class ToolMLInputTests { + + private ToolMLInput toolMLInput; + private Map parameters; + private final String json = "{\"tool_name\":\"TestTool\",\"parameters\":{\"question\":\"test question\",\"model_id\":\"test_model\"}}"; + + @Before + public void setUp() throws IOException { + parameters = new HashMap<>(); + parameters.put("question", "test question"); + parameters.put("model_id", "test_model"); + + XContentParser parser = createParser(json); + toolMLInput = new ToolMLInput(parser, FunctionName.TOOL); + } + + @Test + public void readInputStreamSuccess() throws IOException { + readInputStream(toolMLInput, parsedInput -> { + assertEquals("TestTool", parsedInput.getToolName()); + assertEquals(FunctionName.TOOL, parsedInput.getAlgorithm()); + assertNotNull(parsedInput.getInputDataset()); + }); + } + + @Test + public void testXContentParsing() throws IOException { + XContentParser parser = createParser(json); + ToolMLInput parsed = new ToolMLInput(parser, FunctionName.TOOL); + + assertEquals("TestTool", parsed.getToolName()); + assertEquals(FunctionName.TOOL, parsed.getAlgorithm()); + assertNotNull(parsed.getInputDataset()); + assertTrue(parsed.getInputDataset() instanceof RemoteInferenceInputDataSet); + } + + @Test(expected = IOException.class) + public void testParseInvalidJson() throws IOException { + String invalidJson = "{\"tool_name\":\"TestTool\",\"parameters\":{\"question\":\"test\""; // Missing closing braces + XContentParser parser = createParser(invalidJson); + new ToolMLInput(parser, FunctionName.TOOL); + } + + @Test + public void testParseMissingToolName() throws IOException { + String jsonWithoutToolName = "{\"parameters\":{\"question\":\"test\",\"model_id\":\"123\"}}"; + XContentParser parser = createParser(jsonWithoutToolName); + ToolMLInput parsed = new ToolMLInput(parser, FunctionName.TOOL); + + assertEquals(null, parsed.getToolName()); + assertEquals(FunctionName.TOOL, parsed.getAlgorithm()); + } + + @Test + public void testParseMissingParameters() throws IOException { + String jsonWithoutParams = "{\"tool_name\":\"TestTool\"}"; + XContentParser parser = createParser(jsonWithoutParams); + ToolMLInput parsed = new ToolMLInput(parser, FunctionName.TOOL); + + assertEquals("TestTool", parsed.getToolName()); + assertEquals(null, parsed.getInputDataset()); + } + + private XContentParser createParser(String jsonString) throws IOException { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + jsonString + ); + parser.nextToken(); + return parser; + } + + private void readInputStream(ToolMLInput input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + ToolMLInput parsedInput = new ToolMLInput(streamInput); + verify.accept(parsedInput); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryContainerTests.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryContainerTests.java new file mode 100644 index 0000000000..86ac98e978 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryContainerTests.java @@ -0,0 +1,426 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.TestHelper; + +public class MLMemoryContainerTests { + + private MLMemoryContainer mlMemoryContainer; + private User testUser; + private MemoryStorageConfig testMemoryStorageConfig; + private Instant testCreatedTime; + private Instant testLastUpdatedTime; + + @Before + public void setUp() { + testUser = new User(); // Use empty User constructor like in MLModelTests + // Use millisecond precision to avoid precision loss in JSON serialization + testCreatedTime = Instant.ofEpochMilli(System.currentTimeMillis()); + testLastUpdatedTime = Instant.ofEpochMilli(System.currentTimeMillis() + 3600000); + + testMemoryStorageConfig = MemoryStorageConfig + .builder() + .memoryIndexName("test-memory-index") + .semanticStorageEnabled(true) + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("test-embedding-model") + .dimension(768) + .maxInferSize(10) // Max allowed value is 10 + .build(); + + mlMemoryContainer = MLMemoryContainer + .builder() + .name("test-memory-container") + .description("Test memory container description") + .owner(testUser) + .tenantId("test-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(testMemoryStorageConfig) + .build(); + } + + @Test + public void testConstructorWithBuilder() { + assertNotNull(mlMemoryContainer); + assertEquals("test-memory-container", mlMemoryContainer.getName()); + assertEquals("Test memory container description", mlMemoryContainer.getDescription()); + assertEquals(testUser, mlMemoryContainer.getOwner()); + assertEquals("test-tenant", mlMemoryContainer.getTenantId()); + assertEquals(testCreatedTime, mlMemoryContainer.getCreatedTime()); + assertEquals(testLastUpdatedTime, mlMemoryContainer.getLastUpdatedTime()); + assertEquals(testMemoryStorageConfig, mlMemoryContainer.getMemoryStorageConfig()); + } + + @Test + public void testConstructorWithAllParameters() { + MLMemoryContainer container = new MLMemoryContainer( + "test-name", + "test-description", + testUser, + "test-tenant", + testCreatedTime, + testLastUpdatedTime, + testMemoryStorageConfig + ); + + assertEquals("test-name", container.getName()); + assertEquals("test-description", container.getDescription()); + assertEquals(testUser, container.getOwner()); + assertEquals("test-tenant", container.getTenantId()); + assertEquals(testCreatedTime, container.getCreatedTime()); + assertEquals(testLastUpdatedTime, container.getLastUpdatedTime()); + assertEquals(testMemoryStorageConfig, container.getMemoryStorageConfig()); + } + + @Test + public void testConstructorWithNullValues() { + MLMemoryContainer container = MLMemoryContainer + .builder() + .name(null) + .description(null) + .owner(null) + .tenantId(null) + .createdTime(null) + .lastUpdatedTime(null) + .memoryStorageConfig(null) + .build(); + + assertNull(container.getName()); + assertNull(container.getDescription()); + assertNull(container.getOwner()); + assertNull(container.getTenantId()); + assertNull(container.getCreatedTime()); + assertNull(container.getLastUpdatedTime()); + assertNull(container.getMemoryStorageConfig()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlMemoryContainer.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainer parsedContainer = new MLMemoryContainer(streamInput); + + assertEquals(mlMemoryContainer.getName(), parsedContainer.getName()); + assertEquals(mlMemoryContainer.getDescription(), parsedContainer.getDescription()); + assertEquals(mlMemoryContainer.getOwner(), parsedContainer.getOwner()); + assertEquals(mlMemoryContainer.getTenantId(), parsedContainer.getTenantId()); + assertEquals(mlMemoryContainer.getCreatedTime(), parsedContainer.getCreatedTime()); + assertEquals(mlMemoryContainer.getLastUpdatedTime(), parsedContainer.getLastUpdatedTime()); + assertEquals(mlMemoryContainer.getMemoryStorageConfig(), parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testStreamInputOutputWithNullValues() throws IOException { + MLMemoryContainer containerWithNulls = MLMemoryContainer + .builder() + .name("test-name") + .description("test-description") + .owner(null) + .tenantId("test-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(null) + .build(); + + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + containerWithNulls.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainer parsedContainer = new MLMemoryContainer(streamInput); + + assertEquals(containerWithNulls.getName(), parsedContainer.getName()); + assertEquals(containerWithNulls.getDescription(), parsedContainer.getDescription()); + assertNull(parsedContainer.getOwner()); + assertEquals(containerWithNulls.getTenantId(), parsedContainer.getTenantId()); + assertEquals(containerWithNulls.getCreatedTime(), parsedContainer.getCreatedTime()); + assertEquals(containerWithNulls.getLastUpdatedTime(), parsedContainer.getLastUpdatedTime()); + assertNull(parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlMemoryContainer.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify that all fields are present in the JSON + assert (jsonStr.contains("\"name\":\"test-memory-container\"")); + assert (jsonStr.contains("\"description\":\"Test memory container description\"")); + assert (jsonStr.contains("\"tenant_id\":\"test-tenant\"")); + assert (jsonStr.contains("\"created_time\":" + testCreatedTime.toEpochMilli())); + assert (jsonStr.contains("\"last_updated_time\":" + testLastUpdatedTime.toEpochMilli())); + assert (jsonStr.contains("\"memory_storage_config\"")); + } + + @Test + public void testToXContentWithNullValues() throws IOException { + MLMemoryContainer containerWithNulls = MLMemoryContainer + .builder() + .name("test-name") + .description(null) + .owner(null) + .tenantId(null) + .createdTime(null) + .lastUpdatedTime(null) + .memoryStorageConfig(null) + .build(); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + containerWithNulls.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + assert (jsonStr.contains("\"name\":\"test-name\"")); + // Verify that null fields are not included in JSON + assert (!jsonStr.contains("\"description\"")); + assert (!jsonStr.contains("\"owner\"")); + assert (!jsonStr.contains("\"tenant_id\"")); + assert (!jsonStr.contains("\"created_time\"")); + assert (!jsonStr.contains("\"last_updated_time\"")); + assert (!jsonStr.contains("\"memory_storage_config\"")); + } + + @Test + public void testParseFromXContentWithPartialFields() throws IOException { + String jsonStr = "{\"name\":\"partial-container\",\"description\":\"partial description\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLMemoryContainer parsedContainer = MLMemoryContainer.parse(parser); + + assertEquals("partial-container", parsedContainer.getName()); + assertEquals("partial description", parsedContainer.getDescription()); + assertNull(parsedContainer.getOwner()); + assertNull(parsedContainer.getTenantId()); + assertNull(parsedContainer.getCreatedTime()); + assertNull(parsedContainer.getLastUpdatedTime()); + assertNull(parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testParseFromXContentWithUnknownFields() throws IOException { + String jsonStr = "{\"name\":\"test-container\",\"unknown_field\":\"unknown_value\",\"description\":\"test description\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLMemoryContainer parsedContainer = MLMemoryContainer.parse(parser); + + assertEquals("test-container", parsedContainer.getName()); + assertEquals("test description", parsedContainer.getDescription()); + // Unknown fields should be ignored + assertNull(parsedContainer.getOwner()); + assertNull(parsedContainer.getTenantId()); + } + + @Test + public void testParseFromXContentWithTimeFields() throws IOException { + long createdTimeMillis = testCreatedTime.toEpochMilli(); + long lastUpdatedTimeMillis = testLastUpdatedTime.toEpochMilli(); + + String jsonStr = String + .format( + "{\"name\":\"time-test-container\",\"created_time\":%d,\"last_updated_time\":%d}", + createdTimeMillis, + lastUpdatedTimeMillis + ); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLMemoryContainer parsedContainer = MLMemoryContainer.parse(parser); + + assertEquals("time-test-container", parsedContainer.getName()); + assertEquals(testCreatedTime, parsedContainer.getCreatedTime()); + assertEquals(testLastUpdatedTime, parsedContainer.getLastUpdatedTime()); + assertNull(parsedContainer.getOwner()); + assertNull(parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testParseFromXContentWithMemoryStorageConfig() throws IOException { + // Create a JSON string with memory storage config + String jsonStr = "{\"name\":\"config-test-container\"," + + "\"memory_storage_config\":{" + + "\"memory_index_name\":\"test-index\"," + + "\"semantic_storage_enabled\":true," + + "\"embedding_model_type\":\"TEXT_EMBEDDING\"," + + "\"embedding_model_id\":\"test-model\"," + + "\"dimension\":512," + + "\"max_infer_size\":5" + + "}}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLMemoryContainer parsedContainer = MLMemoryContainer.parse(parser); + + assertEquals("config-test-container", parsedContainer.getName()); + assertNotNull(parsedContainer.getMemoryStorageConfig()); + assertEquals("test-index", parsedContainer.getMemoryStorageConfig().getMemoryIndexName()); + assertEquals(true, parsedContainer.getMemoryStorageConfig().isSemanticStorageEnabled()); + assertEquals(FunctionName.TEXT_EMBEDDING, parsedContainer.getMemoryStorageConfig().getEmbeddingModelType()); + assertEquals("test-model", parsedContainer.getMemoryStorageConfig().getEmbeddingModelId()); + assertEquals(Integer.valueOf(512), parsedContainer.getMemoryStorageConfig().getDimension()); + assertNull(parsedContainer.getMemoryStorageConfig().getMaxInferSize()); // No llmModelId, so maxInferSize is null + } + + @Test + public void testParseFromXContentCompleteRoundTrip() throws IOException { + // Test complete round trip: object -> JSON -> parse -> compare + // Use container without User to avoid parsing issues, but test all other fields + MLMemoryContainer originalContainer = MLMemoryContainer + .builder() + .name("roundtrip-container") + .description("roundtrip description") + .owner(null) // Skip User for now due to parsing complexity + .tenantId("roundtrip-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(testMemoryStorageConfig) + .build(); + + // Convert to JSON + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + originalContainer.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + // Parse back from JSON + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLMemoryContainer parsedContainer = MLMemoryContainer.parse(parser); + + // Verify all fields match + assertEquals(originalContainer.getName(), parsedContainer.getName()); + assertEquals(originalContainer.getDescription(), parsedContainer.getDescription()); + assertEquals(originalContainer.getTenantId(), parsedContainer.getTenantId()); + assertEquals(originalContainer.getCreatedTime(), parsedContainer.getCreatedTime()); + assertEquals(originalContainer.getLastUpdatedTime(), parsedContainer.getLastUpdatedTime()); + assertEquals(originalContainer.getMemoryStorageConfig(), parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testParseFromXContentWithUserField() throws IOException { + // Test parsing with User field using the same approach as other OpenSearch tests + // Create a container with User and serialize it to see the expected JSON format + MLMemoryContainer containerWithUser = MLMemoryContainer + .builder() + .name("user-test-container") + .description("test with user") + .owner(testUser) + .build(); + + // Convert to JSON to see the format + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + containerWithUser.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + // Parse back from JSON - this tests the User.parse() call in line 154 + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLMemoryContainer parsedContainer = MLMemoryContainer.parse(parser); + + // Verify the container was parsed correctly + assertEquals("user-test-container", parsedContainer.getName()); + assertEquals("test with user", parsedContainer.getDescription()); + assertEquals(testUser, parsedContainer.getOwner()); + } + + @Test + public void testEqualsAndHashCode() { + MLMemoryContainer container1 = MLMemoryContainer + .builder() + .name("test-container") + .description("test description") + .owner(testUser) + .tenantId("test-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(testMemoryStorageConfig) + .build(); + + MLMemoryContainer container2 = MLMemoryContainer + .builder() + .name("test-container") + .description("test description") + .owner(testUser) + .tenantId("test-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(testMemoryStorageConfig) + .build(); + + MLMemoryContainer container3 = MLMemoryContainer + .builder() + .name("different-container") + .description("test description") + .owner(testUser) + .tenantId("test-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(testMemoryStorageConfig) + .build(); + + assertEquals(container1, container2); + assertEquals(container1.hashCode(), container2.hashCode()); + assert (!container1.equals(container3)); + assert (container1.hashCode() != container3.hashCode()); + } + + @Test + public void testSettersAndGetters() { + MLMemoryContainer container = new MLMemoryContainer(null, null, null, null, null, null, null); + + container.setName("new-name"); + container.setDescription("new-description"); + container.setOwner(testUser); + container.setTenantId("new-tenant"); + container.setCreatedTime(testCreatedTime); + container.setLastUpdatedTime(testLastUpdatedTime); + container.setMemoryStorageConfig(testMemoryStorageConfig); + + assertEquals("new-name", container.getName()); + assertEquals("new-description", container.getDescription()); + assertEquals(testUser, container.getOwner()); + assertEquals("new-tenant", container.getTenantId()); + assertEquals(testCreatedTime, container.getCreatedTime()); + assertEquals(testLastUpdatedTime, container.getLastUpdatedTime()); + assertEquals(testMemoryStorageConfig, container.getMemoryStorageConfig()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java new file mode 100644 index 0000000000..4aa1d5a137 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java @@ -0,0 +1,445 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +public class MLMemoryTest { + + private MLMemory memoryWithAllFields; + private MLMemory memoryMinimal; + private Map testTags; + private Instant testCreatedTime; + private Instant testUpdatedTime; + private float[] testEmbedding; + private Map sparseEmbedding; + + @Before + public void setUp() { + testCreatedTime = Instant.now(); + testUpdatedTime = Instant.now().plusSeconds(60); + + testTags = new HashMap<>(); + testTags.put("topic", "machine learning"); + testTags.put("priority", "high"); + + testEmbedding = new float[] { 0.1f, 0.2f, 0.3f }; + + sparseEmbedding = new HashMap<>(); + sparseEmbedding.put("token1", 0.5f); + sparseEmbedding.put("token2", 0.8f); + + // Memory with all fields + memoryWithAllFields = MLMemory + .builder() + .sessionId("session-123") + .memory("This is a test memory content") + .memoryType(MemoryType.RAW_MESSAGE) + .userId("user-456") + .agentId("agent-789") + .role("user") + .tags(testTags) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .memoryEmbedding(testEmbedding) + .build(); + + // Minimal memory (only required fields) + memoryMinimal = MLMemory + .builder() + .sessionId("session-minimal") + .memory("Minimal memory") + .memoryType(MemoryType.FACT) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + } + + @Test + public void testBuilderWithAllFields() { + assertNotNull(memoryWithAllFields); + assertEquals("session-123", memoryWithAllFields.getSessionId()); + assertEquals("This is a test memory content", memoryWithAllFields.getMemory()); + assertEquals(MemoryType.RAW_MESSAGE, memoryWithAllFields.getMemoryType()); + assertEquals("user-456", memoryWithAllFields.getUserId()); + assertEquals("agent-789", memoryWithAllFields.getAgentId()); + assertEquals("user", memoryWithAllFields.getRole()); + assertEquals(testTags, memoryWithAllFields.getTags()); + assertEquals(testCreatedTime, memoryWithAllFields.getCreatedTime()); + assertEquals(testUpdatedTime, memoryWithAllFields.getLastUpdatedTime()); + assertEquals(testEmbedding, memoryWithAllFields.getMemoryEmbedding()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(memoryMinimal); + assertEquals("session-minimal", memoryMinimal.getSessionId()); + assertEquals("Minimal memory", memoryMinimal.getMemory()); + assertEquals(MemoryType.FACT, memoryMinimal.getMemoryType()); + assertNull(memoryMinimal.getUserId()); + assertNull(memoryMinimal.getAgentId()); + assertNull(memoryMinimal.getRole()); + assertNull(memoryMinimal.getTags()); + assertEquals(testCreatedTime, memoryMinimal.getCreatedTime()); + assertEquals(testUpdatedTime, memoryMinimal.getLastUpdatedTime()); + assertNull(memoryMinimal.getMemoryEmbedding()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with all fields + BytesStreamOutput out = new BytesStreamOutput(); + memoryWithAllFields.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLMemory deserialized = new MLMemory(in); + + assertEquals(memoryWithAllFields.getSessionId(), deserialized.getSessionId()); + assertEquals(memoryWithAllFields.getMemory(), deserialized.getMemory()); + assertEquals(memoryWithAllFields.getMemoryType(), deserialized.getMemoryType()); + assertEquals(memoryWithAllFields.getUserId(), deserialized.getUserId()); + assertEquals(memoryWithAllFields.getAgentId(), deserialized.getAgentId()); + assertEquals(memoryWithAllFields.getRole(), deserialized.getRole()); + assertEquals(memoryWithAllFields.getTags(), deserialized.getTags()); + assertEquals(memoryWithAllFields.getCreatedTime(), deserialized.getCreatedTime()); + assertEquals(memoryWithAllFields.getLastUpdatedTime(), deserialized.getLastUpdatedTime()); + // Note: memoryEmbedding is not serialized in StreamInput/Output + assertNull(deserialized.getMemoryEmbedding()); + } + + @Test + public void testStreamInputOutputMinimal() throws IOException { + // Test with minimal fields + BytesStreamOutput out = new BytesStreamOutput(); + memoryMinimal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLMemory deserialized = new MLMemory(in); + + assertEquals(memoryMinimal.getSessionId(), deserialized.getSessionId()); + assertEquals(memoryMinimal.getMemory(), deserialized.getMemory()); + assertEquals(memoryMinimal.getMemoryType(), deserialized.getMemoryType()); + assertNull(deserialized.getUserId()); + assertNull(deserialized.getAgentId()); + assertNull(deserialized.getRole()); + assertNull(deserialized.getTags()); + assertEquals(memoryMinimal.getCreatedTime(), deserialized.getCreatedTime()); + assertEquals(memoryMinimal.getLastUpdatedTime(), deserialized.getLastUpdatedTime()); + assertNull(deserialized.getMemoryEmbedding()); + } + + @Test + public void testStreamInputOutputEmptyTags() throws IOException { + // Test with empty tags + MLMemory memoryEmptyTags = MLMemory + .builder() + .sessionId("session-empty-tags") + .memory("Memory with empty tags") + .memoryType(MemoryType.FACT) + .tags(new HashMap<>()) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + memoryEmptyTags.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLMemory deserialized = new MLMemory(in); + + assertNull(deserialized.getTags()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + memoryWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"session_id\":\"session-123\"")); + assertTrue(jsonString.contains("\"memory\":\"This is a test memory content\"")); + assertTrue(jsonString.contains("\"memory_type\":\"RAW_MESSAGE\"")); + assertTrue(jsonString.contains("\"user_id\":\"user-456\"")); + assertTrue(jsonString.contains("\"agent_id\":\"agent-789\"")); + assertTrue(jsonString.contains("\"role\":\"user\"")); + assertTrue(jsonString.contains("\"topic\":\"machine learning\"")); + assertTrue(jsonString.contains("\"priority\":\"high\"")); + assertTrue(jsonString.contains("\"created_time\":" + testCreatedTime.toEpochMilli())); + assertTrue(jsonString.contains("\"last_updated_time\":" + testUpdatedTime.toEpochMilli())); + } + + @Test + public void testToXContentMinimal() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + memoryMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"session_id\":\"session-minimal\"")); + assertTrue(jsonString.contains("\"memory\":\"Minimal memory\"")); + assertTrue(jsonString.contains("\"memory_type\":\"FACT\"")); + // Optional fields should not be present + assertTrue(!jsonString.contains("\"user_id\"")); + assertTrue(!jsonString.contains("\"agent_id\"")); + assertTrue(!jsonString.contains("\"role\"")); + assertTrue(!jsonString.contains("\"tags\"")); + assertTrue(!jsonString.contains("\"memory_embedding\"")); + } + + @Test + public void testParse() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-123\"," + + "\"memory\":\"This is a test memory content\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"user-456\"," + + "\"agent_id\":\"agent-789\"," + + "\"role\":\"user\"," + + "\"tags\":{\"topic\":\"machine learning\",\"priority\":\"high\"}," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{\"values\":[0.1,0.2,0.3]}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-123", parsed.getSessionId()); + assertEquals("This is a test memory content", parsed.getMemory()); + assertEquals(MemoryType.RAW_MESSAGE, parsed.getMemoryType()); + assertEquals("user-456", parsed.getUserId()); + assertEquals("agent-789", parsed.getAgentId()); + assertEquals("user", parsed.getRole()); + assertEquals(2, parsed.getTags().size()); + assertEquals("machine learning", parsed.getTags().get("topic")); + assertEquals("high", parsed.getTags().get("priority")); + assertEquals(testCreatedTime.toEpochMilli(), parsed.getCreatedTime().toEpochMilli()); + assertEquals(testUpdatedTime.toEpochMilli(), parsed.getLastUpdatedTime().toEpochMilli()); + assertNotNull(parsed.getMemoryEmbedding()); + } + + @Test + public void testParseMinimal() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-minimal\"," + + "\"memory\":\"Minimal memory\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-minimal", parsed.getSessionId()); + assertEquals("Minimal memory", parsed.getMemory()); + assertEquals(MemoryType.FACT, parsed.getMemoryType()); + assertNull(parsed.getUserId()); + assertNull(parsed.getAgentId()); + assertNull(parsed.getRole()); + assertNull(parsed.getTags()); + assertEquals(testCreatedTime.toEpochMilli(), parsed.getCreatedTime().toEpochMilli()); + assertEquals(testUpdatedTime.toEpochMilli(), parsed.getLastUpdatedTime().toEpochMilli()); + assertNull(parsed.getMemoryEmbedding()); + } + + @Test + public void testParseWithUnknownFields() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-123\"," + + "\"memory\":\"Test memory\"," + + "\"memory_type\":\"FACT\"," + + "\"unknown_field\":\"should be ignored\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-123", parsed.getSessionId()); + assertEquals("Test memory", parsed.getMemory()); + assertEquals(MemoryType.FACT, parsed.getMemoryType()); + } + + @Test + public void testToIndexMap() { + Map indexMap = memoryWithAllFields.toIndexMap(); + + assertEquals("session-123", indexMap.get("session_id")); + assertEquals("This is a test memory content", indexMap.get("memory")); + assertEquals("RAW_MESSAGE", indexMap.get("memory_type")); + assertEquals("user-456", indexMap.get("user_id")); + assertEquals("agent-789", indexMap.get("agent_id")); + assertEquals("user", indexMap.get("role")); + assertEquals(testTags, indexMap.get("tags")); + assertEquals(testCreatedTime.toEpochMilli(), indexMap.get("created_time")); + assertEquals(testUpdatedTime.toEpochMilli(), indexMap.get("last_updated_time")); + assertEquals(testEmbedding, indexMap.get("memory_embedding")); + } + + @Test + public void testToIndexMapMinimal() { + Map indexMap = memoryMinimal.toIndexMap(); + + assertEquals("session-minimal", indexMap.get("session_id")); + assertEquals("Minimal memory", indexMap.get("memory")); + assertEquals("FACT", indexMap.get("memory_type")); + assertEquals(testCreatedTime.toEpochMilli(), indexMap.get("created_time")); + assertEquals(testUpdatedTime.toEpochMilli(), indexMap.get("last_updated_time")); + + // Optional fields should not be in the map + assertTrue(!indexMap.containsKey("user_id")); + assertTrue(!indexMap.containsKey("agent_id")); + assertTrue(!indexMap.containsKey("role")); + assertTrue(!indexMap.containsKey("tags")); + assertTrue(!indexMap.containsKey("memory_embedding")); + } + + @Test + public void testSettersWork() { + MLMemory memory = MLMemory + .builder() + .sessionId("initial-session") + .memory("initial memory") + .memoryType(MemoryType.FACT) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + + // Test setters + memory.setSessionId("new-session"); + memory.setMemory("new memory"); + memory.setMemoryType(MemoryType.RAW_MESSAGE); + memory.setUserId("new-user"); + memory.setAgentId("new-agent"); + memory.setRole("assistant"); + memory.setTags(testTags); + memory.setMemoryEmbedding(sparseEmbedding); + + assertEquals("new-session", memory.getSessionId()); + assertEquals("new memory", memory.getMemory()); + assertEquals(MemoryType.RAW_MESSAGE, memory.getMemoryType()); + assertEquals("new-user", memory.getUserId()); + assertEquals("new-agent", memory.getAgentId()); + assertEquals("assistant", memory.getRole()); + assertEquals(testTags, memory.getTags()); + assertEquals(sparseEmbedding, memory.getMemoryEmbedding()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Use memory without embedding for round trip test since embedding parsing is complex + MLMemory memoryNoEmbedding = MLMemory + .builder() + .sessionId("session-123") + .memory("This is a test memory content") + .memoryType(MemoryType.RAW_MESSAGE) + .userId("user-456") + .agentId("agent-789") + .role("user") + .tags(testTags) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + memoryNoEmbedding.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MLMemory parsed = MLMemory.parse(parser); + + // Verify all fields match + assertEquals(memoryNoEmbedding.getSessionId(), parsed.getSessionId()); + assertEquals(memoryNoEmbedding.getMemory(), parsed.getMemory()); + assertEquals(memoryNoEmbedding.getMemoryType(), parsed.getMemoryType()); + assertEquals(memoryNoEmbedding.getUserId(), parsed.getUserId()); + assertEquals(memoryNoEmbedding.getAgentId(), parsed.getAgentId()); + assertEquals(memoryNoEmbedding.getRole(), parsed.getRole()); + assertEquals(memoryNoEmbedding.getTags(), parsed.getTags()); + assertEquals(memoryNoEmbedding.getCreatedTime().toEpochMilli(), parsed.getCreatedTime().toEpochMilli()); + assertEquals(memoryNoEmbedding.getLastUpdatedTime().toEpochMilli(), parsed.getLastUpdatedTime().toEpochMilli()); + } + + @Test + public void testSpecialCharactersInFields() throws IOException { + Map specialTags = new HashMap<>(); + specialTags.put("key with spaces", "value with\nnewlines"); + specialTags.put("unicode_key_🔥", "unicode_value_✨"); + + MLMemory specialMemory = MLMemory + .builder() + .sessionId("session-with-special-chars-🚀") + .memory("Memory with\n\ttabs and\nnewlines and \"quotes\"") + .memoryType(MemoryType.FACT) + .role("user/assistant") + .tags(specialTags) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + + // Test XContent round trip + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialMemory.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MLMemory parsed = MLMemory.parse(parser); + + assertEquals(specialMemory.getSessionId(), parsed.getSessionId()); + assertEquals(specialMemory.getMemory(), parsed.getMemory()); + assertEquals(specialMemory.getRole(), parsed.getRole()); + assertEquals(specialMemory.getTags(), parsed.getTags()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstantsTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstantsTest.java new file mode 100644 index 0000000000..0989fe512a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstantsTest.java @@ -0,0 +1,202 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +/** + * Tests for MemoryContainerConstants to ensure constants are properly defined + */ +public class MemoryContainerConstantsTest { + + @Test + public void testMemoryContainerFieldConstants() { + assertEquals("memory_container_id", MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD); + assertEquals("name", MemoryContainerConstants.NAME_FIELD); + assertEquals("description", MemoryContainerConstants.DESCRIPTION_FIELD); + assertEquals("owner", MemoryContainerConstants.OWNER_FIELD); + assertEquals("created_time", MemoryContainerConstants.CREATED_TIME_FIELD); + assertEquals("last_updated_time", MemoryContainerConstants.LAST_UPDATED_TIME_FIELD); + assertEquals("memory_storage_config", MemoryContainerConstants.MEMORY_STORAGE_CONFIG_FIELD); + } + + @Test + public void testMemoryStorageConfigFieldConstants() { + assertEquals("memory_index_name", MemoryContainerConstants.MEMORY_INDEX_NAME_FIELD); + assertEquals("semantic_storage_enabled", MemoryContainerConstants.SEMANTIC_STORAGE_ENABLED_FIELD); + assertEquals("embedding_model_type", MemoryContainerConstants.EMBEDDING_MODEL_TYPE_FIELD); + assertEquals("embedding_model_id", MemoryContainerConstants.EMBEDDING_MODEL_ID_FIELD); + assertEquals("llm_model_id", MemoryContainerConstants.LLM_MODEL_ID_FIELD); + assertEquals("dimension", MemoryContainerConstants.DIMENSION_FIELD); + assertEquals("max_infer_size", MemoryContainerConstants.MAX_INFER_SIZE_FIELD); + } + + @Test + public void testDefaultValues() { + assertEquals(5, MemoryContainerConstants.MAX_INFER_SIZE_DEFAULT_VALUE); + } + + @Test + public void testIndexPrefixes() { + assertEquals("ml-static-memory-", MemoryContainerConstants.STATIC_MEMORY_INDEX_PREFIX); + assertEquals("ml-knn-memory-", MemoryContainerConstants.KNN_MEMORY_INDEX_PREFIX); + assertEquals("ml-sparse-memory-", MemoryContainerConstants.SPARSE_MEMORY_INDEX_PREFIX); + } + + @Test + public void testMemoryDataFieldConstants() { + assertEquals("user_id", MemoryContainerConstants.USER_ID_FIELD); + assertEquals("agent_id", MemoryContainerConstants.AGENT_ID_FIELD); + assertEquals("session_id", MemoryContainerConstants.SESSION_ID_FIELD); + assertEquals("memory", MemoryContainerConstants.MEMORY_FIELD); + assertEquals("memory_embedding", MemoryContainerConstants.MEMORY_EMBEDDING_FIELD); + assertEquals("tags", MemoryContainerConstants.TAGS_FIELD); + assertEquals("memory_id", MemoryContainerConstants.MEMORY_ID_FIELD); + assertEquals("memory_type", MemoryContainerConstants.MEMORY_TYPE_FIELD); + assertEquals("role", MemoryContainerConstants.ROLE_FIELD); + } + + @Test + public void testRequestFieldConstants() { + assertEquals("message", MemoryContainerConstants.MESSAGE_FIELD); + assertEquals("messages", MemoryContainerConstants.MESSAGES_FIELD); + assertEquals("content", MemoryContainerConstants.CONTENT_FIELD); + assertEquals("infer", MemoryContainerConstants.INFER_FIELD); + assertEquals("query", MemoryContainerConstants.QUERY_FIELD); + assertEquals("text", MemoryContainerConstants.TEXT_FIELD); + } + + @Test + public void testKnnIndexSettings() { + assertEquals("lucene", MemoryContainerConstants.KNN_ENGINE); + assertEquals("cosinesimil", MemoryContainerConstants.KNN_SPACE_TYPE); + assertEquals("hnsw", MemoryContainerConstants.KNN_METHOD_NAME); + assertEquals(100, MemoryContainerConstants.KNN_EF_SEARCH); + assertEquals(100, MemoryContainerConstants.KNN_EF_CONSTRUCTION); + assertEquals(16, MemoryContainerConstants.KNN_M); + } + + @Test + public void testRestApiPaths() { + assertEquals("/_plugins/_ml/memory_containers", MemoryContainerConstants.BASE_MEMORY_CONTAINERS_PATH); + assertEquals("/_plugins/_ml/memory_containers/_create", MemoryContainerConstants.CREATE_MEMORY_CONTAINER_PATH); + assertEquals("memory_container_id", MemoryContainerConstants.PARAMETER_MEMORY_CONTAINER_ID); + assertEquals("memory_id", MemoryContainerConstants.PARAMETER_MEMORY_ID); + + String expectedMemoriesPath = "/_plugins/_ml/memory_containers/{memory_container_id}/memories"; + assertEquals(expectedMemoriesPath, MemoryContainerConstants.MEMORIES_PATH); + + String expectedSearchPath = expectedMemoriesPath + "/_search"; + assertEquals(expectedSearchPath, MemoryContainerConstants.SEARCH_MEMORIES_PATH); + + String expectedDeletePath = expectedMemoriesPath + "/{memory_id}"; + assertEquals(expectedDeletePath, MemoryContainerConstants.DELETE_MEMORY_PATH); + assertEquals(expectedDeletePath, MemoryContainerConstants.UPDATE_MEMORY_PATH); + } + + @Test + public void testResponseFields() { + assertEquals("status", MemoryContainerConstants.STATUS_FIELD); + } + + @Test + public void testMemoryDecisionFields() { + assertEquals("memory_decision", MemoryContainerConstants.MEMORY_DECISION_FIELD); + assertEquals("old_memory", MemoryContainerConstants.OLD_MEMORY_FIELD); + assertEquals("retrieved_facts", MemoryContainerConstants.RETRIEVED_FACTS_FIELD); + assertEquals("event", MemoryContainerConstants.EVENT_FIELD); + assertEquals("score", MemoryContainerConstants.SCORE_FIELD); + } + + @Test + public void testApiLimits() { + assertEquals(10, MemoryContainerConstants.MAX_MESSAGES_PER_REQUEST); + assertEquals("Cannot process more than 10 messages in a single request", MemoryContainerConstants.MAX_MESSAGES_EXCEEDED_ERROR); + } + + @Test + public void testErrorMessages() { + // Test semantic storage error messages + assertEquals( + "Embedding model type is required when embedding model ID is provided", + MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR + ); + assertEquals( + "Embedding model ID is required when embedding model type is provided", + MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR + ); + assertEquals("Dimension is required for TEXT_EMBEDDING", MemoryContainerConstants.TEXT_EMBEDDING_DIMENSION_REQUIRED_ERROR); + assertEquals("Dimension is not allowed for SPARSE_ENCODING", MemoryContainerConstants.SPARSE_ENCODING_DIMENSION_NOT_ALLOWED_ERROR); + assertEquals( + "Embedding model type must be either TEXT_EMBEDDING or SPARSE_ENCODING", + MemoryContainerConstants.INVALID_EMBEDDING_MODEL_TYPE_ERROR + ); + assertEquals("Maximum infer size cannot exceed 10", MemoryContainerConstants.MAX_INFER_SIZE_LIMIT_ERROR); + assertTrue(MemoryContainerConstants.FIELD_NOT_ALLOWED_SEMANTIC_DISABLED_ERROR.contains("%s")); + + // Test model validation error messages + assertTrue(MemoryContainerConstants.LLM_MODEL_NOT_FOUND_ERROR.contains("%s")); + assertTrue(MemoryContainerConstants.LLM_MODEL_NOT_REMOTE_ERROR.contains("%s")); + assertTrue(MemoryContainerConstants.EMBEDDING_MODEL_NOT_FOUND_ERROR.contains("%s")); + assertTrue(MemoryContainerConstants.EMBEDDING_MODEL_TYPE_MISMATCH_ERROR.contains("%s")); + assertEquals( + "infer=true requires llm_model_id to be configured in memory storage", + MemoryContainerConstants.INFER_REQUIRES_LLM_MODEL_ERROR + ); + } + + @Test + public void testLlmPrompts() { + // Test Personal Information Organizer prompt + assertNotNull(MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT); + assertTrue(MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT.contains("Personal Information Organizer")); + assertTrue(MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT.contains("Extract and organize personal information")); + assertTrue(MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT.contains("\"facts\"")); + + // Test Default Update Memory prompt + assertNotNull(MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT); + assertTrue(MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT.contains("smart memory manager")); + assertTrue(MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT.contains("memory_decision")); + assertTrue(MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT.contains("ADD|UPDATE|DELETE|NONE")); + } + + @Test + public void testPromptStructure() { + // Verify Personal Information Organizer prompt has proper XML structure + String organizerPrompt = MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT; + assertTrue(organizerPrompt.startsWith("")); + assertTrue(organizerPrompt.contains("")); + assertTrue(organizerPrompt.contains("")); + assertTrue(organizerPrompt.contains("")); + assertTrue(organizerPrompt.contains("")); + assertTrue(organizerPrompt.contains("")); + + // Verify Update Memory prompt has proper XML structure + String updatePrompt = MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT; + assertTrue(updatePrompt.startsWith("")); + assertTrue(updatePrompt.contains("")); + assertTrue(updatePrompt.contains("")); + assertTrue(updatePrompt.contains("")); + assertTrue(updatePrompt.contains("")); + assertTrue(updatePrompt.contains("")); + assertTrue(updatePrompt.contains("")); + assertTrue(updatePrompt.contains("")); + } + + @Test + public void testConstantsConsistency() { + // Verify that DELETE and UPDATE paths use the same pattern + assertEquals(MemoryContainerConstants.DELETE_MEMORY_PATH, MemoryContainerConstants.UPDATE_MEMORY_PATH); + + // Verify that parameter names are used in path construction + assertTrue(MemoryContainerConstants.MEMORIES_PATH.contains("{" + MemoryContainerConstants.PARAMETER_MEMORY_CONTAINER_ID + "}")); + assertTrue(MemoryContainerConstants.DELETE_MEMORY_PATH.contains("{" + MemoryContainerConstants.PARAMETER_MEMORY_ID + "}")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionRequestTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionRequestTest.java new file mode 100644 index 0000000000..b1f539e0b1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionRequestTest.java @@ -0,0 +1,288 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MemoryDecisionRequestTest { + + private MemoryDecisionRequest requestWithAllFields; + private MemoryDecisionRequest requestMinimal; + private MemoryDecisionRequest requestEmpty; + private List testOldMemories; + private List testRetrievedFacts; + + @Before + public void setUp() { + // Create test old memories + testOldMemories = Arrays + .asList( + MemoryDecisionRequest.OldMemory.builder().id("mem-1").text("User's name is John").score(0.95f).build(), + MemoryDecisionRequest.OldMemory.builder().id("mem-2").text("Lives in Boston").score(0.87f).build(), + MemoryDecisionRequest.OldMemory.builder().id("mem-3").text("Works at TechCorp").score(0.76f).build() + ); + + // Create test retrieved facts + testRetrievedFacts = Arrays + .asList("User's name is John", "Lives in San Francisco", "Works at TechCorp", "Has 10 years of experience"); + + // Request with all fields + requestWithAllFields = MemoryDecisionRequest.builder().oldMemory(testOldMemories).retrievedFacts(testRetrievedFacts).build(); + + // Minimal request (only retrieved facts) + requestMinimal = MemoryDecisionRequest.builder().retrievedFacts(Arrays.asList("Single fact")).build(); + + // Empty request + requestEmpty = MemoryDecisionRequest.builder().oldMemory(new ArrayList<>()).retrievedFacts(new ArrayList<>()).build(); + } + + @Test + public void testBuilderWithAllFields() { + assertNotNull(requestWithAllFields); + assertEquals(testOldMemories, requestWithAllFields.getOldMemory()); + assertEquals(testRetrievedFacts, requestWithAllFields.getRetrievedFacts()); + assertEquals(3, requestWithAllFields.getOldMemory().size()); + assertEquals(4, requestWithAllFields.getRetrievedFacts().size()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(requestMinimal); + assertEquals(null, requestMinimal.getOldMemory()); + assertEquals(1, requestMinimal.getRetrievedFacts().size()); + assertEquals("Single fact", requestMinimal.getRetrievedFacts().get(0)); + } + + @Test + public void testBuilderEmpty() { + assertNotNull(requestEmpty); + assertEquals(0, requestEmpty.getOldMemory().size()); + assertEquals(0, requestEmpty.getRetrievedFacts().size()); + } + + @Test + public void testOldMemoryBuilder() { + MemoryDecisionRequest.OldMemory oldMemory = MemoryDecisionRequest.OldMemory + .builder() + .id("test-id") + .text("test memory text") + .score(0.89f) + .build(); + + assertEquals("test-id", oldMemory.getId()); + assertEquals("test memory text", oldMemory.getText()); + assertEquals(0.89f, oldMemory.getScore(), 0.001); + } + + @Test + public void testOldMemoryToXContent() throws IOException { + MemoryDecisionRequest.OldMemory oldMemory = testOldMemories.get(0); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + oldMemory.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"id\":\"mem-1\"")); + assertTrue(jsonString.contains("\"text\":\"User's name is John\"")); + assertTrue(jsonString.contains("\"score\":0.95")); + } + + @Test + public void testToXContentWithAllFields() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + requestWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Check structure + assertTrue(jsonString.contains("\"old_memory\":[")); + assertTrue(jsonString.contains("\"retrieved_facts\":[")); + + // Check old memories + assertTrue(jsonString.contains("\"id\":\"mem-1\"")); + assertTrue(jsonString.contains("\"text\":\"User's name is John\"")); + assertTrue(jsonString.contains("\"score\":0.95")); + assertTrue(jsonString.contains("\"id\":\"mem-2\"")); + assertTrue(jsonString.contains("\"text\":\"Lives in Boston\"")); + assertTrue(jsonString.contains("\"score\":0.87")); + + // Check retrieved facts + assertTrue(jsonString.contains("\"Lives in San Francisco\"")); + assertTrue(jsonString.contains("\"Has 10 years of experience\"")); + } + + @Test + public void testToXContentMinimal() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + requestMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Should have empty old_memory array + assertTrue(jsonString.contains("\"old_memory\":[]")); + // Should have retrieved_facts + assertTrue(jsonString.contains("\"retrieved_facts\":[\"Single fact\"]")); + } + + @Test + public void testToXContentEmpty() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + requestEmpty.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Both arrays should be empty + assertTrue(jsonString.contains("\"old_memory\":[]")); + assertTrue(jsonString.contains("\"retrieved_facts\":[]")); + } + + @Test + public void testToXContentWithNullFields() throws IOException { + MemoryDecisionRequest requestNulls = MemoryDecisionRequest.builder().oldMemory(null).retrievedFacts(null).build(); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + requestNulls.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Should have empty arrays + assertTrue(jsonString.contains("\"old_memory\":[]")); + assertTrue(jsonString.contains("\"retrieved_facts\":[]")); + } + + @Test + public void testToJsonString() { + String jsonString = requestWithAllFields.toJsonString(); + assertNotNull(jsonString); + + // Verify JSON structure + assertTrue(jsonString.contains("\"old_memory\":[")); + assertTrue(jsonString.contains("\"retrieved_facts\":[")); + assertTrue(jsonString.contains("\"id\":\"mem-1\"")); + assertTrue(jsonString.contains("\"Lives in San Francisco\"")); + } + + @Test + public void testToJsonStringMinimal() { + String jsonString = requestMinimal.toJsonString(); + assertNotNull(jsonString); + + assertTrue(jsonString.contains("\"old_memory\":[]")); + assertTrue(jsonString.contains("\"retrieved_facts\":[\"Single fact\"]")); + } + + @Test + public void testDataAnnotationMethods() { + // Test @Data generated methods + MemoryDecisionRequest request1 = MemoryDecisionRequest + .builder() + .oldMemory(testOldMemories) + .retrievedFacts(testRetrievedFacts) + .build(); + + MemoryDecisionRequest request2 = MemoryDecisionRequest + .builder() + .oldMemory(testOldMemories) + .retrievedFacts(testRetrievedFacts) + .build(); + + // Test equals + assertEquals(request1, request2); + assertEquals(request1.hashCode(), request2.hashCode()); + + // Test setters + List newFacts = Arrays.asList("New fact 1", "New fact 2"); + request1.setRetrievedFacts(newFacts); + assertEquals(newFacts, request1.getRetrievedFacts()); + + // Test toString + String str = request1.toString(); + assertTrue(str.contains("oldMemory")); + assertTrue(str.contains("retrievedFacts")); + } + + @Test + public void testOldMemoryDataAnnotations() { + MemoryDecisionRequest.OldMemory memory1 = MemoryDecisionRequest.OldMemory.builder().id("id-1").text("text-1").score(0.9f).build(); + + MemoryDecisionRequest.OldMemory memory2 = MemoryDecisionRequest.OldMemory.builder().id("id-1").text("text-1").score(0.9f).build(); + + // Test equals + assertEquals(memory1, memory2); + assertEquals(memory1.hashCode(), memory2.hashCode()); + + // Test setters + memory1.setId("new-id"); + memory1.setText("new-text"); + memory1.setScore(0.5f); + + assertEquals("new-id", memory1.getId()); + assertEquals("new-text", memory1.getText()); + assertEquals(0.5f, memory1.getScore(), 0.001); + } + + @Test + public void testSpecialCharactersInFields() throws IOException { + List specialMemories = Arrays + .asList( + MemoryDecisionRequest.OldMemory + .builder() + .id("id-with-special-🔥") + .text("Text with\n\ttabs and \"quotes\"") + .score(0.99f) + .build() + ); + + List specialFacts = Arrays.asList("Fact with 'quotes'", "Fact with\nnewlines", "Fact with unicode ✨"); + + MemoryDecisionRequest specialRequest = MemoryDecisionRequest + .builder() + .oldMemory(specialMemories) + .retrievedFacts(specialFacts) + .build(); + + String jsonString = specialRequest.toJsonString(); + assertNotNull(jsonString); + + // Verify special characters are properly handled - JSON may escape unicode + assertTrue(jsonString.contains("id-with-special-")); + // JSON escaping will handle newlines and tabs + assertTrue(jsonString.contains("Text with")); + assertTrue(jsonString.contains("tabs")); + assertTrue(jsonString.contains("quotes")); + assertTrue(jsonString.contains("Fact with unicode")); + } + + @Test + public void testLargeRequest() throws IOException { + // Test with many items + List manyMemories = new ArrayList<>(); + List manyFacts = new ArrayList<>(); + + for (int i = 0; i < 100; i++) { + manyMemories.add(MemoryDecisionRequest.OldMemory.builder().id("mem-" + i).text("Memory text " + i).score(i / 100.0f).build()); + manyFacts.add("Fact number " + i); + } + + MemoryDecisionRequest largeRequest = MemoryDecisionRequest.builder().oldMemory(manyMemories).retrievedFacts(manyFacts).build(); + + String jsonString = largeRequest.toJsonString(); + assertNotNull(jsonString); + + assertEquals(100, largeRequest.getOldMemory().size()); + assertEquals(100, largeRequest.getRetrievedFacts().size()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionTest.java new file mode 100644 index 0000000000..ed5cc677e7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryDecisionTest.java @@ -0,0 +1,319 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryEvent; + +public class MemoryDecisionTest { + + private MemoryDecision decisionWithAllFields; + private MemoryDecision decisionMinimal; + private MemoryDecision addDecision; + private MemoryDecision deleteDecision; + private MemoryDecision noneDecision; + + @Before + public void setUp() { + // UPDATE decision with all fields + decisionWithAllFields = MemoryDecision + .builder() + .id("memory-123") + .text("Updated memory text") + .event(MemoryEvent.UPDATE) + .oldMemory("Original memory text") + .build(); + + // Minimal decision (no oldMemory) + decisionMinimal = MemoryDecision.builder().id("memory-456").text("New memory text").event(MemoryEvent.ADD).build(); + + // Different event types + addDecision = MemoryDecision.builder().id("add-memory-789").text("Adding new memory").event(MemoryEvent.ADD).build(); + + deleteDecision = MemoryDecision.builder().id("delete-memory-101").text("Memory to delete").event(MemoryEvent.DELETE).build(); + + noneDecision = MemoryDecision.builder().id("none-memory-202").text("No change needed").event(MemoryEvent.NONE).build(); + } + + @Test + public void testBuilderWithAllFields() { + assertNotNull(decisionWithAllFields); + assertEquals("memory-123", decisionWithAllFields.getId()); + assertEquals("Updated memory text", decisionWithAllFields.getText()); + assertEquals(MemoryEvent.UPDATE, decisionWithAllFields.getEvent()); + assertEquals("Original memory text", decisionWithAllFields.getOldMemory()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(decisionMinimal); + assertEquals("memory-456", decisionMinimal.getId()); + assertEquals("New memory text", decisionMinimal.getText()); + assertEquals(MemoryEvent.ADD, decisionMinimal.getEvent()); + assertNull(decisionMinimal.getOldMemory()); + } + + @Test + public void testConstructorWithAllParameters() { + MemoryDecision decision = new MemoryDecision("id-1", "text-1", MemoryEvent.UPDATE, "old-text"); + assertEquals("id-1", decision.getId()); + assertEquals("text-1", decision.getText()); + assertEquals(MemoryEvent.UPDATE, decision.getEvent()); + assertEquals("old-text", decision.getOldMemory()); + } + + @Test + public void testConstructorWithNullOldMemory() { + MemoryDecision decision = new MemoryDecision("id-2", "text-2", MemoryEvent.ADD, null); + assertEquals("id-2", decision.getId()); + assertEquals("text-2", decision.getText()); + assertEquals(MemoryEvent.ADD, decision.getEvent()); + assertNull(decision.getOldMemory()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with all fields + BytesStreamOutput out = new BytesStreamOutput(); + decisionWithAllFields.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryDecision deserialized = new MemoryDecision(in); + + assertEquals(decisionWithAllFields.getId(), deserialized.getId()); + assertEquals(decisionWithAllFields.getText(), deserialized.getText()); + assertEquals(decisionWithAllFields.getEvent(), deserialized.getEvent()); + assertEquals(decisionWithAllFields.getOldMemory(), deserialized.getOldMemory()); + } + + @Test + public void testStreamInputOutputMinimal() throws IOException { + // Test with minimal fields + BytesStreamOutput out = new BytesStreamOutput(); + decisionMinimal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryDecision deserialized = new MemoryDecision(in); + + assertEquals(decisionMinimal.getId(), deserialized.getId()); + assertEquals(decisionMinimal.getText(), deserialized.getText()); + assertEquals(decisionMinimal.getEvent(), deserialized.getEvent()); + assertNull(deserialized.getOldMemory()); + } + + @Test + public void testStreamInputOutputAllEventTypes() throws IOException { + // Test all event types + MemoryDecision[] decisions = { addDecision, deleteDecision, noneDecision, decisionWithAllFields }; + + for (MemoryDecision original : decisions) { + BytesStreamOutput out = new BytesStreamOutput(); + original.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryDecision deserialized = new MemoryDecision(in); + + assertEquals(original.getId(), deserialized.getId()); + assertEquals(original.getText(), deserialized.getText()); + assertEquals(original.getEvent(), deserialized.getEvent()); + assertEquals(original.getOldMemory(), deserialized.getOldMemory()); + } + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + decisionWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_id\":\"memory-123\"")); + assertTrue(jsonString.contains("\"text\":\"Updated memory text\"")); + assertTrue(jsonString.contains("\"event\":\"UPDATE\"")); + assertTrue(jsonString.contains("\"old_memory\":\"Original memory text\"")); + } + + @Test + public void testToXContentMinimal() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + decisionMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_id\":\"memory-456\"")); + assertTrue(jsonString.contains("\"text\":\"New memory text\"")); + assertTrue(jsonString.contains("\"event\":\"ADD\"")); + assertTrue(!jsonString.contains("\"old_memory\"")); + } + + @Test + public void testParse() throws IOException { + String jsonString = "{" + + "\"memory_id\":\"memory-123\"," + + "\"text\":\"Updated memory text\"," + + "\"event\":\"UPDATE\"," + + "\"old_memory\":\"Original memory text\"" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MemoryDecision parsed = MemoryDecision.parse(parser); + + assertEquals("memory-123", parsed.getId()); + assertEquals("Updated memory text", parsed.getText()); + assertEquals(MemoryEvent.UPDATE, parsed.getEvent()); + assertEquals("Original memory text", parsed.getOldMemory()); + } + + @Test + public void testParseMinimal() throws IOException { + String jsonString = "{" + "\"memory_id\":\"memory-456\"," + "\"text\":\"New memory text\"," + "\"event\":\"ADD\"" + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MemoryDecision parsed = MemoryDecision.parse(parser); + + assertEquals("memory-456", parsed.getId()); + assertEquals("New memory text", parsed.getText()); + assertEquals(MemoryEvent.ADD, parsed.getEvent()); + assertNull(parsed.getOldMemory()); + } + + @Test + public void testParseWithAlternativeIdField() throws IOException { + // Test parsing with "id" field instead of "memory_id" + String jsonString = "{" + "\"id\":\"memory-789\"," + "\"text\":\"Test memory\"," + "\"event\":\"DELETE\"" + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MemoryDecision parsed = MemoryDecision.parse(parser); + + assertEquals("memory-789", parsed.getId()); + assertEquals("Test memory", parsed.getText()); + assertEquals(MemoryEvent.DELETE, parsed.getEvent()); + } + + @Test + public void testParseWithUnknownFields() throws IOException { + String jsonString = "{" + + "\"memory_id\":\"memory-123\"," + + "\"text\":\"Test memory\"," + + "\"event\":\"NONE\"," + + "\"unknown_field\":\"should be ignored\"," + + "\"another_unknown\":123" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MemoryDecision parsed = MemoryDecision.parse(parser); + + assertEquals("memory-123", parsed.getId()); + assertEquals("Test memory", parsed.getText()); + assertEquals(MemoryEvent.NONE, parsed.getEvent()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + decisionWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MemoryDecision parsed = MemoryDecision.parse(parser); + + // Verify all fields match + assertEquals(decisionWithAllFields.getId(), parsed.getId()); + assertEquals(decisionWithAllFields.getText(), parsed.getText()); + assertEquals(decisionWithAllFields.getEvent(), parsed.getEvent()); + assertEquals(decisionWithAllFields.getOldMemory(), parsed.getOldMemory()); + } + + @Test + public void testDataAnnotationMethods() { + // Test @Data generated methods + MemoryDecision decision1 = MemoryDecision.builder().id("id-1").text("text-1").event(MemoryEvent.ADD).build(); + + MemoryDecision decision2 = MemoryDecision.builder().id("id-1").text("text-1").event(MemoryEvent.ADD).build(); + + // Test equals + assertEquals(decision1, decision2); + assertEquals(decision1.hashCode(), decision2.hashCode()); + + // Test setters + decision1.setId("new-id"); + decision1.setText("new-text"); + decision1.setEvent(MemoryEvent.UPDATE); + decision1.setOldMemory("old-text"); + + assertEquals("new-id", decision1.getId()); + assertEquals("new-text", decision1.getText()); + assertEquals(MemoryEvent.UPDATE, decision1.getEvent()); + assertEquals("old-text", decision1.getOldMemory()); + + // Test toString + String str = decision1.toString(); + assertTrue(str.contains("new-id")); + assertTrue(str.contains("new-text")); + assertTrue(str.contains("UPDATE")); + assertTrue(str.contains("old-text")); + } + + @Test + public void testSpecialCharactersInFields() throws IOException { + MemoryDecision specialDecision = MemoryDecision + .builder() + .id("id-with-special-chars-🚀") + .text("Text with\n\ttabs and\nnewlines and \"quotes\"") + .event(MemoryEvent.UPDATE) + .oldMemory("Old text with 'single quotes' and \\backslashes\\") + .build(); + + // Test XContent round trip + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialDecision.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MemoryDecision parsed = MemoryDecision.parse(parser); + + assertEquals(specialDecision.getId(), parsed.getId()); + assertEquals(specialDecision.getText(), parsed.getText()); + assertEquals(specialDecision.getOldMemory(), parsed.getOldMemory()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryStorageConfigTests.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryStorageConfigTests.java new file mode 100644 index 0000000000..dc0c6db987 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryStorageConfigTests.java @@ -0,0 +1,427 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.TestHelper; + +public class MemoryStorageConfigTests { + + private MemoryStorageConfig textEmbeddingConfig; + private MemoryStorageConfig sparseEncodingConfig; + private MemoryStorageConfig minimalConfig; + + @Before + public void setUp() { + // Text embedding configuration (semantic storage enabled) + textEmbeddingConfig = MemoryStorageConfig + .builder() + .memoryIndexName("test-text-embedding-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("text-embedding-model") + .llmModelId("llm-model") + .dimension(768) + .maxInferSize(8) + .build(); + + // Sparse encoding configuration (semantic storage enabled) + sparseEncodingConfig = MemoryStorageConfig + .builder() + .memoryIndexName("test-sparse-encoding-index") + .embeddingModelType(FunctionName.SPARSE_ENCODING) + .embeddingModelId("sparse-encoding-model") + .llmModelId("llm-model") + .dimension(null) // Not allowed for sparse encoding + .maxInferSize(5) + .build(); + + // Minimal configuration (semantic storage disabled) + minimalConfig = MemoryStorageConfig.builder().memoryIndexName("test-minimal-index").llmModelId("llm-model-only").build(); + } + + @Test + public void testConstructorWithBuilderTextEmbedding() { + assertNotNull(textEmbeddingConfig); + assertEquals("test-text-embedding-index", textEmbeddingConfig.getMemoryIndexName()); + assertTrue(textEmbeddingConfig.isSemanticStorageEnabled()); // Auto-determined + assertEquals(FunctionName.TEXT_EMBEDDING, textEmbeddingConfig.getEmbeddingModelType()); + assertEquals("text-embedding-model", textEmbeddingConfig.getEmbeddingModelId()); + assertEquals("llm-model", textEmbeddingConfig.getLlmModelId()); + assertEquals(Integer.valueOf(768), textEmbeddingConfig.getDimension()); + assertEquals(Integer.valueOf(8), textEmbeddingConfig.getMaxInferSize()); + } + + @Test + public void testConstructorWithBuilderSparseEncoding() { + assertNotNull(sparseEncodingConfig); + assertEquals("test-sparse-encoding-index", sparseEncodingConfig.getMemoryIndexName()); + assertTrue(sparseEncodingConfig.isSemanticStorageEnabled()); // Auto-determined + assertEquals(FunctionName.SPARSE_ENCODING, sparseEncodingConfig.getEmbeddingModelType()); + assertEquals("sparse-encoding-model", sparseEncodingConfig.getEmbeddingModelId()); + assertEquals("llm-model", sparseEncodingConfig.getLlmModelId()); + assertNull(sparseEncodingConfig.getDimension()); // Not allowed for sparse encoding + assertEquals(Integer.valueOf(5), sparseEncodingConfig.getMaxInferSize()); + } + + @Test + public void testConstructorWithBuilderMinimal() { + assertNotNull(minimalConfig); + assertEquals("test-minimal-index", minimalConfig.getMemoryIndexName()); + assertFalse(minimalConfig.isSemanticStorageEnabled()); // Auto-determined as false + assertNull(minimalConfig.getEmbeddingModelType()); + assertNull(minimalConfig.getEmbeddingModelId()); + assertEquals("llm-model-only", minimalConfig.getLlmModelId()); + assertNull(minimalConfig.getDimension()); + assertEquals(Integer.valueOf(5), minimalConfig.getMaxInferSize()); // Default value when llmModelId is present + } + + @Test + public void testConstructorWithAllParameters() { + MemoryStorageConfig config = new MemoryStorageConfig( + "test-index", + false, // This will be overridden by auto-determination + FunctionName.TEXT_EMBEDDING, + "embedding-model", + "llm-model", + 512, + 7 + ); + + assertEquals("test-index", config.getMemoryIndexName()); + assertTrue(config.isSemanticStorageEnabled()); // Auto-determined as true + assertEquals(FunctionName.TEXT_EMBEDDING, config.getEmbeddingModelType()); + assertEquals("embedding-model", config.getEmbeddingModelId()); + assertEquals("llm-model", config.getLlmModelId()); + assertEquals(Integer.valueOf(512), config.getDimension()); + assertEquals(Integer.valueOf(7), config.getMaxInferSize()); + } + + @Test + public void testDefaultMaxInferSize() { + // Test with llmModelId present - should get default value + MemoryStorageConfig configWithLlm = MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + .llmModelId("llm-model") + .dimension(768) + // maxInferSize not set, should use default + .build(); + + assertEquals(Integer.valueOf(MemoryContainerConstants.MAX_INFER_SIZE_DEFAULT_VALUE), configWithLlm.getMaxInferSize()); + + // Test without llmModelId - should be null + MemoryStorageConfig configWithoutLlm = MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + .dimension(768) + .build(); + + assertNull(configWithoutLlm.getMaxInferSize()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + textEmbeddingConfig.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MemoryStorageConfig parsedConfig = new MemoryStorageConfig(streamInput); + + assertEquals(textEmbeddingConfig.getMemoryIndexName(), parsedConfig.getMemoryIndexName()); + assertEquals(textEmbeddingConfig.isSemanticStorageEnabled(), parsedConfig.isSemanticStorageEnabled()); + assertEquals(textEmbeddingConfig.getEmbeddingModelType(), parsedConfig.getEmbeddingModelType()); + assertEquals(textEmbeddingConfig.getEmbeddingModelId(), parsedConfig.getEmbeddingModelId()); + assertEquals(textEmbeddingConfig.getLlmModelId(), parsedConfig.getLlmModelId()); + assertEquals(textEmbeddingConfig.getDimension(), parsedConfig.getDimension()); + assertEquals(textEmbeddingConfig.getMaxInferSize(), parsedConfig.getMaxInferSize()); + } + + @Test + public void testStreamInputOutputWithNullValues() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + minimalConfig.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MemoryStorageConfig parsedConfig = new MemoryStorageConfig(streamInput); + + assertEquals(minimalConfig.getMemoryIndexName(), parsedConfig.getMemoryIndexName()); + assertEquals(minimalConfig.isSemanticStorageEnabled(), parsedConfig.isSemanticStorageEnabled()); + assertNull(parsedConfig.getEmbeddingModelType()); + assertNull(parsedConfig.getEmbeddingModelId()); + assertEquals(minimalConfig.getLlmModelId(), parsedConfig.getLlmModelId()); + assertNull(parsedConfig.getDimension()); + assertEquals(Integer.valueOf(5), parsedConfig.getMaxInferSize()); // Default value when llmModelId is present + } + + @Test + public void testToXContentWithSemanticStorageEnabled() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + textEmbeddingConfig.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify semantic storage enabled fields are present + assertTrue(jsonStr.contains("\"memory_index_name\":\"test-text-embedding-index\"")); + assertTrue(jsonStr.contains("\"semantic_storage_enabled\":true")); + assertTrue(jsonStr.contains("\"embedding_model_type\":\"TEXT_EMBEDDING\"")); + assertTrue(jsonStr.contains("\"embedding_model_id\":\"text-embedding-model\"")); + assertTrue(jsonStr.contains("\"llm_model_id\":\"llm-model\"")); + assertTrue(jsonStr.contains("\"dimension\":768")); + assertTrue(jsonStr.contains("\"max_infer_size\":8")); + } + + @Test + public void testToXContentWithSemanticStorageDisabled() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + minimalConfig.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify only basic fields are present + assertTrue(jsonStr.contains("\"memory_index_name\":\"test-minimal-index\"")); + assertTrue(jsonStr.contains("\"semantic_storage_enabled\":false")); + assertTrue(jsonStr.contains("\"llm_model_id\":\"llm-model-only\"")); + // Verify semantic storage fields are NOT present + assertFalse(jsonStr.contains("\"embedding_model_type\"")); + assertFalse(jsonStr.contains("\"embedding_model_id\"")); + assertFalse(jsonStr.contains("\"dimension\"")); + // max_infer_size is present because llmModelId is set + assertTrue(jsonStr.contains("\"max_infer_size\":5")); + } + + @Test + public void testParseFromXContentWithAllFields() throws IOException { + String jsonStr = "{" + "\"memory_index_name\":\"parsed-index\"," + "\"semantic_storage_enabled\":true," + // This field is ignored + "\"embedding_model_type\":\"TEXT_EMBEDDING\"," + + "\"embedding_model_id\":\"parsed-embedding-model\"," + + "\"llm_model_id\":\"parsed-llm-model\"," + + "\"dimension\":1024," + + "\"max_infer_size\":9" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MemoryStorageConfig parsedConfig = MemoryStorageConfig.parse(parser); + + assertEquals("parsed-index", parsedConfig.getMemoryIndexName()); + assertTrue(parsedConfig.isSemanticStorageEnabled()); // Auto-determined + assertEquals(FunctionName.TEXT_EMBEDDING, parsedConfig.getEmbeddingModelType()); + assertEquals("parsed-embedding-model", parsedConfig.getEmbeddingModelId()); + assertEquals("parsed-llm-model", parsedConfig.getLlmModelId()); + assertEquals(Integer.valueOf(1024), parsedConfig.getDimension()); + assertEquals(Integer.valueOf(9), parsedConfig.getMaxInferSize()); + } + + @Test + public void testParseFromXContentWithPartialFields() throws IOException { + String jsonStr = "{" + "\"memory_index_name\":\"partial-index\"," + "\"llm_model_id\":\"partial-llm-model\"" + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MemoryStorageConfig parsedConfig = MemoryStorageConfig.parse(parser); + + assertEquals("partial-index", parsedConfig.getMemoryIndexName()); + assertFalse(parsedConfig.isSemanticStorageEnabled()); // Auto-determined as false + assertNull(parsedConfig.getEmbeddingModelType()); + assertNull(parsedConfig.getEmbeddingModelId()); + assertEquals("partial-llm-model", parsedConfig.getLlmModelId()); + assertNull(parsedConfig.getDimension()); + assertEquals(Integer.valueOf(5), parsedConfig.getMaxInferSize()); // Default value when llmModelId is present + } + + @Test + public void testParseFromXContentWithUnknownFields() throws IOException { + String jsonStr = "{" + + "\"memory_index_name\":\"unknown-test-index\"," + + "\"unknown_field\":\"unknown_value\"," + + "\"llm_model_id\":\"test-llm\"," + + "\"another_unknown\":123" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MemoryStorageConfig parsedConfig = MemoryStorageConfig.parse(parser); + + assertEquals("unknown-test-index", parsedConfig.getMemoryIndexName()); + assertEquals("test-llm", parsedConfig.getLlmModelId()); + // Unknown fields should be ignored + assertFalse(parsedConfig.isSemanticStorageEnabled()); + } + + @Test + public void testCompleteRoundTrip() throws IOException { + // Test complete round trip: object -> JSON -> parse -> compare + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + textEmbeddingConfig.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MemoryStorageConfig parsedConfig = MemoryStorageConfig.parse(parser); + + assertEquals(textEmbeddingConfig.getMemoryIndexName(), parsedConfig.getMemoryIndexName()); + assertEquals(textEmbeddingConfig.isSemanticStorageEnabled(), parsedConfig.isSemanticStorageEnabled()); + assertEquals(textEmbeddingConfig.getEmbeddingModelType(), parsedConfig.getEmbeddingModelType()); + assertEquals(textEmbeddingConfig.getEmbeddingModelId(), parsedConfig.getEmbeddingModelId()); + assertEquals(textEmbeddingConfig.getLlmModelId(), parsedConfig.getLlmModelId()); + assertEquals(textEmbeddingConfig.getDimension(), parsedConfig.getDimension()); + assertEquals(textEmbeddingConfig.getMaxInferSize(), parsedConfig.getMaxInferSize()); + } + + @Test + public void testEqualsAndHashCode() { + MemoryStorageConfig config1 = MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + .llmModelId("llm-model") + .dimension(768) + .maxInferSize(5) + .build(); + + MemoryStorageConfig config2 = MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + .llmModelId("llm-model") + .dimension(768) + .maxInferSize(5) + .build(); + + MemoryStorageConfig config3 = MemoryStorageConfig + .builder() + .memoryIndexName("different-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + .llmModelId("llm-model") + .dimension(768) + .maxInferSize(5) + .build(); + + assertEquals(config1, config2); + assertEquals(config1.hashCode(), config2.hashCode()); + assertFalse(config1.equals(config3)); + assertTrue(config1.hashCode() != config3.hashCode()); + } + + @Test + public void testSettersAndGetters() { + MemoryStorageConfig config = new MemoryStorageConfig(null, false, null, null, null, null, null); + + config.setMemoryIndexName("new-index"); + config.setSemanticStorageEnabled(true); + config.setEmbeddingModelType(FunctionName.SPARSE_ENCODING); + config.setEmbeddingModelId("new-embedding-model"); + config.setLlmModelId("new-llm-model"); + config.setDimension(1024); + config.setMaxInferSize(10); + + assertEquals("new-index", config.getMemoryIndexName()); + assertTrue(config.isSemanticStorageEnabled()); + assertEquals(FunctionName.SPARSE_ENCODING, config.getEmbeddingModelType()); + assertEquals("new-embedding-model", config.getEmbeddingModelId()); + assertEquals("new-llm-model", config.getLlmModelId()); + assertEquals(Integer.valueOf(1024), config.getDimension()); + assertEquals(Integer.valueOf(10), config.getMaxInferSize()); + } + + // Validation Tests + + @Test(expected = IllegalArgumentException.class) + public void testValidationEmbeddingModelIdWithoutType() { + MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelId("embedding-model") // Missing embeddingModelType + .build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidationEmbeddingModelTypeWithoutId() { + MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) // Missing embeddingModelId + .build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidationTextEmbeddingWithoutDimension() { + MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + // Missing dimension for TEXT_EMBEDDING + .build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidationSparseEncodingWithDimension() { + MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.SPARSE_ENCODING) + .embeddingModelId("embedding-model") + .dimension(768) // Not allowed for SPARSE_ENCODING + .build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidationMaxInferSizeExceedsLimit() { + MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("embedding-model") + .dimension(768) + .maxInferSize(11) // Exceeds limit of 10 + .build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidationInvalidEmbeddingModelType() { + MemoryStorageConfig + .builder() + .memoryIndexName("test-index") + .embeddingModelType(FunctionName.KMEANS) // Invalid embedding model type + .embeddingModelId("embedding-model") + .dimension(768) + .build(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryTypeTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryTypeTest.java new file mode 100644 index 0000000000..35fd1c1865 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryTypeTest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; + +public class MemoryTypeTest { + + @Test + public void testEnumValues() { + // Test all enum values exist + assertEquals(2, MemoryType.values().length); + assertEquals(MemoryType.RAW_MESSAGE, MemoryType.valueOf("RAW_MESSAGE")); + assertEquals(MemoryType.FACT, MemoryType.valueOf("FACT")); + } + + @Test + public void testGetValue() { + assertEquals("RAW_MESSAGE", MemoryType.RAW_MESSAGE.getValue()); + assertEquals("FACT", MemoryType.FACT.getValue()); + } + + @Test + public void testToString() { + assertEquals("RAW_MESSAGE", MemoryType.RAW_MESSAGE.toString()); + assertEquals("FACT", MemoryType.FACT.toString()); + } + + @Test + public void testFromString_ValidValues() { + // Test exact match + assertEquals(MemoryType.RAW_MESSAGE, MemoryType.fromString("RAW_MESSAGE")); + assertEquals(MemoryType.FACT, MemoryType.fromString("FACT")); + + // Test case insensitive + assertEquals(MemoryType.RAW_MESSAGE, MemoryType.fromString("raw_message")); + assertEquals(MemoryType.FACT, MemoryType.fromString("FaCt")); + assertEquals(MemoryType.RAW_MESSAGE, MemoryType.fromString("Raw_Message")); + } + + @Test + public void testFromString_Null() { + assertNull(MemoryType.fromString(null)); + } + + @Test + public void testFromString_InvalidValue() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryType.fromString("INVALID_TYPE")); + assertEquals("Invalid memory type: INVALID_TYPE. Must be either RAW_MESSAGE or FACT", exception.getMessage()); + } + + @Test + public void testFromString_EmptyString() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryType.fromString("")); + assertEquals("Invalid memory type: . Must be either RAW_MESSAGE or FACT", exception.getMessage()); + } + + @Test + public void testFromString_Whitespace() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryType.fromString(" ")); + assertEquals("Invalid memory type: . Must be either RAW_MESSAGE or FACT", exception.getMessage()); + } + + @Test + public void testEnumConsistency() { + // Verify each enum's getValue() returns its name + for (MemoryType type : MemoryType.values()) { + assertNotNull(type.getValue()); + assertEquals(type.getValue(), type.toString()); + assertEquals(type, MemoryType.fromString(type.getValue())); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/BaseModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/BaseModelConfigTests.java index b376f4fcdb..84148aaaf5 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/BaseModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/BaseModelConfigTests.java @@ -44,6 +44,10 @@ public void setUp() { .modelType("testModelType") .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") .additionalConfig(additionalConfig) + .embeddingDimension(768) + .frameworkType(BaseModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .passagePrefix("passage: ") + .queryPrefix("query: ") .build(); function = parser -> { @@ -61,7 +65,8 @@ public void toXContent() throws IOException { config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); assertEquals( - "{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"additional_config\":{\"space_type\":\"l2\"}}", + "{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"additional_config\":{\"space_type\":\"l2\"}," + + "\"embedding_dimension\":768,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", configContent ); } @@ -93,6 +98,11 @@ public void testStreamInputVersionAfter_3_1_0() throws IOException { assertEquals(config.getModelType(), parsedConfig.getModelType()); assertEquals(config.getAllConfig(), parsedConfig.getAllConfig()); assertEquals(config.getAdditionalConfig(), parsedConfig.getAdditionalConfig()); + assertEquals(config.getEmbeddingDimension(), parsedConfig.getEmbeddingDimension()); + assertEquals(config.getFrameworkType(), parsedConfig.getFrameworkType()); + assertEquals(config.getFrameworkType(), parsedConfig.getFrameworkType()); + assertEquals(config.getQueryPrefix(), parsedConfig.getQueryPrefix()); + assertEquals(config.getPassagePrefix(), parsedConfig.getPassagePrefix()); assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); } @@ -109,6 +119,11 @@ public void testStreamInputVersionBefore_3_1_0() throws IOException { assertEquals(config.getModelType(), parsedConfig.getModelType()); assertEquals(config.getAllConfig(), parsedConfig.getAllConfig()); assertNull(parsedConfig.getAdditionalConfig()); + assertEquals(config.getEmbeddingDimension(), parsedConfig.getEmbeddingDimension()); + assertEquals(config.getFrameworkType(), parsedConfig.getFrameworkType()); + assertEquals(config.getFrameworkType(), parsedConfig.getFrameworkType()); + assertEquals(config.getQueryPrefix(), parsedConfig.getQueryPrefix()); + assertEquals(config.getPassagePrefix(), parsedConfig.getPassagePrefix()); assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); } diff --git a/common/src/test/java/org/opensearch/ml/common/model/RemoteModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/RemoteModelConfigTests.java index e05dc0dd09..0890e5158a 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/RemoteModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/RemoteModelConfigTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.model; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.IOException; @@ -17,6 +18,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; @@ -171,4 +173,48 @@ public void readInputStream(RemoteModelConfig config) throws IOException { assertEquals(config.getAdditionalConfig(), parsedConfig.getAdditionalConfig()); assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); } + + @Test + public void readInputStream_VersionCompatibility() throws IOException { + // Test with older version + BytesStreamOutput oldOut = new BytesStreamOutput(); + Version oldVersion = Version.V_3_0_0; + oldOut.setVersion(oldVersion); + config.writeTo(oldOut); + + StreamInput oldIn = oldOut.bytes().streamInput(); + oldIn.setVersion(oldVersion); + RemoteModelConfig oldConfig = new RemoteModelConfig(oldIn); + + // Verify essential fields with old version + assertEquals(config.getModelType(), oldConfig.getModelType()); + assertEquals(config.getAllConfig(), oldConfig.getAllConfig()); + assertEquals(config.getEmbeddingDimension(), oldConfig.getEmbeddingDimension()); + assertEquals(config.getFrameworkType(), oldConfig.getFrameworkType()); + assertEquals(config.getPoolingMode(), oldConfig.getPoolingMode()); + assertEquals(config.isNormalizeResult(), oldConfig.isNormalizeResult()); + assertEquals(config.getModelMaxLength(), oldConfig.getModelMaxLength()); + assertNull(oldConfig.getAdditionalConfig()); + assertEquals(config.getWriteableName(), oldConfig.getWriteableName()); + + // Test with newer version + BytesStreamOutput currentOut = new BytesStreamOutput(); + currentOut.setVersion(Version.V_3_1_0); + config.writeTo(currentOut); + + StreamInput currentIn = currentOut.bytes().streamInput(); + currentIn.setVersion(Version.V_3_1_0); + RemoteModelConfig newConfig = new RemoteModelConfig(currentIn); + + // Verify fields with current version + assertEquals(config.getModelType(), newConfig.getModelType()); + assertEquals(config.getAllConfig(), newConfig.getAllConfig()); + assertEquals(config.getEmbeddingDimension(), newConfig.getEmbeddingDimension()); + assertEquals(config.getFrameworkType(), newConfig.getFrameworkType()); + assertEquals(config.getPoolingMode(), newConfig.getPoolingMode()); + assertEquals(config.isNormalizeResult(), newConfig.isNormalizeResult()); + assertEquals(config.getModelMaxLength(), newConfig.getModelMaxLength()); + assertEquals(config.getAdditionalConfig(), newConfig.getAdditionalConfig()); + assertEquals(config.getWriteableName(), newConfig.getWriteableName()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index 482e9c0fe8..e6d873dd59 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.model; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.IOException; @@ -17,6 +18,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; @@ -118,4 +120,46 @@ public void readInputStream(TextEmbeddingModelConfig config) throws IOException assertEquals(config.getFrameworkType(), parsedConfig.getFrameworkType()); assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); } + + @Test + public void readInputStream_VersionCompatibility() throws IOException { + // Test with older version + BytesStreamOutput oldOut = new BytesStreamOutput(); + Version oldVersion = Version.V_3_0_0; + oldOut.setVersion(oldVersion); + config.writeTo(oldOut); + + StreamInput oldIn = oldOut.bytes().streamInput(); + oldIn.setVersion(oldVersion); + TextEmbeddingModelConfig oldConfig = new TextEmbeddingModelConfig(oldIn); + + // Verify essential fields with old version + assertEquals(config.getModelType(), oldConfig.getModelType()); + assertEquals(config.getAllConfig(), oldConfig.getAllConfig()); + assertNull(oldConfig.getAdditionalConfig()); + assertEquals(config.getEmbeddingDimension(), oldConfig.getEmbeddingDimension()); + assertEquals(config.getFrameworkType(), oldConfig.getFrameworkType()); + assertEquals(config.getWriteableName(), oldConfig.getWriteableName()); + assertEquals(config.getQueryPrefix(), oldConfig.getQueryPrefix()); + assertEquals(config.getPassagePrefix(), oldConfig.getPassagePrefix()); + + // Test with newer version + BytesStreamOutput currentOut = new BytesStreamOutput(); + currentOut.setVersion(Version.V_3_1_0); + config.writeTo(currentOut); + + StreamInput currentIn = currentOut.bytes().streamInput(); + currentIn.setVersion(Version.V_3_1_0); + TextEmbeddingModelConfig newConfig = new TextEmbeddingModelConfig(currentIn); + + // Verify fields with current version + assertEquals(config.getModelType(), newConfig.getModelType()); + assertEquals(config.getAllConfig(), newConfig.getAllConfig()); + assertEquals(config.getAdditionalConfig(), newConfig.getAdditionalConfig()); + assertEquals(config.getEmbeddingDimension(), newConfig.getEmbeddingDimension()); + assertEquals(config.getFrameworkType(), newConfig.getFrameworkType()); + assertEquals(config.getWriteableName(), newConfig.getWriteableName()); + assertEquals(config.getQueryPrefix(), newConfig.getQueryPrefix()); + assertEquals(config.getPassagePrefix(), newConfig.getPassagePrefix()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/settings/MLCommonsSettingsTests.java b/common/src/test/java/org/opensearch/ml/common/settings/MLCommonsSettingsTests.java index cb0294ad20..33c41959b5 100644 --- a/common/src/test/java/org/opensearch/ml/common/settings/MLCommonsSettingsTests.java +++ b/common/src/test/java/org/opensearch/ml/common/settings/MLCommonsSettingsTests.java @@ -70,4 +70,36 @@ public void testRemoteInferenceEnabledByDefault() { public void testAllowModelUrlDisabledByDefault() { assertFalse(MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL.getDefault(null)); } + + @Test + public void testAgenticMemoryDisabledByDefault() { + assertFalse(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED.getDefault(null)); + } + + @Test + public void testAgenticMemorySettingProperties() { + // Test setting key + assertEquals("plugins.ml_commons.agentic_memory_enabled", MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED.getKey()); + + // Test setting is dynamic + assertTrue( + MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED + .getProperties() + .contains(org.opensearch.common.settings.Setting.Property.Dynamic) + ); + + // Test setting is node scope + assertTrue( + MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED + .getProperties() + .contains(org.opensearch.common.settings.Setting.Property.NodeScope) + ); + } + + @Test + public void testAgenticMemoryDisabledMessage() { + String expectedMessage = + "The Agentic Memory APIs are not enabled. To enable, please update the setting plugins.ml_commons.agentic_memory_enabled"; + assertEquals(expectedMessage, MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java index e1dc2b2030..568f1aa4a1 100644 --- a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java @@ -43,7 +43,15 @@ public void setUp() { MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED, - MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED + MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, + MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED, + MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED, + MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED, + MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED, + MLCommonsSettings.ML_COMMONS_TRACING_ENABLED, + MLCommonsSettings.ML_COMMONS_AGENT_TRACING_ENABLED, + MLCommonsSettings.ML_COMMONS_CONNECTOR_TRACING_ENABLED, + MLCommonsSettings.ML_COMMONS_MODEL_TRACING_ENABLED ) ); when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings); @@ -65,6 +73,13 @@ public void testDefaults_allFeaturesEnabled() { .put("plugins.ml_commons.rag_pipeline_feature_enabled", true) .put("plugins.ml_commons.metrics_collection_enabled", true) .put("plugins.ml_commons.metrics_static_collection_enabled", true) + .put("plugins.ml_commons.mcp_connector_enabled", true) + .put("plugins.ml_commons.agentic_search_enabled", true) + .put("plugins.ml_commons.agentic_memory_enabled", true) + .put("plugins.ml_commons.tracing_enabled", true) + .put("plugins.ml_commons.agent_tracing_enabled", true) + .put("plugins.ml_commons.connector_tracing_enabled", true) + .put("plugins.ml_commons.model_tracing_enabled", true) .build(); MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); @@ -81,6 +96,13 @@ public void testDefaults_allFeaturesEnabled() { assertTrue(setting.isRagSearchPipelineEnabled()); assertTrue(setting.isMetricCollectionEnabled()); assertTrue(setting.isStaticMetricCollectionEnabled()); + assertTrue(setting.isMcpConnectorEnabled()); + assertTrue(setting.isAgenticSearchEnabled()); + assertTrue(setting.isAgenticMemoryEnabled()); + assertTrue(setting.isTracingEnabled()); + assertTrue(setting.isAgentTracingEnabled()); + assertTrue(setting.isConnectorTracingEnabled()); + assertTrue(setting.isModelTracingEnabled()); } @Test @@ -99,6 +121,13 @@ public void testDefaults_someFeaturesDisabled() { .put("plugins.ml_commons.rag_pipeline_feature_enabled", false) .put("plugins.ml_commons.metrics_collection_enabled", false) .put("plugins.ml_commons.metrics_static_collection_enabled", false) + .put("plugins.ml_commons.mcp_connector_enabled", false) + .put("plugins.ml_commons.agentic_search_enabled", false) + .put("plugins.ml_commons.agentic_memory_enabled", false) + .put("plugins.ml_commons.tracing_enabled", false) + .put("plugins.ml_commons.agent_tracing_enabled", false) + .put("plugins.ml_commons.connector_tracing_enabled", false) + .put("plugins.ml_commons.model_tracing_enabled", false) .build(); MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); @@ -115,6 +144,13 @@ public void testDefaults_someFeaturesDisabled() { assertFalse(setting.isRagSearchPipelineEnabled()); assertFalse(setting.isMetricCollectionEnabled()); assertFalse(setting.isStaticMetricCollectionEnabled()); + assertFalse(setting.isMcpConnectorEnabled()); + assertFalse(setting.isAgenticSearchEnabled()); + assertFalse(setting.isAgenticMemoryEnabled()); + assertFalse(setting.isTracingEnabled()); + assertFalse(setting.isAgentTracingEnabled()); + assertFalse(setting.isConnectorTracingEnabled()); + assertFalse(setting.isModelTracingEnabled()); } @Test @@ -129,4 +165,29 @@ public void testMultiTenancyChangeNotifiesListeners() { setting.notifyMultiTenancyListeners(true); verify(mockListener).onMultiTenancyEnabledChanged(true); } + + @Test + public void testAgenticMemoryEnabledByDefault() { + Settings settings = Settings.EMPTY; + MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); + + // Should be disabled by default + assertFalse(setting.isAgenticMemoryEnabled()); + } + + @Test + public void testAgenticMemoryCanBeEnabled() { + Settings settings = Settings.builder().put("plugins.ml_commons.agentic_memory_enabled", true).build(); + + MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); + assertTrue(setting.isAgenticMemoryEnabled()); + } + + @Test + public void testAgenticMemoryCanBeDisabled() { + Settings settings = Settings.builder().put("plugins.ml_commons.agentic_memory_enabled", false).build(); + + MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); + assertFalse(setting.isAgenticMemoryEnabled()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index afac271ccf..72eb035279 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -46,27 +46,31 @@ public class MLAgentUpdateInputTest { @Before public void setUp() throws Exception { - LLMSpec llmSpec = LLMSpec.builder().modelId("test-model-id").parameters(Map.of("max_iteration", "5")).build(); MLToolSpec toolSpec = MLToolSpec .builder() .name("test-tool") .type("MLModelTool") .parameters(Map.of("model_id", "test-model-id")) .build(); - MLMemorySpec memorySpec = MLMemorySpec.builder().type("conversation_index").build(); Map parameters = new HashMap<>(); parameters.put("_llm_interface", "test"); + Map llmParameters = new HashMap<>(); + llmParameters.put("max_iteration", "5"); + updateAgentInput = MLAgentUpdateInput .builder() .agentId("test-agent-id") .name("test-agent") .description("test description") - .llm(llmSpec) + .llmModelId("test-model-id") + .llmParameters(llmParameters) .tools(Collections.singletonList(toolSpec)) .parameters(parameters) - .memory(memorySpec) + .memoryType("conversation_index") + .memorySessionId("test-session") + .memoryWindowSize(10) .appType("rag") .lastUpdateTime(Instant.ofEpochMilli(1)) .build(); @@ -88,8 +92,7 @@ public void testToXContent() throws Exception { @Test public void testValidationWithInvalidMemoryType() { IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { - MLMemorySpec invalidMemorySpec = MLMemorySpec.builder().type("invalid_type").build(); - MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memory(invalidMemorySpec).build(); + MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build(); }); assertEquals("Invalid memory type: invalid_type", e.getMessage()); } @@ -122,14 +125,82 @@ public void testValidationWithDuplicateTools() { MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").tools(Arrays.asList(tool1, tool2)).build(); }); assertEquals("Duplicate tool defined: tool1", e.getMessage()); + } + + @Test + public void testValidationWithEmptyAgentName() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { + MLAgentUpdateInput.builder().agentId("test-agent-id").name("").build(); + }); + assertTrue(e.getMessage().contains("Agent name cannot be empty")); + } + + @Test + public void testValidationWithBlankAgentName() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { + MLAgentUpdateInput.builder().agentId("test-agent-id").name(" ").build(); + }); + assertTrue(e.getMessage().contains("Agent name cannot be empty")); + } + + @Test + public void testValidationWithTooLongAgentName() { + String longName = "a".repeat(MLAgent.AGENT_NAME_MAX_LENGTH + 1); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { + MLAgentUpdateInput.builder().agentId("test-agent-id").name(longName).build(); + }); + assertTrue(e.getMessage().contains("exceed max length")); + } + + @Test + public void testValidationWithMaxLengthAgentName() { + String maxLengthName = "a".repeat(MLAgent.AGENT_NAME_MAX_LENGTH); + // Should not throw exception + MLAgentUpdateInput input = MLAgentUpdateInput.builder().agentId("test-agent-id").name(maxLengthName).build(); + assertEquals(maxLengthName, input.getName()); + } + + @Test + public void testValidationWithNullAgentName() { + // Should not throw exception - null names are allowed + MLAgentUpdateInput input = MLAgentUpdateInput.builder().agentId("test-agent-id").name(null).build(); + assertNull(input.getName()); + } + + @Test + public void testValidationWithDuplicateToolsByType() { + // Test duplicate tools identified by type when name is null + MLToolSpec tool1 = MLToolSpec.builder().type("duplicate_type").build(); + MLToolSpec tool2 = MLToolSpec.builder().type("duplicate_type").build(); + + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { + MLAgentUpdateInput.builder().agentId("test-agent-id").tools(Arrays.asList(tool1, tool2)).build(); + }); + assertEquals("Duplicate tool defined: duplicate_type", e.getMessage()); + } + + @Test + public void testValidationWithValidUniqueTools() { + MLToolSpec tool1 = MLToolSpec.builder().name("tool1").type("type1").build(); + MLToolSpec tool2 = MLToolSpec.builder().name("tool2").type("type2").build(); + MLToolSpec tool3 = MLToolSpec.builder().type("type3").build(); // No name, uses type + + // Should not throw exception + MLAgentUpdateInput input = MLAgentUpdateInput.builder().agentId("test-agent-id").tools(Arrays.asList(tool1, tool2, tool3)).build(); + + assertEquals(3, input.getTools().size()); + } + @Test + public void testValidationWithDuplicateToolsByTypeOnly() { + // Test duplicate tools identified by type when name is null MLToolSpec tool3 = MLToolSpec.builder().type("type3").build(); MLToolSpec tool4 = MLToolSpec.builder().type("type3").build(); - e = assertThrows(IllegalArgumentException.class, () -> { + IllegalArgumentException e2 = assertThrows(IllegalArgumentException.class, () -> { MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").tools(Arrays.asList(tool3, tool4)).build(); }); - assertEquals("Duplicate tool defined: type3", e.getMessage()); + assertEquals("Duplicate tool defined: type3", e2.getMessage()); } @Test @@ -192,7 +263,9 @@ public void testParseSuccess() throws Exception { "chat_history": "test" }, "memory": { - "type": "conversation_index" + "type": "conversation_index", + "session_id": "test-session", + "window_size": 5 }, "app_type": "rag" } @@ -200,11 +273,14 @@ public void testParseSuccess() throws Exception { testParseFromJsonString(inputStr, parsedInput -> { assertEquals("test-agent", parsedInput.getName()); assertEquals("test description", parsedInput.getDescription()); - assertEquals("test-model-id", parsedInput.getLlm().getModelId()); + assertEquals("test-model-id", parsedInput.getLlmModelId()); + assertEquals("5", parsedInput.getLlmParameters().get("max_iteration")); assertEquals(1, parsedInput.getTools().size()); assertEquals("test-tool", parsedInput.getTools().getFirst().getName()); assertEquals("test", parsedInput.getParameters().get("chat_history")); - assertEquals("conversation_index", parsedInput.getMemory().getType()); + assertEquals("conversation_index", parsedInput.getMemoryType()); + assertEquals("test-session", parsedInput.getMemorySessionId()); + assertEquals(Integer.valueOf(5), parsedInput.getMemoryWindowSize()); assertEquals("rag", parsedInput.getAppType()); }); } @@ -249,26 +325,46 @@ public void testToMLAgent() { assertEquals(originalAgent.getIsHidden(), updatedAgent.getIsHidden()); assertEquals(updateAgentInput.getName(), updatedAgent.getName()); assertEquals(updateAgentInput.getDescription(), updatedAgent.getDescription()); - assertEquals(updateAgentInput.getLlm(), updatedAgent.getLlm()); + // Check LLM fields separately since we now use separate fields + assertEquals(updateAgentInput.getLlmModelId(), updatedAgent.getLlm().getModelId()); + assertEquals(updateAgentInput.getLlmParameters(), updatedAgent.getLlm().getParameters()); assertEquals(updateAgentInput.getTools(), updatedAgent.getTools()); assertEquals(updateAgentInput.getParameters(), updatedAgent.getParameters()); - assertEquals(updateAgentInput.getMemory(), updatedAgent.getMemory()); + // Check memory fields separately since we now use separate fields + assertEquals(updateAgentInput.getMemoryType(), updatedAgent.getMemory().getType()); + assertEquals(updateAgentInput.getMemorySessionId(), updatedAgent.getMemory().getSessionId()); + assertEquals(updateAgentInput.getMemoryWindowSize(), updatedAgent.getMemory().getWindowSize()); assertEquals(updateAgentInput.getLastUpdateTime(), updatedAgent.getLastUpdateTime()); assertEquals(updateAgentInput.getAppType(), updatedAgent.getAppType()); } @Test public void testReadInputStreamSuccessWithNullFields() throws IOException { - updateAgentInput.setLlm(null); - updateAgentInput.setTools(null); - updateAgentInput.setParameters(null); - updateAgentInput.setMemory(null); + // Create a new input with null LLM fields + MLAgentUpdateInput inputWithNulls = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .name("test-agent") + .description("test description") + .llmModelId(null) + .llmParameters(null) + .tools(null) + .parameters(null) + .memoryType(null) + .memorySessionId(null) + .memoryWindowSize(null) + .appType("rag") + .lastUpdateTime(Instant.ofEpochMilli(1)) + .build(); - readInputStream(updateAgentInput, parsedInput -> { - assertNull(parsedInput.getLlm()); + readInputStream(inputWithNulls, parsedInput -> { + assertNull(parsedInput.getLlmModelId()); + assertNull(parsedInput.getLlmParameters()); assertNull(parsedInput.getTools()); assertNull(parsedInput.getParameters()); - assertNull(parsedInput.getMemory()); + assertNull(parsedInput.getMemoryType()); + assertNull(parsedInput.getMemorySessionId()); + assertNull(parsedInput.getMemoryWindowSize()); }); } @@ -278,10 +374,13 @@ public void testReadInputStreamSuccess() throws IOException { assertEquals(updateAgentInput.getAgentId(), parsedInput.getAgentId()); assertEquals(updateAgentInput.getName(), parsedInput.getName()); assertEquals(updateAgentInput.getDescription(), parsedInput.getDescription()); - assertEquals(updateAgentInput.getLlm().getModelId(), parsedInput.getLlm().getModelId()); + assertEquals(updateAgentInput.getLlmModelId(), parsedInput.getLlmModelId()); + assertEquals(updateAgentInput.getLlmParameters(), parsedInput.getLlmParameters()); assertEquals(updateAgentInput.getTools().size(), parsedInput.getTools().size()); assertEquals(updateAgentInput.getParameters().size(), parsedInput.getParameters().size()); - assertEquals(updateAgentInput.getMemory().getType(), parsedInput.getMemory().getType()); + assertEquals(updateAgentInput.getMemoryType(), parsedInput.getMemoryType()); + assertEquals(updateAgentInput.getMemorySessionId(), parsedInput.getMemorySessionId()); + assertEquals(updateAgentInput.getMemoryWindowSize(), parsedInput.getMemoryWindowSize()); assertEquals(updateAgentInput.getAppType(), parsedInput.getAppType()); assertEquals(updateAgentInput.getLastUpdateTime(), parsedInput.getLastUpdateTime()); }); @@ -313,7 +412,9 @@ public void testParseWithAllFields() throws Exception { "chat_history": "test" }, "memory": { - "type": "conversation_index" + "type": "conversation_index", + "session_id": "test-session", + "window_size": 5 }, "app_type": "rag", "last_updated_time": 1234567890, @@ -323,11 +424,13 @@ public void testParseWithAllFields() throws Exception { testParseFromJsonString(inputStr, parsedInput -> { assertEquals("test-agent", parsedInput.getName()); assertEquals("test description", parsedInput.getDescription()); - assertEquals("test-model-id", parsedInput.getLlm().getModelId()); + assertEquals("test-model-id", parsedInput.getLlmModelId()); assertEquals(1, parsedInput.getTools().size()); assertEquals("test-tool", parsedInput.getTools().getFirst().getName()); assertEquals("test", parsedInput.getParameters().get("chat_history")); - assertEquals("conversation_index", parsedInput.getMemory().getType()); + assertEquals("conversation_index", parsedInput.getMemoryType()); + assertEquals("test-session", parsedInput.getMemorySessionId()); + assertEquals(Integer.valueOf(5), parsedInput.getMemoryWindowSize()); assertEquals("rag", parsedInput.getAppType()); assertEquals(1234567890L, parsedInput.getLastUpdateTime().toEpochMilli()); assertEquals("test-tenant", parsedInput.getTenantId()); @@ -341,7 +444,8 @@ public void testToXContentWithAllFields() throws Exception { .agentId("test-agent-id") .name("test-agent") .description("test description") - .llm(LLMSpec.builder().modelId("test-model-id").parameters(Map.of("max_iteration", "5")).build()) + .llmModelId("test-model-id") + .llmParameters(Map.of("max_iteration", "5")) .tools( Collections .singletonList( @@ -349,7 +453,9 @@ public void testToXContentWithAllFields() throws Exception { ) ) .parameters(Map.of("chat_history", "test")) - .memory(MLMemorySpec.builder().type("conversation_index").build()) + .memoryType("conversation_index") + .memorySessionId("test-session") + .memoryWindowSize(5) .appType("rag") .lastUpdateTime(Instant.ofEpochMilli(1234567890)) .tenantId("test-tenant") @@ -375,7 +481,8 @@ public void testStreamInputOutput() throws IOException { .agentId("test-agent-id") .name("test-agent") .description("test description") - .llm(LLMSpec.builder().modelId("test-model-id").parameters(Map.of("max_iteration", "5")).build()) + .llmModelId("test-model-id") + .llmParameters(Map.of("max_iteration", "5")) .tools( Collections .singletonList( @@ -383,7 +490,9 @@ public void testStreamInputOutput() throws IOException { ) ) .parameters(Map.of("chat_history", "test")) - .memory(MLMemorySpec.builder().type("conversation_index").build()) + .memoryType("conversation_index") + .memorySessionId("test-session") + .memoryWindowSize(10) .appType("rag") .lastUpdateTime(Instant.ofEpochMilli(1234567890)) .tenantId("test-tenant") @@ -393,14 +502,16 @@ public void testStreamInputOutput() throws IOException { assertEquals(input.getAgentId(), parsedInput.getAgentId()); assertEquals(input.getName(), parsedInput.getName()); assertEquals(input.getDescription(), parsedInput.getDescription()); - assertEquals(input.getLlm().getModelId(), parsedInput.getLlm().getModelId()); - assertEquals(input.getLlm().getParameters(), parsedInput.getLlm().getParameters()); + assertEquals(input.getLlmModelId(), parsedInput.getLlmModelId()); + assertEquals(input.getLlmParameters(), parsedInput.getLlmParameters()); assertEquals(input.getTools().size(), parsedInput.getTools().size()); assertEquals(input.getTools().getFirst().getName(), parsedInput.getTools().getFirst().getName()); assertEquals(input.getTools().getFirst().getType(), parsedInput.getTools().getFirst().getType()); assertEquals(input.getTools().getFirst().getParameters(), parsedInput.getTools().getFirst().getParameters()); assertEquals(input.getParameters(), parsedInput.getParameters()); - assertEquals(input.getMemory().getType(), parsedInput.getMemory().getType()); + assertEquals(input.getMemoryType(), parsedInput.getMemoryType()); + assertEquals(input.getMemorySessionId(), parsedInput.getMemorySessionId()); + assertEquals(input.getMemoryWindowSize(), parsedInput.getMemoryWindowSize()); assertEquals(input.getAppType(), parsedInput.getAppType()); assertEquals(input.getLastUpdateTime(), parsedInput.getLastUpdateTime()); assertEquals(input.getTenantId(), parsedInput.getTenantId()); @@ -434,4 +545,454 @@ private String serializationWithToXContent(MLAgentUpdateInput input) throws IOEx assertNotNull(builder); return builder.toString(); } + + @Test + public void testLLMPartialUpdate() { + // Create original agent with LLM configuration + Map originalParams = new HashMap<>(); + originalParams.put("temperature", "0.7"); + originalParams.put("max_tokens", "1000"); + originalParams.put("top_p", "0.9"); + + LLMSpec originalLlm = LLMSpec.builder().modelId("original-model").parameters(originalParams).build(); + + MLAgent originalAgent = MLAgent + .builder() + .name("Test Agent") + .description("Test description") + .llm(originalLlm) + .type(MLAgentType.CONVERSATIONAL.name()) + .build(); + + // Test Case 1: Update only model ID + MLAgentUpdateInput updateInput1 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmModelId("new-model") + // No llmParameters specified - should keep existing ones + .build(); + + MLAgent updatedAgent1 = updateInput1.toMLAgent(originalAgent); + + assertEquals("new-model", updatedAgent1.getLlm().getModelId()); + assertEquals(originalParams, updatedAgent1.getLlm().getParameters()); + + // Test Case 2: Update only parameters + Map updateParams2 = new HashMap<>(); + updateParams2.put("temperature", "0.5"); // Override existing + updateParams2.put("frequency_penalty", "0.1"); // Add new + + MLAgentUpdateInput updateInput2 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmParameters(updateParams2) + // No llmModelId specified - should keep existing one + .build(); + + MLAgent updatedAgent2 = updateInput2.toMLAgent(originalAgent); + + assertEquals("original-model", updatedAgent2.getLlm().getModelId()); + + Map expectedParams = new HashMap<>(); + expectedParams.put("temperature", "0.5"); // Overridden + expectedParams.put("max_tokens", "1000"); // Kept from original + expectedParams.put("top_p", "0.9"); // Kept from original + expectedParams.put("frequency_penalty", "0.1"); // Added + + assertEquals(expectedParams, updatedAgent2.getLlm().getParameters()); + + // Test Case 3: Update both model ID and parameters + Map updateParams3 = new HashMap<>(); + updateParams3.put("top_p", "0.95"); // Override existing + updateParams3.put("presence_penalty", "0.2"); // Add new + + MLAgentUpdateInput updateInput3 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmModelId("another-model") + .llmParameters(updateParams3) + .build(); + + MLAgent updatedAgent3 = updateInput3.toMLAgent(originalAgent); + + assertEquals("another-model", updatedAgent3.getLlm().getModelId()); + + Map expectedParams3 = new HashMap<>(); + expectedParams3.put("temperature", "0.7"); // Kept from original + expectedParams3.put("max_tokens", "1000"); // Kept from original + expectedParams3.put("top_p", "0.95"); // Overridden + expectedParams3.put("presence_penalty", "0.2"); // Added + + assertEquals(expectedParams3, updatedAgent3.getLlm().getParameters()); + + // Test Case 4: No LLM update - should keep original + MLAgentUpdateInput updateInput4 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .name("Updated Name") + // No LLM fields specified + .build(); + + MLAgent updatedAgent4 = updateInput4.toMLAgent(originalAgent); + + assertEquals("Updated Name", updatedAgent4.getName()); + assertEquals(originalLlm.getModelId(), updatedAgent4.getLlm().getModelId()); + assertEquals(originalLlm.getParameters(), updatedAgent4.getLlm().getParameters()); + } + + @Test + public void testMemoryPartialUpdate() { + // Create original agent with memory + MLMemorySpec originalMemory = MLMemorySpec.builder().type("conversation_index").sessionId("original-session").windowSize(5).build(); + + MLAgent originalAgent = MLAgent + .builder() + .name("Test Agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(LLMSpec.builder().modelId("test-model").build()) + .memory(originalMemory) + .build(); + + // Test Case 1: Update only window size + MLAgentUpdateInput updateInput1 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .memoryWindowSize(10) + // No memoryType or memorySessionId specified + .build(); + + MLAgent updatedAgent1 = updateInput1.toMLAgent(originalAgent); + + assertEquals("conversation_index", updatedAgent1.getMemory().getType()); // Preserved + assertEquals("original-session", updatedAgent1.getMemory().getSessionId()); // Preserved + assertEquals(Integer.valueOf(10), updatedAgent1.getMemory().getWindowSize()); // Updated + + // Test Case 2: Update only session ID + MLAgentUpdateInput updateInput2 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .memorySessionId("new-session-123") + // No memoryType or memoryWindowSize specified + .build(); + + MLAgent updatedAgent2 = updateInput2.toMLAgent(originalAgent); + + assertEquals("conversation_index", updatedAgent2.getMemory().getType()); // Preserved + assertEquals("new-session-123", updatedAgent2.getMemory().getSessionId()); // Updated + assertEquals(Integer.valueOf(5), updatedAgent2.getMemory().getWindowSize()); // Preserved + + // Test Case 3: Update multiple memory fields + MLAgentUpdateInput updateInput3 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .memoryType("conversation_index") + .memorySessionId("another-session") + .memoryWindowSize(15) + .build(); + + MLAgent updatedAgent3 = updateInput3.toMLAgent(originalAgent); + + assertEquals("conversation_index", updatedAgent3.getMemory().getType()); + assertEquals("another-session", updatedAgent3.getMemory().getSessionId()); + assertEquals(Integer.valueOf(15), updatedAgent3.getMemory().getWindowSize()); + + // Test Case 4: No memory update - should keep original + MLAgentUpdateInput updateInput4 = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .name("Updated Name") + // No memory fields specified + .build(); + + MLAgent updatedAgent4 = updateInput4.toMLAgent(originalAgent); + + assertEquals("Updated Name", updatedAgent4.getName()); + assertEquals(originalMemory.getType(), updatedAgent4.getMemory().getType()); + assertEquals(originalMemory.getSessionId(), updatedAgent4.getMemory().getSessionId()); + assertEquals(originalMemory.getWindowSize(), updatedAgent4.getMemory().getWindowSize()); + } + + @Test + public void testParseWithEmptyMemoryObject() throws Exception { + String inputStr = """ + { + "agent_id": "test-agent-id", + "memory": {} + } + """; + + testParseFromJsonString(inputStr, parsedInput -> { + assertNull(parsedInput.getMemoryType()); + assertNull(parsedInput.getMemorySessionId()); + assertNull(parsedInput.getMemoryWindowSize()); + }); + } + + @Test + public void testParseWithEmptyLLMObject() throws Exception { + String inputStr = """ + { + "agent_id": "test-agent-id", + "llm": {} + } + """; + + testParseFromJsonString(inputStr, parsedInput -> { + assertNull(parsedInput.getLlmModelId()); + assertNull(parsedInput.getLlmParameters()); + }); + } + + @Test + public void testParseWithUnknownMemoryFields() throws Exception { + String inputStr = """ + { + "agent_id": "test-agent-id", + "memory": { + "type": "conversation_index", + "unknown_field": "should_be_ignored", + "another_unknown": 123 + } + } + """; + + testParseFromJsonString(inputStr, parsedInput -> { + assertEquals("conversation_index", parsedInput.getMemoryType()); + assertNull(parsedInput.getMemorySessionId()); + assertNull(parsedInput.getMemoryWindowSize()); + }); + } + + @Test + public void testParseWithUnknownLLMFields() throws Exception { + String inputStr = """ + { + "agent_id": "test-agent-id", + "llm": { + "model_id": "test-model", + "unknown_field": "should_be_ignored", + "another_unknown": 123 + } + } + """; + + testParseFromJsonString(inputStr, parsedInput -> { + assertEquals("test-model", parsedInput.getLlmModelId()); + assertNull(parsedInput.getLlmParameters()); + }); + } + + @Test + public void testToMLAgentWithConversationalAgentCanUpdateLLMParameters() { + // Create original CONVERSATIONAL agent with existing LLM + LLMSpec originalLlm = LLMSpec.builder().modelId("original-model-id").parameters(Map.of("existing_param", "value")).build(); + + MLAgent originalAgent = MLAgent.builder().name("Test Agent").type(MLAgentType.CONVERSATIONAL.name()).llm(originalLlm).build(); + + // Update LLM parameters without providing model ID (should use existing model ID) + MLAgentUpdateInput updateInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmParameters(Map.of("temperature", "0.7")) + // No llmModelId provided, but CONVERSATIONAL agent has existing LLM + .build(); + + // This should work fine and use the existing model ID + MLAgent updatedAgent = updateInput.toMLAgent(originalAgent); + + assertNotNull(updatedAgent.getLlm()); + assertEquals("original-model-id", updatedAgent.getLlm().getModelId()); + assertEquals("0.7", updatedAgent.getLlm().getParameters().get("temperature")); + assertEquals("value", updatedAgent.getLlm().getParameters().get("existing_param")); + } + + @Test + public void testToMLAgentCanUpdateLLMParametersWhenModelIdProvided() { + // Create original agent without LLM (any type) + MLAgent originalAgent = MLAgent.builder().name("Test Agent").type(MLAgentType.FLOW.name()).build(); // No original LLM + + // Update LLM parameters WITH providing model ID (should work for any agent) + MLAgentUpdateInput updateInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmModelId("new-model-id") + .llmParameters(Map.of("temperature", "0.7")) + .build(); + + // This should work fine since we provided the model ID + MLAgent updatedAgent = updateInput.toMLAgent(originalAgent); + + assertNotNull(updatedAgent.getLlm()); + assertEquals("new-model-id", updatedAgent.getLlm().getModelId()); + assertEquals("0.7", updatedAgent.getLlm().getParameters().get("temperature")); + } + + @Test + public void testToMLAgentPreservesAllOriginalFields() { + Instant createdTime = Instant.now(); + MLAgent originalAgent = MLAgent + .builder() + .name("Original Agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .description("Original description") + .createdTime(createdTime) + .isHidden(true) + .llm(LLMSpec.builder().modelId("original-model").parameters(Map.of("temp", "0.5")).build()) + .tools(Collections.singletonList(MLToolSpec.builder().name("original-tool").type("original-type").build())) + .parameters(Map.of("original", "param")) + .memory(MLMemorySpec.builder().type("conversation_index").sessionId("original-session").windowSize(5).build()) + .appType("original-app") + .build(); + + // Update only the name + MLAgentUpdateInput updateInput = MLAgentUpdateInput.builder().agentId("test-agent-id").name("Updated Name").build(); + + MLAgent updatedAgent = updateInput.toMLAgent(originalAgent); + + // Check that only name was updated, everything else preserved + assertEquals("Updated Name", updatedAgent.getName()); + assertEquals(originalAgent.getType(), updatedAgent.getType()); + assertEquals(originalAgent.getDescription(), updatedAgent.getDescription()); + assertEquals(originalAgent.getCreatedTime(), updatedAgent.getCreatedTime()); + assertEquals(originalAgent.getIsHidden(), updatedAgent.getIsHidden()); + assertEquals(originalAgent.getLlm().getModelId(), updatedAgent.getLlm().getModelId()); + assertEquals(originalAgent.getLlm().getParameters(), updatedAgent.getLlm().getParameters()); + assertEquals(originalAgent.getTools(), updatedAgent.getTools()); + assertEquals(originalAgent.getParameters(), updatedAgent.getParameters()); + assertEquals(originalAgent.getMemory().getType(), updatedAgent.getMemory().getType()); + assertEquals(originalAgent.getMemory().getSessionId(), updatedAgent.getMemory().getSessionId()); + assertEquals(originalAgent.getMemory().getWindowSize(), updatedAgent.getMemory().getWindowSize()); + // appType is set to null because updateInput.appType is null + assertNull(updatedAgent.getAppType()); + } + + @Test + public void testToXContentWithNullValues() throws Exception { + MLAgentUpdateInput input = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .name(null) + .description(null) + .llmModelId(null) + .llmParameters(null) + .tools(null) + .parameters(null) + .memoryType(null) + .memorySessionId(null) + .memoryWindowSize(null) + .appType(null) + .lastUpdateTime(null) + .tenantId(null) + .build(); + + String jsonStr = serializationWithToXContent(input); + + // Should only contain agent_id + assertTrue(jsonStr.contains("\"agent_id\":\"test-agent-id\"")); + assertFalse(jsonStr.contains("\"name\"")); + assertFalse(jsonStr.contains("\"description\"")); + assertFalse(jsonStr.contains("\"llm\"")); + assertFalse(jsonStr.contains("\"tools\"")); + assertFalse(jsonStr.contains("\"parameters\"")); + assertFalse(jsonStr.contains("\"memory\"")); + assertFalse(jsonStr.contains("\"app_type\"")); + assertFalse(jsonStr.contains("\"last_updated_time\"")); + assertFalse(jsonStr.contains("\"tenant_id\"")); + } + + @Test + public void testToXContentWithEmptyCollections() throws Exception { + MLAgentUpdateInput input = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmParameters(Collections.emptyMap()) + .tools(Collections.emptyList()) + .parameters(Collections.emptyMap()) + .build(); + + String jsonStr = serializationWithToXContent(input); + + // Should only contain agent_id (empty collections should not be serialized) + assertTrue(jsonStr.contains("\"agent_id\":\"test-agent-id\"")); + assertFalse(jsonStr.contains("\"llm\"")); + assertFalse(jsonStr.contains("\"tools\"")); + assertFalse(jsonStr.contains("\"parameters\"")); + } + + @Test + public void testCombinedLLMAndMemoryPartialUpdates() { + // Create original agent with both LLM and memory + LLMSpec originalLlm = LLMSpec + .builder() + .modelId("original-model") + .parameters(Map.of("temperature", "0.5", "max_tokens", "100")) + .build(); + + MLMemorySpec originalMemory = MLMemorySpec.builder().type("conversation_index").sessionId("original-session").windowSize(5).build(); + + MLAgent originalAgent = MLAgent + .builder() + .name("Test Agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(originalLlm) + .memory(originalMemory) + .build(); + + // Update some LLM and memory fields simultaneously + MLAgentUpdateInput updateInput = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .llmParameters(Map.of("temperature", "0.8")) // Update LLM parameter + .memoryWindowSize(10) // Update memory window size + .build(); + + MLAgent updatedAgent = updateInput.toMLAgent(originalAgent); + + // Check LLM merging + assertEquals("original-model", updatedAgent.getLlm().getModelId()); // Preserved + assertEquals("0.8", updatedAgent.getLlm().getParameters().get("temperature")); // Updated + assertEquals("100", updatedAgent.getLlm().getParameters().get("max_tokens")); // Preserved + + // Check memory merging + assertEquals("conversation_index", updatedAgent.getMemory().getType()); // Preserved + assertEquals("original-session", updatedAgent.getMemory().getSessionId()); // Preserved + assertEquals(Integer.valueOf(10), updatedAgent.getMemory().getWindowSize()); // Updated + } + + @Test + public void testStreamInputOutputWithVersion() throws IOException { + MLAgentUpdateInput input = MLAgentUpdateInput + .builder() + .agentId("test-agent-id") + .name("test-agent") + .description("test description") + .llmModelId("test-model-id") + .llmParameters(Map.of("temperature", "0.7")) + .memoryType("conversation_index") + .memorySessionId("test-session") + .memoryWindowSize(10) + .appType("rag") + .lastUpdateTime(Instant.ofEpochMilli(1234567890)) + .tenantId("test-tenant") + .build(); + + // Test with different versions + BytesStreamOutput output = new BytesStreamOutput(); + input.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + MLAgentUpdateInput parsedInput = new MLAgentUpdateInput(streamInput); + + assertEquals(input.getAgentId(), parsedInput.getAgentId()); + assertEquals(input.getName(), parsedInput.getName()); + assertEquals(input.getDescription(), parsedInput.getDescription()); + assertEquals(input.getLlmModelId(), parsedInput.getLlmModelId()); + assertEquals(input.getLlmParameters(), parsedInput.getLlmParameters()); + assertEquals(input.getMemoryType(), parsedInput.getMemoryType()); + assertEquals(input.getMemorySessionId(), parsedInput.getMemorySessionId()); + assertEquals(input.getMemoryWindowSize(), parsedInput.getMemoryWindowSize()); + assertEquals(input.getAppType(), parsedInput.getAppType()); + assertEquals(input.getLastUpdateTime(), parsedInput.getLastUpdateTime()); + assertEquals(input.getTenantId(), parsedInput.getTenantId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/list/MLMcpToolsListRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/list/MLMcpToolsListRequestTests.java new file mode 100644 index 0000000000..ec2ac2d106 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/list/MLMcpToolsListRequestTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.mcpserver.requests.list; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.core.common.io.stream.StreamInput; + +public class MLMcpToolsListRequestTests { + + @Before + public void setUp() {} + + @Test + public void testMLMcpToolsListRequest_withStreamInput() throws IOException { + StreamInput streamInput = mock(StreamInput.class); + when(streamInput.readString()).thenReturn("mockNodeId"); + MLMcpToolsListRequest mlMcpToolsListRequest = new MLMcpToolsListRequest(streamInput); + assertNotNull(mlMcpToolsListRequest); + } + + @Test + public void testMLMcpToolsListRequest() { + MLMcpToolsListRequest mlMcpToolsListRequest = new MLMcpToolsListRequest(); + assertNotNull(mlMcpToolsListRequest); + } + + @Test + public void testValidate() { + MLMcpToolsListRequest mlMcpToolsListRequest = new MLMcpToolsListRequest(); + assertNull(mlMcpToolsListRequest.validate()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/message/MLMcpMessageRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/message/MLMcpMessageRequestTest.java index fadef0e930..2ee327ae8a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/message/MLMcpMessageRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/message/MLMcpMessageRequestTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.message; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodeRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodeRequestTest.java index 1373ea2ce3..78c31c9a8b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodeRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodeRequestTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.register; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequestTest.java index 18bb946c6c..edb0c8c4b6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/MLMcpToolsRegisterNodesRequestTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.register; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/RegisterMcpToolTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/RegisterMcpToolTest.java index a338d9b6ee..d83ea9f617 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/RegisterMcpToolTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/register/RegisterMcpToolTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.register; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodeRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodeRequestTest.java index 5380b1e802..752e5962f6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodeRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodeRequestTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.remove; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequestTest.java index 019764d3d2..3252a8fb30 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/remove/MLMcpToolsRemoveNodesRequestTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.remove; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodeRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodeRequestTest.java index 1cc91055ea..a05b824d59 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodeRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodeRequestTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.requests.update; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequestTest.java index 382eabd3b0..da66c96702 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/update/MLMcpToolsUpdateNodesRequestTest.java @@ -1,10 +1,8 @@ package org.opensearch.ml.common.transport.mcpserver.requests.update; /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ import static org.junit.Assert.*; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/list/MLMcpListToolsResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/list/MLMcpListToolsResponseTest.java index 8a494ec3d9..a15a4b6470 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/list/MLMcpListToolsResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/list/MLMcpListToolsResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.list; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodeResponseTest.java index 7ed9041f2f..2ece3d0cfc 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodeResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.register; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodesResponseTest.java index c765e47069..22207b77ef 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/register/MLMcpRegisterNodesResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.register; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodeResponseTest.java index 4e0b1598ed..2e6b7eaaf7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodeResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.remove; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodesResponseTest.java index 0f3ece0053..a44f12b1a1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/remove/MLMcpRemoveNodesResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.remove; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodeResponseTest.java index 3f90a26519..4bd18018ab 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodeResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.update; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodesResponseTest.java index 7ec2294bce..99a6d43930 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/responses/update/MLMcpToolsUpdateNodesResponseTest.java @@ -1,8 +1,6 @@ /* - * - * * Copyright OpenSearch Contributors - * * SPDX-License-Identifier: Apache-2.0 - * + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.common.transport.mcpserver.responses.update; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerInputTests.java new file mode 100644 index 0000000000..d8aa04d984 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerInputTests.java @@ -0,0 +1,389 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig; + +public class MLCreateMemoryContainerInputTests { + + private MLCreateMemoryContainerInput inputWithAllFields; + private MLCreateMemoryContainerInput inputMinimal; + private MemoryStorageConfig testMemoryStorageConfig; + + @Before + public void setUp() { + // Create test memory storage config + testMemoryStorageConfig = MemoryStorageConfig + .builder() + .memoryIndexName("test-memory-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("test-embedding-model") + .llmModelId("test-llm-model") + .dimension(768) + .maxInferSize(8) + .build(); + + // Input with all fields + inputWithAllFields = MLCreateMemoryContainerInput + .builder() + .name("test-memory-container") + .description("Test memory container description") + .memoryStorageConfig(testMemoryStorageConfig) + .tenantId("test-tenant") + .build(); + + // Minimal input (only required fields) + inputMinimal = MLCreateMemoryContainerInput.builder().name("minimal-container").build(); + } + + @Test + public void testConstructorWithBuilder() { + assertNotNull(inputWithAllFields); + assertEquals("test-memory-container", inputWithAllFields.getName()); + assertEquals("Test memory container description", inputWithAllFields.getDescription()); + assertEquals(testMemoryStorageConfig, inputWithAllFields.getMemoryStorageConfig()); + assertEquals("test-tenant", inputWithAllFields.getTenantId()); + } + + @Test + public void testConstructorWithBuilderMinimal() { + assertNotNull(inputMinimal); + assertEquals("minimal-container", inputMinimal.getName()); + assertNull(inputMinimal.getDescription()); + assertNull(inputMinimal.getMemoryStorageConfig()); + assertNull(inputMinimal.getTenantId()); + } + + @Test + public void testConstructorWithAllParameters() { + MLCreateMemoryContainerInput input = new MLCreateMemoryContainerInput( + "param-container", + "param description", + testMemoryStorageConfig, + "param-tenant" + ); + + assertEquals("param-container", input.getName()); + assertEquals("param description", input.getDescription()); + assertEquals(testMemoryStorageConfig, input.getMemoryStorageConfig()); + assertEquals("param-tenant", input.getTenantId()); + } + + @Test + public void testConstructorWithNullOptionalFields() { + MLCreateMemoryContainerInput input = new MLCreateMemoryContainerInput("null-optional-container", null, null, null); + + assertEquals("null-optional-container", input.getName()); + assertNull(input.getDescription()); + assertNull(input.getMemoryStorageConfig()); + assertNull(input.getTenantId()); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullName() { + MLCreateMemoryContainerInput + .builder() + .name(null) // This should throw IllegalArgumentException + .description("test description") + .build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullNameDirectConstructor() { + new MLCreateMemoryContainerInput( + null, // This should throw IllegalArgumentException + "test description", + testMemoryStorageConfig, + "test-tenant" + ); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + inputWithAllFields.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerInput parsedInput = new MLCreateMemoryContainerInput(streamInput); + + assertEquals(inputWithAllFields.getName(), parsedInput.getName()); + assertEquals(inputWithAllFields.getDescription(), parsedInput.getDescription()); + assertEquals(inputWithAllFields.getMemoryStorageConfig(), parsedInput.getMemoryStorageConfig()); + assertEquals(inputWithAllFields.getTenantId(), parsedInput.getTenantId()); + } + + @Test + public void testStreamInputOutputWithNullValues() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + inputMinimal.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerInput parsedInput = new MLCreateMemoryContainerInput(streamInput); + + assertEquals(inputMinimal.getName(), parsedInput.getName()); + assertNull(parsedInput.getDescription()); + assertNull(parsedInput.getMemoryStorageConfig()); + assertNull(parsedInput.getTenantId()); + } + + @Test + public void testToXContentWithAllFields() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + inputWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify all fields are present in the JSON + assertTrue(jsonStr.contains("\"name\":\"test-memory-container\"")); + assertTrue(jsonStr.contains("\"description\":\"Test memory container description\"")); + assertTrue(jsonStr.contains("\"tenant_id\":\"test-tenant\"")); + assertTrue(jsonStr.contains("\"memory_storage_config\"")); + // Verify memory storage config fields are nested + assertTrue(jsonStr.contains("\"memory_index_name\":\"test-memory-index\"")); + assertTrue(jsonStr.contains("\"embedding_model_type\":\"TEXT_EMBEDDING\"")); + } + + @Test + public void testToXContentWithMinimalFields() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + inputMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify only required fields are present + assertTrue(jsonStr.contains("\"name\":\"minimal-container\"")); + // Verify optional fields are not present + assertFalse(jsonStr.contains("\"description\"")); + assertFalse(jsonStr.contains("\"memory_storage_config\"")); + assertFalse(jsonStr.contains("\"tenant_id\"")); + } + + @Test + public void testParseFromXContentWithAllFields() throws IOException { + String jsonStr = "{" + + "\"name\":\"parsed-container\"," + + "\"description\":\"parsed description\"," + + "\"tenant_id\":\"parsed-tenant\"," + + "\"memory_storage_config\":{" + + "\"memory_index_name\":\"parsed-index\"," + + "\"embedding_model_type\":\"TEXT_EMBEDDING\"," + + "\"embedding_model_id\":\"parsed-embedding-model\"," + + "\"llm_model_id\":\"parsed-llm-model\"," + + "\"dimension\":512," + + "\"max_infer_size\":7" + + "}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLCreateMemoryContainerInput parsedInput = MLCreateMemoryContainerInput.parse(parser); + + assertEquals("parsed-container", parsedInput.getName()); + assertEquals("parsed description", parsedInput.getDescription()); + assertEquals("parsed-tenant", parsedInput.getTenantId()); + assertNotNull(parsedInput.getMemoryStorageConfig()); + assertEquals("parsed-index", parsedInput.getMemoryStorageConfig().getMemoryIndexName()); + assertEquals(FunctionName.TEXT_EMBEDDING, parsedInput.getMemoryStorageConfig().getEmbeddingModelType()); + assertEquals("parsed-embedding-model", parsedInput.getMemoryStorageConfig().getEmbeddingModelId()); + assertEquals("parsed-llm-model", parsedInput.getMemoryStorageConfig().getLlmModelId()); + assertEquals(Integer.valueOf(512), parsedInput.getMemoryStorageConfig().getDimension()); + assertEquals(Integer.valueOf(7), parsedInput.getMemoryStorageConfig().getMaxInferSize()); + } + + @Test + public void testParseFromXContentWithMinimalFields() throws IOException { + String jsonStr = "{\"name\":\"minimal-parsed-container\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLCreateMemoryContainerInput parsedInput = MLCreateMemoryContainerInput.parse(parser); + + assertEquals("minimal-parsed-container", parsedInput.getName()); + assertNull(parsedInput.getDescription()); + assertNull(parsedInput.getMemoryStorageConfig()); + assertNull(parsedInput.getTenantId()); + } + + @Test + public void testParseFromXContentWithUnknownFields() throws IOException { + String jsonStr = "{" + + "\"name\":\"unknown-fields-container\"," + + "\"unknown_field\":\"unknown_value\"," + + "\"description\":\"test description\"," + + "\"another_unknown\":123" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLCreateMemoryContainerInput parsedInput = MLCreateMemoryContainerInput.parse(parser); + + assertEquals("unknown-fields-container", parsedInput.getName()); + assertEquals("test description", parsedInput.getDescription()); + // Unknown fields should be ignored + assertNull(parsedInput.getMemoryStorageConfig()); + assertNull(parsedInput.getTenantId()); + } + + @Test + public void testCompleteRoundTrip() throws IOException { + // Test complete round trip: object -> JSON -> parse -> compare + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + inputWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLCreateMemoryContainerInput parsedInput = MLCreateMemoryContainerInput.parse(parser); + + assertEquals(inputWithAllFields.getName(), parsedInput.getName()); + assertEquals(inputWithAllFields.getDescription(), parsedInput.getDescription()); + assertEquals(inputWithAllFields.getTenantId(), parsedInput.getTenantId()); + assertEquals(inputWithAllFields.getMemoryStorageConfig(), parsedInput.getMemoryStorageConfig()); + } + + @Test + public void testEqualsAndHashCode() { + MLCreateMemoryContainerInput input1 = MLCreateMemoryContainerInput + .builder() + .name("test-container") + .description("test description") + .memoryStorageConfig(testMemoryStorageConfig) + .tenantId("test-tenant") + .build(); + + MLCreateMemoryContainerInput input2 = MLCreateMemoryContainerInput + .builder() + .name("test-container") + .description("test description") + .memoryStorageConfig(testMemoryStorageConfig) + .tenantId("test-tenant") + .build(); + + MLCreateMemoryContainerInput input3 = MLCreateMemoryContainerInput + .builder() + .name("different-container") + .description("test description") + .memoryStorageConfig(testMemoryStorageConfig) + .tenantId("test-tenant") + .build(); + + assertEquals(input1, input2); + assertEquals(input1.hashCode(), input2.hashCode()); + assertFalse(input1.equals(input3)); + assertTrue(input1.hashCode() != input3.hashCode()); + } + + @Test + public void testSettersAndGetters() { + MLCreateMemoryContainerInput input = MLCreateMemoryContainerInput.builder().name("initial-name").build(); + + // Test setters + input.setName("new-name"); + input.setDescription("new-description"); + input.setMemoryStorageConfig(testMemoryStorageConfig); + input.setTenantId("new-tenant"); + + // Test getters + assertEquals("new-name", input.getName()); + assertEquals("new-description", input.getDescription()); + assertEquals(testMemoryStorageConfig, input.getMemoryStorageConfig()); + assertEquals("new-tenant", input.getTenantId()); + } + + @Test + public void testToBuilder() { + MLCreateMemoryContainerInput modifiedInput = inputWithAllFields + .toBuilder() + .name("modified-name") + .description("modified description") + .build(); + + assertEquals("modified-name", modifiedInput.getName()); + assertEquals("modified description", modifiedInput.getDescription()); + // Other fields should remain the same + assertEquals(inputWithAllFields.getMemoryStorageConfig(), modifiedInput.getMemoryStorageConfig()); + assertEquals(inputWithAllFields.getTenantId(), modifiedInput.getTenantId()); + } + + @Test + public void testFieldConstants() { + // Test that field constants are correctly defined + assertEquals("name", MLCreateMemoryContainerInput.NAME_FIELD); + assertEquals("description", MLCreateMemoryContainerInput.DESCRIPTION_FIELD); + assertEquals("memory_storage_config", MLCreateMemoryContainerInput.MEMORY_STORAGE_CONFIG_FIELD); + } + + @Test + public void testParseFromXContentWithPartialMemoryStorageConfig() throws IOException { + String jsonStr = "{" + + "\"name\":\"partial-config-container\"," + + "\"description\":\"test with partial config\"," + + "\"memory_storage_config\":{" + + "\"memory_index_name\":\"partial-index\"," + + "\"llm_model_id\":\"partial-llm-model\"" + + "}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLCreateMemoryContainerInput parsedInput = MLCreateMemoryContainerInput.parse(parser); + + assertEquals("partial-config-container", parsedInput.getName()); + assertEquals("test with partial config", parsedInput.getDescription()); + assertNotNull(parsedInput.getMemoryStorageConfig()); + assertEquals("partial-index", parsedInput.getMemoryStorageConfig().getMemoryIndexName()); + assertEquals("partial-llm-model", parsedInput.getMemoryStorageConfig().getLlmModelId()); + // Semantic storage should be disabled due to missing embedding config + assertFalse(parsedInput.getMemoryStorageConfig().isSemanticStorageEnabled()); + } + + @Test + public void testDataAnnotationFunctionality() { + // Test that @Data annotation provides toString, equals, hashCode + String toString = inputWithAllFields.toString(); + assertNotNull(toString); + assertTrue(toString.contains("test-memory-container")); + assertTrue(toString.contains("Test memory container description")); + assertTrue(toString.contains("test-tenant")); + } + + // Helper method for assertions + private void assertTrue(boolean condition) { + org.junit.Assert.assertTrue(condition); + } + + private void assertFalse(boolean condition) { + org.junit.Assert.assertFalse(condition); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerRequestTests.java new file mode 100644 index 0000000000..aa438acabd --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerRequestTests.java @@ -0,0 +1,304 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig; + +public class MLCreateMemoryContainerRequestTests { + + private MLCreateMemoryContainerRequest requestWithAllFields; + private MLCreateMemoryContainerRequest requestMinimal; + private MLCreateMemoryContainerInput testInput; + private MLCreateMemoryContainerInput minimalInput; + private MemoryStorageConfig testMemoryStorageConfig; + + @Before + public void setUp() { + // Create test memory storage config + testMemoryStorageConfig = MemoryStorageConfig + .builder() + .memoryIndexName("test-memory-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("test-embedding-model") + .llmModelId("test-llm-model") + .dimension(768) + .maxInferSize(8) + .build(); + + // Create test input with all fields + testInput = MLCreateMemoryContainerInput + .builder() + .name("test-memory-container") + .description("Test memory container description") + .memoryStorageConfig(testMemoryStorageConfig) + .tenantId("test-tenant") + .build(); + + // Create minimal input + minimalInput = MLCreateMemoryContainerInput.builder().name("minimal-container").build(); + + // Create requests + requestWithAllFields = MLCreateMemoryContainerRequest.builder().mlCreateMemoryContainerInput(testInput).build(); + + requestMinimal = MLCreateMemoryContainerRequest.builder().mlCreateMemoryContainerInput(minimalInput).build(); + } + + @Test + public void testConstructorWithBuilder() { + assertNotNull(requestWithAllFields); + assertEquals(testInput, requestWithAllFields.getMlCreateMemoryContainerInput()); + + assertNotNull(requestMinimal); + assertEquals(minimalInput, requestMinimal.getMlCreateMemoryContainerInput()); + } + + @Test + public void testConstructorWithInput() { + MLCreateMemoryContainerRequest request = new MLCreateMemoryContainerRequest(testInput); + + assertNotNull(request); + assertEquals(testInput, request.getMlCreateMemoryContainerInput()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithAllFields.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerRequest parsedRequest = new MLCreateMemoryContainerRequest(streamInput); + + assertNotNull(parsedRequest); + assertNotNull(parsedRequest.getMlCreateMemoryContainerInput()); + + // Verify the input fields + MLCreateMemoryContainerInput originalInput = requestWithAllFields.getMlCreateMemoryContainerInput(); + MLCreateMemoryContainerInput parsedInput = parsedRequest.getMlCreateMemoryContainerInput(); + + assertEquals(originalInput.getName(), parsedInput.getName()); + assertEquals(originalInput.getDescription(), parsedInput.getDescription()); + assertEquals(originalInput.getTenantId(), parsedInput.getTenantId()); + assertEquals(originalInput.getMemoryStorageConfig(), parsedInput.getMemoryStorageConfig()); + } + + @Test + public void testStreamInputOutputWithMinimalFields() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestMinimal.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerRequest parsedRequest = new MLCreateMemoryContainerRequest(streamInput); + + assertNotNull(parsedRequest); + assertNotNull(parsedRequest.getMlCreateMemoryContainerInput()); + + MLCreateMemoryContainerInput parsedInput = parsedRequest.getMlCreateMemoryContainerInput(); + assertEquals("minimal-container", parsedInput.getName()); + assertNull(parsedInput.getDescription()); + assertNull(parsedInput.getTenantId()); + assertNull(parsedInput.getMemoryStorageConfig()); + } + + @Test + public void testValidateWithValidInput() { + ActionRequestValidationException validationException = requestWithAllFields.validate(); + assertNull(validationException); + } + + @Test + public void testValidateWithMinimalValidInput() { + ActionRequestValidationException validationException = requestMinimal.validate(); + assertNull(validationException); + } + + @Test + public void testValidateWithNullInput() { + MLCreateMemoryContainerRequest requestWithNullInput = MLCreateMemoryContainerRequest + .builder() + .mlCreateMemoryContainerInput(null) + .build(); + + ActionRequestValidationException validationException = requestWithNullInput.validate(); + + assertNotNull(validationException); + assertTrue(validationException.validationErrors().contains("Memory container input can't be null")); + } + + @Test + public void testFromActionRequestWithSameType() { + MLCreateMemoryContainerRequest result = MLCreateMemoryContainerRequest.fromActionRequest(requestWithAllFields); + + assertSame(requestWithAllFields, result); + } + + @Test + public void testFromActionRequestWithDifferentType() throws IOException { + // Create a properly serializable ActionRequest that writes data in the expected format + ActionRequest mockActionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + // Write data in the same format as MLCreateMemoryContainerRequest + super.writeTo(out); // Write ActionRequest base data + testInput.writeTo(out); // Write the MLCreateMemoryContainerInput data + } + }; + + MLCreateMemoryContainerRequest result = MLCreateMemoryContainerRequest.fromActionRequest(mockActionRequest); + + assertNotNull(result); + assertNotNull(result.getMlCreateMemoryContainerInput()); + assertEquals(testInput.getName(), result.getMlCreateMemoryContainerInput().getName()); + assertEquals(testInput.getDescription(), result.getMlCreateMemoryContainerInput().getDescription()); + assertEquals(testInput.getTenantId(), result.getMlCreateMemoryContainerInput().getTenantId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestWithIOException() { + // Create a mock ActionRequest that throws IOException during serialization + ActionRequest mockActionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test IOException"); + } + }; + + MLCreateMemoryContainerRequest.fromActionRequest(mockActionRequest); + } + + @Test + public void testGetterFunctionality() { + assertEquals(testInput, requestWithAllFields.getMlCreateMemoryContainerInput()); + assertEquals(minimalInput, requestMinimal.getMlCreateMemoryContainerInput()); + } + + @Test + public void testBuilderFunctionality() { + MLCreateMemoryContainerRequest request = MLCreateMemoryContainerRequest.builder().mlCreateMemoryContainerInput(testInput).build(); + + assertNotNull(request); + assertEquals(testInput, request.getMlCreateMemoryContainerInput()); + } + + @Test + public void testInheritanceFromActionRequest() { + assertTrue(requestWithAllFields instanceof ActionRequest); + assertTrue(requestMinimal instanceof ActionRequest); + } + + @Test + public void testValidationDelegation() { + // Test that validation is properly delegated to the input object + // The request itself only validates that input is not null + // All other validation is handled by MLCreateMemoryContainerInput and MemoryStorageConfig + + // Valid input should pass validation + ActionRequestValidationException validationException = requestWithAllFields.validate(); + assertNull(validationException); + + // Null input should fail validation + MLCreateMemoryContainerRequest nullInputRequest = MLCreateMemoryContainerRequest + .builder() + .mlCreateMemoryContainerInput(null) + .build(); + + ActionRequestValidationException nullValidationException = nullInputRequest.validate(); + assertNotNull(nullValidationException); + assertEquals(1, nullValidationException.validationErrors().size()); + assertTrue(nullValidationException.validationErrors().get(0).contains("Memory container input can't be null")); + } + + @Test + public void testCompleteRoundTripSerialization() throws IOException { + // Test complete serialization round trip + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithAllFields.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerRequest deserializedRequest = new MLCreateMemoryContainerRequest(streamInput); + + // Verify all nested data is preserved + MLCreateMemoryContainerInput originalInput = requestWithAllFields.getMlCreateMemoryContainerInput(); + MLCreateMemoryContainerInput deserializedInput = deserializedRequest.getMlCreateMemoryContainerInput(); + + assertEquals(originalInput.getName(), deserializedInput.getName()); + assertEquals(originalInput.getDescription(), deserializedInput.getDescription()); + assertEquals(originalInput.getTenantId(), deserializedInput.getTenantId()); + + // Verify nested MemoryStorageConfig + MemoryStorageConfig originalConfig = originalInput.getMemoryStorageConfig(); + MemoryStorageConfig deserializedConfig = deserializedInput.getMemoryStorageConfig(); + + assertEquals(originalConfig.getMemoryIndexName(), deserializedConfig.getMemoryIndexName()); + assertEquals(originalConfig.isSemanticStorageEnabled(), deserializedConfig.isSemanticStorageEnabled()); + assertEquals(originalConfig.getEmbeddingModelType(), deserializedConfig.getEmbeddingModelType()); + assertEquals(originalConfig.getEmbeddingModelId(), deserializedConfig.getEmbeddingModelId()); + assertEquals(originalConfig.getLlmModelId(), deserializedConfig.getLlmModelId()); + assertEquals(originalConfig.getDimension(), deserializedConfig.getDimension()); + assertEquals(originalConfig.getMaxInferSize(), deserializedConfig.getMaxInferSize()); + } + + @Test + public void testFromActionRequestRoundTrip() throws IOException { + // Test that fromActionRequest can properly handle the same request type + MLCreateMemoryContainerRequest reconstructed = MLCreateMemoryContainerRequest.fromActionRequest(requestWithAllFields); + assertSame(requestWithAllFields, reconstructed); + + // Test with minimal request + MLCreateMemoryContainerRequest minimalReconstructed = MLCreateMemoryContainerRequest.fromActionRequest(requestMinimal); + assertSame(requestMinimal, minimalReconstructed); + } + + @Test + public void testNullInputHandling() { + MLCreateMemoryContainerRequest requestWithNull = MLCreateMemoryContainerRequest + .builder() + .mlCreateMemoryContainerInput(null) + .build(); + + assertNotNull(requestWithNull); + assertNull(requestWithNull.getMlCreateMemoryContainerInput()); + + // Validation should catch this + ActionRequestValidationException validationException = requestWithNull.validate(); + assertNotNull(validationException); + assertTrue(validationException.validationErrors().contains("Memory container input can't be null")); + } + + @Test + public void testBuilderWithNullInput() { + MLCreateMemoryContainerRequest request = MLCreateMemoryContainerRequest.builder().mlCreateMemoryContainerInput(null).build(); + + assertNotNull(request); + assertNull(request.getMlCreateMemoryContainerInput()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerResponseTests.java new file mode 100644 index 0000000000..cba573b12a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLCreateMemoryContainerResponseTests.java @@ -0,0 +1,297 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLCreateMemoryContainerResponseTests { + + private MLCreateMemoryContainerResponse responseSuccess; + private MLCreateMemoryContainerResponse responseCreated; + private MLCreateMemoryContainerResponse responseWithLongId; + + @Before + public void setUp() { + // Response with success status + responseSuccess = new MLCreateMemoryContainerResponse("memory-container-123", "success"); + + // Response with created status + responseCreated = new MLCreateMemoryContainerResponse("memory-container-456", "created"); + + // Response with long ID to test edge cases + responseWithLongId = new MLCreateMemoryContainerResponse( + "memory-container-with-very-long-id-that-contains-multiple-segments-and-special-characters-789", + "success" + ); + } + + @Test + public void testConstructorWithParameters() { + assertNotNull(responseSuccess); + assertEquals("memory-container-123", responseSuccess.getMemoryContainerId()); + assertEquals("success", responseSuccess.getStatus()); + + assertNotNull(responseCreated); + assertEquals("memory-container-456", responseCreated.getMemoryContainerId()); + assertEquals("created", responseCreated.getStatus()); + } + + @Test + public void testConstructorWithLongId() { + assertNotNull(responseWithLongId); + assertEquals( + "memory-container-with-very-long-id-that-contains-multiple-segments-and-special-characters-789", + responseWithLongId.getMemoryContainerId() + ); + assertEquals("success", responseWithLongId.getStatus()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseSuccess.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse parsedResponse = new MLCreateMemoryContainerResponse(streamInput); + + assertNotNull(parsedResponse); + assertEquals(responseSuccess.getMemoryContainerId(), parsedResponse.getMemoryContainerId()); + assertEquals(responseSuccess.getStatus(), parsedResponse.getStatus()); + } + + @Test + public void testStreamInputOutputWithDifferentStatus() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseCreated.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse parsedResponse = new MLCreateMemoryContainerResponse(streamInput); + + assertNotNull(parsedResponse); + assertEquals(responseCreated.getMemoryContainerId(), parsedResponse.getMemoryContainerId()); + assertEquals(responseCreated.getStatus(), parsedResponse.getStatus()); + } + + @Test + public void testStreamInputOutputWithLongId() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseWithLongId.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse parsedResponse = new MLCreateMemoryContainerResponse(streamInput); + + assertNotNull(parsedResponse); + assertEquals(responseWithLongId.getMemoryContainerId(), parsedResponse.getMemoryContainerId()); + assertEquals(responseWithLongId.getStatus(), parsedResponse.getStatus()); + } + + @Test + public void testToXContentWithSuccessStatus() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseSuccess.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + assertTrue(jsonStr.contains("\"memory_container_id\":\"memory-container-123\"")); + assertTrue(jsonStr.contains("\"status\":\"success\"")); + + // Verify it's a proper JSON object + assertTrue(jsonStr.startsWith("{")); + assertTrue(jsonStr.endsWith("}")); + } + + @Test + public void testToXContentWithCreatedStatus() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseCreated.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + assertTrue(jsonStr.contains("\"memory_container_id\":\"memory-container-456\"")); + assertTrue(jsonStr.contains("\"status\":\"created\"")); + } + + @Test + public void testToXContentWithLongId() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseWithLongId.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + assertTrue( + jsonStr + .contains( + "\"memory_container_id\":\"memory-container-with-very-long-id-that-contains-multiple-segments-and-special-characters-789\"" + ) + ); + assertTrue(jsonStr.contains("\"status\":\"success\"")); + } + + @Test + public void testGetterMethods() { + assertEquals("memory-container-123", responseSuccess.getMemoryContainerId()); + assertEquals("success", responseSuccess.getStatus()); + + assertEquals("memory-container-456", responseCreated.getMemoryContainerId()); + assertEquals("created", responseCreated.getStatus()); + } + + @Test + public void testInheritanceFromActionResponse() { + assertTrue(responseSuccess instanceof ActionResponse); + assertTrue(responseCreated instanceof ActionResponse); + assertTrue(responseWithLongId instanceof ActionResponse); + } + + @Test + public void testToXContentObjectInterface() { + assertTrue(responseSuccess instanceof org.opensearch.core.xcontent.ToXContentObject); + assertTrue(responseCreated instanceof org.opensearch.core.xcontent.ToXContentObject); + } + + @Test + public void testCompleteRoundTripSerialization() throws IOException { + // Test complete serialization round trip + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseSuccess.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse deserializedResponse = new MLCreateMemoryContainerResponse(streamInput); + + // Verify all data is preserved + assertEquals(responseSuccess.getMemoryContainerId(), deserializedResponse.getMemoryContainerId()); + assertEquals(responseSuccess.getStatus(), deserializedResponse.getStatus()); + + // Test that the deserialized response can be serialized again + BytesStreamOutput secondOutput = new BytesStreamOutput(); + deserializedResponse.writeTo(secondOutput); + + StreamInput secondInput = secondOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse secondDeserialized = new MLCreateMemoryContainerResponse(secondInput); + + assertEquals(responseSuccess.getMemoryContainerId(), secondDeserialized.getMemoryContainerId()); + assertEquals(responseSuccess.getStatus(), secondDeserialized.getStatus()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Test JSON serialization and verify structure + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseSuccess.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + // Verify JSON structure contains expected fields + assertTrue(jsonStr.contains("memory_container_id")); + assertTrue(jsonStr.contains("status")); + assertTrue(jsonStr.contains("memory-container-123")); + assertTrue(jsonStr.contains("success")); + } + + @Test + public void testWithEmptyStrings() throws IOException { + MLCreateMemoryContainerResponse responseWithEmptyStrings = new MLCreateMemoryContainerResponse("", ""); + + assertEquals("", responseWithEmptyStrings.getMemoryContainerId()); + assertEquals("", responseWithEmptyStrings.getStatus()); + + // Test serialization with empty strings + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseWithEmptyStrings.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse parsedResponse = new MLCreateMemoryContainerResponse(streamInput); + + assertEquals("", parsedResponse.getMemoryContainerId()); + assertEquals("", parsedResponse.getStatus()); + + // Test JSON serialization with empty strings + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseWithEmptyStrings.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonStr.contains("\"memory_container_id\":\"\"")); + assertTrue(jsonStr.contains("\"status\":\"\"")); + } + + @Test + public void testWithSpecialCharacters() throws IOException { + MLCreateMemoryContainerResponse responseWithSpecialChars = new MLCreateMemoryContainerResponse( + "memory-container-with-special-chars-!@#$%^&*()_+-=[]{}|;':\",./<>?", + "status-with-special-chars-!@#$%" + ); + + assertEquals("memory-container-with-special-chars-!@#$%^&*()_+-=[]{}|;':\",./<>?", responseWithSpecialChars.getMemoryContainerId()); + assertEquals("status-with-special-chars-!@#$%", responseWithSpecialChars.getStatus()); + + // Test serialization with special characters + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseWithSpecialChars.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateMemoryContainerResponse parsedResponse = new MLCreateMemoryContainerResponse(streamInput); + + assertEquals(responseWithSpecialChars.getMemoryContainerId(), parsedResponse.getMemoryContainerId()); + assertEquals(responseWithSpecialChars.getStatus(), parsedResponse.getStatus()); + } + + @Test + public void testFieldConstants() throws IOException { + // Test that the response uses the correct field constants from MemoryContainerConstants + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseSuccess.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + // Verify that the correct field names are used + assertTrue(jsonStr.contains("memory_container_id")); // MEMORY_CONTAINER_ID_FIELD + assertTrue(jsonStr.contains("status")); // STATUS_FIELD + } + + @Test + public void testMultipleInstancesIndependence() { + // Test that multiple instances don't interfere with each other + MLCreateMemoryContainerResponse response1 = new MLCreateMemoryContainerResponse("id1", "status1"); + MLCreateMemoryContainerResponse response2 = new MLCreateMemoryContainerResponse("id2", "status2"); + + assertEquals("id1", response1.getMemoryContainerId()); + assertEquals("status1", response1.getStatus()); + assertEquals("id2", response2.getMemoryContainerId()); + assertEquals("status2", response2.getStatus()); + + // Verify they don't affect each other + assertNotEquals(response1.getMemoryContainerId(), response2.getMemoryContainerId()); + assertNotEquals(response1.getStatus(), response2.getStatus()); + } + + @Test + public void testLombokGetterAnnotation() { + // Test that @Getter annotation works correctly + assertNotNull(responseSuccess.getMemoryContainerId()); + assertNotNull(responseSuccess.getStatus()); + + // Test that getters return the correct values + assertEquals("memory-container-123", responseSuccess.getMemoryContainerId()); + assertEquals("success", responseSuccess.getStatus()); + } + + // Helper method for assertions + private void assertNotEquals(Object obj1, Object obj2) { + org.junit.Assert.assertNotEquals(obj1, obj2); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteRequestTest.java new file mode 100644 index 0000000000..fbc126e64b --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerDeleteRequestTest.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLMemoryContainerDeleteRequestTest { + private String memoryContainerId; + + @Before + public void setUp() { + memoryContainerId = "test_id"; + } + + @Test + public void writeTo_Success() throws IOException { + MLMemoryContainerDeleteRequest mlMemoryContainerDeleteRequest = MLMemoryContainerDeleteRequest + .builder() + .memoryContainerId(memoryContainerId) + .build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlMemoryContainerDeleteRequest.writeTo(bytesStreamOutput); + MLMemoryContainerDeleteRequest parsedMemoryContainer = new MLMemoryContainerDeleteRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedMemoryContainer.getMemoryContainerId(), memoryContainerId); + } + + @Test + public void validate_Success() { + MLMemoryContainerDeleteRequest mlMemoryContainerDeleteRequest = MLMemoryContainerDeleteRequest + .builder() + .memoryContainerId(memoryContainerId) + .build(); + ActionRequestValidationException actionRequestValidationException = mlMemoryContainerDeleteRequest.validate(); + assertNull(actionRequestValidationException); + } + + @Test + public void validate_Exception_NullMemoryContainerId() { + MLMemoryContainerDeleteRequest mlMemoryContainerDeleteRequest = MLMemoryContainerDeleteRequest.builder().build(); + + ActionRequestValidationException exception = mlMemoryContainerDeleteRequest.validate(); + assertEquals("Validation Failed: 1: ML memory container id can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLMemoryContainerDeleteRequest mlMemoryContainerDeleteRequest = MLMemoryContainerDeleteRequest + .builder() + .memoryContainerId(memoryContainerId) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlMemoryContainerDeleteRequest.writeTo(out); + } + }; + MLMemoryContainerDeleteRequest result = MLMemoryContainerDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlMemoryContainerDeleteRequest); + assertEquals(result.getMemoryContainerId(), mlMemoryContainerDeleteRequest.getMemoryContainerId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLMemoryContainerDeleteRequest.fromActionRequest(actionRequest); + } + + @Test + public void fromActionRequestWithMemoryContainerDeleteRequest_Success() { + MLMemoryContainerDeleteRequest mlMemoryContainerDeleteRequest = MLMemoryContainerDeleteRequest + .builder() + .memoryContainerId(memoryContainerId) + .build(); + MLMemoryContainerDeleteRequest mlMemoryContainerDeleteRequestFromActionRequest = MLMemoryContainerDeleteRequest + .fromActionRequest(mlMemoryContainerDeleteRequest); + assertSame(mlMemoryContainerDeleteRequest, mlMemoryContainerDeleteRequestFromActionRequest); + assertEquals( + mlMemoryContainerDeleteRequest.getMemoryContainerId(), + mlMemoryContainerDeleteRequestFromActionRequest.getMemoryContainerId() + ); + } + + @Test + public void writeTo_withTenantId_Success() throws IOException { + String tenantId = "tenant-1"; + MLMemoryContainerDeleteRequest request = MLMemoryContainerDeleteRequest + .builder() + .memoryContainerId(memoryContainerId) + .tenantId(tenantId) + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + MLMemoryContainerDeleteRequest parsedRequest = new MLMemoryContainerDeleteRequest(out.bytes().streamInput()); + + assertEquals(memoryContainerId, parsedRequest.getMemoryContainerId()); + assertEquals(tenantId, parsedRequest.getTenantId()); + } + + @Test + public void writeTo_withoutTenantId_Success() throws IOException { + MLMemoryContainerDeleteRequest request = MLMemoryContainerDeleteRequest.builder().memoryContainerId(memoryContainerId).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + MLMemoryContainerDeleteRequest parsedRequest = new MLMemoryContainerDeleteRequest(out.bytes().streamInput()); + + assertEquals(memoryContainerId, parsedRequest.getMemoryContainerId()); + assertNull(parsedRequest.getTenantId()); + } + + @Test + public void fromActionRequest_withTenantId_Success() { + MLMemoryContainerDeleteRequest originalRequest = MLMemoryContainerDeleteRequest + .builder() + .memoryContainerId(memoryContainerId) + .tenantId("tenant-1") + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + originalRequest.writeTo(out); + } + }; + + MLMemoryContainerDeleteRequest parsedRequest = MLMemoryContainerDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(originalRequest, parsedRequest); + assertEquals(originalRequest.getMemoryContainerId(), parsedRequest.getMemoryContainerId()); + assertEquals(originalRequest.getTenantId(), parsedRequest.getTenantId()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetRequestTests.java new file mode 100644 index 0000000000..4c6eef6bd0 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetRequestTests.java @@ -0,0 +1,412 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLMemoryContainerGetRequestTests { + + private MLMemoryContainerGetRequest requestWithTenant; + private MLMemoryContainerGetRequest requestWithoutTenant; + private MLMemoryContainerGetRequest requestWithLongId; + + @Before + public void setUp() { + // Request with tenant ID + requestWithTenant = MLMemoryContainerGetRequest.builder().memoryContainerId("memory-container-123").tenantId("test-tenant").build(); + + // Request without tenant ID + requestWithoutTenant = MLMemoryContainerGetRequest.builder().memoryContainerId("memory-container-456").tenantId(null).build(); + + // Request with long ID to test edge cases + requestWithLongId = MLMemoryContainerGetRequest + .builder() + .memoryContainerId("memory-container-with-very-long-id-that-contains-multiple-segments-and-special-characters-789") + .tenantId("tenant-with-long-name-and-special-characters-!@#$%") + .build(); + } + + @Test + public void testConstructorWithBuilder() { + assertNotNull(requestWithTenant); + assertEquals("memory-container-123", requestWithTenant.getMemoryContainerId()); + assertEquals("test-tenant", requestWithTenant.getTenantId()); + + assertNotNull(requestWithoutTenant); + assertEquals("memory-container-456", requestWithoutTenant.getMemoryContainerId()); + assertNull(requestWithoutTenant.getTenantId()); + } + + @Test + public void testConstructorWithBuilderLongValues() { + assertNotNull(requestWithLongId); + assertEquals( + "memory-container-with-very-long-id-that-contains-multiple-segments-and-special-characters-789", + requestWithLongId.getMemoryContainerId() + ); + assertEquals("tenant-with-long-name-and-special-characters-!@#$%", requestWithLongId.getTenantId()); + } + + @Test + public void testConstructorWithParameters() { + MLMemoryContainerGetRequest request = new MLMemoryContainerGetRequest("test-id", "test-tenant"); + + assertNotNull(request); + assertEquals("test-id", request.getMemoryContainerId()); + assertEquals("test-tenant", request.getTenantId()); + } + + @Test + public void testConstructorWithNullTenant() { + MLMemoryContainerGetRequest request = new MLMemoryContainerGetRequest("test-id", null); + + assertNotNull(request); + assertEquals("test-id", request.getMemoryContainerId()); + assertNull(request.getTenantId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithTenant.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetRequest parsedRequest = new MLMemoryContainerGetRequest(streamInput); + + assertNotNull(parsedRequest); + assertEquals(requestWithTenant.getMemoryContainerId(), parsedRequest.getMemoryContainerId()); + assertEquals(requestWithTenant.getTenantId(), parsedRequest.getTenantId()); + } + + @Test + public void testStreamInputOutputWithoutTenant() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithoutTenant.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetRequest parsedRequest = new MLMemoryContainerGetRequest(streamInput); + + assertNotNull(parsedRequest); + assertEquals(requestWithoutTenant.getMemoryContainerId(), parsedRequest.getMemoryContainerId()); + assertNull(parsedRequest.getTenantId()); + } + + @Test + public void testStreamInputOutputWithLongValues() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithLongId.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetRequest parsedRequest = new MLMemoryContainerGetRequest(streamInput); + + assertNotNull(parsedRequest); + assertEquals(requestWithLongId.getMemoryContainerId(), parsedRequest.getMemoryContainerId()); + assertEquals(requestWithLongId.getTenantId(), parsedRequest.getTenantId()); + } + + @Test + public void testValidateWithValidRequest() { + ActionRequestValidationException validationException = requestWithTenant.validate(); + assertNull(validationException); + } + + @Test + public void testValidateWithValidRequestWithoutTenant() { + ActionRequestValidationException validationException = requestWithoutTenant.validate(); + assertNull(validationException); + } + + @Test + public void testValidateWithNullMemoryContainerId() { + MLMemoryContainerGetRequest requestWithNullId = MLMemoryContainerGetRequest + .builder() + .memoryContainerId(null) + .tenantId("test-tenant") + .build(); + + ActionRequestValidationException validationException = requestWithNullId.validate(); + + assertNotNull(validationException); + assertTrue(validationException.validationErrors().contains("Memory container id can't be null")); + } + + @Test + public void testValidateWithEmptyMemoryContainerId() { + // Empty string is considered valid (not null) + MLMemoryContainerGetRequest requestWithEmptyId = MLMemoryContainerGetRequest + .builder() + .memoryContainerId("") + .tenantId("test-tenant") + .build(); + + ActionRequestValidationException validationException = requestWithEmptyId.validate(); + assertNull(validationException); // Empty string should be valid + } + + @Test + public void testFromActionRequestWithSameType() { + MLMemoryContainerGetRequest result = MLMemoryContainerGetRequest.fromActionRequest(requestWithTenant); + + assertSame(requestWithTenant, result); + } + + @Test + public void testFromActionRequestWithDifferentType() throws IOException { + // Create a properly serializable ActionRequest that writes data in the expected format + ActionRequest mockActionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + // Write data in the same format as MLMemoryContainerGetRequest + super.writeTo(out); // Write ActionRequest base data + out.writeString("test-memory-container-id"); + out.writeOptionalString("test-tenant-id"); + } + }; + + MLMemoryContainerGetRequest result = MLMemoryContainerGetRequest.fromActionRequest(mockActionRequest); + + assertNotNull(result); + assertEquals("test-memory-container-id", result.getMemoryContainerId()); + assertEquals("test-tenant-id", result.getTenantId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestWithIOException() { + // Create a mock ActionRequest that throws IOException during serialization + ActionRequest mockActionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test IOException"); + } + }; + + MLMemoryContainerGetRequest.fromActionRequest(mockActionRequest); + } + + @Test + public void testGetterMethods() { + assertEquals("memory-container-123", requestWithTenant.getMemoryContainerId()); + assertEquals("test-tenant", requestWithTenant.getTenantId()); + + assertEquals("memory-container-456", requestWithoutTenant.getMemoryContainerId()); + assertNull(requestWithoutTenant.getTenantId()); + } + + @Test + public void testBuilderFunctionality() { + MLMemoryContainerGetRequest request = MLMemoryContainerGetRequest + .builder() + .memoryContainerId("builder-test-id") + .tenantId("builder-test-tenant") + .build(); + + assertNotNull(request); + assertEquals("builder-test-id", request.getMemoryContainerId()); + assertEquals("builder-test-tenant", request.getTenantId()); + } + + @Test + public void testInheritanceFromActionRequest() { + assertTrue(requestWithTenant instanceof ActionRequest); + assertTrue(requestWithoutTenant instanceof ActionRequest); + assertTrue(requestWithLongId instanceof ActionRequest); + } + + @Test + public void testFieldsAreFinal() { + // Test that fields are final (immutable) - this is enforced by Lombok @FieldDefaults + // We can't directly test final fields, but we can test that there are no setters + try { + // Try to find setter methods - should not exist due to final fields + requestWithTenant.getClass().getMethod("setMemoryContainerId", String.class); + org.junit.Assert.fail("Setter method should not exist for final field"); + } catch (NoSuchMethodException e) { + // Expected - no setter should exist + } + + try { + requestWithTenant.getClass().getMethod("setTenantId", String.class); + org.junit.Assert.fail("Setter method should not exist for final field"); + } catch (NoSuchMethodException e) { + // Expected - no setter should exist + } + } + + @Test + public void testToStringFunctionality() { + String toString = requestWithTenant.toString(); + assertNotNull(toString); + assertTrue(toString.contains("memory-container-123")); + assertTrue(toString.contains("test-tenant")); + assertTrue(toString.contains("MLMemoryContainerGetRequest")); + } + + @Test + public void testCompleteRoundTripSerialization() throws IOException { + // Test complete serialization round trip + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithTenant.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetRequest deserializedRequest = new MLMemoryContainerGetRequest(streamInput); + + // Verify all data is preserved + assertEquals(requestWithTenant.getMemoryContainerId(), deserializedRequest.getMemoryContainerId()); + assertEquals(requestWithTenant.getTenantId(), deserializedRequest.getTenantId()); + + // Test that the deserialized request can be serialized again + BytesStreamOutput secondOutput = new BytesStreamOutput(); + deserializedRequest.writeTo(secondOutput); + + StreamInput secondInput = secondOutput.bytes().streamInput(); + MLMemoryContainerGetRequest secondDeserialized = new MLMemoryContainerGetRequest(secondInput); + + assertEquals(requestWithTenant.getMemoryContainerId(), secondDeserialized.getMemoryContainerId()); + assertEquals(requestWithTenant.getTenantId(), secondDeserialized.getTenantId()); + } + + @Test + public void testValidationWithMultipleErrors() { + // Create a request that would have multiple validation errors if we had more validation rules + MLMemoryContainerGetRequest requestWithNullId = MLMemoryContainerGetRequest + .builder() + .memoryContainerId(null) + .tenantId("test-tenant") + .build(); + + ActionRequestValidationException validationException = requestWithNullId.validate(); + + assertNotNull(validationException); + assertEquals(1, validationException.validationErrors().size()); + assertTrue(validationException.validationErrors().get(0).contains("Memory container id can't be null")); + } + + @Test + public void testWithSpecialCharacters() throws IOException { + MLMemoryContainerGetRequest requestWithSpecialChars = MLMemoryContainerGetRequest + .builder() + .memoryContainerId("memory-container-with-special-chars-!@#$%^&*()_+-=[]{}|;':\",./<>?") + .tenantId("tenant-with-special-chars-!@#$%") + .build(); + + assertEquals("memory-container-with-special-chars-!@#$%^&*()_+-=[]{}|;':\",./<>?", requestWithSpecialChars.getMemoryContainerId()); + assertEquals("tenant-with-special-chars-!@#$%", requestWithSpecialChars.getTenantId()); + + // Test serialization with special characters + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithSpecialChars.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetRequest parsedRequest = new MLMemoryContainerGetRequest(streamInput); + + assertEquals(requestWithSpecialChars.getMemoryContainerId(), parsedRequest.getMemoryContainerId()); + assertEquals(requestWithSpecialChars.getTenantId(), parsedRequest.getTenantId()); + } + + @Test + public void testWithEmptyStrings() throws IOException { + MLMemoryContainerGetRequest requestWithEmptyStrings = MLMemoryContainerGetRequest + .builder() + .memoryContainerId("") + .tenantId("") + .build(); + + assertEquals("", requestWithEmptyStrings.getMemoryContainerId()); + assertEquals("", requestWithEmptyStrings.getTenantId()); + + // Test serialization with empty strings + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + requestWithEmptyStrings.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetRequest parsedRequest = new MLMemoryContainerGetRequest(streamInput); + + assertEquals("", parsedRequest.getMemoryContainerId()); + assertEquals("", parsedRequest.getTenantId()); + + // Validation should pass with empty string (not null) + ActionRequestValidationException validationException = requestWithEmptyStrings.validate(); + assertNull(validationException); + } + + @Test + public void testFromActionRequestRoundTrip() throws IOException { + // Test that fromActionRequest can properly handle the same request type + MLMemoryContainerGetRequest reconstructed = MLMemoryContainerGetRequest.fromActionRequest(requestWithTenant); + assertSame(requestWithTenant, reconstructed); + + // Test with request without tenant + MLMemoryContainerGetRequest minimalReconstructed = MLMemoryContainerGetRequest.fromActionRequest(requestWithoutTenant); + assertSame(requestWithoutTenant, minimalReconstructed); + } + + @Test + public void testLombokAnnotations() { + // Test @Getter annotation + assertNotNull(requestWithTenant.getMemoryContainerId()); + assertNotNull(requestWithTenant.getTenantId()); + + // Test @ToString annotation + String toString = requestWithTenant.toString(); + assertNotNull(toString); + assertTrue(toString.length() > 0); + + // Test @FieldDefaults (final fields) - no setters should exist + try { + requestWithTenant.getClass().getMethod("setMemoryContainerId", String.class); + org.junit.Assert.fail("Should not have setter for final field"); + } catch (NoSuchMethodException e) { + // Expected + } + } + + @Test + public void testMultipleInstancesIndependence() { + // Test that multiple instances don't interfere with each other + MLMemoryContainerGetRequest request1 = MLMemoryContainerGetRequest.builder().memoryContainerId("id1").tenantId("tenant1").build(); + + MLMemoryContainerGetRequest request2 = MLMemoryContainerGetRequest.builder().memoryContainerId("id2").tenantId("tenant2").build(); + + assertEquals("id1", request1.getMemoryContainerId()); + assertEquals("tenant1", request1.getTenantId()); + assertEquals("id2", request2.getMemoryContainerId()); + assertEquals("tenant2", request2.getTenantId()); + + // Verify they don't affect each other + assertNotEquals(request1.getMemoryContainerId(), request2.getMemoryContainerId()); + assertNotEquals(request1.getTenantId(), request2.getTenantId()); + } + + // Helper method for assertions + private void assertNotEquals(Object obj1, Object obj2) { + org.junit.Assert.assertNotEquals(obj1, obj2); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetResponseTests.java new file mode 100644 index 0000000000..ba997389cd --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MLMemoryContainerGetResponseTests.java @@ -0,0 +1,444 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.memorycontainer.MLMemoryContainer; +import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig; + +public class MLMemoryContainerGetResponseTests { + + private MLMemoryContainerGetResponse responseWithAllFields; + private MLMemoryContainerGetResponse responseMinimal; + private MLMemoryContainer testMemoryContainer; + private MLMemoryContainer minimalMemoryContainer; + private MemoryStorageConfig testMemoryStorageConfig; + private User testUser; + private Instant testCreatedTime; + private Instant testLastUpdatedTime; + + @Before + public void setUp() { + testUser = new User(); // Use empty User constructor + // Use millisecond precision to avoid precision loss in JSON serialization + testCreatedTime = Instant.ofEpochMilli(System.currentTimeMillis()); + testLastUpdatedTime = Instant.ofEpochMilli(System.currentTimeMillis() + 3600000); + + // Create test memory storage config + testMemoryStorageConfig = MemoryStorageConfig + .builder() + .memoryIndexName("test-memory-index") + .embeddingModelType(FunctionName.TEXT_EMBEDDING) + .embeddingModelId("test-embedding-model") + .llmModelId("test-llm-model") + .dimension(768) + .maxInferSize(8) + .build(); + + // Create test memory container with all fields + testMemoryContainer = MLMemoryContainer + .builder() + .name("test-memory-container") + .description("Test memory container description") + .owner(testUser) + .tenantId("test-tenant") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(testMemoryStorageConfig) + .build(); + + // Create minimal memory container + minimalMemoryContainer = MLMemoryContainer.builder().name("minimal-container").build(); + + // Create responses + responseWithAllFields = MLMemoryContainerGetResponse.builder().mlMemoryContainer(testMemoryContainer).build(); + + responseMinimal = MLMemoryContainerGetResponse.builder().mlMemoryContainer(minimalMemoryContainer).build(); + } + + @Test + public void testConstructorWithBuilder() { + assertNotNull(responseWithAllFields); + assertEquals(testMemoryContainer, responseWithAllFields.getMlMemoryContainer()); + + assertNotNull(responseMinimal); + assertEquals(minimalMemoryContainer, responseMinimal.getMlMemoryContainer()); + } + + @Test + public void testConstructorWithMemoryContainer() { + MLMemoryContainerGetResponse response = new MLMemoryContainerGetResponse(testMemoryContainer); + + assertNotNull(response); + assertEquals(testMemoryContainer, response.getMlMemoryContainer()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseWithAllFields.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetResponse parsedResponse = new MLMemoryContainerGetResponse(streamInput); + + assertNotNull(parsedResponse); + assertNotNull(parsedResponse.getMlMemoryContainer()); + + // Verify the memory container fields + MLMemoryContainer originalContainer = responseWithAllFields.getMlMemoryContainer(); + MLMemoryContainer parsedContainer = parsedResponse.getMlMemoryContainer(); + + assertEquals(originalContainer.getName(), parsedContainer.getName()); + assertEquals(originalContainer.getDescription(), parsedContainer.getDescription()); + assertEquals(originalContainer.getTenantId(), parsedContainer.getTenantId()); + assertEquals(originalContainer.getCreatedTime(), parsedContainer.getCreatedTime()); + assertEquals(originalContainer.getLastUpdatedTime(), parsedContainer.getLastUpdatedTime()); + assertEquals(originalContainer.getMemoryStorageConfig(), parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testStreamInputOutputWithMinimalFields() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseMinimal.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetResponse parsedResponse = new MLMemoryContainerGetResponse(streamInput); + + assertNotNull(parsedResponse); + assertNotNull(parsedResponse.getMlMemoryContainer()); + + MLMemoryContainer parsedContainer = parsedResponse.getMlMemoryContainer(); + assertEquals("minimal-container", parsedContainer.getName()); + assertNull(parsedContainer.getDescription()); + assertNull(parsedContainer.getTenantId()); + assertNull(parsedContainer.getCreatedTime()); + assertNull(parsedContainer.getLastUpdatedTime()); + assertNull(parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testToXContentWithAllFields() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify that all memory container fields are present in the JSON + assertTrue(jsonStr.contains("\"name\":\"test-memory-container\"")); + assertTrue(jsonStr.contains("\"description\":\"Test memory container description\"")); + assertTrue(jsonStr.contains("\"tenant_id\":\"test-tenant\"")); + assertTrue(jsonStr.contains("\"created_time\":" + testCreatedTime.toEpochMilli())); + assertTrue(jsonStr.contains("\"last_updated_time\":" + testLastUpdatedTime.toEpochMilli())); + assertTrue(jsonStr.contains("\"memory_storage_config\"")); + + // Verify nested memory storage config fields + assertTrue(jsonStr.contains("\"memory_index_name\":\"test-memory-index\"")); + assertTrue(jsonStr.contains("\"embedding_model_type\":\"TEXT_EMBEDDING\"")); + assertTrue(jsonStr.contains("\"embedding_model_id\":\"test-embedding-model\"")); + } + + @Test + public void testToXContentWithMinimalFields() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + assertNotNull(jsonStr); + // Verify only required fields are present + assertTrue(jsonStr.contains("\"name\":\"minimal-container\"")); + // Verify optional fields are not present + assertFalse(jsonStr.contains("\"description\"")); + assertFalse(jsonStr.contains("\"tenant_id\"")); + assertFalse(jsonStr.contains("\"created_time\"")); + assertFalse(jsonStr.contains("\"last_updated_time\"")); + assertFalse(jsonStr.contains("\"memory_storage_config\"")); + } + + @Test + public void testGetterMethod() { + assertEquals(testMemoryContainer, responseWithAllFields.getMlMemoryContainer()); + assertEquals(minimalMemoryContainer, responseMinimal.getMlMemoryContainer()); + } + + @Test + public void testInheritanceFromActionResponse() { + assertTrue(responseWithAllFields instanceof ActionResponse); + assertTrue(responseMinimal instanceof ActionResponse); + } + + @Test + public void testToXContentObjectInterface() { + assertTrue(responseWithAllFields instanceof org.opensearch.core.xcontent.ToXContentObject); + assertTrue(responseMinimal instanceof org.opensearch.core.xcontent.ToXContentObject); + } + + @Test + public void testFromActionResponseWithSameType() { + MLMemoryContainerGetResponse result = MLMemoryContainerGetResponse.fromActionResponse(responseWithAllFields); + + assertSame(responseWithAllFields, result); + } + + @Test + public void testFromActionResponseWithDifferentType() throws IOException { + // Create a properly serializable ActionResponse that writes data in the expected format + ActionResponse mockActionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + // Write data in the same format as MLMemoryContainerGetResponse + testMemoryContainer.writeTo(out); + } + }; + + MLMemoryContainerGetResponse result = MLMemoryContainerGetResponse.fromActionResponse(mockActionResponse); + + assertNotNull(result); + assertNotNull(result.getMlMemoryContainer()); + assertEquals(testMemoryContainer.getName(), result.getMlMemoryContainer().getName()); + assertEquals(testMemoryContainer.getDescription(), result.getMlMemoryContainer().getDescription()); + assertEquals(testMemoryContainer.getTenantId(), result.getMlMemoryContainer().getTenantId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionResponseWithIOException() { + // Create a mock ActionResponse that throws IOException during serialization + ActionResponse mockActionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test IOException"); + } + }; + + MLMemoryContainerGetResponse.fromActionResponse(mockActionResponse); + } + + @Test + public void testBuilderFunctionality() { + MLMemoryContainerGetResponse response = MLMemoryContainerGetResponse.builder().mlMemoryContainer(testMemoryContainer).build(); + + assertNotNull(response); + assertEquals(testMemoryContainer, response.getMlMemoryContainer()); + } + + @Test + public void testToStringFunctionality() { + String toString = responseWithAllFields.toString(); + assertNotNull(toString); + assertTrue(toString.length() > 0); + // The toString should contain the class name + assertTrue(toString.contains("MLMemoryContainerGetResponse")); + } + + @Test + public void testCompleteRoundTripSerialization() throws IOException { + // Test complete serialization round trip + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + responseWithAllFields.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetResponse deserializedResponse = new MLMemoryContainerGetResponse(streamInput); + + // Verify all nested data is preserved + MLMemoryContainer originalContainer = responseWithAllFields.getMlMemoryContainer(); + MLMemoryContainer deserializedContainer = deserializedResponse.getMlMemoryContainer(); + + assertEquals(originalContainer.getName(), deserializedContainer.getName()); + assertEquals(originalContainer.getDescription(), deserializedContainer.getDescription()); + assertEquals(originalContainer.getTenantId(), deserializedContainer.getTenantId()); + assertEquals(originalContainer.getCreatedTime(), deserializedContainer.getCreatedTime()); + assertEquals(originalContainer.getLastUpdatedTime(), deserializedContainer.getLastUpdatedTime()); + + // Verify nested MemoryStorageConfig + MemoryStorageConfig originalConfig = originalContainer.getMemoryStorageConfig(); + MemoryStorageConfig deserializedConfig = deserializedContainer.getMemoryStorageConfig(); + + assertEquals(originalConfig.getMemoryIndexName(), deserializedConfig.getMemoryIndexName()); + assertEquals(originalConfig.isSemanticStorageEnabled(), deserializedConfig.isSemanticStorageEnabled()); + assertEquals(originalConfig.getEmbeddingModelType(), deserializedConfig.getEmbeddingModelType()); + assertEquals(originalConfig.getEmbeddingModelId(), deserializedConfig.getEmbeddingModelId()); + assertEquals(originalConfig.getLlmModelId(), deserializedConfig.getLlmModelId()); + assertEquals(originalConfig.getDimension(), deserializedConfig.getDimension()); + assertEquals(originalConfig.getMaxInferSize(), deserializedConfig.getMaxInferSize()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Test JSON serialization and verify structure + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + // Verify JSON structure contains expected fields + assertTrue(jsonStr.contains("name")); + assertTrue(jsonStr.contains("description")); + assertTrue(jsonStr.contains("tenant_id")); + assertTrue(jsonStr.contains("created_time")); + assertTrue(jsonStr.contains("last_updated_time")); + assertTrue(jsonStr.contains("memory_storage_config")); + assertTrue(jsonStr.contains("test-memory-container")); + assertTrue(jsonStr.contains("Test memory container description")); + } + + @Test + public void testWithNullMemoryContainer() throws IOException { + MLMemoryContainerGetResponse responseWithNull = MLMemoryContainerGetResponse.builder().mlMemoryContainer(null).build(); + + assertNotNull(responseWithNull); + assertNull(responseWithNull.getMlMemoryContainer()); + } + + @Test + public void testFromActionResponseRoundTrip() throws IOException { + // Test that fromActionResponse can properly handle the same response type + MLMemoryContainerGetResponse reconstructed = MLMemoryContainerGetResponse.fromActionResponse(responseWithAllFields); + assertSame(responseWithAllFields, reconstructed); + + // Test with minimal response + MLMemoryContainerGetResponse minimalReconstructed = MLMemoryContainerGetResponse.fromActionResponse(responseMinimal); + assertSame(responseMinimal, minimalReconstructed); + } + + @Test + public void testLombokAnnotations() { + // Test @Getter annotation + assertNotNull(responseWithAllFields.getMlMemoryContainer()); + + // Test @ToString annotation + String toString = responseWithAllFields.toString(); + assertNotNull(toString); + assertTrue(toString.length() > 0); + + // Test @Builder annotation + MLMemoryContainerGetResponse builderResponse = MLMemoryContainerGetResponse + .builder() + .mlMemoryContainer(testMemoryContainer) + .build(); + assertNotNull(builderResponse); + assertEquals(testMemoryContainer, builderResponse.getMlMemoryContainer()); + } + + @Test + public void testMultipleInstancesIndependence() { + // Test that multiple instances don't interfere with each other + MLMemoryContainer container1 = MLMemoryContainer.builder().name("container1").description("description1").build(); + + MLMemoryContainer container2 = MLMemoryContainer.builder().name("container2").description("description2").build(); + + MLMemoryContainerGetResponse response1 = MLMemoryContainerGetResponse.builder().mlMemoryContainer(container1).build(); + + MLMemoryContainerGetResponse response2 = MLMemoryContainerGetResponse.builder().mlMemoryContainer(container2).build(); + + assertEquals("container1", response1.getMlMemoryContainer().getName()); + assertEquals("description1", response1.getMlMemoryContainer().getDescription()); + assertEquals("container2", response2.getMlMemoryContainer().getName()); + assertEquals("description2", response2.getMlMemoryContainer().getDescription()); + + // Verify they don't affect each other + assertNotEquals(response1.getMlMemoryContainer().getName(), response2.getMlMemoryContainer().getName()); + assertNotEquals(response1.getMlMemoryContainer().getDescription(), response2.getMlMemoryContainer().getDescription()); + } + + @Test + public void testDelegationToMLMemoryContainer() throws IOException { + // Test that the response properly delegates to the wrapped MLMemoryContainer + + // Test toXContent delegation + XContentBuilder responseBuilder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseWithAllFields.toXContent(responseBuilder, EMPTY_PARAMS); + String responseJson = TestHelper.xContentBuilderToString(responseBuilder); + + XContentBuilder containerBuilder = XContentBuilder.builder(XContentType.JSON.xContent()); + testMemoryContainer.toXContent(containerBuilder, EMPTY_PARAMS); + String containerJson = TestHelper.xContentBuilderToString(containerBuilder); + + // The JSON output should be identical since response delegates to container + assertEquals(containerJson, responseJson); + } + + @Test + public void testSerializationWithComplexMemoryContainer() throws IOException { + // Create a memory container with complex nested structure + MemoryStorageConfig complexConfig = MemoryStorageConfig + .builder() + .memoryIndexName("complex-memory-index-with-long-name") + .embeddingModelType(FunctionName.SPARSE_ENCODING) + .embeddingModelId("complex-sparse-encoding-model-id") + .llmModelId("complex-llm-model-id") + .maxInferSize(10) + .build(); + + MLMemoryContainer complexContainer = MLMemoryContainer + .builder() + .name("complex-memory-container-with-special-chars-!@#$%") + .description("Complex description with\nnewlines and\ttabs and special chars: !@#$%^&*()") + .owner(testUser) + .tenantId("complex-tenant-id-with-special-chars") + .createdTime(testCreatedTime) + .lastUpdatedTime(testLastUpdatedTime) + .memoryStorageConfig(complexConfig) + .build(); + + MLMemoryContainerGetResponse complexResponse = MLMemoryContainerGetResponse.builder().mlMemoryContainer(complexContainer).build(); + + // Test serialization + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + complexResponse.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLMemoryContainerGetResponse parsedResponse = new MLMemoryContainerGetResponse(streamInput); + + // Verify complex data is preserved + MLMemoryContainer parsedContainer = parsedResponse.getMlMemoryContainer(); + assertEquals(complexContainer.getName(), parsedContainer.getName()); + assertEquals(complexContainer.getDescription(), parsedContainer.getDescription()); + assertEquals(complexContainer.getTenantId(), parsedContainer.getTenantId()); + assertEquals(complexContainer.getMemoryStorageConfig(), parsedContainer.getMemoryStorageConfig()); + } + + @Test + public void testActionResponseIntegration() { + // Test that the response properly integrates with ActionResponse framework + assertTrue(responseWithAllFields instanceof ActionResponse); + + // Test that it can be treated as an ActionResponse + ActionResponse genericResponse = responseWithAllFields; + assertNotNull(genericResponse); + + // Test conversion back + MLMemoryContainerGetResponse convertedBack = MLMemoryContainerGetResponse.fromActionResponse(genericResponse); + assertSame(responseWithAllFields, convertedBack); + } + + // Helper method for assertions + private void assertNotEquals(Object obj1, Object obj2) { + org.junit.Assert.assertNotEquals(obj1, obj2); + } + + private void assertFalse(boolean condition) { + org.junit.Assert.assertFalse(condition); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MemoryContainerActionClassesTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MemoryContainerActionClassesTest.java new file mode 100644 index 0000000000..29c7c920d8 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/MemoryContainerActionClassesTest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import org.junit.Test; + +/** + * Tests for memory container Action classes + */ +public class MemoryContainerActionClassesTest { + + @Test + public void testMLCreateMemoryContainerAction() { + assertNotNull(MLCreateMemoryContainerAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/memory_containers/create", MLCreateMemoryContainerAction.NAME); + assertEquals(MLCreateMemoryContainerAction.NAME, MLCreateMemoryContainerAction.INSTANCE.name()); + } + + @Test + public void testMLMemoryContainerGetAction() { + assertNotNull(MLMemoryContainerGetAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/memory_containers/get", MLMemoryContainerGetAction.NAME); + assertEquals(MLMemoryContainerGetAction.NAME, MLMemoryContainerGetAction.INSTANCE.name()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInputTest.java new file mode 100644 index 0000000000..20c8593d67 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInputTest.java @@ -0,0 +1,366 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +public class MLAddMemoriesInputTest { + + private MLAddMemoriesInput inputWithAllFields; + private MLAddMemoriesInput inputMinimal; + private MLAddMemoriesInput inputNoOptionals; + private List testMessages; + private Map testTags; + + @Before + public void setUp() { + testMessages = Arrays + .asList( + new MessageInput("user", "Hello, how are you?"), + new MessageInput("assistant", "I'm doing well, thank you!"), + new MessageInput("user", "What can you help me with?") + ); + + testTags = new HashMap<>(); + testTags.put("topic", "greeting"); + testTags.put("priority", "low"); + + // Input with all fields + inputWithAllFields = MLAddMemoriesInput + .builder() + .memoryContainerId("container-123") + .messages(testMessages) + .sessionId("session-456") + .agentId("agent-789") + .infer(true) + .tags(testTags) + .build(); + + // Minimal input (only required fields) + inputMinimal = MLAddMemoriesInput.builder().messages(Arrays.asList(new MessageInput(null, "Single message"))).build(); + + // Input without optional fields + inputNoOptionals = new MLAddMemoriesInput( + "container-999", + Arrays.asList(new MessageInput("user", "Test message")), + null, + null, + null, + null + ); + } + + @Test + public void testBuilderWithAllFields() { + assertNotNull(inputWithAllFields); + assertEquals("container-123", inputWithAllFields.getMemoryContainerId()); + assertEquals(testMessages, inputWithAllFields.getMessages()); + assertEquals(3, inputWithAllFields.getMessages().size()); + assertEquals("session-456", inputWithAllFields.getSessionId()); + assertEquals("agent-789", inputWithAllFields.getAgentId()); + assertEquals(Boolean.TRUE, inputWithAllFields.getInfer()); + assertEquals(testTags, inputWithAllFields.getTags()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(inputMinimal); + assertNull(inputMinimal.getMemoryContainerId()); + assertEquals(1, inputMinimal.getMessages().size()); + assertNull(inputMinimal.getSessionId()); + assertNull(inputMinimal.getAgentId()); + assertNull(inputMinimal.getInfer()); + assertNull(inputMinimal.getTags()); + } + + @Test + public void testConstructorValidation() { + // Test null messages + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new MLAddMemoriesInput("container-1", null, null, null, null, null) + ); + assertEquals("Messages list cannot be empty", exception.getMessage()); + + // Test empty messages + exception = assertThrows( + IllegalArgumentException.class, + () -> new MLAddMemoriesInput("container-1", new ArrayList<>(), null, null, null, null) + ); + assertEquals("Messages list cannot be empty", exception.getMessage()); + + // Test that limit is removed - should be able to create with many messages + List manyMessages = new ArrayList<>(); + for (int i = 0; i < 100; i++) { // Test with 100 messages + manyMessages.add(new MessageInput("user", "Message " + i)); + } + // Should not throw exception anymore + MLAddMemoriesInput inputWithManyMessages = new MLAddMemoriesInput("container-1", manyMessages, null, null, null, null); + assertNotNull(inputWithManyMessages); + assertEquals(100, inputWithManyMessages.getMessages().size()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with all fields + BytesStreamOutput out = new BytesStreamOutput(); + inputWithAllFields.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesInput deserialized = new MLAddMemoriesInput(in); + + assertEquals(inputWithAllFields.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(inputWithAllFields.getMessages().size(), deserialized.getMessages().size()); + for (int i = 0; i < inputWithAllFields.getMessages().size(); i++) { + MessageInput original = inputWithAllFields.getMessages().get(i); + MessageInput deser = deserialized.getMessages().get(i); + assertEquals(original.getRole(), deser.getRole()); + assertEquals(original.getContent(), deser.getContent()); + } + assertEquals(inputWithAllFields.getSessionId(), deserialized.getSessionId()); + assertEquals(inputWithAllFields.getAgentId(), deserialized.getAgentId()); + assertEquals(inputWithAllFields.getInfer(), deserialized.getInfer()); + assertEquals(inputWithAllFields.getTags(), deserialized.getTags()); + } + + @Test + public void testStreamInputOutputMinimal() throws IOException { + // Test with minimal fields + BytesStreamOutput out = new BytesStreamOutput(); + inputMinimal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesInput deserialized = new MLAddMemoriesInput(in); + + assertNull(deserialized.getMemoryContainerId()); + assertEquals(1, deserialized.getMessages().size()); + assertNull(deserialized.getSessionId()); + assertNull(deserialized.getAgentId()); + assertNull(deserialized.getInfer()); + assertNull(deserialized.getTags()); + } + + @Test + public void testStreamInputOutputEmptyTags() throws IOException { + // Test with empty tags + MLAddMemoriesInput inputEmptyTags = MLAddMemoriesInput + .builder() + .messages(Arrays.asList(new MessageInput("user", "Test"))) + .tags(new HashMap<>()) + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + inputEmptyTags.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesInput deserialized = new MLAddMemoriesInput(in); + + assertNull(deserialized.getTags()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_container_id\":\"container-123\"")); + assertTrue(jsonString.contains("\"messages\":[")); + assertTrue(jsonString.contains("\"role\":\"user\"")); + assertTrue(jsonString.contains("\"content\":\"Hello, how are you?\"")); + assertTrue(jsonString.contains("\"session_id\":\"session-456\"")); + assertTrue(jsonString.contains("\"agent_id\":\"agent-789\"")); + assertTrue(jsonString.contains("\"infer\":true")); + assertTrue(jsonString.contains("\"topic\":\"greeting\"")); + } + + @Test + public void testToXContentMinimal() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(!jsonString.contains("\"memory_container_id\"")); + assertTrue(jsonString.contains("\"messages\":[")); + assertTrue(jsonString.contains("\"content\":\"Single message\"")); + assertTrue(!jsonString.contains("\"session_id\"")); + assertTrue(!jsonString.contains("\"agent_id\"")); + assertTrue(!jsonString.contains("\"infer\"")); + assertTrue(!jsonString.contains("\"tags\"")); + } + + @Test + public void testParse() throws IOException { + String jsonString = "{" + + "\"memory_container_id\":\"container-123\"," + + "\"messages\":[" + + "{\"role\":\"user\",\"content\":\"Test message 1\"}," + + "{\"role\":\"assistant\",\"content\":\"Test response\"}" + + "]," + + "\"session_id\":\"session-789\"," + + "\"agent_id\":\"agent-456\"," + + "\"infer\":false," + + "\"tags\":{\"key\":\"value\"}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLAddMemoriesInput parsed = MLAddMemoriesInput.parse(parser); + + assertEquals("container-123", parsed.getMemoryContainerId()); + assertEquals(2, parsed.getMessages().size()); + assertEquals("user", parsed.getMessages().get(0).getRole()); + assertEquals("Test message 1", parsed.getMessages().get(0).getContent()); + assertEquals("session-789", parsed.getSessionId()); + assertEquals("agent-456", parsed.getAgentId()); + assertEquals(Boolean.FALSE, parsed.getInfer()); + assertEquals(1, parsed.getTags().size()); + assertEquals("value", parsed.getTags().get("key")); + } + + @Test + public void testParseMinimal() throws IOException { + String jsonString = "{" + "\"messages\":[" + "{\"content\":\"Minimal message\"}" + "]" + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLAddMemoriesInput parsed = MLAddMemoriesInput.parse(parser); + + assertNull(parsed.getMemoryContainerId()); + assertEquals(1, parsed.getMessages().size()); + assertEquals("Minimal message", parsed.getMessages().get(0).getContent()); + assertNull(parsed.getMessages().get(0).getRole()); + } + + @Test + public void testParseWithUnknownFields() throws IOException { + String jsonString = "{" + + "\"messages\":[{\"content\":\"Test\"}]," + + "\"unknown_field\":\"ignored\"," + + "\"another_unknown\":123" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLAddMemoriesInput parsed = MLAddMemoriesInput.parse(parser); + + assertEquals(1, parsed.getMessages().size()); + assertEquals("Test", parsed.getMessages().get(0).getContent()); + } + + @Test + public void testSetters() { + MLAddMemoriesInput input = MLAddMemoriesInput.builder().messages(Arrays.asList(new MessageInput("user", "Initial"))).build(); + + input.setMemoryContainerId("new-container"); + input.setSessionId("new-session"); + input.setAgentId("new-agent"); + input.setInfer(true); + input.setTags(testTags); + + assertEquals("new-container", input.getMemoryContainerId()); + assertEquals("new-session", input.getSessionId()); + assertEquals("new-agent", input.getAgentId()); + assertEquals(Boolean.TRUE, input.getInfer()); + assertEquals(testTags, input.getTags()); + } + + @Test + public void testLargeNumberOfMessages() { + // Test that we can handle a large number of messages now that limit is removed + List manyMessages = new ArrayList<>(); + for (int i = 0; i < 1000; i++) { // Test with 1000 messages + manyMessages.add(new MessageInput("user", "Message " + i)); + } + + // Should succeed with large number of messages + MLAddMemoriesInput input = new MLAddMemoriesInput("container-1", manyMessages, null, null, null, null); + assertEquals(1000, input.getMessages().size()); + } + + @Test + public void testSpecialCharactersInFields() throws IOException { + Map specialTags = new HashMap<>(); + specialTags.put("key with spaces", "value with\nnewlines"); + + List specialMessages = Arrays + .asList( + new MessageInput("user", "Message with\n\ttabs and \"quotes\""), + new MessageInput("assistant", "Response with unicode 🚀✨") + ); + + MLAddMemoriesInput specialInput = MLAddMemoriesInput + .builder() + .memoryContainerId("container-with-special-chars") + .messages(specialMessages) + .sessionId("session-🔥") + .tags(specialTags) + .build(); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialInput.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesInput deserialized = new MLAddMemoriesInput(in); + + assertEquals(specialInput.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(specialInput.getSessionId(), deserialized.getSessionId()); + assertEquals(2, deserialized.getMessages().size()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MLAddMemoriesInput parsed = MLAddMemoriesInput.parse(parser); + + // Verify all fields match + assertEquals(inputWithAllFields.getMemoryContainerId(), parsed.getMemoryContainerId()); + assertEquals(inputWithAllFields.getMessages().size(), parsed.getMessages().size()); + assertEquals(inputWithAllFields.getSessionId(), parsed.getSessionId()); + assertEquals(inputWithAllFields.getAgentId(), parsed.getAgentId()); + assertEquals(inputWithAllFields.getInfer(), parsed.getInfer()); + assertEquals(inputWithAllFields.getTags(), parsed.getTags()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesRequestTest.java new file mode 100644 index 0000000000..a3b6f4e310 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesRequestTest.java @@ -0,0 +1,178 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +public class MLAddMemoriesRequestTest { + + private MLAddMemoriesInput testInput; + private MLAddMemoriesRequest request; + + @Before + public void setUp() { + MessageInput message = new MessageInput("user", "Test message content"); + testInput = MLAddMemoriesInput + .builder() + .messages(Arrays.asList(message)) + .memoryContainerId("container-123") + .sessionId("session-456") + .agentId("agent-789") + .infer(true) + .build(); + + request = MLAddMemoriesRequest.builder().mlAddMemoryInput(testInput).build(); + } + + @Test + public void testBuilder() { + assertNotNull(request); + assertNotNull(request.getMlAddMemoryInput()); + assertEquals(testInput, request.getMlAddMemoryInput()); + assertEquals("container-123", request.getMlAddMemoryInput().getMemoryContainerId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesRequest deserialized = new MLAddMemoriesRequest(in); + + assertNotNull(deserialized.getMlAddMemoryInput()); + assertEquals(request.getMlAddMemoryInput().getMemoryContainerId(), deserialized.getMlAddMemoryInput().getMemoryContainerId()); + assertEquals(request.getMlAddMemoryInput().getSessionId(), deserialized.getMlAddMemoryInput().getSessionId()); + assertEquals(request.getMlAddMemoryInput().getAgentId(), deserialized.getMlAddMemoryInput().getAgentId()); + assertEquals(request.getMlAddMemoryInput().getInfer(), deserialized.getMlAddMemoryInput().getInfer()); + assertEquals(request.getMlAddMemoryInput().getMessages().size(), deserialized.getMlAddMemoryInput().getMessages().size()); + } + + @Test + public void testValidateSuccess() { + ActionRequestValidationException exception = request.validate(); + assertNull(exception); + } + + @Test + public void testValidateWithNullInput() { + MLAddMemoriesRequest invalidRequest = MLAddMemoriesRequest.builder().mlAddMemoryInput(null).build(); + + ActionRequestValidationException exception = invalidRequest.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("ML add memory input can't be null")); + } + + @Test + public void testToString() { + String toString = request.toString(); + assertNotNull(toString); + assertTrue(toString.contains("MLAddMemoriesRequest")); + assertTrue(toString.contains("mlAddMemoryInput")); + } + + @Test(expected = IllegalArgumentException.class) + public void testWithEmptyMessages() { + // Empty messages list is not allowed - should throw exception + MLAddMemoriesInput.builder().messages(Collections.emptyList()).memoryContainerId("container-empty").build(); + } + + @Test + public void testWithMultipleMessages() throws IOException { + MessageInput msg1 = new MessageInput("user", "First message"); + MessageInput msg2 = new MessageInput("assistant", "Second message"); + MessageInput msg3 = new MessageInput("user", "Third message"); + + MLAddMemoriesInput multiInput = MLAddMemoriesInput + .builder() + .messages(Arrays.asList(msg1, msg2, msg3)) + .memoryContainerId("container-multi") + .sessionId("session-multi") + .build(); + + MLAddMemoriesRequest multiRequest = MLAddMemoriesRequest.builder().mlAddMemoryInput(multiInput).build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + multiRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesRequest deserialized = new MLAddMemoriesRequest(in); + + assertEquals(3, deserialized.getMlAddMemoryInput().getMessages().size()); + assertEquals("First message", deserialized.getMlAddMemoryInput().getMessages().get(0).getContent()); + assertEquals("Second message", deserialized.getMlAddMemoryInput().getMessages().get(1).getContent()); + assertEquals("Third message", deserialized.getMlAddMemoryInput().getMessages().get(2).getContent()); + } + + @Test + public void testWithMinimalInput() throws IOException { + MessageInput message = new MessageInput(null, "Minimal message"); + MLAddMemoriesInput minimalInput = MLAddMemoriesInput + .builder() + .messages(Arrays.asList(message)) + .memoryContainerId("container-minimal") + .build(); + + MLAddMemoriesRequest minimalRequest = MLAddMemoriesRequest.builder().mlAddMemoryInput(minimalInput).build(); + + assertNull(minimalRequest.validate()); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + minimalRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesRequest deserialized = new MLAddMemoriesRequest(in); + + assertEquals("container-minimal", deserialized.getMlAddMemoryInput().getMemoryContainerId()); + assertNull(deserialized.getMlAddMemoryInput().getSessionId()); + assertNull(deserialized.getMlAddMemoryInput().getAgentId()); + assertNull(deserialized.getMlAddMemoryInput().getInfer()); // Null when not set + } + + @Test + public void testWithComplexTags() throws IOException { + MessageInput message = new MessageInput("user", "Tagged message"); + MLAddMemoriesInput taggedInput = MLAddMemoriesInput + .builder() + .messages(Arrays.asList(message)) + .memoryContainerId("container-tags") + .sessionId("session-tags") + .agentId("agent-tags") + .tags(java.util.Map.of("category", "technical", "priority", "high", "timestamp", "2024-01-01")) + .infer(false) + .build(); + + MLAddMemoriesRequest taggedRequest = MLAddMemoriesRequest.builder().mlAddMemoryInput(taggedInput).build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + taggedRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesRequest deserialized = new MLAddMemoriesRequest(in); + + assertEquals(3, deserialized.getMlAddMemoryInput().getTags().size()); + assertEquals("technical", deserialized.getMlAddMemoryInput().getTags().get("category")); + assertEquals("high", deserialized.getMlAddMemoryInput().getTags().get("priority")); + assertEquals("2024-01-01", deserialized.getMlAddMemoryInput().getTags().get("timestamp")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesResponseTest.java new file mode 100644 index 0000000000..228af28e01 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesResponseTest.java @@ -0,0 +1,258 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLAddMemoriesResponseTest { + + private MLAddMemoriesResponse responseWithResults; + private MLAddMemoriesResponse responseEmpty; + private MLAddMemoriesResponse responseMinimal; + private List testResults; + + @Before + public void setUp() { + testResults = Arrays + .asList( + MemoryResult.builder().memoryId("mem-1").memory("User's name is John").event(MemoryEvent.ADD).build(), + MemoryResult + .builder() + .memoryId("mem-2") + .memory("Lives in San Francisco") + .event(MemoryEvent.UPDATE) + .oldMemory("Lives in Boston") + .build(), + MemoryResult.builder().memoryId("mem-3").memory("Works at TechCorp").event(MemoryEvent.NONE).build() + ); + + // Response with results + responseWithResults = MLAddMemoriesResponse.builder().results(testResults).sessionId("session-123").build(); + + // Empty response + responseEmpty = MLAddMemoriesResponse.builder().results(new ArrayList<>()).sessionId("session-empty").build(); + + // Minimal response (null results defaults to empty list) + responseMinimal = MLAddMemoriesResponse.builder().sessionId("session-minimal").build(); + } + + @Test + public void testBuilderWithResults() { + assertNotNull(responseWithResults); + assertEquals(testResults, responseWithResults.getResults()); + assertEquals(3, responseWithResults.getResults().size()); + assertEquals("session-123", responseWithResults.getSessionId()); + } + + @Test + public void testBuilderEmpty() { + assertNotNull(responseEmpty); + assertEquals(0, responseEmpty.getResults().size()); + assertEquals("session-empty", responseEmpty.getSessionId()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(responseMinimal); + assertEquals(0, responseMinimal.getResults().size()); + assertEquals("session-minimal", responseMinimal.getSessionId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with results + BytesStreamOutput out = new BytesStreamOutput(); + responseWithResults.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesResponse deserialized = new MLAddMemoriesResponse(in); + + assertEquals(responseWithResults.getResults().size(), deserialized.getResults().size()); + for (int i = 0; i < responseWithResults.getResults().size(); i++) { + MemoryResult original = responseWithResults.getResults().get(i); + MemoryResult deser = deserialized.getResults().get(i); + assertEquals(original.getMemoryId(), deser.getMemoryId()); + assertEquals(original.getMemory(), deser.getMemory()); + assertEquals(original.getEvent(), deser.getEvent()); + assertEquals(original.getOldMemory(), deser.getOldMemory()); + } + assertEquals(responseWithResults.getSessionId(), deserialized.getSessionId()); + } + + @Test + public void testStreamInputOutputEmpty() throws IOException { + // Test with empty results + BytesStreamOutput out = new BytesStreamOutput(); + responseEmpty.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesResponse deserialized = new MLAddMemoriesResponse(in); + + assertEquals(0, deserialized.getResults().size()); + assertEquals(responseEmpty.getSessionId(), deserialized.getSessionId()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + responseWithResults.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"results\":[")); + assertTrue(jsonString.contains("\"id\":\"mem-1\"")); + assertTrue(jsonString.contains("\"text\":\"User's name is John\"")); + assertTrue(jsonString.contains("\"event\":\"ADD\"")); + assertTrue(jsonString.contains("\"id\":\"mem-2\"")); + assertTrue(jsonString.contains("\"text\":\"Lives in San Francisco\"")); + assertTrue(jsonString.contains("\"event\":\"UPDATE\"")); + assertTrue(jsonString.contains("\"old_memory\":\"Lives in Boston\"")); + assertTrue(jsonString.contains("\"session_id\":\"session-123\"")); + } + + @Test + public void testToXContentEmpty() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + responseEmpty.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"results\":[]")); + assertTrue(jsonString.contains("\"session_id\":\"session-empty\"")); + } + + @Test + public void testToString() { + String str = responseWithResults.toString(); + assertNotNull(str); + assertTrue(str.contains("session-123")); + assertTrue(str.contains("results")); + } + + @Test + public void testDifferentEventTypes() throws IOException { + List mixedResults = Arrays + .asList( + new MemoryResult("add-1", "New fact", MemoryEvent.ADD, null), + new MemoryResult("update-1", "Updated fact", MemoryEvent.UPDATE, "Old fact"), + new MemoryResult("delete-1", "Deleted fact", MemoryEvent.DELETE, null), + new MemoryResult("none-1", "Unchanged fact", MemoryEvent.NONE, null) + ); + + MLAddMemoriesResponse mixedResponse = MLAddMemoriesResponse.builder().results(mixedResults).sessionId("session-mixed").build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + mixedResponse.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesResponse deserialized = new MLAddMemoriesResponse(in); + + assertEquals(4, deserialized.getResults().size()); + assertEquals(MemoryEvent.ADD, deserialized.getResults().get(0).getEvent()); + assertEquals(MemoryEvent.UPDATE, deserialized.getResults().get(1).getEvent()); + assertEquals(MemoryEvent.DELETE, deserialized.getResults().get(2).getEvent()); + assertEquals(MemoryEvent.NONE, deserialized.getResults().get(3).getEvent()); + } + + @Test + public void testLargeResponse() throws IOException { + // Test with many results + List manyResults = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + manyResults + .add( + MemoryResult + .builder() + .memoryId("mem-" + i) + .memory("Memory content " + i) + .event(i % 2 == 0 ? MemoryEvent.ADD : MemoryEvent.UPDATE) + .oldMemory(i % 2 == 0 ? null : "Old memory " + i) + .build() + ); + } + + MLAddMemoriesResponse largeResponse = MLAddMemoriesResponse.builder().results(manyResults).sessionId("session-large").build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + largeResponse.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesResponse deserialized = new MLAddMemoriesResponse(in); + + assertEquals(100, deserialized.getResults().size()); + assertEquals("session-large", deserialized.getSessionId()); + } + + @Test + public void testSpecialCharactersInResponse() throws IOException { + List specialResults = Arrays + .asList( + MemoryResult + .builder() + .memoryId("mem-special-🚀") + .memory("Memory with\n\ttabs and \"quotes\"") + .event(MemoryEvent.ADD) + .build(), + MemoryResult + .builder() + .memoryId("mem-unicode-✨") + .memory("Memory with unicode characters") + .event(MemoryEvent.UPDATE) + .oldMemory("Old memory with 'single quotes'") + .build() + ); + + MLAddMemoriesResponse specialResponse = MLAddMemoriesResponse + .builder() + .results(specialResults) + .sessionId("session-special-chars") + .build(); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialResponse.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("mem-special-")); + assertTrue(jsonString.contains("Memory with")); + assertTrue(jsonString.contains("tabs")); + assertTrue(jsonString.contains("quotes")); + } + + @Test + public void testConstructorWithNullResults() { + MLAddMemoriesResponse response = new MLAddMemoriesResponse(null, "session-null"); + assertNotNull(response.getResults()); + assertEquals(0, response.getResults().size()); + assertEquals("session-null", response.getSessionId()); + } + + @Test + public void testResultsOrder() throws IOException { + // Verify results maintain their order + BytesStreamOutput out = new BytesStreamOutput(); + responseWithResults.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLAddMemoriesResponse deserialized = new MLAddMemoriesResponse(in); + + for (int i = 0; i < testResults.size(); i++) { + assertEquals(testResults.get(i).getMemoryId(), deserialized.getResults().get(i).getMemoryId()); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryRequestTest.java new file mode 100644 index 0000000000..29191e3b2c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLDeleteMemoryRequestTest.java @@ -0,0 +1,197 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLDeleteMemoryRequestTest { + + private MLDeleteMemoryRequest requestNormal; + private MLDeleteMemoryRequest requestEmpty; + + @Before + public void setUp() { + requestNormal = MLDeleteMemoryRequest.builder().memoryContainerId("container-123").memoryId("memory-456").build(); + + requestEmpty = MLDeleteMemoryRequest.builder().memoryContainerId(null).memoryId(null).build(); + } + + @Test + public void testBuilderNormal() { + assertNotNull(requestNormal); + assertEquals("container-123", requestNormal.getMemoryContainerId()); + assertEquals("memory-456", requestNormal.getMemoryId()); + } + + @Test + public void testBuilderWithNullValues() { + assertNotNull(requestEmpty); + assertNull(requestEmpty.getMemoryContainerId()); + assertNull(requestEmpty.getMemoryId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + requestNormal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLDeleteMemoryRequest deserialized = new MLDeleteMemoryRequest(in); + + assertEquals(requestNormal.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(requestNormal.getMemoryId(), deserialized.getMemoryId()); + } + + @Test + public void testValidateSuccess() { + ActionRequestValidationException exception = requestNormal.validate(); + assertNull(exception); + } + + @Test + public void testValidateWithNullContainerId() { + MLDeleteMemoryRequest request = MLDeleteMemoryRequest.builder().memoryContainerId(null).memoryId("memory-123").build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Memory container id can't be null")); + } + + @Test + public void testValidateWithNullMemoryId() { + MLDeleteMemoryRequest request = MLDeleteMemoryRequest.builder().memoryContainerId("container-123").memoryId(null).build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Memory id can't be null")); + } + + @Test + public void testValidateWithBothNull() { + ActionRequestValidationException exception = requestEmpty.validate(); + assertNotNull(exception); + assertEquals(2, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Memory container id can't be null")); + assertTrue(exception.validationErrors().get(1).contains("Memory id can't be null")); + } + + @Test + public void testFromActionRequestSameInstance() { + MLDeleteMemoryRequest result = MLDeleteMemoryRequest.fromActionRequest(requestNormal); + assertEquals(requestNormal, result); + } + + @Test + public void testFromActionRequestDifferentInstance() throws IOException { + // Create a mock ActionRequest that's not MLDeleteMemoryRequest + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString("test-container"); + out.writeString("test-memory"); + } + }; + + MLDeleteMemoryRequest result = MLDeleteMemoryRequest.fromActionRequest(mockRequest); + assertNotNull(result); + assertEquals("test-container", result.getMemoryContainerId()); + assertEquals("test-memory", result.getMemoryId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestIOException() { + // Create a mock ActionRequest that throws IOException + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test exception"); + } + }; + + MLDeleteMemoryRequest.fromActionRequest(mockRequest); + } + + @Test + public void testSpecialCharacters() throws IOException { + MLDeleteMemoryRequest specialRequest = MLDeleteMemoryRequest + .builder() + .memoryContainerId("container-with-special-chars-🚀") + .memoryId("memory-with-\n\ttabs-and-\"quotes\"") + .build(); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialRequest.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLDeleteMemoryRequest deserialized = new MLDeleteMemoryRequest(in); + + assertEquals(specialRequest.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(specialRequest.getMemoryId(), deserialized.getMemoryId()); + } + + @Test + public void testEmptyStrings() { + MLDeleteMemoryRequest emptyStringRequest = MLDeleteMemoryRequest.builder().memoryContainerId("").memoryId("").build(); + + assertNotNull(emptyStringRequest); + assertEquals("", emptyStringRequest.getMemoryContainerId()); + assertEquals("", emptyStringRequest.getMemoryId()); + + // Empty strings should pass validation (only null check in validate method) + ActionRequestValidationException exception = emptyStringRequest.validate(); + assertNull(exception); + } + + @Test + public void testLongIds() throws IOException { + // Test with very long IDs + StringBuilder longId = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + longId.append("a"); + } + + MLDeleteMemoryRequest longRequest = MLDeleteMemoryRequest + .builder() + .memoryContainerId(longId.toString()) + .memoryId(longId.toString() + "-memory") + .build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + longRequest.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLDeleteMemoryRequest deserialized = new MLDeleteMemoryRequest(in); + + assertEquals(longRequest.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(longRequest.getMemoryId(), deserialized.getMemoryId()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequestTest.java new file mode 100644 index 0000000000..6cffd6b253 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequestTest.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLGetMemoryRequestTest { + + private MLGetMemoryRequest requestNormal; + private MLGetMemoryRequest requestWithNulls; + + @Before + public void setUp() { + requestNormal = MLGetMemoryRequest.builder().memoryContainerId("container-123").memoryId("memory-456").build(); + + requestWithNulls = MLGetMemoryRequest.builder().memoryContainerId(null).memoryId(null).build(); + } + + @Test + public void testBuilderNormal() { + assertNotNull(requestNormal); + assertEquals("container-123", requestNormal.getMemoryContainerId()); + assertEquals("memory-456", requestNormal.getMemoryId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + requestNormal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLGetMemoryRequest deserialized = new MLGetMemoryRequest(in); + + assertEquals(requestNormal.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(requestNormal.getMemoryId(), deserialized.getMemoryId()); + } + + @Test + public void testValidateSuccess() { + ActionRequestValidationException exception = requestNormal.validate(); + assertNull(exception); + } + + @Test + public void testValidateWithNullValues() { + ActionRequestValidationException exception = requestWithNulls.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("memoryContainerId and memoryId id can not be null")); + } + + @Test + public void testFromActionRequestSameInstance() { + MLGetMemoryRequest result = MLGetMemoryRequest.fromActionRequest(requestNormal); + assertEquals(requestNormal, result); + } + + @Test + public void testFromActionRequestDifferentInstance() throws IOException { + // Create a mock ActionRequest that's not MLGetMemoryRequest + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString("test-container"); + out.writeString("test-memory"); + } + }; + + MLGetMemoryRequest result = MLGetMemoryRequest.fromActionRequest(mockRequest); + assertNotNull(result); + assertEquals("test-container", result.getMemoryContainerId()); + assertEquals("test-memory", result.getMemoryId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestIOException() { + // Create a mock ActionRequest that throws IOException + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test exception"); + } + }; + + MLGetMemoryRequest.fromActionRequest(mockRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryResponseTest.java new file mode 100644 index 0000000000..6352cce9a4 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryResponseTest.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.memorycontainer.MLMemory; +import org.opensearch.ml.common.memorycontainer.MemoryType; + +public class MLGetMemoryResponseTest { + + private MLGetMemoryResponse responseNormal; + private MLMemory testMemory; + + @Before + public void setUp() { + testMemory = MLMemory + .builder() + .sessionId("test-session") + .memory("Test memory content") + .memoryType(MemoryType.RAW_MESSAGE) + .userId("test-user") + .createdTime(Instant.now()) + .lastUpdatedTime(Instant.now()) + .build(); + + responseNormal = MLGetMemoryResponse.builder().mlMemory(testMemory).build(); + } + + @Test + public void testBuilderNormal() { + assertNotNull(responseNormal); + assertNotNull(responseNormal.getMlMemory()); + assertEquals("test-session", responseNormal.getMlMemory().getSessionId()); + assertEquals("Test memory content", responseNormal.getMlMemory().getMemory()); + assertEquals(MemoryType.RAW_MESSAGE, responseNormal.getMlMemory().getMemoryType()); + assertEquals("test-user", responseNormal.getMlMemory().getUserId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + responseNormal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLGetMemoryResponse deserialized = new MLGetMemoryResponse(in); + + assertNotNull(deserialized.getMlMemory()); + assertEquals(responseNormal.getMlMemory().getSessionId(), deserialized.getMlMemory().getSessionId()); + assertEquals(responseNormal.getMlMemory().getMemory(), deserialized.getMlMemory().getMemory()); + assertEquals(responseNormal.getMlMemory().getMemoryType(), deserialized.getMlMemory().getMemoryType()); + assertEquals(responseNormal.getMlMemory().getUserId(), deserialized.getMlMemory().getUserId()); + } + + @Test + public void testFromActionResponseSameInstance() { + MLGetMemoryResponse result = MLGetMemoryResponse.fromActionResponse(responseNormal); + assertEquals(responseNormal, result); + } + + @Test + public void testFromActionResponseDifferentInstance() throws IOException { + // Create a mock ActionResponse that's not MLGetMemoryResponse + ActionResponse mockResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + testMemory.writeTo(out); + } + }; + + MLGetMemoryResponse result = MLGetMemoryResponse.fromActionResponse(mockResponse); + assertNotNull(result); + assertNotNull(result.getMlMemory()); + assertEquals("test-session", result.getMlMemory().getSessionId()); + assertEquals("Test memory content", result.getMlMemory().getMemory()); + assertEquals(MemoryType.RAW_MESSAGE, result.getMlMemory().getMemoryType()); + assertEquals("test-user", result.getMlMemory().getUserId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionResponseIOException() { + // Create a mock ActionResponse that throws IOException + ActionResponse mockResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test exception"); + } + }; + + MLGetMemoryResponse.fromActionResponse(mockResponse); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + responseNormal.toXContent(builder, null); + String jsonString = builder.toString(); + + assertNotNull(jsonString); + assertTrue(jsonString.contains("test-session")); + assertTrue(jsonString.contains("Test memory content")); + assertTrue(jsonString.contains("RAW_MESSAGE")); + assertTrue(jsonString.contains("test-user")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesInputTest.java new file mode 100644 index 0000000000..ec454e61fb --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesInputTest.java @@ -0,0 +1,285 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +public class MLSearchMemoriesInputTest { + + private MLSearchMemoriesInput inputWithContainerId; + private MLSearchMemoriesInput inputWithoutContainerId; + + @Before + public void setUp() { + inputWithContainerId = MLSearchMemoriesInput + .builder() + .memoryContainerId("container-123") + .query("machine learning concepts") + .build(); + + inputWithoutContainerId = MLSearchMemoriesInput.builder().query("search without container id").build(); + } + + @Test + public void testBuilderWithContainerId() { + assertNotNull(inputWithContainerId); + assertEquals("container-123", inputWithContainerId.getMemoryContainerId()); + assertEquals("machine learning concepts", inputWithContainerId.getQuery()); + } + + @Test + public void testBuilderWithoutContainerId() { + assertNotNull(inputWithoutContainerId); + assertNull(inputWithoutContainerId.getMemoryContainerId()); + assertEquals("search without container id", inputWithoutContainerId.getQuery()); + } + + @Test + public void testConstructorWithContainerId() { + MLSearchMemoriesInput input = new MLSearchMemoriesInput("container-456", "test query"); + assertEquals("container-456", input.getMemoryContainerId()); + assertEquals("test query", input.getQuery()); + } + + @Test + public void testConstructorWithoutContainerId() { + MLSearchMemoriesInput input = new MLSearchMemoriesInput(null, "another query"); + assertNull(input.getMemoryContainerId()); + assertEquals("another query", input.getQuery()); + } + + @Test + public void testConstructorWithNullQuery() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new MLSearchMemoriesInput("container-1", null) + ); + assertEquals("Query cannot be null or empty", exception.getMessage()); + } + + @Test + public void testConstructorWithEmptyQuery() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new MLSearchMemoriesInput("container-1", "") + ); + assertEquals("Query cannot be null or empty", exception.getMessage()); + } + + @Test + public void testConstructorWithWhitespaceQuery() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new MLSearchMemoriesInput("container-1", " ") + ); + assertEquals("Query cannot be null or empty", exception.getMessage()); + } + + @Test + public void testQueryTrimming() { + MLSearchMemoriesInput input = new MLSearchMemoriesInput("container-1", " query with spaces "); + assertEquals("query with spaces", input.getQuery()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with container ID + BytesStreamOutput out = new BytesStreamOutput(); + inputWithContainerId.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesInput deserialized = new MLSearchMemoriesInput(in); + + assertEquals(inputWithContainerId.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(inputWithContainerId.getQuery(), deserialized.getQuery()); + } + + @Test + public void testStreamInputOutputWithoutContainerId() throws IOException { + // Test without container ID + BytesStreamOutput out = new BytesStreamOutput(); + inputWithoutContainerId.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesInput deserialized = new MLSearchMemoriesInput(in); + + assertNull(deserialized.getMemoryContainerId()); + assertEquals(inputWithoutContainerId.getQuery(), deserialized.getQuery()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputWithContainerId.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_container_id\":\"container-123\"")); + assertTrue(jsonString.contains("\"query\":\"machine learning concepts\"")); + } + + @Test + public void testToXContentWithoutContainerId() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputWithoutContainerId.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(!jsonString.contains("\"memory_container_id\"")); + assertTrue(jsonString.contains("\"query\":\"search without container id\"")); + } + + @Test + public void testParse() throws IOException { + String jsonString = "{\"memory_container_id\":\"container-789\",\"query\":\"neural networks\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLSearchMemoriesInput parsed = MLSearchMemoriesInput.parse(parser); + + assertEquals("container-789", parsed.getMemoryContainerId()); + assertEquals("neural networks", parsed.getQuery()); + } + + @Test + public void testParseWithoutContainerId() throws IOException { + String jsonString = "{\"query\":\"deep learning\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLSearchMemoriesInput parsed = MLSearchMemoriesInput.parse(parser); + + assertNull(parsed.getMemoryContainerId()); + assertEquals("deep learning", parsed.getQuery()); + } + + @Test + public void testParseWithUnknownFields() throws IOException { + String jsonString = "{\"query\":\"test query\",\"unknown_field\":\"ignored\",\"another\":123}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLSearchMemoriesInput parsed = MLSearchMemoriesInput.parse(parser); + + assertEquals("test query", parsed.getQuery()); + } + + @Test + public void testSetters() { + MLSearchMemoriesInput input = new MLSearchMemoriesInput(null, "initial query"); + + input.setMemoryContainerId("new-container"); + input.setQuery("updated query"); + + assertEquals("new-container", input.getMemoryContainerId()); + assertEquals("updated query", input.getQuery()); + } + + @Test + public void testSpecialCharactersInQuery() throws IOException { + MLSearchMemoriesInput specialInput = new MLSearchMemoriesInput( + "container-special", + "Query with\n\ttabs,\nnewlines, \"quotes\", 'single quotes', and unicode 🚀✨" + ); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialInput.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesInput deserialized = new MLSearchMemoriesInput(in); + + assertEquals(specialInput.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(specialInput.getQuery(), deserialized.getQuery()); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialInput.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("Query with")); + assertTrue(jsonString.contains("tabs")); + assertTrue(jsonString.contains("quotes")); + } + + @Test + public void testLongQuery() throws IOException { + // Test with a very long query + StringBuilder longQuery = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + longQuery.append("word").append(i).append(" "); + } + + MLSearchMemoriesInput longInput = new MLSearchMemoriesInput("container-1", longQuery.toString().trim()); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + longInput.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesInput deserialized = new MLSearchMemoriesInput(in); + + assertEquals(longInput.getQuery(), deserialized.getQuery()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputWithContainerId.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MLSearchMemoriesInput parsed = MLSearchMemoriesInput.parse(parser); + + // Verify all fields match + assertEquals(inputWithContainerId.getMemoryContainerId(), parsed.getMemoryContainerId()); + assertEquals(inputWithContainerId.getQuery(), parsed.getQuery()); + } + + @Test + public void testComplexQueries() { + // Test various complex query patterns + String[] queries = { + "machine learning AND deep learning", + "\"exact phrase matching\"", + "wildcard* search?", + "field:value AND (nested OR query)", + "fuzzy~2 search", + "+required -excluded" }; + + for (String query : queries) { + MLSearchMemoriesInput input = new MLSearchMemoriesInput("container-1", query); + assertEquals(query, input.getQuery()); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesRequestTest.java new file mode 100644 index 0000000000..cc62c95ff1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesRequestTest.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLSearchMemoriesRequestTest { + + private MLSearchMemoriesInput testInput; + private MLSearchMemoriesRequest request; + + @Before + public void setUp() { + testInput = MLSearchMemoriesInput.builder().memoryContainerId("container-123").query("machine learning concepts").build(); + + request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(testInput).tenantId("tenant-456").build(); + } + + @Test + public void testBuilder() { + assertNotNull(request); + assertNotNull(request.getMlSearchMemoriesInput()); + assertEquals(testInput, request.getMlSearchMemoriesInput()); + assertEquals("tenant-456", request.getTenantId()); + } + + @Test + public void testConstructor() { + MLSearchMemoriesRequest constructedRequest = new MLSearchMemoriesRequest(testInput, "tenant-789"); + assertNotNull(constructedRequest); + assertEquals(testInput, constructedRequest.getMlSearchMemoriesInput()); + assertEquals("tenant-789", constructedRequest.getTenantId()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesRequest deserialized = new MLSearchMemoriesRequest(in); + + assertNotNull(deserialized.getMlSearchMemoriesInput()); + assertEquals( + request.getMlSearchMemoriesInput().getMemoryContainerId(), + deserialized.getMlSearchMemoriesInput().getMemoryContainerId() + ); + assertEquals(request.getMlSearchMemoriesInput().getQuery(), deserialized.getMlSearchMemoriesInput().getQuery()); + assertEquals(request.getTenantId(), deserialized.getTenantId()); + } + + @Test + public void testStreamInputOutputWithNullTenant() throws IOException { + MLSearchMemoriesRequest requestNoTenant = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(testInput).tenantId(null).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + requestNoTenant.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesRequest deserialized = new MLSearchMemoriesRequest(in); + + assertNotNull(deserialized.getMlSearchMemoriesInput()); + assertNull(deserialized.getTenantId()); + } + + @Test + public void testValidateSuccess() { + ActionRequestValidationException exception = request.validate(); + assertNull(exception); + } + + @Test + public void testValidateWithNullInput() { + MLSearchMemoriesRequest invalidRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput(null) + .tenantId("tenant-123") + .build(); + + ActionRequestValidationException exception = invalidRequest.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Search memories input can't be null")); + } + + @Test + public void testFromActionRequestSameInstance() { + MLSearchMemoriesRequest result = MLSearchMemoriesRequest.fromActionRequest(request); + assertEquals(request, result); + } + + @Test + public void testFromActionRequestDifferentInstance() throws IOException { + // Create a mock ActionRequest that's not MLSearchMemoriesRequest + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + testInput.writeTo(out); + out.writeOptionalString("mock-tenant"); + } + }; + + MLSearchMemoriesRequest result = MLSearchMemoriesRequest.fromActionRequest(mockRequest); + assertNotNull(result); + assertEquals("container-123", result.getMlSearchMemoriesInput().getMemoryContainerId()); + assertEquals("machine learning concepts", result.getMlSearchMemoriesInput().getQuery()); + assertEquals("mock-tenant", result.getTenantId()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestIOException() { + // Create a mock ActionRequest that throws IOException + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test exception"); + } + }; + + MLSearchMemoriesRequest.fromActionRequest(mockRequest); + } + + @Test + public void testSetters() { + MLSearchMemoriesRequest mutableRequest = new MLSearchMemoriesRequest(testInput, "initial-tenant"); + + // Test setMlSearchMemoriesInput + MLSearchMemoriesInput newInput = MLSearchMemoriesInput.builder().memoryContainerId("new-container").query("new query").build(); + mutableRequest.setMlSearchMemoriesInput(newInput); + assertEquals(newInput, mutableRequest.getMlSearchMemoriesInput()); + + // Test setTenantId + mutableRequest.setTenantId("new-tenant"); + assertEquals("new-tenant", mutableRequest.getTenantId()); + } + + @Test(expected = IllegalArgumentException.class) + public void testWithEmptyQuery() { + // Empty query is not allowed - should throw exception + MLSearchMemoriesInput.builder().memoryContainerId("container-empty").query("").build(); + } + + @Test + public void testWithSpecialCharacters() throws IOException { + MLSearchMemoriesInput specialInput = MLSearchMemoriesInput + .builder() + .memoryContainerId("container-with-special-chars-🚀") + .query("Query with \"quotes\" and\n\ttabs and unicode 🔥") + .build(); + + MLSearchMemoriesRequest specialRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput(specialInput) + .tenantId("tenant-特殊文字") + .build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + specialRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesRequest deserialized = new MLSearchMemoriesRequest(in); + + assertEquals("container-with-special-chars-🚀", deserialized.getMlSearchMemoriesInput().getMemoryContainerId()); + assertTrue(deserialized.getMlSearchMemoriesInput().getQuery().contains("quotes")); + assertTrue(deserialized.getMlSearchMemoriesInput().getQuery().contains("tabs")); + assertEquals("tenant-特殊文字", deserialized.getTenantId()); + } + + @Test + public void testWithLongQuery() throws IOException { + StringBuilder longQuery = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + longQuery.append("word").append(i).append(" "); + } + + MLSearchMemoriesInput longInput = MLSearchMemoriesInput + .builder() + .memoryContainerId("container-long") + .query(longQuery.toString().trim()) + .build(); + + MLSearchMemoriesRequest longRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput(longInput) + .tenantId("tenant-long") + .build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + longRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesRequest deserialized = new MLSearchMemoriesRequest(in); + + assertEquals(longQuery.toString().trim(), deserialized.getMlSearchMemoriesInput().getQuery()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesResponseTest.java new file mode 100644 index 0000000000..9191fa6809 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLSearchMemoriesResponseTest.java @@ -0,0 +1,346 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.memorycontainer.MemoryType; + +public class MLSearchMemoriesResponseTest { + + private MLSearchMemoriesResponse responseWithHits; + private MLSearchMemoriesResponse responseEmpty; + private MLSearchMemoriesResponse responseTimedOut; + private List testHits; + + @Before + public void setUp() { + Map tags = new HashMap<>(); + tags.put("topic", "ML"); + + testHits = Arrays + .asList( + MemorySearchResult + .builder() + .memoryId("mem-1") + .memory("Machine learning is a subset of AI") + .score(0.95f) + .sessionId("session-123") + .userId("user-456") + .memoryType(MemoryType.RAW_MESSAGE) + .role("assistant") + .tags(tags) + .createdTime(Instant.now()) + .lastUpdatedTime(Instant.now()) + .build(), + MemorySearchResult + .builder() + .memoryId("mem-2") + .memory("Deep learning uses neural networks") + .score(0.87f) + .sessionId("session-123") + .memoryType(MemoryType.FACT) + .build(), + MemorySearchResult.builder().memoryId("mem-3").memory("Neural networks have multiple layers").score(0.75f).build() + ); + + // Response with hits + responseWithHits = MLSearchMemoriesResponse.builder().hits(testHits).totalHits(25L).maxScore(0.95f).timedOut(false).build(); + + // Empty response + responseEmpty = MLSearchMemoriesResponse.builder().hits(new ArrayList<>()).totalHits(0L).maxScore(0.0f).timedOut(false).build(); + + // Timed out response + responseTimedOut = MLSearchMemoriesResponse + .builder() + .hits(Arrays.asList(testHits.get(0))) + .totalHits(1L) + .maxScore(0.95f) + .timedOut(true) + .build(); + } + + @Test + public void testBuilderWithHits() { + assertNotNull(responseWithHits); + assertEquals(testHits, responseWithHits.getHits()); + assertEquals(3, responseWithHits.getHits().size()); + assertEquals(25L, responseWithHits.getTotalHits()); + assertEquals(0.95f, responseWithHits.getMaxScore(), 0.001f); + assertFalse(responseWithHits.isTimedOut()); + } + + @Test + public void testBuilderEmpty() { + assertNotNull(responseEmpty); + assertEquals(0, responseEmpty.getHits().size()); + assertEquals(0L, responseEmpty.getTotalHits()); + assertEquals(0.0f, responseEmpty.getMaxScore(), 0.001f); + assertFalse(responseEmpty.isTimedOut()); + } + + @Test + public void testBuilderTimedOut() { + assertNotNull(responseTimedOut); + assertEquals(1, responseTimedOut.getHits().size()); + assertEquals(1L, responseTimedOut.getTotalHits()); + assertEquals(0.95f, responseTimedOut.getMaxScore(), 0.001f); + assertTrue(responseTimedOut.isTimedOut()); + } + + @Test + public void testConstructorWithNullHits() { + MLSearchMemoriesResponse response = new MLSearchMemoriesResponse(null, 0L, 0.0f, false); + assertNotNull(response.getHits()); + assertEquals(0, response.getHits().size()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with hits + BytesStreamOutput out = new BytesStreamOutput(); + responseWithHits.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesResponse deserialized = new MLSearchMemoriesResponse(in); + + assertEquals(responseWithHits.getHits().size(), deserialized.getHits().size()); + assertEquals(responseWithHits.getTotalHits(), deserialized.getTotalHits()); + assertEquals(responseWithHits.getMaxScore(), deserialized.getMaxScore(), 0.001f); + assertEquals(responseWithHits.isTimedOut(), deserialized.isTimedOut()); + + // Verify individual hits + for (int i = 0; i < responseWithHits.getHits().size(); i++) { + MemorySearchResult original = responseWithHits.getHits().get(i); + MemorySearchResult deser = deserialized.getHits().get(i); + assertEquals(original.getMemoryId(), deser.getMemoryId()); + assertEquals(original.getMemory(), deser.getMemory()); + assertEquals(original.getScore(), deser.getScore(), 0.001f); + } + } + + @Test + public void testStreamInputOutputEmpty() throws IOException { + // Test with empty hits + BytesStreamOutput out = new BytesStreamOutput(); + responseEmpty.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesResponse deserialized = new MLSearchMemoriesResponse(in); + + assertEquals(0, deserialized.getHits().size()); + assertEquals(0L, deserialized.getTotalHits()); + assertEquals(0.0f, deserialized.getMaxScore(), 0.001f); + assertFalse(deserialized.isTimedOut()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + responseWithHits.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Check structure + assertTrue(jsonString.contains("\"timed_out\":false")); + assertTrue(jsonString.contains("\"hits\":{")); + assertTrue(jsonString.contains("\"total\":25")); + assertTrue(jsonString.contains("\"max_score\":0.95")); + assertTrue(jsonString.contains("\"hits\":[")); + + // Check individual hits + assertTrue(jsonString.contains("\"memory_id\":\"mem-1\"")); + assertTrue(jsonString.contains("\"memory\":\"Machine learning is a subset of AI\"")); + assertTrue(jsonString.contains("\"_score\":0.95")); + assertTrue(jsonString.contains("\"session_id\":\"session-123\"")); + } + + @Test + public void testToXContentEmpty() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + responseEmpty.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"timed_out\":false")); + assertTrue(jsonString.contains("\"total\":0")); + assertTrue(jsonString.contains("\"max_score\":0.0")); + assertTrue(jsonString.contains("\"hits\":[]")); + } + + @Test + public void testToXContentTimedOut() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + responseTimedOut.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"timed_out\":true")); + assertTrue(jsonString.contains("\"total\":1")); + assertTrue(jsonString.contains("\"max_score\":0.95")); + assertEquals(1, jsonString.split("\"memory_id\"").length - 1); // Only one hit + } + + @Test + public void testLargeResponse() throws IOException { + // Test with many hits + List manyHits = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + manyHits.add(MemorySearchResult.builder().memoryId("mem-" + i).memory("Memory content " + i).score(1.0f - (i * 0.01f)).build()); + } + + MLSearchMemoriesResponse largeResponse = MLSearchMemoriesResponse + .builder() + .hits(manyHits) + .totalHits(1000L) + .maxScore(1.0f) + .timedOut(false) + .build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + largeResponse.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesResponse deserialized = new MLSearchMemoriesResponse(in); + + assertEquals(100, deserialized.getHits().size()); + assertEquals(1000L, deserialized.getTotalHits()); + assertEquals(1.0f, deserialized.getMaxScore(), 0.001f); + } + + @Test + public void testDifferentScoreValues() { + // Test various score configurations + MLSearchMemoriesResponse response1 = MLSearchMemoriesResponse + .builder() + .hits(new ArrayList<>()) + .totalHits(0L) + .maxScore(Float.NaN) + .timedOut(false) + .build(); + + MLSearchMemoriesResponse response2 = MLSearchMemoriesResponse + .builder() + .hits(testHits) + .totalHits(100L) + .maxScore(Float.POSITIVE_INFINITY) + .timedOut(false) + .build(); + + assertEquals(Float.NaN, response1.getMaxScore(), 0.001f); + assertEquals(Float.POSITIVE_INFINITY, response2.getMaxScore(), 0.001f); + } + + @Test + public void testHitsOrdering() throws IOException { + // Verify hits maintain their order + BytesStreamOutput out = new BytesStreamOutput(); + responseWithHits.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLSearchMemoriesResponse deserialized = new MLSearchMemoriesResponse(in); + + for (int i = 0; i < responseWithHits.getHits().size(); i++) { + assertEquals(responseWithHits.getHits().get(i).getMemoryId(), deserialized.getHits().get(i).getMemoryId()); + assertEquals(responseWithHits.getHits().get(i).getScore(), deserialized.getHits().get(i).getScore(), 0.001f); + } + } + + @Test + public void testResponseStructure() throws IOException { + // Test the nested JSON structure matches OpenSearch conventions + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + responseWithHits.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Verify structure: { "timed_out": ..., "hits": { "total": ..., "max_score": ..., "hits": [...] } } + assertTrue(jsonString.startsWith("{\"timed_out\":")); + assertTrue(jsonString.contains(",\"hits\":{\"total\":")); + assertTrue(jsonString.contains(",\"max_score\":")); + assertTrue(jsonString.contains(",\"hits\":[")); + assertTrue(jsonString.endsWith("}}")); + } + + @Test + public void testPartialResults() { + // Test response with partial results (timed out but has some hits) + MLSearchMemoriesResponse partialResponse = MLSearchMemoriesResponse + .builder() + .hits(Arrays.asList(testHits.get(0), testHits.get(1))) + .totalHits(50L) // More than returned hits + .maxScore(0.95f) + .timedOut(true) + .build(); + + assertEquals(2, partialResponse.getHits().size()); + assertEquals(50L, partialResponse.getTotalHits()); + assertTrue(partialResponse.isTimedOut()); + assertTrue(partialResponse.getTotalHits() > partialResponse.getHits().size()); + } + + @Test + public void testSpecialCharactersInHits() throws IOException { + List specialHits = Arrays + .asList( + MemorySearchResult + .builder() + .memoryId("mem-special-🚀") + .memory("Memory with\n\ttabs and \"quotes\"") + .score(0.9f) + .sessionId("session-✨") + .build() + ); + + MLSearchMemoriesResponse specialResponse = MLSearchMemoriesResponse + .builder() + .hits(specialHits) + .totalHits(1L) + .maxScore(0.9f) + .timedOut(false) + .build(); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialResponse.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("mem-special-")); + assertTrue(jsonString.contains("Memory with")); + assertTrue(jsonString.contains("tabs")); + } + + @Test + public void testZeroMaxScore() throws IOException { + // Test when all hits have 0 score + List zeroScoreHits = Arrays + .asList(MemorySearchResult.builder().memoryId("mem-zero-1").memory("Memory with zero score").score(0.0f).build()); + + MLSearchMemoriesResponse zeroScoreResponse = MLSearchMemoriesResponse + .builder() + .hits(zeroScoreHits) + .totalHits(1L) + .maxScore(0.0f) + .timedOut(false) + .build(); + + assertEquals(0.0f, zeroScoreResponse.getMaxScore(), 0.001f); + assertEquals(0.0f, zeroScoreResponse.getHits().get(0).getScore(), 0.001f); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryInputTest.java new file mode 100644 index 0000000000..c8459331f5 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryInputTest.java @@ -0,0 +1,258 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +public class MLUpdateMemoryInputTest { + + private MLUpdateMemoryInput inputNormal; + private MLUpdateMemoryInput inputWithWhitespace; + + @Before + public void setUp() { + inputNormal = MLUpdateMemoryInput.builder().text("Updated memory content").build(); + + inputWithWhitespace = MLUpdateMemoryInput.builder().text(" Text with surrounding spaces ").build(); + } + + @Test + public void testBuilderNormal() { + assertNotNull(inputNormal); + assertEquals("Updated memory content", inputNormal.getText()); + } + + @Test + public void testBuilderWithWhitespace() { + assertNotNull(inputWithWhitespace); + // Should be trimmed + assertEquals("Text with surrounding spaces", inputWithWhitespace.getText()); + } + + @Test + public void testConstructor() { + MLUpdateMemoryInput input = new MLUpdateMemoryInput("Test text"); + assertEquals("Test text", input.getText()); + } + + @Test + public void testConstructorWithTrimming() { + MLUpdateMemoryInput input = new MLUpdateMemoryInput(" Trimmed text "); + assertEquals("Trimmed text", input.getText()); + } + + @Test + public void testConstructorWithNullText() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new MLUpdateMemoryInput((String) null)); + assertEquals("Text cannot be null or empty", exception.getMessage()); + } + + @Test + public void testConstructorWithEmptyText() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new MLUpdateMemoryInput("")); + assertEquals("Text cannot be null or empty", exception.getMessage()); + } + + @Test + public void testConstructorWithWhitespaceOnlyText() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new MLUpdateMemoryInput(" ")); + assertEquals("Text cannot be null or empty", exception.getMessage()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + inputNormal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLUpdateMemoryInput deserialized = new MLUpdateMemoryInput(in); + + assertEquals(inputNormal.getText(), deserialized.getText()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputNormal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"text\":\"Updated memory content\"")); + } + + @Test + public void testParse() throws IOException { + String jsonString = "{\"text\":\"Parsed memory text\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLUpdateMemoryInput parsed = MLUpdateMemoryInput.parse(parser); + + assertEquals("Parsed memory text", parsed.getText()); + } + + @Test + public void testParseWithUnknownFields() throws IOException { + String jsonString = "{\"text\":\"Valid text\",\"unknown_field\":\"ignored\",\"another\":123}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLUpdateMemoryInput parsed = MLUpdateMemoryInput.parse(parser); + + assertEquals("Valid text", parsed.getText()); + } + + @Test + public void testSetter() { + MLUpdateMemoryInput input = new MLUpdateMemoryInput("Initial text"); + input.setText("Updated text"); + assertEquals("Updated text", input.getText()); + } + + @Test + public void testSpecialCharactersInText() throws IOException { + MLUpdateMemoryInput specialInput = new MLUpdateMemoryInput( + "Text with\n\ttabs,\nnewlines, \"quotes\", 'single quotes', and unicode 🚀✨" + ); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialInput.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLUpdateMemoryInput deserialized = new MLUpdateMemoryInput(in); + + assertEquals(specialInput.getText(), deserialized.getText()); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialInput.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("Text with")); + assertTrue(jsonString.contains("tabs")); + assertTrue(jsonString.contains("quotes")); + } + + @Test + public void testLongText() throws IOException { + // Test with very long text + StringBuilder longText = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + longText.append("This is sentence ").append(i).append(". "); + } + + MLUpdateMemoryInput longInput = new MLUpdateMemoryInput(longText.toString().trim()); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + longInput.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLUpdateMemoryInput deserialized = new MLUpdateMemoryInput(in); + + assertEquals(longInput.getText(), deserialized.getText()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputNormal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MLUpdateMemoryInput parsed = MLUpdateMemoryInput.parse(parser); + + // Verify field matches + assertEquals(inputNormal.getText(), parsed.getText()); + } + + @Test + public void testMultilineText() throws IOException { + String multilineText = "Line 1\nLine 2\nLine 3\nWith multiple lines"; + MLUpdateMemoryInput multilineInput = new MLUpdateMemoryInput(multilineText); + + assertEquals(multilineText, multilineInput.getText()); + + // Test XContent handling + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + multilineInput.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MLUpdateMemoryInput parsed = MLUpdateMemoryInput.parse(parser); + + assertEquals(multilineText, parsed.getText()); + } + + @Test + public void testSingleCharacterText() { + MLUpdateMemoryInput singleChar = new MLUpdateMemoryInput("A"); + assertEquals("A", singleChar.getText()); + } + + @Test + public void testTextWithOnlySpecialCharacters() { + String specialOnly = "!@#$%^&*()_+-=[]{}|;':\",./<>?"; + MLUpdateMemoryInput specialInput = new MLUpdateMemoryInput(specialOnly); + assertEquals(specialOnly, specialInput.getText()); + } + + @Test + public void testParseWithMissingTextField() throws IOException { + // Parse with missing text field should throw exception when building + String jsonString = "{\"other_field\":\"value\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + assertThrows(IllegalArgumentException.class, () -> MLUpdateMemoryInput.parse(parser)); + } + + @Test + public void testSimpleJsonStructure() throws IOException { + // Verify the JSON structure is simple: just {"text": "..."} + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + inputNormal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Should be a simple object with just one field + assertTrue(jsonString.startsWith("{\"text\":")); + assertTrue(jsonString.endsWith("\"}")); + assertEquals(1, jsonString.split("\":").length - 1); // Only one field + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequestTest.java new file mode 100644 index 0000000000..3cb7a27c9d --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequestTest.java @@ -0,0 +1,326 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLUpdateMemoryRequestTest { + + private MLUpdateMemoryRequest requestNormal; + private MLUpdateMemoryRequest requestWithNulls; + private MLUpdateMemoryInput testInput; + + @Before + public void setUp() { + testInput = MLUpdateMemoryInput.builder().text("Updated memory content").build(); + + requestNormal = MLUpdateMemoryRequest + .builder() + .memoryContainerId("container-123") + .memoryId("memory-456") + .mlUpdateMemoryInput(testInput) + .build(); + + requestWithNulls = MLUpdateMemoryRequest.builder().memoryContainerId(null).memoryId(null).mlUpdateMemoryInput(null).build(); + } + + @Test + public void testBuilderNormal() { + assertNotNull(requestNormal); + assertEquals("container-123", requestNormal.getMemoryContainerId()); + assertEquals("memory-456", requestNormal.getMemoryId()); + assertNotNull(requestNormal.getMlUpdateMemoryInput()); + assertEquals("Updated memory content", requestNormal.getMlUpdateMemoryInput().getText()); + } + + @Test + public void testBuilderWithNullValues() { + assertNotNull(requestWithNulls); + assertNull(requestWithNulls.getMemoryContainerId()); + assertNull(requestWithNulls.getMemoryId()); + assertNull(requestWithNulls.getMlUpdateMemoryInput()); + } + + @Test + public void testStreamInputOutput() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + requestNormal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLUpdateMemoryRequest deserialized = new MLUpdateMemoryRequest(in); + + assertEquals(requestNormal.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(requestNormal.getMemoryId(), deserialized.getMemoryId()); + assertNotNull(deserialized.getMlUpdateMemoryInput()); + assertEquals(requestNormal.getMlUpdateMemoryInput().getText(), deserialized.getMlUpdateMemoryInput().getText()); + } + + @Test + public void testValidateSuccess() { + ActionRequestValidationException exception = requestNormal.validate(); + assertNull(exception); + } + + @Test + public void testValidateWithNullInput() { + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId("container-123") + .memoryId("memory-456") + .mlUpdateMemoryInput(null) + .build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Update memory input can't be null")); + } + + @Test + public void testValidateWithNullContainerId() { + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId(null) + .memoryId("memory-456") + .mlUpdateMemoryInput(testInput) + .build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Memory container id can't be null")); + } + + @Test + public void testValidateWithNullMemoryId() { + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId("container-123") + .memoryId(null) + .mlUpdateMemoryInput(testInput) + .build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(1, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Memory id can't be null")); + } + + @Test + public void testValidateWithAllNull() { + ActionRequestValidationException exception = requestWithNulls.validate(); + assertNotNull(exception); + assertEquals(3, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Update memory input can't be null")); + assertTrue(exception.validationErrors().get(1).contains("Memory container id can't be null")); + assertTrue(exception.validationErrors().get(2).contains("Memory id can't be null")); + } + + @Test + public void testValidateWithTwoNulls() { + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId(null) + .memoryId(null) + .mlUpdateMemoryInput(testInput) + .build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(2, exception.validationErrors().size()); + assertTrue(exception.validationErrors().get(0).contains("Memory container id can't be null")); + assertTrue(exception.validationErrors().get(1).contains("Memory id can't be null")); + } + + @Test + public void testFromActionRequestSameInstance() { + MLUpdateMemoryRequest result = MLUpdateMemoryRequest.fromActionRequest(requestNormal); + assertEquals(requestNormal, result); + } + + @Test + public void testFromActionRequestDifferentInstance() throws IOException { + // Create a mock ActionRequest that's not MLUpdateMemoryRequest + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString("test-container"); + out.writeString("test-memory"); + testInput.writeTo(out); + } + }; + + MLUpdateMemoryRequest result = MLUpdateMemoryRequest.fromActionRequest(mockRequest); + assertNotNull(result); + assertEquals("test-container", result.getMemoryContainerId()); + assertEquals("test-memory", result.getMemoryId()); + assertNotNull(result.getMlUpdateMemoryInput()); + assertEquals("Updated memory content", result.getMlUpdateMemoryInput().getText()); + } + + @Test(expected = UncheckedIOException.class) + public void testFromActionRequestIOException() { + // Create a mock ActionRequest that throws IOException + ActionRequest mockRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("Test exception"); + } + }; + + MLUpdateMemoryRequest.fromActionRequest(mockRequest); + } + + @Test + public void testSetMlUpdateMemoryInput() { + MLUpdateMemoryInput newInput = MLUpdateMemoryInput.builder().text("New updated text").build(); + + requestNormal.setMlUpdateMemoryInput(newInput); + assertEquals("New updated text", requestNormal.getMlUpdateMemoryInput().getText()); + } + + @Test + public void testSpecialCharacters() throws IOException { + MLUpdateMemoryInput specialInput = MLUpdateMemoryInput.builder().text("Text with\n\ttabs and \"quotes\" and unicode 🚀✨").build(); + + MLUpdateMemoryRequest specialRequest = MLUpdateMemoryRequest + .builder() + .memoryContainerId("container-with-special-chars-🌟") + .memoryId("memory-with-unicode-💫") + .mlUpdateMemoryInput(specialInput) + .build(); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialRequest.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLUpdateMemoryRequest deserialized = new MLUpdateMemoryRequest(in); + + assertEquals(specialRequest.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(specialRequest.getMemoryId(), deserialized.getMemoryId()); + assertEquals(specialRequest.getMlUpdateMemoryInput().getText(), deserialized.getMlUpdateMemoryInput().getText()); + } + + @Test + public void testEmptyStrings() { + MLUpdateMemoryInput emptyInput = MLUpdateMemoryInput + .builder() + .text("Valid text") // Text can't be empty as per MLUpdateMemoryInput validation + .build(); + + MLUpdateMemoryRequest emptyStringRequest = MLUpdateMemoryRequest + .builder() + .memoryContainerId("") + .memoryId("") + .mlUpdateMemoryInput(emptyInput) + .build(); + + assertNotNull(emptyStringRequest); + assertEquals("", emptyStringRequest.getMemoryContainerId()); + assertEquals("", emptyStringRequest.getMemoryId()); + + // Empty strings should pass validation (only null check in validate method) + ActionRequestValidationException exception = emptyStringRequest.validate(); + assertNull(exception); + } + + @Test + public void testLongIds() throws IOException { + // Test with very long IDs + StringBuilder longId = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + longId.append("a"); + } + + StringBuilder longText = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + longText.append("This is sentence ").append(i).append(". "); + } + + MLUpdateMemoryInput longInput = MLUpdateMemoryInput.builder().text(longText.toString().trim()).build(); + + MLUpdateMemoryRequest longRequest = MLUpdateMemoryRequest + .builder() + .memoryContainerId(longId.toString()) + .memoryId(longId.toString() + "-memory") + .mlUpdateMemoryInput(longInput) + .build(); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + longRequest.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MLUpdateMemoryRequest deserialized = new MLUpdateMemoryRequest(in); + + assertEquals(longRequest.getMemoryContainerId(), deserialized.getMemoryContainerId()); + assertEquals(longRequest.getMemoryId(), deserialized.getMemoryId()); + assertEquals(longRequest.getMlUpdateMemoryInput().getText(), deserialized.getMlUpdateMemoryInput().getText()); + } + + @Test + public void testMultipleInputUpdates() { + MLUpdateMemoryInput input1 = MLUpdateMemoryInput.builder().text("First text").build(); + MLUpdateMemoryInput input2 = MLUpdateMemoryInput.builder().text("Second text").build(); + MLUpdateMemoryInput input3 = MLUpdateMemoryInput.builder().text("Third text").build(); + + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId("container-123") + .memoryId("memory-456") + .mlUpdateMemoryInput(input1) + .build(); + + assertEquals("First text", request.getMlUpdateMemoryInput().getText()); + + request.setMlUpdateMemoryInput(input2); + assertEquals("Second text", request.getMlUpdateMemoryInput().getText()); + + request.setMlUpdateMemoryInput(input3); + assertEquals("Third text", request.getMlUpdateMemoryInput().getText()); + } + + @Test + public void testValidationOrderWithMultipleNulls() { + // Test to ensure validation errors are added in correct order + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId("container-123") + .memoryId(null) + .mlUpdateMemoryInput(null) + .build(); + + ActionRequestValidationException exception = request.validate(); + assertNotNull(exception); + assertEquals(2, exception.validationErrors().size()); + // Input validation comes first + assertTrue(exception.validationErrors().get(0).contains("Update memory input can't be null")); + assertTrue(exception.validationErrors().get(1).contains("Memory id can't be null")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryActionClassesTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryActionClassesTest.java new file mode 100644 index 0000000000..42e47e69e7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryActionClassesTest.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import org.junit.Test; + +/** + * Tests for memory-related Action classes + */ +public class MemoryActionClassesTest { + + @Test + public void testMLAddMemoriesAction() { + assertNotNull(MLAddMemoriesAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/memory_containers/memories/add", MLAddMemoriesAction.NAME); + assertEquals(MLAddMemoriesAction.NAME, MLAddMemoriesAction.INSTANCE.name()); + } + + @Test + public void testMLSearchMemoriesAction() { + assertNotNull(MLSearchMemoriesAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/memory_containers/memories/search", MLSearchMemoriesAction.NAME); + assertEquals(MLSearchMemoriesAction.NAME, MLSearchMemoriesAction.INSTANCE.name()); + } + + @Test + public void testMLDeleteMemoryAction() { + assertNotNull(MLDeleteMemoryAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/memory_containers/memory/delete", MLDeleteMemoryAction.NAME); + assertEquals(MLDeleteMemoryAction.NAME, MLDeleteMemoryAction.INSTANCE.name()); + } + + @Test + public void testMLUpdateMemoryAction() { + assertNotNull(MLUpdateMemoryAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/memory_containers/memory/update", MLUpdateMemoryAction.NAME); + assertEquals(MLUpdateMemoryAction.NAME, MLUpdateMemoryAction.INSTANCE.name()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryEventTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryEventTest.java new file mode 100644 index 0000000000..566c742254 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryEventTest.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; + +public class MemoryEventTest { + + @Test + public void testEnumValues() { + // Test all enum values exist + assertEquals(4, MemoryEvent.values().length); + assertEquals(MemoryEvent.ADD, MemoryEvent.valueOf("ADD")); + assertEquals(MemoryEvent.UPDATE, MemoryEvent.valueOf("UPDATE")); + assertEquals(MemoryEvent.DELETE, MemoryEvent.valueOf("DELETE")); + assertEquals(MemoryEvent.NONE, MemoryEvent.valueOf("NONE")); + } + + @Test + public void testGetValue() { + assertEquals("ADD", MemoryEvent.ADD.getValue()); + assertEquals("UPDATE", MemoryEvent.UPDATE.getValue()); + assertEquals("DELETE", MemoryEvent.DELETE.getValue()); + assertEquals("NONE", MemoryEvent.NONE.getValue()); + } + + @Test + public void testToString() { + assertEquals("ADD", MemoryEvent.ADD.toString()); + assertEquals("UPDATE", MemoryEvent.UPDATE.toString()); + assertEquals("DELETE", MemoryEvent.DELETE.toString()); + assertEquals("NONE", MemoryEvent.NONE.toString()); + } + + @Test + public void testFromString_ValidValues() { + // Test exact match + assertEquals(MemoryEvent.ADD, MemoryEvent.fromString("ADD")); + assertEquals(MemoryEvent.UPDATE, MemoryEvent.fromString("UPDATE")); + assertEquals(MemoryEvent.DELETE, MemoryEvent.fromString("DELETE")); + assertEquals(MemoryEvent.NONE, MemoryEvent.fromString("NONE")); + + // Test case insensitive + assertEquals(MemoryEvent.ADD, MemoryEvent.fromString("add")); + assertEquals(MemoryEvent.UPDATE, MemoryEvent.fromString("Update")); + assertEquals(MemoryEvent.DELETE, MemoryEvent.fromString("dElEtE")); + assertEquals(MemoryEvent.NONE, MemoryEvent.fromString("none")); + } + + @Test + public void testFromString_Null() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryEvent.fromString(null)); + assertEquals("Memory event value cannot be null", exception.getMessage()); + } + + @Test + public void testFromString_InvalidValue() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryEvent.fromString("INVALID_EVENT")); + assertEquals("Unknown memory event: INVALID_EVENT", exception.getMessage()); + } + + @Test + public void testFromString_EmptyString() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryEvent.fromString("")); + assertEquals("Unknown memory event: ", exception.getMessage()); + } + + @Test + public void testFromString_Whitespace() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> MemoryEvent.fromString(" ")); + assertEquals("Unknown memory event: ", exception.getMessage()); + } + + @Test + public void testEnumConsistency() { + // Verify each enum's getValue() returns its name + for (MemoryEvent event : MemoryEvent.values()) { + assertNotNull(event.getValue()); + assertEquals(event.getValue(), event.toString()); + assertEquals(event, MemoryEvent.fromString(event.getValue())); + } + } + + @Test + public void testAllEventsHandled() { + // Ensure all events are properly handled in fromString + String[] expectedEvents = { "ADD", "UPDATE", "DELETE", "NONE" }; + for (String eventStr : expectedEvents) { + MemoryEvent event = MemoryEvent.fromString(eventStr); + assertNotNull("Event should not be null for: " + eventStr, event); + assertEquals(eventStr, event.getValue()); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryResultTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryResultTest.java new file mode 100644 index 0000000000..1786f2ce03 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemoryResultTest.java @@ -0,0 +1,258 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MemoryResultTest { + + private MemoryResult resultWithAllFields; + private MemoryResult resultMinimal; + private MemoryResult addResult; + private MemoryResult updateResult; + private MemoryResult deleteResult; + private MemoryResult noneResult; + + @Before + public void setUp() { + // UPDATE result with all fields including oldMemory + resultWithAllFields = MemoryResult + .builder() + .memoryId("memory-123") + .memory("Updated memory text") + .event(MemoryEvent.UPDATE) + .oldMemory("Original memory text") + .build(); + + // Minimal result (no oldMemory) + resultMinimal = MemoryResult.builder().memoryId("memory-456").memory("New memory text").event(MemoryEvent.ADD).build(); + + // Different event types + addResult = new MemoryResult("add-789", "Adding new memory", MemoryEvent.ADD, null); + updateResult = new MemoryResult("update-101", "Updating memory", MemoryEvent.UPDATE, "Old text"); + deleteResult = new MemoryResult("delete-202", "Deleting memory", MemoryEvent.DELETE, null); + noneResult = new MemoryResult("none-303", "No change", MemoryEvent.NONE, null); + } + + @Test + public void testBuilderWithAllFields() { + assertNotNull(resultWithAllFields); + assertEquals("memory-123", resultWithAllFields.getMemoryId()); + assertEquals("Updated memory text", resultWithAllFields.getMemory()); + assertEquals(MemoryEvent.UPDATE, resultWithAllFields.getEvent()); + assertEquals("Original memory text", resultWithAllFields.getOldMemory()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(resultMinimal); + assertEquals("memory-456", resultMinimal.getMemoryId()); + assertEquals("New memory text", resultMinimal.getMemory()); + assertEquals(MemoryEvent.ADD, resultMinimal.getEvent()); + assertNull(resultMinimal.getOldMemory()); + } + + @Test + public void testConstructorWithAllParameters() { + MemoryResult result = new MemoryResult("id-1", "text-1", MemoryEvent.UPDATE, "old-text"); + assertEquals("id-1", result.getMemoryId()); + assertEquals("text-1", result.getMemory()); + assertEquals(MemoryEvent.UPDATE, result.getEvent()); + assertEquals("old-text", result.getOldMemory()); + } + + @Test + public void testConstructorWithNullOldMemory() { + MemoryResult result = new MemoryResult("id-2", "text-2", MemoryEvent.ADD, null); + assertEquals("id-2", result.getMemoryId()); + assertEquals("text-2", result.getMemory()); + assertEquals(MemoryEvent.ADD, result.getEvent()); + assertNull(result.getOldMemory()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with all fields + BytesStreamOutput out = new BytesStreamOutput(); + resultWithAllFields.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryResult deserialized = new MemoryResult(in); + + assertEquals(resultWithAllFields.getMemoryId(), deserialized.getMemoryId()); + assertEquals(resultWithAllFields.getMemory(), deserialized.getMemory()); + assertEquals(resultWithAllFields.getEvent(), deserialized.getEvent()); + assertEquals(resultWithAllFields.getOldMemory(), deserialized.getOldMemory()); + } + + @Test + public void testStreamInputOutputMinimal() throws IOException { + // Test with minimal fields + BytesStreamOutput out = new BytesStreamOutput(); + resultMinimal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryResult deserialized = new MemoryResult(in); + + assertEquals(resultMinimal.getMemoryId(), deserialized.getMemoryId()); + assertEquals(resultMinimal.getMemory(), deserialized.getMemory()); + assertEquals(resultMinimal.getEvent(), deserialized.getEvent()); + assertNull(deserialized.getOldMemory()); + } + + @Test + public void testStreamInputOutputAllEventTypes() throws IOException { + // Test all event types + MemoryResult[] results = { addResult, updateResult, deleteResult, noneResult }; + + for (MemoryResult original : results) { + BytesStreamOutput out = new BytesStreamOutput(); + original.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryResult deserialized = new MemoryResult(in); + + assertEquals(original.getMemoryId(), deserialized.getMemoryId()); + assertEquals(original.getMemory(), deserialized.getMemory()); + assertEquals(original.getEvent(), deserialized.getEvent()); + assertEquals(original.getOldMemory(), deserialized.getOldMemory()); + } + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + resultWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Uses "id" and "text" fields instead of "memory_id" and "memory" + assertTrue(jsonString.contains("\"id\":\"memory-123\"")); + assertTrue(jsonString.contains("\"text\":\"Updated memory text\"")); + assertTrue(jsonString.contains("\"event\":\"UPDATE\"")); + assertTrue(jsonString.contains("\"old_memory\":\"Original memory text\"")); + } + + @Test + public void testToXContentMinimal() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + resultMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"id\":\"memory-456\"")); + assertTrue(jsonString.contains("\"text\":\"New memory text\"")); + assertTrue(jsonString.contains("\"event\":\"ADD\"")); + // old_memory should not be present + assertTrue(!jsonString.contains("\"old_memory\"")); + } + + @Test + public void testToXContentDifferentEvents() throws IOException { + // Test ADD event + XContentBuilder addBuilder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + addResult.toXContent(addBuilder, EMPTY_PARAMS); + String addJson = TestHelper.xContentBuilderToString(addBuilder); + assertTrue(addJson.contains("\"event\":\"ADD\"")); + assertTrue(!addJson.contains("\"old_memory\"")); + + // Test UPDATE event + XContentBuilder updateBuilder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + updateResult.toXContent(updateBuilder, EMPTY_PARAMS); + String updateJson = TestHelper.xContentBuilderToString(updateBuilder); + assertTrue(updateJson.contains("\"event\":\"UPDATE\"")); + assertTrue(updateJson.contains("\"old_memory\":\"Old text\"")); + + // Test DELETE event + XContentBuilder deleteBuilder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + deleteResult.toXContent(deleteBuilder, EMPTY_PARAMS); + String deleteJson = TestHelper.xContentBuilderToString(deleteBuilder); + assertTrue(deleteJson.contains("\"event\":\"DELETE\"")); + + // Test NONE event + XContentBuilder noneBuilder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + noneResult.toXContent(noneBuilder, EMPTY_PARAMS); + String noneJson = TestHelper.xContentBuilderToString(noneBuilder); + assertTrue(noneJson.contains("\"event\":\"NONE\"")); + } + + @Test + public void testToString() { + String str = resultWithAllFields.toString(); + assertNotNull(str); + assertTrue(str.contains("memory-123")); + assertTrue(str.contains("Updated memory text")); + assertTrue(str.contains("UPDATE")); + assertTrue(str.contains("Original memory text")); + } + + @Test + public void testSpecialCharactersInFields() throws IOException { + MemoryResult specialResult = MemoryResult + .builder() + .memoryId("id-with-special-chars-🚀") + .memory("Text with\n\ttabs and\nnewlines and \"quotes\"") + .event(MemoryEvent.UPDATE) + .oldMemory("Old text with 'single quotes' and \\backslashes\\") + .build(); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialResult.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryResult deserialized = new MemoryResult(in); + + assertEquals(specialResult.getMemoryId(), deserialized.getMemoryId()); + assertEquals(specialResult.getMemory(), deserialized.getMemory()); + assertEquals(specialResult.getOldMemory(), deserialized.getOldMemory()); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialResult.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("id-with-special-chars-")); + assertTrue(jsonString.contains("Text with")); + assertTrue(jsonString.contains("tabs")); + } + + @Test + public void testEmptyStrings() throws IOException { + MemoryResult emptyResult = new MemoryResult("", "", MemoryEvent.NONE, ""); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + emptyResult.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemoryResult deserialized = new MemoryResult(in); + + assertEquals("", deserialized.getMemoryId()); + assertEquals("", deserialized.getMemory()); + assertEquals(MemoryEvent.NONE, deserialized.getEvent()); + assertEquals("", deserialized.getOldMemory()); + } + + @Test + public void testBuilderDefaults() { + // Test builder with only required fields + MemoryResult result = MemoryResult.builder().memoryId("test-id").memory("test memory").event(MemoryEvent.ADD).build(); + + assertEquals("test-id", result.getMemoryId()); + assertEquals("test memory", result.getMemory()); + assertEquals(MemoryEvent.ADD, result.getEvent()); + assertNull(result.getOldMemory()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemorySearchResultTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemorySearchResultTest.java new file mode 100644 index 0000000000..16a3d43c05 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MemorySearchResultTest.java @@ -0,0 +1,375 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.memorycontainer.MemoryType; + +public class MemorySearchResultTest { + + private MemorySearchResult resultWithAllFields; + private MemorySearchResult resultMinimal; + private MemorySearchResult resultNoOptionals; + private Map testTags; + private Instant testCreatedTime; + private Instant testUpdatedTime; + + @Before + public void setUp() { + testCreatedTime = Instant.now(); + testUpdatedTime = Instant.now().plusSeconds(60); + + testTags = new HashMap<>(); + testTags.put("topic", "machine learning"); + testTags.put("priority", "high"); + + // Result with all fields + resultWithAllFields = MemorySearchResult + .builder() + .memoryId("memory-123") + .memory("This is a test memory content") + .score(0.95f) + .sessionId("session-456") + .agentId("agent-789") + .userId("user-101") + .memoryType(MemoryType.RAW_MESSAGE) + .role("user") + .tags(testTags) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + + // Minimal result (only required fields) + resultMinimal = MemorySearchResult.builder().memoryId("memory-minimal").memory("Minimal memory").score(0.5f).build(); + + // Result without optional fields + resultNoOptionals = new MemorySearchResult( + "memory-no-opt", + "Memory without optionals", + 0.75f, + null, + null, + null, + null, + null, + null, + null, + null + ); + } + + @Test + public void testBuilderWithAllFields() { + assertNotNull(resultWithAllFields); + assertEquals("memory-123", resultWithAllFields.getMemoryId()); + assertEquals("This is a test memory content", resultWithAllFields.getMemory()); + assertEquals(0.95f, resultWithAllFields.getScore(), 0.001f); + assertEquals("session-456", resultWithAllFields.getSessionId()); + assertEquals("agent-789", resultWithAllFields.getAgentId()); + assertEquals("user-101", resultWithAllFields.getUserId()); + assertEquals(MemoryType.RAW_MESSAGE, resultWithAllFields.getMemoryType()); + assertEquals("user", resultWithAllFields.getRole()); + assertEquals(testTags, resultWithAllFields.getTags()); + assertEquals(testCreatedTime, resultWithAllFields.getCreatedTime()); + assertEquals(testUpdatedTime, resultWithAllFields.getLastUpdatedTime()); + } + + @Test + public void testBuilderMinimal() { + assertNotNull(resultMinimal); + assertEquals("memory-minimal", resultMinimal.getMemoryId()); + assertEquals("Minimal memory", resultMinimal.getMemory()); + assertEquals(0.5f, resultMinimal.getScore(), 0.001f); + assertNull(resultMinimal.getSessionId()); + assertNull(resultMinimal.getAgentId()); + assertNull(resultMinimal.getUserId()); + assertNull(resultMinimal.getMemoryType()); + assertNull(resultMinimal.getRole()); + assertNull(resultMinimal.getTags()); + assertNull(resultMinimal.getCreatedTime()); + assertNull(resultMinimal.getLastUpdatedTime()); + } + + @Test + public void testConstructorWithAllParameters() { + Map tags = new HashMap<>(); + tags.put("key", "value"); + Instant now = Instant.now(); + + MemorySearchResult result = new MemorySearchResult( + "id-1", + "memory-1", + 0.85f, + "session-1", + "agent-1", + "user-1", + MemoryType.FACT, + "assistant", + tags, + now, + now.plusSeconds(10) + ); + + assertEquals("id-1", result.getMemoryId()); + assertEquals("memory-1", result.getMemory()); + assertEquals(0.85f, result.getScore(), 0.001f); + assertEquals("session-1", result.getSessionId()); + assertEquals("agent-1", result.getAgentId()); + assertEquals("user-1", result.getUserId()); + assertEquals(MemoryType.FACT, result.getMemoryType()); + assertEquals("assistant", result.getRole()); + assertEquals(tags, result.getTags()); + assertEquals(now, result.getCreatedTime()); + assertEquals(now.plusSeconds(10), result.getLastUpdatedTime()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with all fields + BytesStreamOutput out = new BytesStreamOutput(); + resultWithAllFields.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemorySearchResult deserialized = new MemorySearchResult(in); + + assertEquals(resultWithAllFields.getMemoryId(), deserialized.getMemoryId()); + assertEquals(resultWithAllFields.getMemory(), deserialized.getMemory()); + assertEquals(resultWithAllFields.getScore(), deserialized.getScore(), 0.001f); + assertEquals(resultWithAllFields.getSessionId(), deserialized.getSessionId()); + assertEquals(resultWithAllFields.getAgentId(), deserialized.getAgentId()); + assertEquals(resultWithAllFields.getUserId(), deserialized.getUserId()); + assertEquals(resultWithAllFields.getMemoryType(), deserialized.getMemoryType()); + assertEquals(resultWithAllFields.getRole(), deserialized.getRole()); + assertEquals(resultWithAllFields.getTags(), deserialized.getTags()); + assertEquals(resultWithAllFields.getCreatedTime(), deserialized.getCreatedTime()); + assertEquals(resultWithAllFields.getLastUpdatedTime(), deserialized.getLastUpdatedTime()); + } + + @Test + public void testStreamInputOutputMinimal() throws IOException { + // Test with minimal fields + BytesStreamOutput out = new BytesStreamOutput(); + resultMinimal.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemorySearchResult deserialized = new MemorySearchResult(in); + + assertEquals(resultMinimal.getMemoryId(), deserialized.getMemoryId()); + assertEquals(resultMinimal.getMemory(), deserialized.getMemory()); + assertEquals(resultMinimal.getScore(), deserialized.getScore(), 0.001f); + assertNull(deserialized.getSessionId()); + assertNull(deserialized.getAgentId()); + assertNull(deserialized.getUserId()); + assertNull(deserialized.getMemoryType()); + assertNull(deserialized.getRole()); + assertNull(deserialized.getTags()); + assertNull(deserialized.getCreatedTime()); + assertNull(deserialized.getLastUpdatedTime()); + } + + @Test + public void testStreamInputOutputEmptyTags() throws IOException { + // Test with empty tags + MemorySearchResult resultEmptyTags = MemorySearchResult + .builder() + .memoryId("memory-empty-tags") + .memory("Memory with empty tags") + .score(0.8f) + .tags(new HashMap<>()) + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + resultEmptyTags.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemorySearchResult deserialized = new MemorySearchResult(in); + + assertNull(deserialized.getTags()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + resultWithAllFields.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_id\":\"memory-123\"")); + assertTrue(jsonString.contains("\"memory\":\"This is a test memory content\"")); + assertTrue(jsonString.contains("\"_score\":0.95")); + assertTrue(jsonString.contains("\"session_id\":\"session-456\"")); + assertTrue(jsonString.contains("\"agent_id\":\"agent-789\"")); + assertTrue(jsonString.contains("\"user_id\":\"user-101\"")); + assertTrue(jsonString.contains("\"memory_type\":\"RAW_MESSAGE\"")); + assertTrue(jsonString.contains("\"role\":\"user\"")); + assertTrue(jsonString.contains("\"topic\":\"machine learning\"")); + assertTrue(jsonString.contains("\"priority\":\"high\"")); + assertTrue(jsonString.contains("\"created_time\":" + testCreatedTime.toEpochMilli())); + assertTrue(jsonString.contains("\"last_updated_time\":" + testUpdatedTime.toEpochMilli())); + } + + @Test + public void testToXContentMinimal() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + resultMinimal.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_id\":\"memory-minimal\"")); + assertTrue(jsonString.contains("\"memory\":\"Minimal memory\"")); + assertTrue(jsonString.contains("\"_score\":0.5")); + // Optional fields should not be present + assertTrue(!jsonString.contains("\"session_id\"")); + assertTrue(!jsonString.contains("\"agent_id\"")); + assertTrue(!jsonString.contains("\"user_id\"")); + assertTrue(!jsonString.contains("\"memory_type\"")); + assertTrue(!jsonString.contains("\"role\"")); + assertTrue(!jsonString.contains("\"tags\"")); + assertTrue(!jsonString.contains("\"created_time\"")); + assertTrue(!jsonString.contains("\"last_updated_time\"")); + } + + @Test + public void testToXContentNoOptionals() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + resultNoOptionals.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"memory_id\":\"memory-no-opt\"")); + assertTrue(jsonString.contains("\"memory\":\"Memory without optionals\"")); + assertTrue(jsonString.contains("\"_score\":0.75")); + // All optional fields should be absent + assertTrue(!jsonString.contains("\"session_id\"")); + assertTrue(!jsonString.contains("\"tags\"")); + assertTrue(!jsonString.contains("\"created_time\"")); + } + + @Test + public void testToString() { + String str = resultWithAllFields.toString(); + assertNotNull(str); + assertTrue(str.contains("memory-123")); + assertTrue(str.contains("This is a test memory content")); + assertTrue(str.contains("0.95")); + assertTrue(str.contains("session-456")); + } + + @Test + public void testDifferentMemoryTypes() throws IOException { + // Test with FACT type + MemorySearchResult factResult = MemorySearchResult + .builder() + .memoryId("fact-123") + .memory("User's name is John") + .score(0.9f) + .memoryType(MemoryType.FACT) + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + factResult.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemorySearchResult deserialized = new MemorySearchResult(in); + + assertEquals(MemoryType.FACT, deserialized.getMemoryType()); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + factResult.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + assertTrue(jsonString.contains("\"memory_type\":\"FACT\"")); + } + + @Test + public void testScoreValues() { + // Test various score values + float[] scores = { 0.0f, 0.5f, 0.999f, 1.0f, 100.0f }; + + for (float score : scores) { + MemorySearchResult result = MemorySearchResult.builder().memoryId("id-" + score).memory("memory-" + score).score(score).build(); + + assertEquals(score, result.getScore(), 0.001f); + } + } + + @Test + public void testSpecialCharactersInFields() throws IOException { + Map specialTags = new HashMap<>(); + specialTags.put("key with spaces", "value with\nnewlines"); + specialTags.put("unicode_key_🔥", "unicode_value_✨"); + + MemorySearchResult specialResult = MemorySearchResult + .builder() + .memoryId("id-with-special-chars-🚀") + .memory("Memory with\n\ttabs and\nnewlines and \"quotes\"") + .score(0.99f) + .sessionId("session-with-special-chars") + .role("user/assistant") + .tags(specialTags) + .createdTime(testCreatedTime) + .lastUpdatedTime(testUpdatedTime) + .build(); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialResult.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemorySearchResult deserialized = new MemorySearchResult(in); + + assertEquals(specialResult.getMemoryId(), deserialized.getMemoryId()); + assertEquals(specialResult.getMemory(), deserialized.getMemory()); + assertEquals(specialResult.getRole(), deserialized.getRole()); + assertEquals(specialResult.getTags(), deserialized.getTags()); + } + + @Test + public void testNullHandling() throws IOException { + // Create result with explicit nulls + MemorySearchResult nullResult = new MemorySearchResult( + "id-null", + "memory-null", + 0.1f, + null, + null, + null, + null, + null, + null, + null, + null + ); + + // Test serialization + BytesStreamOutput out = new BytesStreamOutput(); + nullResult.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MemorySearchResult deserialized = new MemorySearchResult(in); + + assertEquals("id-null", deserialized.getMemoryId()); + assertEquals("memory-null", deserialized.getMemory()); + assertEquals(0.1f, deserialized.getScore(), 0.001f); + assertNull(deserialized.getSessionId()); + assertNull(deserialized.getAgentId()); + assertNull(deserialized.getUserId()); + assertNull(deserialized.getMemoryType()); + assertNull(deserialized.getRole()); + assertNull(deserialized.getTags()); + assertNull(deserialized.getCreatedTime()); + assertNull(deserialized.getLastUpdatedTime()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MessageInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MessageInputTest.java new file mode 100644 index 0000000000..b5bba155f1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/memorycontainer/memory/MessageInputTest.java @@ -0,0 +1,242 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.memorycontainer.memory; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +public class MessageInputTest { + + private MessageInput messageWithRole; + private MessageInput messageWithoutRole; + + @Before + public void setUp() { + messageWithRole = MessageInput.builder().role("user").content("Hello, how are you?").build(); + + messageWithoutRole = MessageInput.builder().content("Just a message without role").build(); + } + + @Test + public void testBuilderWithRole() { + assertNotNull(messageWithRole); + assertEquals("user", messageWithRole.getRole()); + assertEquals("Hello, how are you?", messageWithRole.getContent()); + } + + @Test + public void testBuilderWithoutRole() { + assertNotNull(messageWithoutRole); + assertNull(messageWithoutRole.getRole()); + assertEquals("Just a message without role", messageWithoutRole.getContent()); + } + + @Test + public void testConstructorWithRole() { + MessageInput message = new MessageInput("assistant", "I'm doing well, thank you!"); + assertEquals("assistant", message.getRole()); + assertEquals("I'm doing well, thank you!", message.getContent()); + } + + @Test + public void testConstructorWithoutRole() { + MessageInput message = new MessageInput(null, "Message content"); + assertNull(message.getRole()); + assertEquals("Message content", message.getContent()); + } + + @Test + public void testConstructorWithNullContent() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new MessageInput("user", null)); + assertEquals("Content is required", exception.getMessage()); + } + + @Test + public void testConstructorWithEmptyContent() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> new MessageInput("user", "")); + assertEquals("Content is required", exception.getMessage()); + } + + @Test + public void testStreamInputOutput() throws IOException { + // Test with role + BytesStreamOutput out = new BytesStreamOutput(); + messageWithRole.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MessageInput deserialized = new MessageInput(in); + + assertEquals(messageWithRole.getRole(), deserialized.getRole()); + assertEquals(messageWithRole.getContent(), deserialized.getContent()); + } + + @Test + public void testStreamInputOutputWithoutRole() throws IOException { + // Test without role + BytesStreamOutput out = new BytesStreamOutput(); + messageWithoutRole.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MessageInput deserialized = new MessageInput(in); + + assertNull(deserialized.getRole()); + assertEquals(messageWithoutRole.getContent(), deserialized.getContent()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + messageWithRole.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("\"role\":\"user\"")); + assertTrue(jsonString.contains("\"content\":\"Hello, how are you?\"")); + } + + @Test + public void testToXContentWithoutRole() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + messageWithoutRole.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(!jsonString.contains("\"role\"")); + assertTrue(jsonString.contains("\"content\":\"Just a message without role\"")); + } + + @Test + public void testParse() throws IOException { + String jsonString = "{\"role\":\"user\",\"content\":\"Test message\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MessageInput parsed = MessageInput.parse(parser); + + assertEquals("user", parsed.getRole()); + assertEquals("Test message", parsed.getContent()); + } + + @Test + public void testParseWithoutRole() throws IOException { + String jsonString = "{\"content\":\"Message without role\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MessageInput parsed = MessageInput.parse(parser); + + assertNull(parsed.getRole()); + assertEquals("Message without role", parsed.getContent()); + } + + @Test + public void testParseWithUnknownFields() throws IOException { + String jsonString = "{\"role\":\"assistant\",\"content\":\"Test\",\"unknown\":\"field\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MessageInput parsed = MessageInput.parse(parser); + + assertEquals("assistant", parsed.getRole()); + assertEquals("Test", parsed.getContent()); + } + + @Test + public void testSetters() { + MessageInput message = new MessageInput(null, "Initial content"); + + message.setRole("system"); + message.setContent("Updated content"); + + assertEquals("system", message.getRole()); + assertEquals("Updated content", message.getContent()); + } + + @Test + public void testSpecialCharactersInContent() throws IOException { + MessageInput specialMessage = new MessageInput( + "user", + "Content with\n\ttabs,\nnewlines, \"quotes\", 'single quotes', and unicode 🚀✨" + ); + + // Test serialization round trip + BytesStreamOutput out = new BytesStreamOutput(); + specialMessage.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MessageInput deserialized = new MessageInput(in); + + assertEquals(specialMessage.getRole(), deserialized.getRole()); + assertEquals(specialMessage.getContent(), deserialized.getContent()); + + // Test XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + specialMessage.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + assertTrue(jsonString.contains("Content with")); + assertTrue(jsonString.contains("tabs")); + assertTrue(jsonString.contains("quotes")); + } + + @Test + public void testRoleValues() throws IOException { + String[] roles = { "user", "assistant", "system", "human", "ai", null }; + + for (String role : roles) { + MessageInput message = new MessageInput(role, "Test content"); + assertEquals(role, message.getRole()); + + // Test round trip + BytesStreamOutput out = new BytesStreamOutput(); + message.writeTo(out); + StreamInput in = out.bytes().streamInput(); + MessageInput deserialized = new MessageInput(in); + assertEquals(role, deserialized.getRole()); + } + } + + @Test + public void testXContentRoundTrip() throws IOException { + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + messageWithRole.toXContent(builder, EMPTY_PARAMS); + String jsonString = TestHelper.xContentBuilderToString(builder); + + // Parse back + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + MessageInput parsed = MessageInput.parse(parser); + + // Verify all fields match + assertEquals(messageWithRole.getRole(), parsed.getRole()); + assertEquals(messageWithRole.getContent(), parsed.getContent()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java index 2b191c14a2..66e7de8570 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java @@ -67,7 +67,7 @@ public void toXContentTest() throws IOException { + "\"algorithm\":\"KMEANS\"," + "\"model_version\":\"1.0.0\"," + "\"model_content\":\"content\"," - + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null},\"model_state\":\"TRAINED\"}", + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null},\"model_state\":\"TRAINED\"}", jsonStr ); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 5591e8d273..8d4c952d72 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -38,6 +38,7 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.controller.MLRateLimiter; +import org.opensearch.ml.common.model.BaseModelConfig; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -62,8 +63,8 @@ public class MLUpdateModelInputTest { "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" - + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" - + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"," + + "\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"},\"connector\":" + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" @@ -77,13 +78,18 @@ public class MLUpdateModelInputTest { + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" + "\"test-connector_id\"}"; + private final String expectedOutputStrSpaceType = + "{\"model_id\":\"test-model-id\",\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + + "\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"sparse_encoding\",\"additional_config\":{\"space_type\":\"l2\"}}}"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() throws Exception { - MLModelConfig config = TextEmbeddingModelConfig - .builder() + MLModelConfig config = BaseModelConfig + .baseModelConfigBuilder() .modelType("testModelType") .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) @@ -214,6 +220,20 @@ public void parseWithIllegalFieldWithoutModel() throws Exception { }); } + @Test + public void parseWithSpaceType() throws Exception { + String expectedInputStrWithSpaceType = "{\"model_id\":\"test-model-id\",\"name\":\"name\",\"description\":\"description\"," + + "\"model_group_id\":\"modelGroupId\",\"model_config\":{\"model_type\":\"sparse_encoding\"," + + "\"additional_config\":{\"space_type\":\"l2\"}}}"; + testParseFromJsonString(expectedInputStrWithSpaceType, parsedInput -> { + try { + assertEquals(expectedOutputStrSpaceType, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + @Test public void serializationWithTenantId_Success() throws IOException { MLUpdateModelInput input = updateModelInput.toBuilder().tenantId("tenant-1").build(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java index 2d4db967c1..3bf346d1c3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java @@ -86,7 +86,7 @@ public void toXContentTest() throws IOException { + "\"create_time\":123," + "\"last_update_time\":123," + "\"error\":\"error\"," - + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null,\"user_requested_tenant_access\":null}," + "\"is_async\":true}", jsonStr ); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 272854983a..d1082ac472 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -44,7 +44,18 @@ public void setup() { Map additionalConfig = new HashMap<>(); additionalConfig.put("test_key", "test_value"); - config = new BaseModelConfig("Model Type", "\"test_key1\":\"test_value1\"", additionalConfig); + config = new BaseModelConfig( + "Model Type", + "\"test_key1\":\"test_value1\"", + additionalConfig, + 768, + BaseModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + BaseModelConfig.PoolingMode.MEAN, + false, + null, + null, + null + ); mLRegisterModelMetaInput = new MLRegisterModelMetaInput( "Model Name", @@ -128,7 +139,8 @@ public void testToXContent() throws IOException { + "\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\"," + "\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\"," + "\"model_config\":{\"model_type\":\"Model Type\",\"all_config\":\"\\\"test_key1\\\":\\\"test_value1\\\"\"," - + "\"additional_config\":{\"test_key\":\"test_value\"}},\"total_chunks\":2," + + "\"additional_config\":{\"test_key\":\"test_value\"},\"embedding_dimension\":768," + + "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"pooling_mode\":\"MEAN\"},\"total_chunks\":2," + "\"add_all_backend_roles\":false,\"does_version_create_model_group\":false,\"is_hidden\":false}"; assertEquals(expected, mlModelContent); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index 7ca1b86fbe..da72ddebe6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -34,7 +34,18 @@ public void setUp() { Map additionalConfig = new HashMap<>(); additionalConfig.put("test_key", "test_value"); - config = new BaseModelConfig("Model Type", "\"test_key1\":\"test_value1\"", additionalConfig); + config = new BaseModelConfig( + "Model Type", + "\"test_key1\":\"test_value1\"", + additionalConfig, + 768, + BaseModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + BaseModelConfig.PoolingMode.MEAN, + false, + null, + null, + null + ); mlRegisterModelMetaInput = new MLRegisterModelMetaInput( "Model Name", FunctionName.BATCH_RCF, diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index 693caa7d7a..155f78d3d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -7,17 +7,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; -import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; -import static org.opensearch.ml.common.utils.StringUtils.getJsonPath; -import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath; -import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath; -import static org.opensearch.ml.common.utils.StringUtils.parseParameters; -import static org.opensearch.ml.common.utils.StringUtils.toJson; +import static org.opensearch.ml.common.utils.StringUtils.*; import java.io.IOException; import java.util.ArrayList; @@ -190,7 +186,7 @@ public void addDefaultMethod_NoEscape() { public void addDefaultMethod_Escape() { String input = "return escape(\"abc\n123\");"; String result = StringUtils.addDefaultMethod(input); - Assert.assertNotEquals(input, result); + assertNotEquals(input, result); assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION)); } @@ -858,4 +854,99 @@ public void testValidateFields_InvalidCharacterSet() { assertTrue(exception.getMessage().contains("Field1")); } + @Test + public void prepareJsonValue_returnsRawIfJson() { + String json = "{\"key\": 123}"; + String result = StringUtils.prepareJsonValue(json); + assertSame(json, result); // branch where isJson(input)==true + } + + @Test + public void prepareJsonValue_escapesBadCharsOtherwise() { + String input = "Tom & Jerry \"