diff --git a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java index 737a01d5a..951f5a7e8 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java @@ -68,6 +68,14 @@ public SumWorkflow.Output expand(SdkWorkflowBuilder builder, Input input) { .result(); SdkBindingData abcd = builder.apply("post-sum", new SumTask(), SumTask.SumInput.create(abc, d)).getOutputs(); - return SumWorkflow.Output.create(abcd); + SdkBindingData result = + builder + .apply( + "fibonacci", + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTask.Input.create(abcd)) + .getOutputs() + .output(); + return SumWorkflow.Output.create(result); } } diff --git a/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java b/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java index b79a68ef9..d298d60cd 100644 --- a/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java +++ b/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java @@ -54,9 +54,13 @@ public void testMockTasks() { new SumTask(), SumTask.SumInput.create(SdkBindingDataFactory.of(0L), SdkBindingDataFactory.of(4L)), SdkBindingDataFactory.of(42L)) + .withTaskOutput( + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTask.Input.create(SdkBindingDataFactory.of(42L)), + DynamicFibonacciWorkflowTask.Output.create(SdkBindingDataFactory.of(123L))) .execute(); - assertEquals(42L, result.getIntegerOutput("result")); + assertEquals(123L, result.getIntegerOutput("result")); } @Test @@ -87,9 +91,12 @@ public void testMockSubWorkflow() { new SumTask(), SumInput.create(SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(4L)), SdkBindingDataFactory.of(15L)) + .withTask( + new DynamicFibonacciWorkflowTask(), + input -> DynamicFibonacciWorkflowTask.Output.create(SdkBindingDataFactory.of(42L))) .execute(); - assertEquals(15L, result.getIntegerOutput("result")); + assertEquals(42L, result.getIntegerOutput("result")); } @Test diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java index b187ff77f..ec5528f12 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java @@ -42,9 +42,11 @@ import org.flyte.api.v1.WorkflowNode; import org.flyte.api.v1.WorkflowNode.Reference; import org.flyte.api.v1.WorkflowTemplate; +import org.flyte.flytekit.SdkDynamicWorkflowTask; import org.flyte.flytekit.SdkRemoteLaunchPlan; import org.flyte.flytekit.SdkRemoteTask; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.SdkType; import org.flyte.flytekit.SdkWorkflow; import org.flyte.localengine.ExecutionContext; @@ -321,20 +323,27 @@ public SdkTestingExecutor withFixedInputs(SdkType type, T value) { public SdkTestingExecutor withTaskOutput( SdkRunnableTask task, InputT input, OutputT output) { - TestingRunnableTask fixedTask = - getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType()); - - return toBuilder() - .putFixedTask(task.getName(), fixedTask.withFixedOutput(input, output)) - .build(); + return withTaskOutput0(task, input, output); } public SdkTestingExecutor withTaskOutput( SdkRemoteTask task, InputT input, OutputT output) { + return withTaskOutput0(task, input, output); + } + + public SdkTestingExecutor withTaskOutput( + SdkDynamicWorkflowTask task, InputT input, OutputT output) { + return withTaskOutput0(task, input, output); + } + + private SdkTestingExecutor withTaskOutput0( + SdkTransform task, InputT input, OutputT output) { TestingRunnableTask fixedTask = - getFixedTaskOrDefault(task.name(), task.inputs(), task.outputs()); + getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType()); - return toBuilder().putFixedTask(task.name(), fixedTask.withFixedOutput(input, output)).build(); + return toBuilder() + .putFixedTask(task.getName(), fixedTask.withFixedOutput(input, output)) + .build(); } public SdkTestingExecutor withLaunchPlanOutput( @@ -361,6 +370,16 @@ public SdkTestingExecutor withLaunchPlan( public SdkTestingExecutor withTask( SdkRunnableTask task, Function runFn) { + return withTask0(task, runFn); + } + + public SdkTestingExecutor withTask( + SdkDynamicWorkflowTask task, Function runFn) { + return withTask0(task, runFn); + } + + private SdkTestingExecutor withTask0( + SdkTransform task, Function runFn) { TestingRunnableTask fixedTask = getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType());