forked from opensearch-project/skills
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Mock http server for LLM; Integration test for visualization tool (op…
…ensearch-project#92) (opensearch-project#102) * mock server and integTest for visualization * update rest status * add refresh * merge from main * rename variable name --------- (cherry picked from commit 4c76e4c) Signed-off-by: Hailong Cui <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Signed-off-by: yuye-aws <[email protected]>
- Loading branch information
Showing
6 changed files
with
441 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
src/test/java/org/opensearch/integTest/MockHttpServer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.integTest; | ||
|
||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.net.InetAddress; | ||
import java.net.InetSocketAddress; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import com.google.gson.Gson; | ||
import com.sun.net.httpserver.HttpServer; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class MockHttpServer { | ||
|
||
private static Gson gson = new Gson(); | ||
|
||
public static HttpServer setupMockLLM(List<PromptHandler> promptHandlers) throws IOException { | ||
HttpServer server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); | ||
|
||
server.createContext("/invoke", exchange -> { | ||
InputStream ins = exchange.getRequestBody(); | ||
String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8); | ||
Map<String, String> map = gson.fromJson(req, Map.class); | ||
String prompt = map.get("prompt"); | ||
log.debug("prompt received: {}", prompt); | ||
|
||
String llmRes = ""; | ||
for (PromptHandler promptHandler : promptHandlers) { | ||
if (promptHandler.apply(prompt)) { | ||
PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse(); | ||
llmResponse.setCompletion(promptHandler.response(prompt)); | ||
llmRes = gson.toJson(llmResponse); | ||
break; | ||
} | ||
} | ||
byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8); | ||
exchange.sendResponseHeaders(200, llmResBytes.length); | ||
exchange.getResponseBody().write(llmResBytes); | ||
exchange.close(); | ||
}); | ||
return server; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.integTest; | ||
|
||
import com.google.gson.annotations.SerializedName; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
public class PromptHandler { | ||
|
||
boolean apply(String prompt) { | ||
return prompt.contains(llmThought().getQuestion()); | ||
} | ||
|
||
LLMThought llmThought() { | ||
return new LLMThought(); | ||
} | ||
|
||
String response(String prompt) { | ||
if (prompt.contains("TOOL RESPONSE: ")) { | ||
return "```json{\n" | ||
+ " \"thought\": \"Thought: Now I know the final answer\",\n" | ||
+ " \"final_answer\": \"final answer\"\n" | ||
+ "}```"; | ||
} else { | ||
return "```json{\n" | ||
+ " \"thought\": \"Thought: Let me use tool to figure out\",\n" | ||
+ " \"action\": \"" | ||
+ this.llmThought().getAction() | ||
+ "\",\n" | ||
+ " \"action_input\": \"" | ||
+ this.llmThought().getActionInput() | ||
+ "\"\n" | ||
+ "}```"; | ||
} | ||
} | ||
|
||
@Builder | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Data | ||
static class LLMThought { | ||
String question; | ||
String action; | ||
String actionInput; | ||
} | ||
|
||
@Data | ||
static class LLMResponse { | ||
String completion; | ||
@SerializedName("stop_reason") | ||
String stopReason = "stop_sequence"; | ||
String stop = "\\n\\nHuman:"; | ||
} | ||
} |
219 changes: 219 additions & 0 deletions
219
src/test/java/org/opensearch/integTest/ToolIntegrationTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.integTest; | ||
|
||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.UUID; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
import org.junit.After; | ||
import org.junit.Before; | ||
import org.opensearch.client.Request; | ||
import org.opensearch.client.RequestOptions; | ||
import org.opensearch.client.Response; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.JsonParser; | ||
import com.sun.net.httpserver.HttpServer; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public abstract class ToolIntegrationTest extends BaseAgentToolsIT { | ||
protected HttpServer server; | ||
protected String modelId; | ||
protected String agentId; | ||
protected String modelGroupId; | ||
protected String connectorId; | ||
|
||
private final Gson gson = new Gson(); | ||
|
||
abstract List<PromptHandler> promptHandlers(); | ||
|
||
abstract String toolType(); | ||
|
||
@Before | ||
public void setupTestAgent() throws IOException, InterruptedException { | ||
server = MockHttpServer.setupMockLLM(promptHandlers()); | ||
server.start(); | ||
clusterSettings(false); | ||
try { | ||
connectorId = setUpConnector(); | ||
} catch (Exception e) { | ||
// Wait for ML encryption master key has been initialized | ||
TimeUnit.SECONDS.sleep(10); | ||
connectorId = setUpConnector(); | ||
} | ||
modelGroupId = setupModelGroup(); | ||
modelId = setupLLMModel(connectorId, modelGroupId); | ||
// wait for model to get deployed | ||
TimeUnit.SECONDS.sleep(1); | ||
agentId = setupConversationalAgent(modelId); | ||
log.info("model_id: {}, agent_id: {}", modelId, agentId); | ||
} | ||
|
||
@After | ||
public void cleanUpClusterSetting() throws IOException { | ||
clusterSettings(true); | ||
} | ||
|
||
@After | ||
public void stopMockLLM() { | ||
server.stop(1); | ||
} | ||
|
||
private String setUpConnector() { | ||
String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort()); | ||
return createConnector( | ||
"{\n" | ||
+ " \"name\": \"BedRock test claude Connector\",\n" | ||
+ " \"description\": \"The connector to BedRock service for claude model\",\n" | ||
+ " \"version\": 1,\n" | ||
+ " \"protocol\": \"aws_sigv4\",\n" | ||
+ " \"parameters\": {\n" | ||
+ " \"region\": \"us-east-1\",\n" | ||
+ " \"service_name\": \"bedrock\",\n" | ||
+ " \"anthropic_version\": \"bedrock-2023-05-31\",\n" | ||
+ " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n" | ||
+ " \"auth\": \"Sig_V4\",\n" | ||
+ " \"content_type\": \"application/json\",\n" | ||
+ " \"max_tokens_to_sample\": 8000,\n" | ||
+ " \"temperature\": 0.0001,\n" | ||
+ " \"response_filter\": \"$.completion\"\n" | ||
+ " },\n" | ||
+ " \"credential\": {\n" | ||
+ " \"access_key\": \"<key>\",\n" | ||
+ " \"secret_key\": \"<secret>\"\n" | ||
+ " },\n" | ||
+ " \"actions\": [\n" | ||
+ " {\n" | ||
+ " \"action_type\": \"predict\",\n" | ||
+ " \"method\": \"POST\",\n" | ||
+ " \"url\": \"" | ||
+ url | ||
+ "\",\n" | ||
+ " \"headers\": {\n" | ||
+ " \"content-type\": \"application/json\",\n" | ||
+ " \"x-amz-content-sha256\": \"required\"\n" | ||
+ " },\n" | ||
+ " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" | ||
+ " }\n" | ||
+ " ]\n" | ||
+ "}" | ||
); | ||
} | ||
|
||
private void clusterSettings(boolean clean) throws IOException { | ||
if (!clean) { | ||
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); | ||
updateClusterSettings("plugins.ml_commons.memory_feature_enabled", true); | ||
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); | ||
} else { | ||
updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", null); | ||
updateClusterSettings("plugins.ml_commons.memory_feature_enabled", null); | ||
updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", null); | ||
} | ||
} | ||
|
||
private String setupModelGroup() throws IOException { | ||
Request request = new Request("POST", "/_plugins/_ml/model_groups/_register"); | ||
request | ||
.setJsonEntity( | ||
"{\n" | ||
+ " \"name\": \"test_model_group_bedrock-" | ||
+ UUID.randomUUID() | ||
+ "\",\n" | ||
+ " \"description\": \"This is a public model group\"\n" | ||
+ "}" | ||
); | ||
Response response = executeRequest(request); | ||
|
||
String resp = readResponse(response); | ||
|
||
return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString(); | ||
} | ||
|
||
private String setupLLMModel(String connectorId, String modelGroupId) throws IOException { | ||
Request request = new Request("POST", "/_plugins/_ml/models/_register?deploy=true"); | ||
request | ||
.setJsonEntity( | ||
"{\n" | ||
+ " \"name\": \"Bedrock Claude V2 model\",\n" | ||
+ " \"function_name\": \"remote\",\n" | ||
+ " \"model_group_id\": \"" | ||
+ modelGroupId | ||
+ "\",\n" | ||
+ " \"description\": \"test model\",\n" | ||
+ " \"connector_id\": \"" | ||
+ connectorId | ||
+ "\"\n" | ||
+ "}" | ||
); | ||
Response response = executeRequest(request); | ||
|
||
String resp = readResponse(response); | ||
|
||
return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); | ||
} | ||
|
||
private String setupConversationalAgent(String modelId) throws IOException { | ||
Request request = new Request("POST", "/_plugins/_ml/agents/_register"); | ||
request | ||
.setJsonEntity( | ||
"{\n" | ||
+ " \"name\": \"integTest-agent\",\n" | ||
+ " \"type\": \"conversational\",\n" | ||
+ " \"description\": \"this is a test agent\",\n" | ||
+ " \"llm\": {\n" | ||
+ " \"model_id\": \"" | ||
+ modelId | ||
+ "\",\n" | ||
+ " \"parameters\": {\n" | ||
+ " \"max_iteration\": \"5\",\n" | ||
+ " \"stop_when_no_tool_found\": \"true\",\n" | ||
+ " \"response_filter\": \"$.completion\"\n" | ||
+ " }\n" | ||
+ " },\n" | ||
+ " \"tools\": [\n" | ||
+ " {\n" | ||
+ " \"type\": \"" | ||
+ toolType() | ||
+ "\",\n" | ||
+ " \"name\": \"" | ||
+ toolType() | ||
+ "\",\n" | ||
+ " \"include_output_in_agent_response\": true,\n" | ||
+ " \"description\": \"tool description\"\n" | ||
+ " }\n" | ||
+ " ],\n" | ||
+ " \"memory\": {\n" | ||
+ " \"type\": \"conversation_index\"\n" | ||
+ " }\n" | ||
+ "}" | ||
); | ||
Response response = executeRequest(request); | ||
|
||
String resp = readResponse(response); | ||
|
||
return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString(); | ||
} | ||
|
||
public static Response executeRequest(Request request) throws IOException { | ||
RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); | ||
builder.addHeader("Content-Type", "application/json"); | ||
request.setOptions(builder); | ||
return client().performRequest(request); | ||
} | ||
|
||
public static String readResponse(Response response) throws IOException { | ||
try (InputStream ins = response.getEntity().getContent()) { | ||
return String.join("", org.opensearch.common.io.Streams.readAllLines(ins)); | ||
} | ||
} | ||
} |
Oops, something went wrong.