From 37490dc14206a819419b66b90c608059d8dd1425 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 2 Oct 2023 13:43:22 -0700 Subject: [PATCH] Add memory client to machine learning client Signed-off-by: HenryL27 --- client/build.gradle | 1 + .../ml/client/MachineLearningClient.java | 6 + .../ml/client/MachineLearningNodeClient.java | 13 +- .../opensearch/ml/client/MemoryClient.java | 37 +++ .../ml/client/MemoryNodeClient.java | 82 +++++ .../ml/client/MachineLearningClientTest.java | 13 + .../client/MachineLearningNodeClientTest.java | 17 +- .../ml/client/MemoryNodeClientTest.java | 294 ++++++++++++++++++ 8 files changed, 460 insertions(+), 3 deletions(-) create mode 100644 client/src/main/java/org/opensearch/ml/client/MemoryClient.java create mode 100644 client/src/main/java/org/opensearch/ml/client/MemoryNodeClient.java create mode 100644 client/src/test/java/org/opensearch/ml/client/MemoryNodeClientTest.java diff --git a/client/build.gradle b/client/build.gradle index cc4f904083..d57e739099 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -14,6 +14,7 @@ plugins { dependencies { implementation project(':opensearch-ml-common') + implementation project(':opensearch-ml-memory') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 8452dcc3c2..bc85d73d48 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -226,4 +226,10 @@ default ActionFuture searchTask(SearchRequest searchRequest) { * @param listener action listener */ void searchTask(SearchRequest searchRequest, ActionListener listener); + + /** + * Get the memory client + * @return A Memory client for accessing conversation memory + */ + MemoryClient memory(); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index d7328b9766..200ad27212 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -53,10 +53,16 @@ import static org.opensearch.ml.common.input.InputHelper.getFunctionName; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) -@RequiredArgsConstructor public class MachineLearningNodeClient implements MachineLearningClient { Client client; + MemoryClient memoryClient; + + public MachineLearningNodeClient(Client client) { + this.client = client; + this.memoryClient = new MemoryNodeClient(client); + } + @Override public void predict(String modelId, MLInput mlInput, ActionListener listener) { @@ -218,4 +224,9 @@ private void validateMLInput(MLInput mlInput, boolean requireInput) { throw new IllegalArgumentException("input data set can't be null"); } } + + @Override + public MemoryClient memory() { + return memoryClient; + } } diff --git a/client/src/main/java/org/opensearch/ml/client/MemoryClient.java b/client/src/main/java/org/opensearch/ml/client/MemoryClient.java new file mode 100644 index 0000000000..f01eaced5b --- /dev/null +++ b/client/src/main/java/org/opensearch/ml/client/MemoryClient.java @@ -0,0 +1,37 @@ +package org.opensearch.ml.client; + +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest; +import org.opensearch.ml.memory.action.conversation.DeleteConversationResponse; +import org.opensearch.ml.memory.action.conversation.GetConversationsRequest; +import org.opensearch.ml.memory.action.conversation.GetConversationsResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; + +public interface MemoryClient { + + void createConversation(CreateConversationRequest request, ActionListener listener); + + ActionFuture createConversation(CreateConversationRequest request); + + void createInteraction(CreateInteractionRequest request, ActionListener listener); + + ActionFuture createInteraction(CreateInteractionRequest request); + + void getConversations(GetConversationsRequest request, ActionListener listener); + + ActionFuture getConversations(GetConversationsRequest request); + + void getInteractions(GetInteractionsRequest request, ActionListener listener); + + ActionFuture getInteractions(GetInteractionsRequest request); + + void deleteConversation(DeleteConversationRequest request, ActionListener listener); + + ActionFuture deleteConversation(DeleteConversationRequest request); +} diff --git a/client/src/main/java/org/opensearch/ml/client/MemoryNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MemoryNodeClient.java new file mode 100644 index 0000000000..1d75e3d0ca --- /dev/null +++ b/client/src/main/java/org/opensearch/ml/client/MemoryNodeClient.java @@ -0,0 +1,82 @@ +package org.opensearch.ml.client; + +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; +import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest; +import org.opensearch.ml.memory.action.conversation.DeleteConversationResponse; +import org.opensearch.ml.memory.action.conversation.GetConversationsAction; +import org.opensearch.ml.memory.action.conversation.GetConversationsRequest; +import org.opensearch.ml.memory.action.conversation.GetConversationsResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; + +import lombok.AllArgsConstructor; +import lombok.RequiredArgsConstructor; + +@AllArgsConstructor +public class MemoryNodeClient implements MemoryClient { + + Client client; + + public void createConversation(CreateConversationRequest request, ActionListener listener) { + client.execute(CreateConversationAction.INSTANCE, request, listener); + } + + + public ActionFuture createConversation(CreateConversationRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + createConversation(request, fut); + return fut; + } + + + public void createInteraction(CreateInteractionRequest request, ActionListener listener) { + client.execute(CreateInteractionAction.INSTANCE, request, listener); + } + + public ActionFuture createInteraction(CreateInteractionRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + createInteraction(request, fut); + return fut; + } + + public void getConversations(GetConversationsRequest request, ActionListener listener) { + client.execute(GetConversationsAction.INSTANCE, request, listener); + } + + public ActionFuture getConversations(GetConversationsRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + getConversations(request, fut); + return fut; + } + + public void getInteractions(GetInteractionsRequest request, ActionListener listener) { + client.execute(GetInteractionsAction.INSTANCE, request, listener); + } + + public ActionFuture getInteractions(GetInteractionsRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + getInteractions(request, fut); + return fut; + } + + public void deleteConversation(DeleteConversationRequest request, ActionListener listener) { + client.execute(DeleteConversationAction.INSTANCE, request, listener); + } + + public ActionFuture deleteConversation(DeleteConversationRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + deleteConversation(request, fut); + return fut; + } +} diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 2a98812090..efa710c0e7 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -58,6 +58,9 @@ public class MachineLearningClientTest { @Mock SearchResponse searchResponse; + @Mock + MemoryClient memoryClient; + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -132,6 +135,11 @@ public void deleteTask(String taskId, ActionListener listener) { public void searchTask(SearchRequest searchRequest, ActionListener listener) { listener.onResponse(searchResponse); } + + @Override + public MemoryClient memory() { + return memoryClient; + } }; } @@ -251,4 +259,9 @@ public void deleteTask() { public void searchTask() { assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet()); } + + @Test + public void memory() { + assertEquals(memoryClient, machineLearningClient.memory()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 2f6f11998d..c778593996 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -36,8 +36,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; @@ -57,6 +55,9 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -64,17 +65,22 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; +import lombok.extern.log4j.Log4j2; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.opensearch.ml.common.input.Constants.ACTION; import static org.opensearch.ml.common.input.Constants.ALGORITHM; @@ -85,6 +91,7 @@ import static org.opensearch.ml.common.input.Constants.TRAIN; import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; +@Log4j2 public class MachineLearningNodeClientTest { @Mock(answer = RETURNS_DEEP_STUBS) @@ -590,4 +597,10 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti SearchResponse.Clusters.EMPTY ); } + + @Test + public void memory() { + MemoryClient memoryClient = machineLearningNodeClient.memory(); + assertNotNull(memoryClient); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MemoryNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MemoryNodeClientTest.java new file mode 100644 index 0000000000..be1cd6e1d2 --- /dev/null +++ b/client/src/test/java/org/opensearch/ml/client/MemoryNodeClientTest.java @@ -0,0 +1,294 @@ +package org.opensearch.ml.client; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; +import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest; +import org.opensearch.ml.memory.action.conversation.DeleteConversationResponse; +import org.opensearch.ml.memory.action.conversation.GetConversationsAction; +import org.opensearch.ml.memory.action.conversation.GetConversationsRequest; +import org.opensearch.ml.memory.action.conversation.GetConversationsResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; + +public class MemoryNodeClientTest { + + @Mock + Client client; + + @Mock + ActionListener createConversationListener; + + @Mock + ActionListener createInteractionListener; + + @Mock + ActionListener getConversationsListener; + + @Mock + ActionListener getInteractionsListener; + + @Mock + ActionListener deleteConversationListener; + + @InjectMocks + MemoryNodeClient memoryClient; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void createConversation_Success() { + CreateConversationResponse response = new CreateConversationResponse("Test id"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); + CreateConversationRequest request = new CreateConversationRequest(); + memoryClient.createConversation(request, createConversationListener); + + verify(createConversationListener, times(1)).onResponse(argCaptor.capture()); + assertEquals(response, argCaptor.getValue()); + } + + @Test + public void createConversation_Fails() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("CC Fail")); + return null; + }).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + CreateConversationRequest request = new CreateConversationRequest(); + memoryClient.createConversation(request, createConversationListener); + + verify(createConversationListener, times(1)).onFailure(argCaptor.capture()); + assertEquals("CC Fail", argCaptor.getValue().getMessage()); + } + + @Test + public void createConversation_Future() { + CreateConversationResponse response = new CreateConversationResponse("Test id"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(CreateConversationAction.INSTANCE), any(), any()); + + CreateConversationRequest request = new CreateConversationRequest(); + assertEquals(memoryClient.createConversation(request).actionGet(), response); + } + + @Test + public void createInteraction_Success() { + CreateInteractionResponse response = new CreateInteractionResponse("Test IID"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(CreateInteractionAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class); + CreateInteractionRequest request = new CreateInteractionRequest("cid", "inp", "pt", "rsp", "ogn", "add"); + memoryClient.createInteraction(request, createInteractionListener); + + verify(createInteractionListener, times(1)).onResponse(argCaptor.capture()); + assertEquals(response, argCaptor.getValue()); + } + + @Test + public void createInteraction_Fails() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("CI Fail")); + return null; + }).when(client).execute(eq(CreateInteractionAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + CreateInteractionRequest request = new CreateInteractionRequest("cid", "inp", "pt", "rsp", "ogn", "add"); + memoryClient.createInteraction(request, createInteractionListener); + + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assertEquals("CI Fail", argCaptor.getValue().getMessage()); + } + + @Test + public void createInteraction_Future() { + CreateInteractionResponse response = new CreateInteractionResponse("Test IID"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(CreateInteractionAction.INSTANCE), any(), any()); + + CreateInteractionRequest request = new CreateInteractionRequest("cid", "inp", "pt", "rsp", "ogn", "add"); + assertEquals(memoryClient.createInteraction(request).actionGet(), response); + } + + @Test + public void getConversations_Success() { + GetConversationsResponse response = new GetConversationsResponse(List.of(), 4, false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(GetConversationsAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationsResponse.class); + GetConversationsRequest request = new GetConversationsRequest(); + memoryClient.getConversations(request, getConversationsListener); + + verify(getConversationsListener, times(1)).onResponse(argCaptor.capture()); + assertEquals(response, argCaptor.getValue()); + } + + @Test + public void getConversations_Fails() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("GC Fail")); + return null; + }).when(client).execute(eq(GetConversationsAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + GetConversationsRequest request = new GetConversationsRequest(); + memoryClient.getConversations(request, getConversationsListener); + + verify(getConversationsListener, times(1)).onFailure(argCaptor.capture()); + assertEquals("GC Fail", argCaptor.getValue().getMessage()); + } + + @Test + public void getConversations_Future() { + GetConversationsResponse response = new GetConversationsResponse(List.of(), 4, false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(GetConversationsAction.INSTANCE), any(), any()); + + GetConversationsRequest request = new GetConversationsRequest(); + assertEquals(memoryClient.getConversations(request).actionGet(), response); + } + + @Test + public void getInteractions_Success() { + GetInteractionsResponse response = new GetInteractionsResponse(List.of(), 4, false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(GetInteractionsAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionsResponse.class); + GetInteractionsRequest request = new GetInteractionsRequest("Test CID"); + memoryClient.getInteractions(request, getInteractionsListener); + + verify(getInteractionsListener, times(1)).onResponse(argCaptor.capture()); + assertEquals(response, argCaptor.getValue()); + } + + @Test + public void getInteractions_Fails() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("GI Fail")); + return null; + }).when(client).execute(eq(GetInteractionsAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + GetInteractionsRequest request = new GetInteractionsRequest("Test CID"); + memoryClient.getInteractions(request, getInteractionsListener); + + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assertEquals("GI Fail", argCaptor.getValue().getMessage()); + } + + @Test + public void getInteractions_Future() { + GetInteractionsResponse response = new GetInteractionsResponse(List.of(), 4, false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(GetInteractionsAction.INSTANCE), any(), any()); + + GetInteractionsRequest request = new GetInteractionsRequest("Test CID"); + assertEquals(memoryClient.getInteractions(request).actionGet(), response); + } + + @Test + public void deleteConversation_Success() { + DeleteConversationResponse response = new DeleteConversationResponse(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(DeleteConversationAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(DeleteConversationResponse.class); + DeleteConversationRequest request = new DeleteConversationRequest("Test CID"); + memoryClient.deleteConversation(request, deleteConversationListener); + + verify(deleteConversationListener, times(1)).onResponse(argCaptor.capture()); + assertEquals(response, argCaptor.getValue()); + } + + @Test + public void deleteConversation_Fails() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("DC Fail")); + return null; + }).when(client).execute(eq(DeleteConversationAction.INSTANCE), any(), any()); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + DeleteConversationRequest request = new DeleteConversationRequest("Test CID"); + memoryClient.deleteConversation(request, deleteConversationListener); + + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assertEquals("DC Fail", argCaptor.getValue().getMessage()); + } + + @Test + public void deleteConversation_Future() { + DeleteConversationResponse response = new DeleteConversationResponse(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(DeleteConversationAction.INSTANCE), any(), any()); + + DeleteConversationRequest request = new DeleteConversationRequest("Test CID"); + assertEquals(memoryClient.deleteConversation(request).actionGet(), response); + } + +}