Skip to content

Commit

Permalink
Add tenantId to GetWorkflowStateRequest
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 22, 2024
1 parent d55e694 commit a91c20e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}

GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all);
GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all, tenantId);
return channel -> client.execute(GetWorkflowStateAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
filterByEnabled,
flowFrameworkSettings.isMultiTenancyEnabled(),
listener,
() -> executeDeprovisionRequest(request, listener, context),
() -> executeDeprovisionRequest(request, tenantId, listener, context),
client,
sdkClient,
clusterService,
Expand All @@ -156,12 +156,13 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

private void executeDeprovisionRequest(
WorkflowRequest request,
String tenantId,
ActionListener<WorkflowResponse> listener,
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
String allowDelete = request.getParams().get(ALLOW_DELETE);
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true, tenantId);
logger.info("Querying state for workflow: {}", workflowId);
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
context.restore();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.transport;

import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.Nullable;
Expand All @@ -32,14 +33,18 @@ public class GetWorkflowStateRequest extends ActionRequest {
*/
private boolean all;

private String tenantId;

/**
* Instantiates a new GetWorkflowStateRequest
* @param workflowId the documentId of the workflow
* @param all whether the get request is looking for all fields in status
* @param tenantId the tenant id
*/
public GetWorkflowStateRequest(@Nullable String workflowId, boolean all) {
public GetWorkflowStateRequest(@Nullable String workflowId, boolean all, String tenantId) {
this.workflowId = workflowId;
this.all = all;
this.tenantId = tenantId;
}

/**
Expand All @@ -51,6 +56,10 @@ public GetWorkflowStateRequest(StreamInput in) throws IOException {
super(in);
this.workflowId = in.readString();
this.all = in.readBoolean();
// TODO: After backport, change to next 2.x release
if (in.getVersion().onOrAfter(Version.CURRENT)) {
this.tenantId = in.readOptionalString();
}
}

/**
Expand All @@ -75,6 +84,10 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(workflowId);
out.writeBoolean(all);
// TODO: After backport, change to next 2.x release
if (out.getVersion().onOrAfter(Version.CURRENT)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi
filterByEnabled,
flowFrameworkSettings.isMultiTenancyEnabled(),
listener,
() -> executeReprovisionRequest(request, listener, context),
() -> executeReprovisionRequest(request, tenantId, listener, context),
client,
sdkClient,
clusterService,
Expand All @@ -170,18 +170,20 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi
/**
* Execute the reprovision request
* @param request the reprovision request
* @param tenantId
* @param listener the action listener
* @param context the thread context
*/
private void executeReprovisionRequest(
ReprovisionWorkflowRequest request,
String tenantId,
ActionListener<WorkflowResponse> listener,
ThreadContext.StoredContext context
) {
String workflowId = request.getWorkflowId();
logger.info("Querying state for workflow: {}", workflowId);
// Retrieve state and resources created
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true, tenantId);
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
context.restore();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public void onFailure(Exception e) {}
}

public void testGetTransportAction() throws IOException {
GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest("1234", false);
GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest("1234", false, null);
getWorkflowStateTransportAction.doExecute(task, getWorkflowRequest, response);
}

Expand All @@ -125,7 +125,7 @@ public void testGetAction() {
}

public void testGetWorkflowStateRequest() throws IOException {
GetWorkflowStateRequest request = new GetWorkflowStateRequest("1234", false);
GetWorkflowStateRequest request = new GetWorkflowStateRequest("1234", false, null);
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);
StreamInput input = out.bytes().streamInput();
Expand Down Expand Up @@ -164,7 +164,7 @@ public void testGetWorkflowStateResponse() throws IOException {

public void testExecuteGetWorkflowStateRequestFailure() throws IOException {
String workflowId = "test-workflow";
GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false);
GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false, null);
ActionListener<GetWorkflowStateResponse> listener = mock(ActionListener.class);

// Stub client.get to force on failure
Expand All @@ -185,7 +185,7 @@ public void testExecuteGetWorkflowStateRequestFailure() throws IOException {

public void testExecuteGetWorkflowStateRequestIndexNotFound() throws IOException {
String workflowId = "test-workflow";
GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false);
GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false, null);
ActionListener<GetWorkflowStateResponse> listener = mock(ActionListener.class);

// Stub client.get to force on failure
Expand All @@ -206,7 +206,7 @@ public void testExecuteGetWorkflowStateRequestIndexNotFound() throws IOException

public void testExecuteGetWorkflowStateRequestParseFailure() throws IOException {
String workflowId = "test-workflow";
GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false);
GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false, null);
ActionListener<GetWorkflowStateResponse> listener = mock(ActionListener.class);

// Stub client.get to force on response
Expand Down

0 comments on commit a91c20e

Please sign in to comment.