Skip to content

Commit

Permalink
Mock http server for LLM; Integration test for visualization tool (op…
Browse files Browse the repository at this point in the history
…ensearch-project#92) (opensearch-project#102)

* mock server and integTest for visualization

* update rest status

* add refresh

* merge from main

* rename variable name

---------

(cherry picked from commit 4c76e4c)

Signed-off-by: Hailong Cui <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
2 people authored and yuye-aws committed Apr 26, 2024
1 parent 84f7f72 commit 0b12066
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 4 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ dependencies {
// Test dependencies
testImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.json', name: 'json', version: '20231013'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0'
testImplementation("net.bytebuddy:byte-buddy:1.14.7")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.client.Client;
import org.opensearch.client.Requests;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand All @@ -24,9 +25,6 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
Expand Down Expand Up @@ -113,7 +111,6 @@ public void onFailure(Exception e) {
});
}

@VisibleForTesting
String trimIdPrefix(String id) {
id = Optional.ofNullable(id).orElse("");
if (id.startsWith(SAVED_OBJECT_TYPE)) {
Expand Down
52 changes: 52 additions & 0 deletions src/test/java/org/opensearch/integTest/MockHttpServer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.integTest;

import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

import com.google.gson.Gson;
import com.sun.net.httpserver.HttpServer;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class MockHttpServer {

private static Gson gson = new Gson();

public static HttpServer setupMockLLM(List<PromptHandler> promptHandlers) throws IOException {
HttpServer server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);

server.createContext("/invoke", exchange -> {
InputStream ins = exchange.getRequestBody();
String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8);
Map<String, String> map = gson.fromJson(req, Map.class);
String prompt = map.get("prompt");
log.debug("prompt received: {}", prompt);

String llmRes = "";
for (PromptHandler promptHandler : promptHandlers) {
if (promptHandler.apply(prompt)) {
PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse();
llmResponse.setCompletion(promptHandler.response(prompt));
llmRes = gson.toJson(llmResponse);
break;
}
}
byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8);
exchange.sendResponseHeaders(200, llmResBytes.length);
exchange.getResponseBody().write(llmResBytes);
exchange.close();
});
return server;
}
}
61 changes: 61 additions & 0 deletions src/test/java/org/opensearch/integTest/PromptHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.integTest;

import com.google.gson.annotations.SerializedName;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

public class PromptHandler {

boolean apply(String prompt) {
return prompt.contains(llmThought().getQuestion());
}

LLMThought llmThought() {
return new LLMThought();
}

String response(String prompt) {
if (prompt.contains("TOOL RESPONSE: ")) {
return "```json{\n"
+ " \"thought\": \"Thought: Now I know the final answer\",\n"
+ " \"final_answer\": \"final answer\"\n"
+ "}```";
} else {
return "```json{\n"
+ " \"thought\": \"Thought: Let me use tool to figure out\",\n"
+ " \"action\": \""
+ this.llmThought().getAction()
+ "\",\n"
+ " \"action_input\": \""
+ this.llmThought().getActionInput()
+ "\"\n"
+ "}```";
}
}

@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
static class LLMThought {
String question;
String action;
String actionInput;
}

@Data
static class LLMResponse {
String completion;
@SerializedName("stop_reason")
String stopReason = "stop_sequence";
String stop = "\\n\\nHuman:";
}
}
219 changes: 219 additions & 0 deletions src/test/java/org/opensearch/integTest/ToolIntegrationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.integTest;

import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.Response;

import com.google.gson.Gson;
import com.google.gson.JsonParser;
import com.sun.net.httpserver.HttpServer;

