Skip to content

Commit

Permalink
Add memory client to machine learning client
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Oct 2, 2023
1 parent 4f6d0d0 commit 37490dc
Show file tree
Hide file tree
Showing 8 changed files with 460 additions and 3 deletions.
1 change: 1 addition & 0 deletions client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ plugins {

dependencies {
implementation project(':opensearch-ml-common')
implementation project(':opensearch-ml-memory')
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,10 @@ default ActionFuture<SearchResponse> searchTask(SearchRequest searchRequest) {
* @param listener action listener
*/
void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener);

/**
* Get the memory client
* @return A Memory client for accessing conversation memory
*/
MemoryClient memory();
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,16 @@
import static org.opensearch.ml.common.input.InputHelper.getFunctionName;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@RequiredArgsConstructor
public class MachineLearningNodeClient implements MachineLearningClient {

Client client;
MemoryClient memoryClient;

public MachineLearningNodeClient(Client client) {
this.client = client;
this.memoryClient = new MemoryNodeClient(client);
}


@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand Down Expand Up @@ -218,4 +224,9 @@ private void validateMLInput(MLInput mlInput, boolean requireInput) {
throw new IllegalArgumentException("input data set can't be null");
}
}

@Override
public MemoryClient memory() {
return memoryClient;
}
}
37 changes: 37 additions & 0 deletions client/src/main/java/org/opensearch/ml/client/MemoryClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.opensearch.ml.client;

import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest;
import org.opensearch.ml.memory.action.conversation.DeleteConversationResponse;
import org.opensearch.ml.memory.action.conversation.GetConversationsRequest;
import org.opensearch.ml.memory.action.conversation.GetConversationsResponse;
import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest;
import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse;

public interface MemoryClient {

void createConversation(CreateConversationRequest request, ActionListener<CreateConversationResponse> listener);

ActionFuture<CreateConversationResponse> createConversation(CreateConversationRequest request);

void createInteraction(CreateInteractionRequest request, ActionListener<CreateInteractionResponse> listener);

ActionFuture<CreateInteractionResponse> createInteraction(CreateInteractionRequest request);

void getConversations(GetConversationsRequest request, ActionListener<GetConversationsResponse> listener);

ActionFuture<GetConversationsResponse> getConversations(GetConversationsRequest request);

void getInteractions(GetInteractionsRequest request, ActionListener<GetInteractionsResponse> listener);

ActionFuture<GetInteractionsResponse> getInteractions(GetInteractionsRequest request);

void deleteConversation(DeleteConversationRequest request, ActionListener<DeleteConversationResponse> listener);

ActionFuture<DeleteConversationResponse> deleteConversation(DeleteConversationRequest request);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package org.opensearch.ml.client;

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.ml.memory.action.conversation.CreateInteractionAction;
import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.DeleteConversationAction;
import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest;
import org.opensearch.ml.memory.action.conversation.DeleteConversationResponse;
import org.opensearch.ml.memory.action.conversation.GetConversationsAction;
import org.opensearch.ml.memory.action.conversation.GetConversationsRequest;
import org.opensearch.ml.memory.action.conversation.GetConversationsResponse;
import org.opensearch.ml.memory.action.conversation.GetInteractionsAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest;
import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse;

import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;

@AllArgsConstructor
public class MemoryNodeClient implements MemoryClient {

Client client;

public void createConversation(CreateConversationRequest request, ActionListener<CreateConversationResponse> listener) {
client.execute(CreateConversationAction.INSTANCE, request, listener);
}


public ActionFuture<CreateConversationResponse> createConversation(CreateConversationRequest request) {
PlainActionFuture<CreateConversationResponse> fut = PlainActionFuture.newFuture();
createConversation(request, fut);
return fut;
}


public void createInteraction(CreateInteractionRequest request, ActionListener<CreateInteractionResponse> listener) {
client.execute(CreateInteractionAction.INSTANCE, request, listener);
}

public ActionFuture<CreateInteractionResponse> createInteraction(CreateInteractionRequest request) {
PlainActionFuture<CreateInteractionResponse> fut = PlainActionFuture.newFuture();
createInteraction(request, fut);
return fut;
}

public void getConversations(GetConversationsRequest request, ActionListener<GetConversationsResponse> listener) {
client.execute(GetConversationsAction.INSTANCE, request, listener);
}

public ActionFuture<GetConversationsResponse> getConversations(GetConversationsRequest request) {
PlainActionFuture<GetConversationsResponse> fut = PlainActionFuture.newFuture();
getConversations(request, fut);
return fut;
}

public void getInteractions(GetInteractionsRequest request, ActionListener<GetInteractionsResponse> listener) {
client.execute(GetInteractionsAction.INSTANCE, request, listener);
}

public ActionFuture<GetInteractionsResponse> getInteractions(GetInteractionsRequest request) {
PlainActionFuture<GetInteractionsResponse> fut = PlainActionFuture.newFuture();
getInteractions(request, fut);
return fut;
}

public void deleteConversation(DeleteConversationRequest request, ActionListener<DeleteConversationResponse> listener) {
client.execute(DeleteConversationAction.INSTANCE, request, listener);
}

public ActionFuture<DeleteConversationResponse> deleteConversation(DeleteConversationRequest request) {
PlainActionFuture<DeleteConversationResponse> fut = PlainActionFuture.newFuture();
deleteConversation(request, fut);
return fut;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ public class MachineLearningClientTest {
@Mock
SearchResponse searchResponse;

@Mock
MemoryClient memoryClient;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -132,6 +135,11 @@ public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
}

@Override
public MemoryClient memory() {
return memoryClient;
}
};
}

Expand Down Expand Up @@ -251,4 +259,9 @@ public void deleteTask() {
public void searchTask() {
assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet());
}

@Test
public void memory() {
assertEquals(memoryClient, machineLearningClient.memory());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
Expand All @@ -57,24 +55,32 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.suggest.Suggest;

import lombok.extern.log4j.Log4j2;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.input.Constants.ACTION;
import static org.opensearch.ml.common.input.Constants.ALGORITHM;
Expand All @@ -85,6 +91,7 @@
import static org.opensearch.ml.common.input.Constants.TRAIN;
import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT;

@Log4j2
public class MachineLearningNodeClientTest {

@Mock(answer = RETURNS_DEEP_STUBS)
Expand Down Expand Up @@ -590,4 +597,10 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti
SearchResponse.Clusters.EMPTY
);
}

@Test
public void memory() {
MemoryClient memoryClient = machineLearningNodeClient.memory();
assertNotNull(memoryClient);
}
}
Loading

0 comments on commit 37490dc

Please sign in to comment.