Skip to content

Commit

Permalink
add reprovision sync execution
Browse files Browse the repository at this point in the history
Signed-off-by: Junwei Dai <[email protected]>

# Conflicts:
#	src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java
  • Loading branch information
Junwei Dai committed Jan 15, 2025
1 parent 2584329 commit 15e052b
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
*/
package org.opensearch.flowframework.common;

import org.opensearch.common.unit.TimeValue;

/**
* Representation of common values that are used across project
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
Expand Down Expand Up @@ -43,6 +44,7 @@
import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
import static org.opensearch.flowframework.common.CommonValue.VALIDATION;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -88,6 +90,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false);
boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false);
String useCase = request.param(USE_CASE);
TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE);

// If provisioning, consume all other params and pass to provision transport action
Map<String, String> params = provision
Expand Down Expand Up @@ -145,6 +148,17 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
);
return processError(ffe, params, request);
}
// Ensure wait_for_completion is not set unless reprovision or provision is true
if (waitForCompletionTimeout != TimeValue.MINUS_ONE && !(reprovision || provision)) {
FlowFrameworkException ffe = new FlowFrameworkException(
"Request parameters "
+ request.consumedParams()
+ " are not allowed unless the 'provision' or 'reprovision' parameter is set to true.",
RestStatus.BAD_REQUEST
);
return processError(ffe, params, request);
}

try {
Template template;
Map<String, String> useCaseDefaultsMap = Collections.emptyMap();
Expand Down Expand Up @@ -219,7 +233,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (updateFields) {
params = Map.of(UPDATE_WORKFLOW_FIELDS, "true");
}

if (waitForCompletionTimeout != TimeValue.MINUS_ONE) {
params = Map.of(WAIT_FOR_COMPLETION_TIMEOUT, waitForCompletionTimeout.toString());
}
WorkflowRequest workflowRequest = new WorkflowRequest(
workflowId,
template,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener<Wo
ReprovisionWorkflowAction.INSTANCE,
reprovisionRequest,
ActionListener.wrap(reprovisionResponse -> {
listener.onResponse(new WorkflowResponse(reprovisionResponse.getWorkflowId()));
listener.onResponse(
reprovisionRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE
? new WorkflowResponse(reprovisionResponse.getWorkflowId())
: new WorkflowResponse(
reprovisionResponse.getWorkflowId(),
reprovisionResponse.getWorkflowState()
)
);
}, exception -> {
String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage(
"Reprovisioning failed for workflow {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
Expand All @@ -34,6 +35,7 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.WorkflowTimeoutUtility;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
Expand All @@ -48,6 +50,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD;
Expand Down Expand Up @@ -243,9 +247,23 @@ private void executeReprovisionRequest(
Template updatedTemplateWithProvisionedTime = Template.builder(updatedTemplate)
.lastProvisionedTime(Instant.now())
.build();
executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener);

listener.onResponse(new WorkflowResponse(workflowId));
if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) {
executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener);
} else {
executeWorkflowSync(
workflowId,
updatedTemplate,
reprovisionProcessSequence,
listener,
request.getWaitForCompletionTimeout().getMillis()
);
}

if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) {
listener.onResponse(new WorkflowResponse(workflowId));
} else {
logger.info("Waiting for workflow completion");
}

}, exception -> {
String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage("Failed to update workflow state: {}", workflowId)
Expand Down Expand Up @@ -284,13 +302,42 @@ private void executeWorkflowAsync(
try {
threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> {
updateTemplate(template, workflowId);
executeWorkflow(template, workflowSequence, workflowId);
executeWorkflow(template, workflowSequence, workflowId, listener, false);
});
} catch (Exception exception) {
listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception)));
}
}

private void executeWorkflowSync(
String workflowId,
Template template,
List<ProcessNode> workflowSequence,
ActionListener<WorkflowResponse> listener,
long timeout
) {
AtomicBoolean isResponseSent = new AtomicBoolean(false);
CompletableFuture.runAsync(() -> {
try {
updateTemplate(template, workflowId);
executeWorkflow(template, workflowSequence, workflowId, new ActionListener<>() {
@Override
public void onResponse(WorkflowResponse workflowResponse) {
WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener);
}

@Override
public void onFailure(Exception e) {
WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener);
}
}, true);
} catch (Exception ex) {
WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener);
}
}, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL));
WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent);
}

/**
* Replace template document
* @param template The template to store after reprovisioning completes successfully
Expand All @@ -310,7 +357,13 @@ private void updateTemplate(Template template, String workflowId) {
* @param workflowSequence The topologically sorted workflow to execute
* @param workflowId The workflowId associated with the workflow that is executing
*/
private void executeWorkflow(Template template, List<ProcessNode> workflowSequence, String workflowId) {
private void executeWorkflow(
Template template,
List<ProcessNode> workflowSequence,
String workflowId,
ActionListener<WorkflowResponse> listener,
boolean isSyncExecution
) {
String currentStepId = "";
try {
Map<String, PlainActionFuture<?>> workflowFutureMap = new LinkedHashMap<>();
Expand Down Expand Up @@ -349,7 +402,23 @@ private void executeWorkflow(Template template, List<ProcessNode> workflowSequen
ActionListener.wrap(updateResponse -> {

logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED);

if (isSyncExecution) {
client.execute(
GetWorkflowStateAction.INSTANCE,
new GetWorkflowStateRequest(workflowId, false),
ActionListener.wrap(response -> {
listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState()));
}, exception -> {
String errorMessage = "Failed to get workflow state.";
logger.error(errorMessage, exception);
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
}
})
);
}
}, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); })
);
} catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;

