forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add search and singular APIs to conversation memory (opensearch-proje…
…ct#1504) * add searchConversation Signed-off-by: HenryL27 <[email protected]> * add searchinteractions Signed-off-by: HenryL27 <[email protected]> * add searchConversationsITTests Signed-off-by: HenryL27 <[email protected]> * add searchInteractionsITTests Signed-off-by: HenryL27 <[email protected]> * add unit tests for storage-layer search Signed-off-by: HenryL27 <[email protected]> * add Search transport actions and tests Signed-off-by: HenryL27 <[email protected]> * add rest search actions Signed-off-by: HenryL27 <[email protected]> * add search rest actions Signed-off-by: HenryL27 <[email protected]> * Add singular get actions at storage layer Signed-off-by: HenryL27 <[email protected]> * Add OpenSearhMemoryHandler unit tests for singular get Signed-off-by: HenryL27 <[email protected]> * Add singular get transport layer Signed-off-by: HenryL27 <[email protected]> * add singular get rest actions Signed-off-by: HenryL27 <[email protected]> * fix async return value problem Signed-off-by: HenryL27 <[email protected]> * address esay PR comments Signed-off-by: HenryL27 <[email protected]> --------- Signed-off-by: HenryL27 <[email protected]>
- Loading branch information
Showing
48 changed files
with
3,875 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
/** | ||
* Action for retrieving a top-level conversation object by id | ||
*/ | ||
public class GetConversationAction extends ActionType<GetConversationResponse> { | ||
/** Instance of this */ | ||
public static final GetConversationAction INSTANCE = new GetConversationAction(); | ||
/** Name of this action */ | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get"; | ||
|
||
private GetConversationAction() { | ||
super(NAME, GetConversationResponse::new); | ||
} | ||
} |
77 changes: 77 additions & 0 deletions
77
...ry/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* Action Request object for GetConversation (singular) | ||
*/ | ||
@AllArgsConstructor | ||
public class GetConversationRequest extends ActionRequest { | ||
@Getter | ||
private String conversationId; | ||
|
||
/** | ||
* Stream Constructor | ||
* @param in input stream to read this from | ||
* @throws IOException if something goes wrong reading from stream | ||
*/ | ||
public GetConversationRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.conversationId = in.readString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.conversationId); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (this.conversationId == null) { | ||
exception = addValidationError("GetConversation Request must have a conversation id", exception); | ||
} | ||
return exception; | ||
} | ||
|
||
/** | ||
* Creates a GetConversationRequest from a rest request | ||
* @param request Rest Request representing a GetConversationRequest | ||
* @return the new GetConversationRequest | ||
* @throws IOException if something goes wrong in translation | ||
*/ | ||
public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException { | ||
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); | ||
return new GetConversationRequest(conversationId); | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
...y/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.conversation.ConversationMeta; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* ActionResponse object for GetConversation (singular) | ||
*/ | ||
@AllArgsConstructor | ||
public class GetConversationResponse extends ActionResponse implements ToXContentObject { | ||
|
||
@Getter | ||
private ConversationMeta conversation; | ||
|
||
/** | ||
* Stream Constructor | ||
* @param in input stream to read this from | ||
* @throws IOException if soething goes wrong in reading | ||
*/ | ||
public GetConversationResponse(StreamInput in) throws IOException { | ||
super(in); | ||
this.conversation = ConversationMeta.fromStream(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
this.conversation.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
return this.conversation.toXContent(builder, params); | ||
} | ||
} |
100 changes: 100 additions & 0 deletions
100
...ain/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.conversation.ConversationMeta; | ||
import org.opensearch.ml.common.conversation.ConversationalIndexConstants; | ||
import org.opensearch.ml.memory.ConversationalMemoryHandler; | ||
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
/** | ||
* Transport Action for GetConversation | ||
*/ | ||
@Log4j2 | ||
public class GetConversationTransportAction extends HandledTransportAction<GetConversationRequest, GetConversationResponse> { | ||
private Client client; | ||
private ConversationalMemoryHandler cmHandler; | ||
|
||
private volatile boolean featureIsEnabled; | ||
|
||
/** | ||
* Constructor | ||
* @param transportService for inter-node communications | ||
* @param actionFilters for filtering actions | ||
* @param cmHandler Handler for conversational memory operations | ||
* @param client OS Client for dealing with OS | ||
* @param clusterService for some cluster ops | ||
*/ | ||
@Inject | ||
public GetConversationTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
OpenSearchConversationalMemoryHandler cmHandler, | ||
Client client, | ||
ClusterService clusterService | ||
) { | ||
super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new); | ||
this.client = client; | ||
this.cmHandler = cmHandler; | ||
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); | ||
clusterService | ||
.getClusterSettings() | ||
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); | ||
} | ||
|
||
@Override | ||
public void doExecute(Task task, GetConversationRequest request, ActionListener<GetConversationResponse> actionListener) { | ||
if (!featureIsEnabled) { | ||
actionListener | ||
.onFailure( | ||
new OpenSearchException( | ||
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting " | ||
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() | ||
) | ||
); | ||
return; | ||
} else { | ||
String conversationId = request.getConversationId(); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { | ||
ActionListener<GetConversationResponse> internalListener = ActionListener | ||
.runBefore(actionListener, () -> context.restore()); | ||
ActionListener<ConversationMeta> al = ActionListener.wrap(conversationMeta -> { | ||
internalListener.onResponse(new GetConversationResponse(conversationMeta)); | ||
}, e -> { internalListener.onFailure(e); }); | ||
cmHandler.getConversation(conversationId, al); | ||
} catch (Exception e) { | ||
log.error("Failed to get Conversation " + conversationId, e); | ||
actionListener.onFailure(e); | ||
} | ||
|
||
} | ||
|
||
} | ||
} |
Oops, something went wrong.