Skip to content

Commit

Permalink
support nested query in neural sparse
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Jul 15, 2024
1 parent 62ac87f commit d55b10d
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ public class NeuralSparseSearchTool extends AbstractRetrieverTool {
public static final String TYPE = "NeuralSparseSearchTool";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String NESTED_PATH_FIELD = "nested_path";

private String name = TYPE;
private String modelId;
private String embeddingField;
private String nestedPath;

@Builder
public NeuralSparseSearchTool(
Expand All @@ -46,11 +48,13 @@ public NeuralSparseSearchTool(
String embeddingField,
String[] sourceFields,
Integer docSize,
String modelId
String modelId,
String nestedPath
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.modelId = modelId;
this.embeddingField = embeddingField;
this.nestedPath = nestedPath;
}

@Override
Expand All @@ -61,8 +65,29 @@ protected String getQueryBody(String queryText) {
);
}

Map<String, Object> queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));
Map<String, Object> queryBody;
if (StringUtils.isBlank(nestedPath)) {
queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));
} else {
queryBody = Map
.of(
"query",
Map
.of(
"nested",
Map
.of(
"path",
nestedPath,
"score_mode",
"max",
"query",
Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId)))
)
)
);
}

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
Expand Down Expand Up @@ -99,6 +124,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return NeuralSparseSearchTool
.builder()
.client(client)
Expand All @@ -108,6 +134,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
.sourceFields(sourceFields)
.modelId(modelId)
.docSize(docSize)
.nestedPath(nestedPath)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class NeuralSparseSearchToolTests {
public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh";
public static final String TEST_EMBEDDING_FIELD = "test embedding";
public static final String TEST_MODEL_ID = "123fsd23134";
public static final String TEST_NESTED_PATH = "nested_path";
private Map<String, Object> params = new HashMap<>();

@Before
Expand Down Expand Up @@ -60,6 +61,23 @@ public void testGetQueryBody() {
assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id"));
}

/**
 * Verifies the query body produced when a nested path is configured: the neural sparse clause
 * must be wrapped in a {@code nested} query carrying the configured path and a {@code max}
 * score mode, while the inner clause still holds the query text and model id.
 */
@Test
@SneakyThrows
public void testGetQueryBodyWithNestedPath() {
    // Typed copy of the shared fixture params (a raw Map would make the create(...) call unchecked).
    Map<String, Object> nestedParams = new HashMap<>(params);
    nestedParams.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
    NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(nestedParams);

    Map<String, Map<String, Map<String, Object>>> nestedQueryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class);
    Map<String, Object> nestedClause = nestedQueryBody.get("query").get("nested");

    // Expectations are tied to the inputs supplied above rather than to duplicated literals.
    assertEquals(TEST_NESTED_PATH, nestedClause.get("path"));
    assertEquals("max", nestedClause.get("score_mode"));

    // The wrapped query must be the same neural_sparse clause the non-nested variant emits.
    Map<String, Map<String, Map<String, String>>> queryBody = (Map<String, Map<String, Map<String, String>>>) nestedClause.get("query");
    // NOTE(review): "test embedding" / "123fsd23134" presumably come from the shared fixture params - confirm against setup.
    assertEquals(TEST_QUERY_TEXT, queryBody.get("neural_sparse").get("test embedding").get("query_text"));
    assertEquals("123fsd23134", queryBody.get("neural_sparse").get("test embedding").get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
Expand Down Expand Up @@ -110,6 +128,11 @@ public void testCreateToolsParseParams() {
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.MODEL_ID_FIELD, 123))
);

assertThrows(
ClassCastException.class,
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.NESTED_PATH_FIELD, 123))
);

assertThrows(
JsonSyntaxException.class,
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.SOURCE_FIELD, "123"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -22,6 +21,7 @@

public class NeuralSparseSearchToolIT extends BaseAgentToolsIT {
public static String TEST_INDEX_NAME = "test_index";
public static String TEST_NESTED_INDEX_NAME = "test_index_nested";

private String modelId;
private String registerAgentRequestBody;
Expand Down Expand Up @@ -64,12 +64,55 @@ private void prepareIndex() {
addDocToIndex(TEST_INDEX_NAME, "2", List.of("text", "embedding"), List.of("text doc 3", Map.of("test", 5, "a", 6)));
}

/**
 * Creates the nested test index, whose {@code embedding} field is mapped as {@code nested}
 * with a {@code sparse} sub-field of type {@code rank_features}, and indexes three documents
 * (ids "0" to "2") into it.
 */
@SneakyThrows
private void prepareNestedIndex() {
    String nestedIndexConfiguration = "{\n"
        + " \"mappings\": {\n"
        + " \"properties\": {\n"
        + " \"text\": {\n"
        + " \"type\": \"text\"\n"
        + " },\n"
        + " \"embedding\": {\n"
        + " \"type\": \"nested\",\n"
        + " \"properties\":{\n"
        + " \"sparse\":{\n"
        + " \"type\":\"rank_features\"\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + " }\n"
        + "}";
    createIndexWithConfiguration(TEST_NESTED_INDEX_NAME, nestedIndexConfiguration);

    addNestedDoc("0", "text doc 1", Map.of("hello", 1, "world", 2));
    addNestedDoc("1", "text doc 2", Map.of("a", 3, "b", 4));
    addNestedDoc("2", "text doc 3", Map.of("test", 5, "a", 6));
}

/**
 * Indexes one document into the nested test index; the token weights are stored as a
 * single-element list under {@code embedding.sparse}.
 */
@SneakyThrows
private void addNestedDoc(String docId, String text, Map<String, Integer> tokenWeights) {
    addDocToIndex(
        TEST_NESTED_INDEX_NAME,
        docId,
        List.of("text", "embedding"),
        List.of(text, Map.of("sparse", List.of(tokenWeights)))
    );
}

@Before
@SneakyThrows
public void setUp() {
super.setUp();
prepareModel();
prepareIndex();
prepareNestedIndex();
registerAgentRequestBody = Files
.readString(
Path
Expand Down Expand Up @@ -127,6 +170,23 @@ public void testNeuralSparseSearchToolInFlowAgent() {
);
}

/**
 * Registers a flow agent whose tool targets the nested index (nested path {@code embedding},
 * embedding field {@code embedding.sparse}), executes it with the question "a", and checks
 * the exact response text.
 */
public void testNeuralSparseSearchToolInFlowAgent_withNestedIndex() {
    // Derive the nested variant of the shared request body by rewriting three settings, in order:
    // the nested path, the embedding field, and the target index.
    String nestedRequestBody = registerAgentRequestBody
        .replace("\"nested_path\": \"\"", "\"nested_path\": \"embedding\"")
        .replace("\"embedding_field\": \"embedding\"", "\"embedding_field\": \"embedding.sparse\"")
        .replace("\"index\": \"test_index\"", "\"index\": \"test_index_nested\"");

    String agentId = createAgent(nestedRequestBody);
    String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");

    String expected = "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n"
        + "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n";
    assertEquals("The agent execute response not equal with expected.", expected, result);
}

public void testNeuralSparseSearchToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"index": "test_index",
"embedding_field": "embedding",
"source_field": ["text"],
"input": "${parameters.question}"
"input": "${parameters.question}",
"nested_path": ""
}
}
]
Expand Down

0 comments on commit d55b10d

Please sign in to comment.