/**
* Transport Request to create, provision, and deprovision a workflow
Expand Down Expand Up @@ -154,7 +155,8 @@ public WorkflowRequest(
this.validation = validation;
this.provision = provisionOrUpdate && !params.containsKey(UPDATE_WORKFLOW_FIELDS);
this.updateFields = !provision && Boolean.parseBoolean(params.get(UPDATE_WORKFLOW_FIELDS));
if (!this.provision && params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k))) {
if (!this.provision
&& params.keySet().stream().anyMatch(k -> !UPDATE_WORKFLOW_FIELDS.equals(k) && !WAIT_FOR_COMPLETION_TIMEOUT.equals(k))) {
throw new IllegalArgumentException("Params may only be included when provisioning.");
}
this.params = this.updateFields ? Collections.emptyMap() : params;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.flowframework.transport;

import org.opensearch.Version;
import org.opensearch.common.Nullable;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -67,6 +68,7 @@ public String getWorkflowId() {
* Gets the workflowState of this repsonse
* @return the workflowState
*/
@Nullable
public WorkflowState getWorkflowState() {
return this.workflowState;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import java.util.concurrent.atomic.AtomicBoolean;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL;

/**
* Utility class for managing timeout tasks in workflow execution.
* This class provides methods to schedule timeout handlers, wrap listeners with timeout cancellation logic,
Expand Down Expand Up @@ -57,7 +59,7 @@ public static ActionListener<WorkflowResponse> scheduleTimeoutHandler(
Scheduler.ScheduledCancellable scheduledCancellable = threadPool.schedule(
new WorkflowTimeoutListener(client, workflowId, listener, isResponseSent),
TimeValue.timeValueMillis(adjustedTimeout),
ThreadPool.Names.GENERIC
PROVISION_WORKFLOW_THREAD_POOL
);

return wrapWithTimeoutCancellationListener(listener, scheduledCancellable, isResponseSent);
Expand Down Expand Up @@ -181,6 +183,7 @@ public static void fetchWorkflowStateAfterTimeout(
final String workflowId,
final ActionListener<WorkflowResponse> listener
) {
logger.info("Fetching workflow state after timeout");
client.execute(
GetWorkflowStateAction.INSTANCE,
new GetWorkflowStateRequest(workflowId, false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import static org.opensearch.flowframework.common.CommonValue.REPROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -128,7 +129,7 @@ public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception {
assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123"));
}

public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Exception {
public void testRestCreateWorkflowWithWaitForCompletionTimeout() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("wait_for_completion_timeout", "5s")))
.withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON)
Expand Down Expand Up @@ -162,6 +163,23 @@ public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception
);
}

public void testCreateWorkflowRequestWithWaitForTimeCompletionTimeoutButNoProvision() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.createWorkflowPath)
.withParams(Map.of(WAIT_FOR_COMPLETION_TIMEOUT, "1s"))
.withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON)
.build();
FakeRestChannel channel = new FakeRestChannel(request, false, 1);
createWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status());
assertTrue(
channel.capturedResponse()
.content()
.utf8ToString()
.contains("are not allowed unless the 'provision' or 'reprovision' parameter is set to true.")
);
}

public void testCreateWorkflowRequestWithUpdateAndProvision() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.createWorkflowPath)
Expand Down

0 comments on commit 15e052b

Please sign in to comment.