diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/DaprWorkflowContextImpl.java b/sdk-workflows/src/main/java/io/dapr/workflows/DaprWorkflowContextImpl.java index 75d904bcd..c6f474d70 100644 --- a/sdk-workflows/src/main/java/io/dapr/workflows/DaprWorkflowContextImpl.java +++ b/sdk-workflows/src/main/java/io/dapr/workflows/DaprWorkflowContextImpl.java @@ -18,6 +18,9 @@ import com.microsoft.durabletask.TaskCanceledException; import com.microsoft.durabletask.TaskOptions; import com.microsoft.durabletask.TaskOrchestrationContext; +import io.dapr.workflows.saga.DaprSagaContextImpl; +import io.dapr.workflows.saga.Saga; +import io.dapr.workflows.saga.SagaContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.helpers.NOPLogger; @@ -32,6 +35,7 @@ public class DaprWorkflowContextImpl implements WorkflowContext { private final TaskOrchestrationContext innerContext; private final Logger logger; + private final Saga saga; /** * Constructor for DaprWorkflowContextImpl. @@ -51,6 +55,23 @@ public DaprWorkflowContextImpl(TaskOrchestrationContext context) throws IllegalA * @throws IllegalArgumentException if context or logger is null */ public DaprWorkflowContextImpl(TaskOrchestrationContext context, Logger logger) throws IllegalArgumentException { + this(context, logger, null); + } + + public DaprWorkflowContextImpl(TaskOrchestrationContext context, Saga saga) throws IllegalArgumentException { + this(context, LoggerFactory.getLogger(WorkflowContext.class), saga); + } + + /** + * Constructor for DaprWorkflowContextImpl. + * + * @param context TaskOrchestrationContext + * @param logger Logger + * @param saga saga object, if null, saga is disabled + * @throws IllegalArgumentException if context or logger is null + */ + public DaprWorkflowContextImpl(TaskOrchestrationContext context, Logger logger, Saga saga) + throws IllegalArgumentException { if (context == null) { throw new IllegalArgumentException("Context cannot be null"); } @@ -60,6 +81,7 @@ public DaprWorkflowContextImpl(TaskOrchestrationContext context, Logger logger) this.innerContext = context; this.logger = logger; + this.saga = saga; } /** @@ -110,15 +132,20 @@ public Task waitForExternalEvent(String name, Duration timeout, Class } /** - * Waits for an event to be raised named {@code name} and returns a {@link Task} that completes when the event is + * Waits for an event to be raised named {@code name} and returns a {@link Task} + * that completes when the event is * received or is canceled when {@code timeout} expires. * - *

See {@link #waitForExternalEvent(String, Duration, Class)} for a full description. + *

See {@link #waitForExternalEvent(String, Duration, Class)} for a full + * description. * * @param name the case-insensitive name of the event to wait for - * @param timeout the amount of time to wait before canceling the returned {@code Task} - * @return a new {@link Task} that completes when the external event is received or when {@code timeout} expires - * @throws TaskCanceledException if the specified {@code timeout} value expires before the event is received + * @param timeout the amount of time to wait before canceling the returned + * {@code Task} + * @return a new {@link Task} that completes when the external event is received + * or when {@code timeout} expires + * @throws TaskCanceledException if the specified {@code timeout} value expires + * before the event is received */ @Override public Task waitForExternalEvent(String name, Duration timeout) throws TaskCanceledException { @@ -126,10 +153,12 @@ public Task waitForExternalEvent(String name, Duration timeout) throws } /** - * Waits for an event to be raised named {@code name} and returns a {@link Task} that completes when the event is + * Waits for an event to be raised named {@code name} and returns a {@link Task} + * that completes when the event is * received. * - *

See {@link #waitForExternalEvent(String, Duration, Class)} for a full description. + *

See {@link #waitForExternalEvent(String, Duration, Class)} for a full + * description. * * @param name the case-insensitive name of the event to wait for * @return a new {@link Task} that completes when the external event is received @@ -172,7 +201,6 @@ public Task createTimer(Duration duration) { return this.innerContext.createTimer(duration); } - /** * {@inheritDoc} */ @@ -185,7 +213,7 @@ public T getInput(Class targetType) { */ @Override public Task callSubWorkflow(String name, @Nullable Object input, @Nullable String instanceID, - @Nullable TaskOptions options, Class returnType) { + @Nullable TaskOptions options, Class returnType) { return this.innerContext.callSubOrchestrator(name, input, instanceID, options, returnType); } @@ -213,4 +241,13 @@ public void continueAsNew(Object input, boolean preserveUnprocessedEvents) { public UUID newUuid() { return this.innerContext.newUUID(); } + + @Override + public SagaContext getSagaContext() { + if (this.saga == null) { + throw new UnsupportedOperationException("Saga is not enabled"); + } + + return new DaprSagaContextImpl(this.saga, this); + } } diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/Workflow.java b/sdk-workflows/src/main/java/io/dapr/workflows/Workflow.java index 66b5c02d7..94bb4c828 100644 --- a/sdk-workflows/src/main/java/io/dapr/workflows/Workflow.java +++ b/sdk-workflows/src/main/java/io/dapr/workflows/Workflow.java @@ -13,11 +13,16 @@ package io.dapr.workflows; +import com.microsoft.durabletask.interruption.ContinueAsNewInterruption; +import com.microsoft.durabletask.interruption.OrchestratorBlockedException; +import io.dapr.workflows.saga.SagaCompensationException; +import io.dapr.workflows.saga.SagaOption; + /** * Common interface for workflow implementations. */ public abstract class Workflow { - public Workflow(){ + public Workflow() { } /** @@ -30,10 +35,50 @@ public Workflow(){ /** * Executes the workflow logic. * - * @param ctx provides access to methods for scheduling durable tasks and getting information about the current + * @param ctx provides access to methods for scheduling durable tasks and + * getting information about the current * workflow instance. */ public void run(WorkflowContext ctx) { - this.create().run(ctx); + WorkflowStub stub = this.create(); + + if (!this.isSagaEnabled()) { + // saga disabled + stub.run(ctx); + } else { + // saga enabled + try { + stub.run(ctx); + } catch (OrchestratorBlockedException | ContinueAsNewInterruption e) { + throw e; + } catch (SagaCompensationException e) { + // Saga compensation is triggered gracefully but failed in exception + // don't need to trigger compensation again + throw e; + } catch (Exception e) { + try { + ctx.getSagaContext().compensate(); + } catch (Exception se) { + se.addSuppressed(e); + throw se; + } + + throw e; + } + } + } + + public boolean isSagaEnabled() { + return this.getSagaOption() != null; + } + + /** + * get saga configuration. + * + * @return saga configuration + */ + public SagaOption getSagaOption() { + // by default, saga is disabled + return null; } } diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowContext.java b/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowContext.java index 8338cd393..5315616ff 100644 --- a/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowContext.java +++ b/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowContext.java @@ -18,6 +18,7 @@ import com.microsoft.durabletask.TaskCanceledException; import com.microsoft.durabletask.TaskFailedException; import com.microsoft.durabletask.TaskOptions; +import io.dapr.workflows.saga.SagaContext; import org.slf4j.Logger; import javax.annotation.Nullable; @@ -530,4 +531,12 @@ default void continueAsNew(Object input) { default UUID newUuid() { throw new RuntimeException("No implementation found."); } + + /** + * get saga context. + * + * @return saga context + * @throws UnsupportedOperationException if saga is not enabled. + */ + SagaContext getSagaContext(); } diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowStub.java b/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowStub.java index 561a6e1a7..6a109c626 100644 --- a/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowStub.java +++ b/sdk-workflows/src/main/java/io/dapr/workflows/WorkflowStub.java @@ -13,8 +13,6 @@ package io.dapr.workflows; -import io.dapr.workflows.WorkflowContext; - @FunctionalInterface public interface WorkflowStub { void run(WorkflowContext ctx); diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/runtime/OrchestratorWrapper.java b/sdk-workflows/src/main/java/io/dapr/workflows/runtime/OrchestratorWrapper.java index f28eed0de..d104c9c3e 100644 --- a/sdk-workflows/src/main/java/io/dapr/workflows/runtime/OrchestratorWrapper.java +++ b/sdk-workflows/src/main/java/io/dapr/workflows/runtime/OrchestratorWrapper.java @@ -17,6 +17,7 @@ import com.microsoft.durabletask.TaskOrchestrationFactory; import io.dapr.workflows.DaprWorkflowContextImpl; import io.dapr.workflows.Workflow; +import io.dapr.workflows.saga.Saga; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; @@ -55,7 +56,13 @@ public TaskOrchestration create() { String.format("Unable to instantiate instance of workflow class '%s'", this.name), e ); } - workflow.run(new DaprWorkflowContextImpl(ctx)); + + if (workflow.getSagaOption() != null) { + Saga saga = new Saga(workflow.getSagaOption()); + workflow.run(new DaprWorkflowContextImpl(ctx, saga)); + } else { + workflow.run(new DaprWorkflowContextImpl(ctx)); + } }; } diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/saga/CompensatationInformation.java b/sdk-workflows/src/main/java/io/dapr/workflows/saga/CompensatationInformation.java new file mode 100644 index 000000000..cf0fe202c --- /dev/null +++ b/sdk-workflows/src/main/java/io/dapr/workflows/saga/CompensatationInformation.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.workflows.saga; + +import com.microsoft.durabletask.TaskOptions; + +/** + * Information for a compensation activity. + */ +class CompensatationInformation { + private final String compensatationActivityClassName; + private final Object compensatationActivityInput; + private final TaskOptions taskOptions; + + /** + * Constructor for a compensation information. + * + * @param compensatationActivityClassName Class name of the activity to do + * compensatation. + * @param compensatationActivityInput Input of the activity to do + * compensatation. + * @param taskOptions task options to set retry strategy + */ + public CompensatationInformation(String compensatationActivityClassName, + Object compensatationActivityInput, TaskOptions taskOptions) { + this.compensatationActivityClassName = compensatationActivityClassName; + this.compensatationActivityInput = compensatationActivityInput; + this.taskOptions = taskOptions; + } + + /** + * Gets the class name of the activity. + * + * @return the class name of the activity. + */ + public String getCompensatationActivityClassName() { + return compensatationActivityClassName; + } + + /** + * Gets the input of the activity. + * + * @return the input of the activity. + */ + public Object getCompensatationActivityInput() { + return compensatationActivityInput; + } + + /** + * get task options. + * + * @return task options, null if not set + */ + public TaskOptions getTaskOptions() { + return taskOptions; + } +} \ No newline at end of file diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/saga/DaprSagaContextImpl.java b/sdk-workflows/src/main/java/io/dapr/workflows/saga/DaprSagaContextImpl.java new file mode 100644 index 000000000..5ede2af7f --- /dev/null +++ b/sdk-workflows/src/main/java/io/dapr/workflows/saga/DaprSagaContextImpl.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.workflows.saga; + +import io.dapr.workflows.WorkflowContext; + +/** + * Dapr Saga Context implementation. + */ +public class DaprSagaContextImpl implements SagaContext { + + private final Saga saga; + private final WorkflowContext workflowContext; + + /** + * Constructor to build up instance. + * + * @param saga Saga instance. + * @param workflowContext Workflow context. + * @throws IllegalArgumentException if saga or workflowContext is null. + */ + public DaprSagaContextImpl(Saga saga, WorkflowContext workflowContext) { + if (saga == null) { + throw new IllegalArgumentException("Saga should not be null"); + } + if (workflowContext == null) { + throw new IllegalArgumentException("workflowContext should not be null"); + } + + this.saga = saga; + this.workflowContext = workflowContext; + } + + @Override + public void registerCompensation(String activityClassName, Object activityInput) { + this.saga.registerCompensation(activityClassName, activityInput); + } + + @Override + public void compensate() { + this.saga.compensate(workflowContext); + } +} diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/saga/Saga.java b/sdk-workflows/src/main/java/io/dapr/workflows/saga/Saga.java new file mode 100644 index 000000000..f2a151b9e --- /dev/null +++ b/sdk-workflows/src/main/java/io/dapr/workflows/saga/Saga.java @@ -0,0 +1,130 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.workflows.saga; + +import com.microsoft.durabletask.Task; +import com.microsoft.durabletask.TaskOptions; +import com.microsoft.durabletask.interruption.ContinueAsNewInterruption; +import com.microsoft.durabletask.interruption.OrchestratorBlockedException; +import io.dapr.workflows.WorkflowContext; + +import java.util.ArrayList; +import java.util.List; + +public final class Saga { + private final SagaOption option; + private final List compensationActivities = new ArrayList<>(); + + /** + * Build up a Saga with its options. + * + * @param option Saga option. + */ + public Saga(SagaOption option) { + if (option == null) { + throw new IllegalArgumentException("option is required and should not be null."); + } + this.option = option; + } + + /** + * Register a compensation activity. + * + * @param activityClassName name of the activity class + * @param activityInput input of the activity to be compensated + */ + public void registerCompensation(String activityClassName, Object activityInput) { + this.registerCompensation(activityClassName, activityInput, null); + } + + /** + * Register a compensation activity. + * + * @param activityClassName name of the activity class + * @param activityInput input of the activity to be compensated + * @param taskOptions task options to set retry strategy + */ + public void registerCompensation(String activityClassName, Object activityInput, TaskOptions taskOptions) { + if (activityClassName == null || activityClassName.isEmpty()) { + throw new IllegalArgumentException("activityClassName is required and should not be null or empty."); + } + this.compensationActivities.add(new CompensatationInformation(activityClassName, activityInput, taskOptions)); + } + + /** + * Compensate all registered activities. + * + * @param ctx Workflow context. + */ + public void compensate(WorkflowContext ctx) { + // Check if parallel compensation is enabled + // Specical case: when parallel compensation is enabled and there is only one + // compensation, we still + // compensate sequentially. + if (option.isParallelCompensation() && compensationActivities.size() > 1) { + compensateInParallel(ctx); + } else { + compensateSequentially(ctx); + } + } + + private void compensateInParallel(WorkflowContext ctx) { + List> tasks = new ArrayList<>(compensationActivities.size()); + for (CompensatationInformation compensationActivity : compensationActivities) { + Task task = executeCompensateActivity(ctx, compensationActivity); + tasks.add(task); + } + + try { + ctx.allOf(tasks).await(); + } catch (Exception e) { + throw new SagaCompensationException("Failed to compensate in parallel.", e); + } + } + + private void compensateSequentially(WorkflowContext ctx) { + SagaCompensationException sagaException = null; + for (int i = compensationActivities.size() - 1; i >= 0; i--) { + String activityClassName = compensationActivities.get(i).getCompensatationActivityClassName(); + try { + executeCompensateActivity(ctx, compensationActivities.get(i)).await(); + } catch (OrchestratorBlockedException | ContinueAsNewInterruption e) { + throw e; + } catch (Exception e) { + if (sagaException == null) { + sagaException = new SagaCompensationException( + "Exception in saga compensatation: activity=" + activityClassName, e); + ; + } else { + sagaException.addSuppressed(e); + } + + if (!option.isContinueWithError()) { + throw sagaException; + } + } + } + + if (sagaException != null) { + throw sagaException; + } + } + + private Task executeCompensateActivity(WorkflowContext ctx, CompensatationInformation info) + throws SagaCompensationException { + String activityClassName = info.getCompensatationActivityClassName(); + return ctx.callActivity(activityClassName, info.getCompensatationActivityInput(), + info.getTaskOptions()); + } +} diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaCompensationException.java b/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaCompensationException.java new file mode 100644 index 000000000..07396d9b5 --- /dev/null +++ b/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaCompensationException.java @@ -0,0 +1,28 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.workflows.saga; + +/** + * saga compensation exception. + */ +public class SagaCompensationException extends RuntimeException { + /** + * build up a SagaCompensationException. + * @param message exception message + * @param cause exception cause + */ + public SagaCompensationException(String message, Exception cause) { + super(message, cause); + } +} \ No newline at end of file diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaContext.java b/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaContext.java new file mode 100644 index 000000000..03470ff92 --- /dev/null +++ b/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaContext.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.workflows.saga; + +/** + * Saga context. + */ +public interface SagaContext { + /** + * Register a compensation activity. + * + * @param activityClassName name of the activity class + * @param activityInput input of the activity to be compensated + */ + void registerCompensation(String activityClassName, Object activityInput); + + /** + * Compensate all registered activities. + * + */ + void compensate(); + +} diff --git a/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaOption.java b/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaOption.java new file mode 100644 index 000000000..b13b2af77 --- /dev/null +++ b/sdk-workflows/src/main/java/io/dapr/workflows/saga/SagaOption.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.workflows.saga; + +/** + * Saga option. + */ +public final class SagaOption { + private final boolean parallelCompensation; + private final int maxParallelThread; + private final boolean continueWithError; + + private SagaOption(boolean parallelCompensation, int maxParallelThread, boolean continueWithError) { + this.parallelCompensation = parallelCompensation; + this.maxParallelThread = maxParallelThread; + this.continueWithError = continueWithError; + } + + public boolean isParallelCompensation() { + return parallelCompensation; + } + + public boolean isContinueWithError() { + return continueWithError; + } + + public int getMaxParallelThread() { + return maxParallelThread; + } + + public static Builder newBuilder() { + return new Builder(); + } + + public static final class Builder { + // by default compensation is sequential + private boolean parallelCompensation = false; + + // by default max parallel thread is 16, it's enough for most cases + private int maxParallelThread = 16; + + // by default set continueWithError to be true + // So if a compensation fails, we should continue with the next compensations + private boolean continueWithError = true; + + /** + * Set parallel compensation. + * @param parallelCompensation parallel compensation or not + * @return this builder itself + */ + public Builder setParallelCompensation(boolean parallelCompensation) { + this.parallelCompensation = parallelCompensation; + return this; + } + + /** + * set max parallel thread. + * + *

Only valid when parallelCompensation is true. + * @param maxParallelThread max parallel thread + * @return this builder itself + */ + public Builder setMaxParallelThread(int maxParallelThread) { + if (maxParallelThread <= 2) { + throw new IllegalArgumentException("maxParallelThread should be greater than 1."); + } + this.maxParallelThread = maxParallelThread; + return this; + } + + /** + * Set continue with error. + * + *

Only valid when parallelCompensation is false. + * @param continueWithError continue with error or not + * @return this builder itself + */ + public Builder setContinueWithError(boolean continueWithError) { + this.continueWithError = continueWithError; + return this; + } + + /** + * Build Saga optiion. + * @return Saga optiion + */ + public SagaOption build() { + return new SagaOption(this.parallelCompensation, this.maxParallelThread, this.continueWithError); + } + } +} diff --git a/sdk-workflows/src/test/java/io/dapr/workflows/DaprWorkflowContextImplTest.java b/sdk-workflows/src/test/java/io/dapr/workflows/DaprWorkflowContextImplTest.java index 8c0ce49d4..3ea03ddbb 100644 --- a/sdk-workflows/src/test/java/io/dapr/workflows/DaprWorkflowContextImplTest.java +++ b/sdk-workflows/src/test/java/io/dapr/workflows/DaprWorkflowContextImplTest.java @@ -20,6 +20,9 @@ import com.microsoft.durabletask.TaskOptions; import com.microsoft.durabletask.TaskOrchestrationContext; +import io.dapr.workflows.saga.Saga; +import io.dapr.workflows.saga.SagaContext; + import org.jetbrains.annotations.Nullable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -31,6 +34,7 @@ import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -130,6 +134,11 @@ public Task callSubWorkflow(String name, @Nullable Object input, @Nullabl public void continueAsNew(Object input, boolean preserveUnprocessedEvents) { } + + @Override + public SagaContext getSagaContext() { + return null; + } }; } @@ -181,13 +190,13 @@ public void callActivityTest() { @Test public void DaprWorkflowContextWithEmptyInnerContext() { assertThrows(IllegalArgumentException.class, () -> { - context = new DaprWorkflowContextImpl(mockInnerContext, null); + context = new DaprWorkflowContextImpl(mockInnerContext, (Logger)null); }); } @Test public void DaprWorkflowContextWithEmptyLogger() { assertThrows(IllegalArgumentException.class, () -> { - context = new DaprWorkflowContextImpl(null, null); + context = new DaprWorkflowContextImpl(null, (Logger)null); }); } @@ -309,4 +318,21 @@ public void newUuidTestNoImplementationExceptionTest() { String expectedMessage = "No implementation found."; assertEquals(expectedMessage, runtimeException.getMessage()); } + + @Test + public void getSagaContextTest_sagaEnabled() { + Saga saga = mock(Saga.class); + WorkflowContext context = new DaprWorkflowContextImpl(mockInnerContext, saga); + + SagaContext sagaContext = context.getSagaContext(); + assertNotNull("SagaContext should not be null", sagaContext); + } + + @Test + public void getSagaContextTest_sagaDisabled() { + WorkflowContext context = new DaprWorkflowContextImpl(mockInnerContext); + assertThrows(UnsupportedOperationException.class, () -> { + context.getSagaContext(); + }); + } } diff --git a/sdk-workflows/src/test/java/io/dapr/workflows/WorkflowTest.java b/sdk-workflows/src/test/java/io/dapr/workflows/WorkflowTest.java new file mode 100644 index 000000000..528af3191 --- /dev/null +++ b/sdk-workflows/src/test/java/io/dapr/workflows/WorkflowTest.java @@ -0,0 +1,197 @@ +package io.dapr.workflows; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import org.junit.Test; + +import com.microsoft.durabletask.interruption.ContinueAsNewInterruption; +import com.microsoft.durabletask.interruption.OrchestratorBlockedException; + +import io.dapr.workflows.saga.SagaCompensationException; +import io.dapr.workflows.saga.SagaContext; +import io.dapr.workflows.saga.SagaOption; + +public class WorkflowTest { + + @Test + public void testWorkflow_WithoutSaga() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithoutSaga(stub); + assertNull(workflow.getSagaOption()); + assertFalse(workflow.isSagaEnabled()); + + WorkflowContext ctx = mock(WorkflowContext.class); + doNothing().when(stub).run(ctx); + workflow.run(ctx); + + verify(stub, times(1)).run(eq(ctx)); + } + + @Test + public void testWorkflow_WithoutSaga_throwException() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithoutSaga(stub); + WorkflowContext ctx = mock(WorkflowContext.class); + Exception e = new RuntimeException(); + doThrow(e).when(stub).run(ctx); + + // should throw the exception, not catch + assertThrows(RuntimeException.class, () -> { + workflow.run(ctx); + }); + verify(stub, times(1)).run(eq(ctx)); + } + + @Test + public void testWorkflow_WithSaga() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithSaga(stub); + assertNotNull(workflow.getSagaOption()); + assertTrue(workflow.isSagaEnabled()); + + WorkflowContext ctx = mock(WorkflowContext.class); + doNothing().when(stub).run(ctx); + workflow.run(ctx); + + verify(stub, times(1)).run(eq(ctx)); + } + + @Test + public void testWorkflow_WithSaga_shouldNotCatch_OrchestratorBlockedException() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithSaga(stub); + + WorkflowContext ctx = mock(WorkflowContext.class); + Exception e = new OrchestratorBlockedException("test"); + doThrow(e).when(stub).run(ctx); + + // should not catch OrchestratorBlockedException + assertThrows(OrchestratorBlockedException.class, () -> { + workflow.run(ctx); + }); + verify(stub, times(1)).run(eq(ctx)); + } + + @Test + public void testWorkflow_WithSaga_shouldNotCatch_ContinueAsNewInterruption() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithSaga(stub); + + WorkflowContext ctx = mock(WorkflowContext.class); + Exception e = new ContinueAsNewInterruption("test"); + doThrow(e).when(stub).run(ctx); + + // should not catch ContinueAsNewInterruption + assertThrows(ContinueAsNewInterruption.class, () -> { + workflow.run(ctx); + }); + verify(stub, times(1)).run(eq(ctx)); + } + + @Test + public void testWorkflow_WithSaga_shouldNotCatch_SagaCompensationException() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithSaga(stub); + + WorkflowContext ctx = mock(WorkflowContext.class); + Exception e = new SagaCompensationException("test", null); + doThrow(e).when(stub).run(ctx); + + // should not catch SagaCompensationException + assertThrows(SagaCompensationException.class, () -> { + workflow.run(ctx); + }); + verify(stub, times(1)).run(eq(ctx)); + } + + @Test + public void testWorkflow_WithSaga_triggerCompensate() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithSaga(stub); + + WorkflowContext ctx = mock(WorkflowContext.class); + Exception e = new RuntimeException("test", null); + doThrow(e).when(stub).run(ctx); + SagaContext sagaContext = mock(SagaContext.class); + doReturn(sagaContext).when(ctx).getSagaContext(); + doNothing().when(sagaContext).compensate(); + + assertThrows(RuntimeException.class, () -> { + workflow.run(ctx); + }); + verify(stub, times(1)).run(eq(ctx)); + verify(sagaContext, times(1)).compensate(); + } + + @Test + public void testWorkflow_WithSaga_compensateFaile() { + WorkflowStub stub = mock(WorkflowStub.class); + Workflow workflow = new WorkflowWithSaga(stub); + + WorkflowContext ctx = mock(WorkflowContext.class); + Exception e = new RuntimeException("workflow fail", null); + doThrow(e).when(stub).run(ctx); + SagaContext sagaContext = mock(SagaContext.class); + doReturn(sagaContext).when(ctx).getSagaContext(); + Exception e2 = new RuntimeException("compensate fail", null); + doThrow(e2).when(sagaContext).compensate(); + + try { + workflow.run(ctx); + fail("sholdd throw exception"); + } catch (Exception ex) { + assertEquals(e2.getMessage(), ex.getMessage()); + assertEquals(1, ex.getSuppressed().length); + assertEquals(e.getMessage(), ex.getSuppressed()[0].getMessage()); + } + + verify(stub, times(1)).run(eq(ctx)); + verify(sagaContext, times(1)).compensate(); + } + + public static class WorkflowWithoutSaga extends Workflow { + private final WorkflowStub stub; + + public WorkflowWithoutSaga(WorkflowStub stub) { + this.stub = stub; + } + + @Override + public WorkflowStub create() { + return stub; + } + } + + public static class WorkflowWithSaga extends Workflow { + private final WorkflowStub stub; + + public WorkflowWithSaga(WorkflowStub stub) { + this.stub = stub; + } + + @Override + public WorkflowStub create() { + return stub; + } + + @Override + public SagaOption getSagaOption() { + return SagaOption.newBuilder() + .setParallelCompensation(false) + .build(); + } + } +} diff --git a/sdk-workflows/src/test/java/io/dapr/workflows/saga/DaprSagaContextImplTest.java b/sdk-workflows/src/test/java/io/dapr/workflows/saga/DaprSagaContextImplTest.java new file mode 100644 index 000000000..9c1918a41 --- /dev/null +++ b/sdk-workflows/src/test/java/io/dapr/workflows/saga/DaprSagaContextImplTest.java @@ -0,0 +1,54 @@ +package io.dapr.workflows.saga; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import org.junit.Test; + +import io.dapr.workflows.WorkflowContext; + +public class DaprSagaContextImplTest { + + @Test + public void testDaprSagaContextImpl_IllegalArgumentException() { + Saga saga = mock(Saga.class); + WorkflowContext workflowContext = mock(WorkflowContext.class); + + assertThrows(IllegalArgumentException.class, () -> { + new DaprSagaContextImpl(saga, null); + }); + + assertThrows(IllegalArgumentException.class, () -> { + new DaprSagaContextImpl(null, workflowContext); + }); + } + + @Test + public void test_registerCompensation() { + Saga saga = mock(Saga.class); + WorkflowContext workflowContext = mock(WorkflowContext.class); + DaprSagaContextImpl ctx = new DaprSagaContextImpl(saga, workflowContext); + + String activityClassName = "name1"; + Object activityInput = new Object(); + doNothing().when(saga).registerCompensation(activityClassName, activityInput); + + ctx.registerCompensation(activityClassName, activityInput); + verify(saga, times(1)).registerCompensation(activityClassName, activityInput); + } + + @Test + public void test_compensate() { + Saga saga = mock(Saga.class); + WorkflowContext workflowContext = mock(WorkflowContext.class); + DaprSagaContextImpl ctx = new DaprSagaContextImpl(saga, workflowContext); + + doNothing().when(saga).compensate(workflowContext); + + ctx.compensate(); + verify(saga, times(1)).compensate(workflowContext); + } +} diff --git a/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaIntegrationTest.java b/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaIntegrationTest.java new file mode 100644 index 000000000..0a39d64f2 --- /dev/null +++ b/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaIntegrationTest.java @@ -0,0 +1,324 @@ +package io.dapr.workflows.saga; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; + +import com.microsoft.durabletask.TaskActivityContext; + +import io.dapr.workflows.runtime.WorkflowActivity; +import io.dapr.workflows.runtime.WorkflowActivityContext; + +public class SagaIntegrationTest { + + private static int count = 0; + private static Object countLock = new Object(); + + @Test + public void testSaga_CompensateSequentially() { + int runCount = 10; + int succeedCount = 0; + int compensateCount = 0; + + for (int i = 0; i < runCount; i++) { + boolean isSuccueed = doExecuteWorkflowWithSaga(false); + if (isSuccueed) { + succeedCount++; + } else { + compensateCount++; + } + } + + System.out.println("Run workflow with saga " + runCount + " times: succeed " + succeedCount + + " times, failed and compensated " + compensateCount + " times"); + } + + @Test + public void testSaga_compensateInParallel() { + int runCount = 100; + int succeedCount = 0; + int compensateCount = 0; + + for (int i = 0; i < runCount; i++) { + boolean isSuccueed = doExecuteWorkflowWithSaga(true); + if (isSuccueed) { + succeedCount++; + } else { + compensateCount++; + } + } + + System.out.println("Run workflow with saga " + runCount + " times: succeed " + succeedCount + + " times, failed and compensated " + compensateCount + " times"); + } + + private boolean doExecuteWorkflowWithSaga(boolean parallelCompensation) { + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(parallelCompensation) + .setContinueWithError(true).build(); + Saga saga = new Saga(config); + boolean workflowSuccess = false; + + // reset count to zero + synchronized (countLock) { + count = 0; + } + + Integer addInput = 100; + Integer subtractInput = 20; + Integer multiplyInput = 10; + Integer divideInput = 5; + + try { + // step1: add activity + String result = callActivity(AddActivity.class.getName(), addInput, String.class); + saga.registerCompensation(AddCompentationActivity.class.getName(), addInput); + // step2: subtract activity + result = callActivity(SubtractActivity.class.getName(), subtractInput, String.class); + saga.registerCompensation(SubtractCompentationActivity.class.getName(), subtractInput); + + if (parallelCompensation) { + // only add/subtract activities support parallel compensation + // so in step3 and step4 we repeat add/subtract activities + + // step3: add activity again + result = callActivity(AddActivity.class.getName(), addInput, String.class); + saga.registerCompensation(AddCompentationActivity.class.getName(), addInput); + + // step4: substract activity again + result = callActivity(SubtractActivity.class.getName(), subtractInput, String.class); + saga.registerCompensation(SubtractCompentationActivity.class.getName(), subtractInput); + } else { + // step3: multiply activity + result = callActivity(MultiplyActivity.class.getName(), multiplyInput, String.class); + saga.registerCompensation(MultiplyCompentationActivity.class.getName(), multiplyInput); + + // step4: divide activity + result = callActivity(DivideActivity.class.getName(), divideInput, String.class); + saga.registerCompensation(DivideCompentationActivity.class.getName(), divideInput); + } + + randomFail(); + + workflowSuccess = true; + } catch (Exception e) { + saga.compensate(SagaTest.createMockContext()); + } + + if (workflowSuccess) { + int expectResult = 0; + if (parallelCompensation) { + expectResult = 0 + addInput - subtractInput + addInput - subtractInput; + } else { + expectResult = (0 + addInput - subtractInput) * multiplyInput / divideInput; + } + assertEquals(expectResult, count); + } else { + assertEquals(0, count); + } + + return workflowSuccess; + } + + // mock to call activity in dapr workflow + private V callActivity(String activityClassName, Object input, Class returnType) { + try { + Class activityClass = Class.forName(activityClassName); + WorkflowActivity activity = (WorkflowActivity) activityClass.getDeclaredConstructor().newInstance(); + WorkflowActivityContext ctx = new WorkflowActivityContext(new TaskActivityContext() { + + @Override + public java.lang.String getName() { + return activityClassName; + } + + @Override + public T getInput(Class targetType) { + return (T) input; + } + }); + + randomFail(); + + return (V) activity.run(ctx); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static void randomFail() { + int randomInt = (int) (Math.random() * 100); + // if randomInt mod 10 is 0, then throw exception + if (randomInt % 10 == 0) { + throw new RuntimeException("random fail"); + } + } + + public static class AddActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount + input; + count = updatedCount; + } + + String resultString = "current count is updated from " + originalCount + " to " + updatedCount + + " after adding " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class AddCompentationActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount - input; + count = updatedCount; + } + + String resultString = "current count is compensated from " + originalCount + " to " + + updatedCount + " after compensate adding " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class SubtractActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount - input; + count = updatedCount; + } + + String resultString = "current count is updated from " + originalCount + " to " + updatedCount + + " after substracting " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class SubtractCompentationActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount + input; + count = updatedCount; + } + + String resultString = "current count is compensated from " + originalCount + " to " + updatedCount + + " after compensate substracting " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class MultiplyActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount * input; + count = updatedCount; + } + + String resultString = "current count is updated from " + originalCount + " to " + updatedCount + + " after multiplying " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class MultiplyCompentationActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount / input; + count = updatedCount; + } + + String resultString = "current count is compensated from " + originalCount + " to " + updatedCount + + " after compensate multiplying " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class DivideActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount / input; + count = updatedCount; + } + + String resultString = "current count is updated from " + originalCount + " to " + updatedCount + + " after dividing " + input; + // System.out.println(resultString); + return resultString; + } + } + + public static class DivideCompentationActivity implements WorkflowActivity { + + @Override + public String run(WorkflowActivityContext ctx) { + Integer input = ctx.getInput(Integer.class); + + int originalCount = 0; + int updatedCount = 0; + synchronized (countLock) { + originalCount = count; + updatedCount = originalCount * input; + count = updatedCount; + } + + String resultString = "current count is compensated from " + originalCount + " to " + updatedCount + + " after compensate dividing " + input; + // System.out.println(resultString); + return resultString; + } + } +} diff --git a/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaOptionTest.java b/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaOptionTest.java new file mode 100644 index 000000000..996f199dc --- /dev/null +++ b/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaOptionTest.java @@ -0,0 +1,50 @@ +package io.dapr.workflows.saga; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; + +public class SagaOptionTest { + + @Test + public void testBuild() { + SagaOption.Builder builder = SagaOption.newBuilder(); + builder.setParallelCompensation(true); + builder.setMaxParallelThread(32); + builder.setContinueWithError(false); + SagaOption option = builder.build(); + + assertEquals(true, option.isParallelCompensation()); + assertEquals(32, option.getMaxParallelThread()); + assertEquals(false, option.isContinueWithError()); + } + + @Test + public void testBuild_default() { + SagaOption.Builder builder = SagaOption.newBuilder(); + SagaOption option = builder.build(); + + assertEquals(false, option.isParallelCompensation()); + assertEquals(16, option.getMaxParallelThread()); + assertEquals(true, option.isContinueWithError()); + } + + @Test + public void testsetMaxParallelThread() { + SagaOption.Builder builder = SagaOption.newBuilder(); + + assertThrows(IllegalArgumentException.class, () -> { + builder.setMaxParallelThread(0); + }); + + assertThrows(IllegalArgumentException.class, () -> { + builder.setMaxParallelThread(1); + }); + + assertThrows(IllegalArgumentException.class, () -> { + builder.setMaxParallelThread(-1); + }); + } + +} diff --git a/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaTest.java b/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaTest.java new file mode 100644 index 000000000..314565509 --- /dev/null +++ b/sdk-workflows/src/test/java/io/dapr/workflows/saga/SagaTest.java @@ -0,0 +1,454 @@ +/* + * Copyright 2023 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ +package io.dapr.workflows.saga; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import com.microsoft.durabletask.Task; +import com.microsoft.durabletask.TaskOptions; + +import io.dapr.workflows.WorkflowContext; +import io.dapr.workflows.runtime.WorkflowActivity; +import io.dapr.workflows.runtime.WorkflowActivityContext; + +public class SagaTest { + + public static WorkflowContext createMockContext() { + WorkflowContext workflowContext = mock(WorkflowContext.class); + when(workflowContext.callActivity(anyString(), any(), eq((TaskOptions) null))).thenAnswer(new ActivityAnswer()); + when(workflowContext.allOf(anyList())).thenAnswer(new AllActivityAnswer()); + + return workflowContext; + } + + @Test + public void testSaga_IllegalArgument() { + assertThrows(IllegalArgumentException.class, () -> { + new Saga(null); + }); + } + + @Test + public void testregisterCompensation() { + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(false) + .setContinueWithError(true).build(); + Saga saga = new Saga(config); + + saga.registerCompensation(MockActivity.class.getName(), new MockActivityInput()); + } + + @Test + public void testregisterCompensation_IllegalArgument() { + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(false) + .setContinueWithError(true).build(); + Saga saga = new Saga(config); + + assertThrows(IllegalArgumentException.class, () -> { + saga.registerCompensation(null, "input"); + }); + assertThrows(IllegalArgumentException.class, () -> { + saga.registerCompensation("", "input"); + }); + } + + @Test + public void testCompensateInParallel() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(true).build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + saga.compensate(createMockContext()); + + assertEquals(3, MockCompentationActivity.compensateOrder.size()); + } + + @Test + public void testCompensateInParallel_exception_1failed() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(true).build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + input2.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> { + saga.compensate(createMockContext()); + }); + assertNotNull(exception.getCause()); + // 3 compentation activities, 2 succeed, 1 failed + assertEquals(0, exception.getSuppressed().length); + assertEquals(2, MockCompentationActivity.compensateOrder.size()); + } + + @Test + public void testCompensateInParallel_exception_2failed() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(true).build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + input2.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + input3.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> { + saga.compensate(createMockContext()); + }); + assertNotNull(exception.getCause()); + // 3 compentation activities, 1 succeed, 2 failed + assertEquals(1, MockCompentationActivity.compensateOrder.size()); + } + + @Test + public void testCompensateInParallel_exception_3failed() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(true).build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + input1.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + input2.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + input3.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> { + saga.compensate(createMockContext()); + }); + assertNotNull(exception.getCause()); + // 3 compentation activities, 0 succeed, 3 failed + assertEquals(0, MockCompentationActivity.compensateOrder.size()); + } + + @Test + public void testCompensateSequentially() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(false).build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + saga.compensate(createMockContext()); + + assertEquals(3, MockCompentationActivity.compensateOrder.size()); + + // the order should be 3 / 2 / 1 + assertEquals(Integer.valueOf(3), MockCompentationActivity.compensateOrder.get(0)); + assertEquals(Integer.valueOf(2), MockCompentationActivity.compensateOrder.get(1)); + assertEquals(Integer.valueOf(1), MockCompentationActivity.compensateOrder.get(2)); + } + + @Test + public void testCompensateSequentially_continueWithError() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(false) + .setContinueWithError(true) + .build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + input2.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> { + saga.compensate(createMockContext()); + }); + assertNotNull(exception.getCause()); + assertEquals(0, exception.getSuppressed().length); + + // 3 compentation activities, 2 succeed, 1 failed + assertEquals(2, MockCompentationActivity.compensateOrder.size()); + // the order should be 3 / 1 + assertEquals(Integer.valueOf(3), MockCompentationActivity.compensateOrder.get(0)); + assertEquals(Integer.valueOf(1), MockCompentationActivity.compensateOrder.get(1)); + } + + @Test + public void testCompensateSequentially_continueWithError_suppressed() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(false) + .setContinueWithError(true) + .build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + input2.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + input3.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> { + saga.compensate(createMockContext()); + }); + assertNotNull(exception.getCause()); + assertEquals(1, exception.getSuppressed().length); + + // 3 compentation activities, 1 succeed, 2 failed + assertEquals(1, MockCompentationActivity.compensateOrder.size()); + // the order should be 3 / 1 + assertEquals(Integer.valueOf(1), MockCompentationActivity.compensateOrder.get(0)); + } + + @Test + public void testCompensateSequentially_notContinueWithError() { + MockCompentationActivity.compensateOrder.clear(); + + SagaOption config = SagaOption.newBuilder() + .setParallelCompensation(false) + .setContinueWithError(false) + .build(); + Saga saga = new Saga(config); + MockActivityInput input1 = new MockActivityInput(); + input1.setOrder(1); + saga.registerCompensation(MockCompentationActivity.class.getName(), input1); + MockActivityInput input2 = new MockActivityInput(); + input2.setOrder(2); + input2.setThrowException(true); + saga.registerCompensation(MockCompentationActivity.class.getName(), input2); + MockActivityInput input3 = new MockActivityInput(); + input3.setOrder(3); + saga.registerCompensation(MockCompentationActivity.class.getName(), input3); + + SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> { + saga.compensate(createMockContext()); + }); + assertNotNull(exception.getCause()); + assertEquals(0, exception.getSuppressed().length); + + // 3 compentation activities, 1 succeed, 1 failed and not continue + assertEquals(1, MockCompentationActivity.compensateOrder.size()); + // the order should be 3 / 1 + assertEquals(Integer.valueOf(3), MockCompentationActivity.compensateOrder.get(0)); + } + + public static class MockActivity implements WorkflowActivity { + + @Override + public Object run(WorkflowActivityContext ctx) { + MockActivityOutput output = new MockActivityOutput(); + output.setSucceed(true); + return output; + } + } + + public static class MockCompentationActivity implements WorkflowActivity { + + private static List compensateOrder = Collections.synchronizedList(new ArrayList<>()); + + @Override + public Object run(WorkflowActivityContext ctx) { + MockActivityInput input = ctx.getInput(MockActivityInput.class); + + if (input.isThrowException()) { + throw new RuntimeException("compensate failed: order=" + input.getOrder()); + } + + compensateOrder.add(input.getOrder()); + return null; + } + } + + public static class MockActivityInput { + private int order = 0; + private boolean throwException; + + public int getOrder() { + return order; + } + + public void setOrder(int order) { + this.order = order; + } + + public boolean isThrowException() { + return throwException; + } + + public void setThrowException(boolean throwException) { + this.throwException = throwException; + } + } + + public static class MockActivityOutput { + private boolean succeed; + + public boolean isSucceed() { + return succeed; + } + + public void setSucceed(boolean succeed) { + this.succeed = succeed; + } + } + + public static class ActivityAnswer implements Answer> { + + @Override + public Task answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + String name = (String) args[0]; + Object input = args[1]; + + WorkflowActivity activity; + WorkflowActivityContext activityContext = Mockito.mock(WorkflowActivityContext.class); + try { + activity = (WorkflowActivity) Class.forName(name).getDeclaredConstructor().newInstance(); + } catch (Exception e) { + fail(e); + return null; + } + + Task task = mock(Task.class); + when(task.await()).thenAnswer(invocation1 -> { + Mockito.doReturn(input).when(activityContext).getInput(Mockito.any()); + activity.run(activityContext); + return null; + }); + return task; + } + + } + + public static class AllActivityAnswer implements Answer> { + @Override + public Task answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + List> tasks = (List>) args[0]; + + ExecutorService executor = Executors.newFixedThreadPool(5); + List> compensationTasks = new ArrayList<>(); + for (Task task : tasks) { + Callable compensationTask = new Callable() { + @Override + public Void call() { + return task.await(); + } + }; + compensationTasks.add(compensationTask); + } + + List> resultFutures; + try { + resultFutures = executor.invokeAll(compensationTasks, 2, TimeUnit.SECONDS); + } catch (InterruptedException e) { + fail(e); + return null; + } + + Task task = mock(Task.class); + when(task.await()).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + Exception exception = null; + for (Future resultFuture : resultFutures) { + try { + resultFuture.get(); + } catch (Exception e) { + exception = e; + } + } + if (exception != null) { + throw exception; + } + return null; + } + }); + return task; + } + } + +}