diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java index d2eb03ed..7af1ce31 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java @@ -77,7 +77,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } - GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all); + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all, tenantId); return channel -> client.execute(GetWorkflowStateAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index d0119d65..87380b23 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -140,7 +140,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener executeDeprovisionRequest(request, listener, context), + () -> executeDeprovisionRequest(request, tenantId, listener, context), client, sdkClient, clusterService, @@ -156,12 +156,13 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); String allowDelete = request.getParams().get(ALLOW_DELETE); - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true, tenantId); logger.info("Querying state for workflow: {}", workflowId); client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { context.restore(); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java index 7fd546c2..abb64e6c 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Nullable; @@ -32,14 +33,18 @@ public class GetWorkflowStateRequest extends ActionRequest { */ private boolean all; + private String tenantId; + /** * Instantiates a new GetWorkflowStateRequest * @param workflowId the documentId of the workflow * @param all whether the get request is looking for all fields in status + * @param tenantId the tenant id */ - public GetWorkflowStateRequest(@Nullable String workflowId, boolean all) { + public GetWorkflowStateRequest(@Nullable String workflowId, boolean all, String tenantId) { this.workflowId = workflowId; this.all = all; + this.tenantId = tenantId; } /** @@ -51,6 +56,10 @@ public GetWorkflowStateRequest(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); this.all = in.readBoolean(); + // TODO: After backport, change to next 2.x release + if (in.getVersion().onOrAfter(Version.CURRENT)) { + this.tenantId = in.readOptionalString(); + } } /** @@ -75,6 +84,10 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(workflowId); out.writeBoolean(all); + // TODO: After backport, change to next 2.x release + if (out.getVersion().onOrAfter(Version.CURRENT)) { + out.writeOptionalString(tenantId); + } } @Override diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index a0825dd3..8b881402 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -150,7 +150,7 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi filterByEnabled, flowFrameworkSettings.isMultiTenancyEnabled(), listener, - () -> executeReprovisionRequest(request, listener, context), + () -> executeReprovisionRequest(request, tenantId, listener, context), client, sdkClient, clusterService, @@ -170,18 +170,20 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi /** * Execute the reprovision request * @param request the reprovision request + * @param tenantId * @param listener the action listener * @param context the thread context */ private void executeReprovisionRequest( ReprovisionWorkflowRequest request, + String tenantId, ActionListener listener, ThreadContext.StoredContext context ) { String workflowId = request.getWorkflowId(); logger.info("Querying state for workflow: {}", workflowId); // Retrieve state and resources created - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true, tenantId); client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { context.restore(); diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java index ec7b390a..690b5eff 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java @@ -115,7 +115,7 @@ public void onFailure(Exception e) {} } public void testGetTransportAction() throws IOException { - GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest("1234", false); + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest("1234", false, null); getWorkflowStateTransportAction.doExecute(task, getWorkflowRequest, response); } @@ -125,7 +125,7 @@ public void testGetAction() { } public void testGetWorkflowStateRequest() throws IOException { - GetWorkflowStateRequest request = new GetWorkflowStateRequest("1234", false); + GetWorkflowStateRequest request = new GetWorkflowStateRequest("1234", false, null); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); @@ -164,7 +164,7 @@ public void testGetWorkflowStateResponse() throws IOException { public void testExecuteGetWorkflowStateRequestFailure() throws IOException { String workflowId = "test-workflow"; - GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false); + GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false, null); ActionListener listener = mock(ActionListener.class); // Stub client.get to force on failure @@ -185,7 +185,7 @@ public void testExecuteGetWorkflowStateRequestFailure() throws IOException { public void testExecuteGetWorkflowStateRequestIndexNotFound() throws IOException { String workflowId = "test-workflow"; - GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false); + GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false, null); ActionListener listener = mock(ActionListener.class); // Stub client.get to force on failure @@ -206,7 +206,7 @@ public void testExecuteGetWorkflowStateRequestIndexNotFound() throws IOException public void testExecuteGetWorkflowStateRequestParseFailure() throws IOException { String workflowId = "test-workflow"; - GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false); + GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false, null); ActionListener listener = mock(ActionListener.class); // Stub client.get to force on response