From d66e1c4c2d0b0cf467bbd0b2fd1dccb7962e5f04 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 2 Jan 2024 03:28:57 +0000 Subject: [PATCH] Base class for Integ test; add integ test for NeuralSparseSearchTool (#86) * add common components Signed-off-by: zhichao-aws * add common components Signed-off-by: zhichao-aws * add basic components Signed-off-by: zhichao-aws * rebase main Signed-off-by: zhichao-aws * add basic components for it, add it Signed-off-by: zhichao-aws * rebase main Signed-off-by: zhichao-aws * tidy Signed-off-by: zhichao-aws * change neural sparse model to pretrained tokenizer Signed-off-by: zhichao-aws * rm redundant line Signed-off-by: zhichao-aws * add comments Signed-off-by: zhichao-aws * tidy Signed-off-by: zhichao-aws * add register connector Signed-off-by: zhichao-aws --------- Signed-off-by: zhichao-aws (cherry picked from commit 34ae75f81ed47ecc2e4006d32343e12fa388678b) Signed-off-by: github-actions[bot] --- build.gradle | 22 ++ .../integTest/BaseAgentToolsIT.java | 241 ++++++++++++++++++ .../integTest/NeuralSparseSearchToolIT.java | 161 ++++++++++++ .../OpenSearchSecureRestTestCase.java | 185 ++++++++++++++ ...eural_sparse_search_tool_request_body.json | 17 ++ ...er_sparse_encoding_model_request_body.json | 5 + 6 files changed, 631 insertions(+) create mode 100644 src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java create mode 100644 src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java create mode 100644 src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java create mode 100644 src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json create mode 100644 src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json diff --git a/build.gradle b/build.gradle index 2a7ba02f..f72a66a8 100644 --- a/build.gradle +++ b/build.gradle @@ -125,6 +125,8 @@ dependencies { zipArchive group: 
'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${version}" zipArchive "org.opensearch.plugin:opensearch-anomaly-detection:${version}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${version}" + zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${version}" // Test dependencies testImplementation "org.opensearch.test:framework:${opensearch_version}" @@ -348,6 +350,26 @@ testClusters.integTest { } } +// Remote Integration Tests +task integTestRemote(type: RestIntegTestTask) { + testClassesDirs = sourceSets.test.output.classesDirs + classpath = sourceSets.test.runtimeClasspath + + systemProperty "https", System.getProperty("https") + systemProperty "user", System.getProperty("user") + systemProperty "password", System.getProperty("password") + + systemProperty 'cluster.number_of_nodes', "${_numNodes}" + + systemProperty 'tests.security.manager', 'false' + // Run tests with remote cluster only if rest case is defined + if (System.getProperty("tests.rest.cluster") != null) { + filter { + includeTestsMatching "org.opensearch.integTest.*IT" + } + } +} + // Automatically sets up the integration test cluster locally run { useCluster testClusters.integTest diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java new file mode 100644 index 00000000..1162064a --- /dev/null +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; 
+import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.*; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; + +import com.google.common.collect.ImmutableList; +import com.google.gson.Gson; + +import lombok.SneakyThrows; + +public abstract class BaseAgentToolsIT extends OpenSearchSecureRestTestCase { + public static final Gson gson = new Gson(); + private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; + private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + + /** + * Update cluster settings to run ml models + */ + @Before + public void updateClusterSettings() { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); + // default threshold for native circuit breaker is 90, it may be not enough on test runner machine + updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100); + updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); + } + + @SneakyThrows + protected void updateClusterSettings(String settingKey, Object value) { + XContentBuilder builder = XContentFactory + .jsonBuilder() + .startObject() + .startObject("persistent") + .field(settingKey, value) + 
.endObject() + .endObject(); + Response response = makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + builder.toString(), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + @SneakyThrows + private Map parseResponseToMap(Response response) { + Map responseInMap = XContentHelper + .convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity()), false); + response.getEntity().toString(); + return responseInMap; + } + + @SneakyThrows + private Object parseFieldFromResponse(Response response, String field) { + assertNotNull(field); + Map map = parseResponseToMap(response); + Object result = map.get(field); + assertNotNull(result); + return result; + } + + protected String createConnector(String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, MLModel.CONNECTOR_ID_FIELD).toString(); + } + + protected String registerModel(String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, MLTask.TASK_ID_FIELD).toString(); + } + + protected String deployModel(String modelId) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, (String) null, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, MLTask.TASK_ID_FIELD).toString(); + } + + @SneakyThrows + protected Response waitTaskComplete(String taskId) { + for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; 
i++) { + Response response = makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, (String) null, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + String state = parseFieldFromResponse(response, MLTask.STATE_FIELD).toString(); + if (state.equals(MLTaskState.COMPLETED.toString())) { + return response; + } + if (state.equals(MLTaskState.FAILED.toString()) + || state.equals(MLTaskState.CANCELLED.toString()) + || state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) { + fail("The task failed with state " + state); + } + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + } + fail("The task failed to complete after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds."); + return null; + } + + // Register the model then deploy it. Returns the model_id until the model is deployed + protected String registerModelThenDeploy(String requestBody) { + String registerModelTaskId = registerModel(requestBody); + Response registerTaskResponse = waitTaskComplete(registerModelTaskId); + String modelId = parseFieldFromResponse(registerTaskResponse, MLTask.MODEL_ID_FIELD).toString(); + String deployModelTaskId = deployModel(modelId); + waitTaskComplete(deployModelTaskId); + return modelId; + } + + protected void createIndexWithConfiguration(String indexName, String indexConfiguration) throws Exception { + Response response = makeRequest(client(), "PUT", indexName, null, indexConfiguration, null); + assertEquals("true", parseFieldFromResponse(response, "acknowledged").toString()); + assertEquals(indexName, parseFieldFromResponse(response, "index").toString()); + } + + @SneakyThrows + protected void addDocToIndex(String indexName, String docId, List fieldNames, List fieldContents) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldNames.size(); i++) { + builder.field(fieldNames.get(i), fieldContents.get(i)); + } + builder.endObject(); + Response 
response = makeRequest( + client(), + "POST", + "/" + indexName + "/_doc/" + docId + "?refresh=true", + null, + builder.toString(), + null + ); + assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + public String createAgent(String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/agents/_register", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, AgentMLInput.AGENT_ID_FIELD).toString(); + } + + private String parseStringResponseFromExecuteAgentResponse(Response response) { + Map responseInMap = parseResponseToMap(response); + Optional optionalResult = Optional + .ofNullable(responseInMap) + .map(m -> (List) m.get(ModelTensorOutput.INFERENCE_RESULT_FIELD)) + .map(l -> (Map) l.get(0)) + .map(m -> (List) m.get(ModelTensors.OUTPUT_FIELD)) + .map(l -> (Map) l.get(0)) + .map(m -> (String) (m.get(ModelTensor.RESULT_FIELD))); + return optionalResult.get(); + } + + // execute the agent, and return the String response from the json structure + // {"inference_results": [{"output": [{"name": "response","result": "the result to return."}]}]} + public String executeAgent(String agentId, String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, requestBody, null); + return parseStringResponseFromExecuteAgentResponse(response); + } + + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + String jsonEntity, + List
headers + ) { + HttpEntity httpEntity = StringUtils.isBlank(jsonEntity) ? null : new StringEntity(jsonEntity, ContentType.APPLICATION_JSON); + return makeRequest(client, method, endpoint, params, httpEntity, headers); + } + + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + HttpEntity entity, + List
headers + ) { + return makeRequest(client, method, endpoint, params, entity, headers, false); + } + + @SneakyThrows + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + HttpEntity entity, + List
headers, + boolean strictDeprecationMode + ) { + Request request = new Request(method, endpoint); + + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + if (headers != null) { + headers.forEach(header -> options.addHeader(header.getName(), header.getValue())); + } + options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + + if (params != null) { + params.forEach(request::addParameter); + } + if (entity != null) { + request.setEntity(entity); + } + return client.performRequest(request); + } +} diff --git a/src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java b/src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java new file mode 100644 index 00000000..2dda2095 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java @@ -0,0 +1,161 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class NeuralSparseSearchToolIT extends BaseAgentToolsIT { + public static String TEST_INDEX_NAME = "test_index"; + + private String modelId; + private String registerAgentRequestBody; + + @SneakyThrows + private void prepareModel() { + String requestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json") + .toURI() + ) + ); + modelId = registerModelThenDeploy(requestBody); + } + + @SneakyThrows + private void prepareIndex() { + createIndexWithConfiguration( + TEST_INDEX_NAME, + "{\n" + + 
" \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\": \"rank_features\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(TEST_INDEX_NAME, "0", List.of("text", "embedding"), List.of("text doc 1", Map.of("hello", 1, "world", 2))); + addDocToIndex(TEST_INDEX_NAME, "1", List.of("text", "embedding"), List.of("text doc 2", Map.of("a", 3, "b", 4))); + addDocToIndex(TEST_INDEX_NAME, "2", List.of("text", "embedding"), List.of("text doc 3", Map.of("test", 5, "a", 6))); + } + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareModel(); + prepareIndex(); + registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + public void testNeuralSparseSearchToolInFlowAgent() { + String agentId = createAgent(registerAgentRequestBody); + // successful case + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n" + + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + + // use non-exist token to test the case the tool can not find match docs. 
+ String result2 = executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}"); + assertEquals("The agent execute response not equal with expected.", "Can not get any match from search result.", result2); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("illegal_argument_exception")) + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() { + String agentId = createAgent(registerAgentRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"2\",\"_score\":2.4136734}\n" + + "{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace("\"embedding\"", "\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString("failed to create query: [neural_sparse] query only works on [rank_features] fields"), + containsString("search_phase_execution_exception") + ) + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalIndexField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace("test_index", "test_index2")); + Exception exception = assertThrows(ResponseException.class, () -> 
executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("index_not_found_exception")) + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace(modelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("status_exception"))); + } +} diff --git a/src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java b/src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java new file mode 100644 index 00000000..511609a7 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import 
org.apache.hc.client5.http.ssl.NoopHostnameVerifier; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.ssl.TlsStrategy; +import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.apache.hc.core5.util.Timeout; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +/** + * Base class for running the integration tests on a secure cluster. The plugin IT test should either extend this + * class or create another base class by extending this class to make sure that their IT can be run on secure clusters. 
+ */ +public abstract class OpenSearchSecureRestTestCase extends OpenSearchRestTestCase { + + private static final String PROTOCOL_HTTP = "http"; + private static final String PROTOCOL_HTTPS = "https"; + private static final String SYS_PROPERTY_KEY_HTTPS = "https"; + private static final String SYS_PROPERTY_KEY_CLUSTER_ENDPOINT = "tests.rest.cluster"; + private static final String SYS_PROPERTY_KEY_USER = "user"; + private static final String SYS_PROPERTY_KEY_PASSWORD = "password"; + private static final String DEFAULT_SOCKET_TIMEOUT = "60s"; + private static final String INTERNAL_INDICES_PREFIX = "."; + private static String protocol; + + @Override + protected String getProtocol() { + if (protocol == null) { + protocol = readProtocolFromSystemProperty(); + } + return protocol; + } + + private String readProtocolFromSystemProperty() { + final boolean isHttps = Optional.ofNullable(System.getProperty(SYS_PROPERTY_KEY_HTTPS)).map("true"::equalsIgnoreCase).orElse(false); + if (!isHttps) { + return PROTOCOL_HTTP; + } + + // currently only external cluster is supported for security enabled testing + if (Optional.ofNullable(System.getProperty(SYS_PROPERTY_KEY_CLUSTER_ENDPOINT)).isEmpty()) { + throw new RuntimeException("cluster url should be provided for security enabled testing"); + } + return PROTOCOL_HTTPS; + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { + final RestClientBuilder builder = RestClient.builder(hosts); + if (PROTOCOL_HTTPS.equals(getProtocol())) { + configureHttpsClient(builder, settings); + } else { + configureClient(builder, settings); + } + + return builder.build(); + } + + private void configureHttpsClient(final RestClientBuilder builder, final Settings settings) { + final Map headers = ThreadContext.buildDefaultHeaders(settings); + final Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new 
BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + final String userName = Optional + .ofNullable(System.getProperty(SYS_PROPERTY_KEY_USER)) + .orElseThrow(() -> new RuntimeException("user name is missing")); + final String password = Optional + .ofNullable(System.getProperty(SYS_PROPERTY_KEY_PASSWORD)) + .orElseThrow(() -> new RuntimeException("password is missing")); + final BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + final AuthScope anyScope = new AuthScope(null, -1); + credentialsProvider.setCredentials(anyScope, new UsernamePasswordCredentials(userName, password.toCharArray())); + try { + final TlsStrategy tlsStrategy = ClientTlsStrategyBuilder + .create() + .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) + .build(); + final PoolingAsyncClientConnectionManager connectionManager = PoolingAsyncClientConnectionManagerBuilder + .create() + .setMaxConnPerRoute(DEFAULT_MAX_CONN_PER_ROUTE) + .setMaxConnTotal(DEFAULT_MAX_CONN_TOTAL) + .setTlsStrategy(tlsStrategy) + .build(); + return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(connectionManager); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); + final TimeValue socketTimeout = TimeValue + .parseTimeValue(socketTimeoutString == null ? 
DEFAULT_SOCKET_TIMEOUT : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); + builder.setRequestConfigCallback(conf -> { + Timeout timeout = Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())); + conf.setConnectTimeout(timeout); + conf.setResponseTimeout(timeout); + return conf; + }); + if (settings.hasValue(CLIENT_PATH_PREFIX)) { + builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); + } + } + + /** + * wipeAllIndices won't work since it cannot delete security index. Use deleteExternalIndices instead. + */ + @Override + protected boolean preserveIndicesUponCompletion() { + return true; + } + + @After + public void deleteExternalIndices() throws IOException { + final Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json" + "&expand_wildcards=all")); + final MediaType xContentType = MediaType.fromMediaType(response.getEntity().getContentType()); + try ( + final XContentParser parser = xContentType + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) + ) { + final XContentParser.Token token = parser.nextToken(); + final List> parserList; + if (token == XContentParser.Token.START_ARRAY) { + parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); + } else { + parserList = Collections.singletonList(parser.mapOrdered()); + } + + final List externalIndices = parserList + .stream() + .map(index -> (String) index.get("index")) + .filter(indexName -> indexName != null) + .filter(indexName -> !indexName.startsWith(INTERNAL_INDICES_PREFIX)) + .collect(Collectors.toList()); + + for (final String indexName : externalIndices) { + adminClient().performRequest(new Request("DELETE", "/" + indexName)); + } + } + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json 
b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json new file mode 100644 index 00000000..ac2a2987 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json @@ -0,0 +1,17 @@ +{ + "name": "Test_Neural_Sparse_Agent_For_RAG", + "type": "flow", + "tools": [ + { + "type": "NeuralSparseSearchTool", + "parameters": { + "description":"use this tool to search data from the test index", + "model_id": "<MODEL_ID>", + "index": "test_index", + "embedding_field": "embedding", + "source_field": ["text"], + "input": "${parameters.question}" + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json new file mode 100644 index 00000000..8eb7901c --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json @@ -0,0 +1,5 @@ +{ + "name":"amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1", + "version":"1.0.1", + "model_format": "TORCH_SCRIPT" +} \ No newline at end of file