Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into time-bug-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
brianf-aws authored Sep 17, 2024
2 parents fc18739 + 0d26931 commit c606ffa
Show file tree
Hide file tree
Showing 101 changed files with 7,629 additions and 292 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand Down Expand Up @@ -428,4 +429,20 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

/**
* Get config
* @param configId ML config id
*/
default ActionFuture<MLConfig> getConfig(String configId) {
PlainActionFuture<MLConfig> actionFuture = PlainActionFuture.newFuture();
getConfig(configId, actionFuture);
return actionFuture;
}

/**
* Get config
* @param configId ML config id
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, ActionListener<MLConfig> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand All @@ -39,6 +40,9 @@
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetAction;
import org.opensearch.ml.common.transport.config.MLConfigGetRequest;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
Expand Down Expand Up @@ -309,6 +313,13 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();

client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand All @@ -331,6 +342,17 @@ private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(Act
return actionListener;
}

private ActionListener<MLConfigGetResponse> getMlGetConfigResponseActionListener(ActionListener<MLConfig> listener) {
ActionListener<MLConfigGetResponse> internalListener = ActionListener.wrap(mlConfigGetResponse -> {
listener.onResponse(mlConfigGetResponse.getMlConfig());
}, listener::onFailure);
ActionListener<MLConfigGetResponse> actionListener = wrapActionListener(internalListener, res -> {
MLConfigGetResponse getResponse = MLConfigGetResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.ml.common.input.Constants.KMEANS;
import static org.opensearch.ml.common.input.Constants.TRAIN;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -28,8 +29,10 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.Configuration;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand All @@ -46,6 +49,7 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -99,9 +103,13 @@ public class MachineLearningClientTest {
@Mock
MLRegisterAgentResponse registerAgentResponse;

@Mock
MLConfigGetResponse configGetResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
private MLConfig mlConfig;
private ToolMetadata toolMetadata;
private List<ToolMetadata> toolsList = new ArrayList<>();

Expand All @@ -124,6 +132,14 @@ public void setUp() {
.build();
toolsList.add(toolMetadata);

mlConfig = MLConfig
.builder()
.type("dummyType")
.configuration(Configuration.builder().agentId("agentId").build())
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.build();

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand Down Expand Up @@ -231,6 +247,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}
};
}

Expand Down Expand Up @@ -503,4 +524,9 @@ public void getTool() {
public void listTools() {
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
}

@Test
public void getConfig() {
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.input.Constants.ACTION;
import static org.opensearch.ml.common.input.Constants.ALGORITHM;
import static org.opensearch.ml.common.input.Constants.KMEANS;
Expand All @@ -40,6 +41,7 @@
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -51,12 +53,15 @@
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.Configuration;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
Expand Down Expand Up @@ -84,6 +89,9 @@
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetAction;
import org.opensearch.ml.common.transport.config.MLConfigGetRequest;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
Expand Down Expand Up @@ -206,6 +214,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<ToolMetadata> getToolActionListener;

@Mock
ActionListener<MLConfig> getMlConfigListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -951,6 +962,43 @@ public void listTools() {
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
}

@Test
public void getConfig() {
MLConfig mlConfig = MLConfig.builder().type("type").configuration(Configuration.builder().agentId("agentId").build()).build();

doAnswer(invocation -> {
ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2);
MLConfigGetResponse output = MLConfigGetResponse.builder().mlConfig(mlConfig).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLConfig> argumentCaptor = ArgumentCaptor.forClass(MLConfig.class);
machineLearningNodeClient.getConfig("agentId", getMlConfigListener);

verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any());
verify(getMlConfigListener).onResponse(argumentCaptor.capture());
assertEquals("agentId", argumentCaptor.getValue().getConfiguration().getAgentId());
assertEquals("type", argumentCaptor.getValue().getType());
}

@Test
public void getConfigRejectedMasterKey() {
doAnswer(invocation -> {
ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new OpenSearchStatusException("You are not allowed to access this config doc", RestStatus.FORBIDDEN));
return null;
}).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any());

ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
machineLearningNodeClient.getConfig(MASTER_KEY, getMlConfigListener);

verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any());
verify(getMlConfigListener).onFailure(argumentCaptor.capture());
assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status());
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public class CommonValue {
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3;
Expand Down Expand Up @@ -391,6 +391,9 @@ public class CommonValue {
+ " \""
+ MLTask.IS_ASYNC_TASK_FIELD
+ "\" : {\"type\" : \"boolean\"}, \n"
+ " \""
+ MLTask.REMOTE_JOB_FIELD
+ "\" : {\"type\": \"flat_object\"}, \n"
+ USER_FIELD_MAPPING
+ " }\n"
+ "}";
Expand Down Expand Up @@ -575,6 +578,7 @@ public class CommonValue {
public static final Version VERSION_2_12_0 = Version.fromString("2.12.0");
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_15_0 = Version.fromString("2.15.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
public static final Version VERSION_2_17_0 = Version.fromString("2.17.0");
}
Loading

0 comments on commit c606ffa

Please sign in to comment.