import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class ToolIntegrationTest extends BaseAgentToolsIT {
protected HttpServer server;
protected String modelId;
protected String agentId;
protected String modelGroupId;
protected String connectorId;

private final Gson gson = new Gson();

abstract List<PromptHandler> promptHandlers();

abstract String toolType();

@Before
public void setupTestAgent() throws IOException, InterruptedException {
server = MockHttpServer.setupMockLLM(promptHandlers());
server.start();
clusterSettings(false);
try {
connectorId = setUpConnector();
} catch (Exception e) {
// Wait for ML encryption master key has been initialized
TimeUnit.SECONDS.sleep(10);
connectorId = setUpConnector();
}
modelGroupId = setupModelGroup();
modelId = setupLLMModel(connectorId, modelGroupId);
// wait for model to get deployed
TimeUnit.SECONDS.sleep(1);
agentId = setupConversationalAgent(modelId);
log.info("model_id: {}, agent_id: {}", modelId, agentId);
}

@After
public void cleanUpClusterSetting() throws IOException {
clusterSettings(true);
}

@After
public void stopMockLLM() {
server.stop(1);
}

private String setUpConnector() {
String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort());
return createConnector(
"{\n"
+ " \"name\": \"BedRock test claude Connector\",\n"
+ " \"description\": \"The connector to BedRock service for claude model\",\n"
+ " \"version\": 1,\n"
+ " \"protocol\": \"aws_sigv4\",\n"
+ " \"parameters\": {\n"
+ " \"region\": \"us-east-1\",\n"
+ " \"service_name\": \"bedrock\",\n"
+ " \"anthropic_version\": \"bedrock-2023-05-31\",\n"
+ " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n"
+ " \"auth\": \"Sig_V4\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens_to_sample\": 8000,\n"
+ " \"temperature\": 0.0001,\n"
+ " \"response_filter\": \"$.completion\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"access_key\": \"<key>\",\n"
+ " \"secret_key\": \"<secret>\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \""
+ url
+ "\",\n"
+ " \"headers\": {\n"
+ " \"content-type\": \"application/json\",\n"
+ " \"x-amz-content-sha256\": \"required\"\n"
+ " },\n"
+ " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n"
+ " }\n"
+ " ]\n"
+ "}"
);
}

private void clusterSettings(boolean clean) throws IOException {
if (!clean) {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
updateClusterSettings("plugins.ml_commons.memory_feature_enabled", true);
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$"));
} else {
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", null);
updateClusterSettings("plugins.ml_commons.memory_feature_enabled", null);
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", null);
}
}

private String setupModelGroup() throws IOException {
Request request = new Request("POST", "/_plugins/_ml/model_groups/_register");
request
.setJsonEntity(
"{\n"
+ " \"name\": \"test_model_group_bedrock-"
+ UUID.randomUUID()
+ "\",\n"
+ " \"description\": \"This is a public model group\"\n"
+ "}"
);
Response response = executeRequest(request);

String resp = readResponse(response);

return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString();
}

private String setupLLMModel(String connectorId, String modelGroupId) throws IOException {
Request request = new Request("POST", "/_plugins/_ml/models/_register?deploy=true");
request
.setJsonEntity(
"{\n"
+ " \"name\": \"Bedrock Claude V2 model\",\n"
+ " \"function_name\": \"remote\",\n"
+ " \"model_group_id\": \""
+ modelGroupId
+ "\",\n"
+ " \"description\": \"test model\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\"\n"
+ "}"
);
Response response = executeRequest(request);

String resp = readResponse(response);

return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString();
}

private String setupConversationalAgent(String modelId) throws IOException {
Request request = new Request("POST", "/_plugins/_ml/agents/_register");
request
.setJsonEntity(
"{\n"
+ " \"name\": \"integTest-agent\",\n"
+ " \"type\": \"conversational\",\n"
+ " \"description\": \"this is a test agent\",\n"
+ " \"llm\": {\n"
+ " \"model_id\": \""
+ modelId
+ "\",\n"
+ " \"parameters\": {\n"
+ " \"max_iteration\": \"5\",\n"
+ " \"stop_when_no_tool_found\": \"true\",\n"
+ " \"response_filter\": \"$.completion\"\n"
+ " }\n"
+ " },\n"
+ " \"tools\": [\n"
+ " {\n"
+ " \"type\": \""
+ toolType()
+ "\",\n"
+ " \"name\": \""
+ toolType()
+ "\",\n"
+ " \"include_output_in_agent_response\": true,\n"
+ " \"description\": \"tool description\"\n"
+ " }\n"
+ " ],\n"
+ " \"memory\": {\n"
+ " \"type\": \"conversation_index\"\n"
+ " }\n"
+ "}"
);
Response response = executeRequest(request);

String resp = readResponse(response);

return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString();
}

public static Response executeRequest(Request request) throws IOException {
RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder();
builder.addHeader("Content-Type", "application/json");
request.setOptions(builder);
return client().performRequest(request);
}

public static String readResponse(Response response) throws IOException {
try (InputStream ins = response.getEntity().getContent()) {
return String.join("", org.opensearch.common.io.Streams.readAllLines(ins));
}
}
}
Loading

0 comments on commit 0b12066

Please sign in to comment.