From 1db25b89f432622b81eab3b3df6721998a41b95c Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Thu, 28 Sep 2023 14:44:18 +0200 Subject: [PATCH] Support testing dynamic workflow task Signed-off-by: Hongxin Liang --- .../java/org/flyte/examples/UberWorkflow.java | 10 +++++- .../java/org/flyte/examples/WorkflowTest.java | 11 +++++-- .../flytekit/testing/SdkTestingExecutor.java | 32 ++++++++++++++++--- pom.xml | 12 +++---- 4 files changed, 52 insertions(+), 13 deletions(-) diff --git a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java index 737a01d5a..951f5a7e8 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java @@ -68,6 +68,14 @@ public SumWorkflow.Output expand(SdkWorkflowBuilder builder, Input input) { .result(); SdkBindingData abcd = builder.apply("post-sum", new SumTask(), SumTask.SumInput.create(abc, d)).getOutputs(); - return SumWorkflow.Output.create(abcd); + SdkBindingData result = + builder + .apply( + "fibonacci", + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTask.Input.create(abcd)) + .getOutputs() + .output(); + return SumWorkflow.Output.create(result); } } diff --git a/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java b/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java index b79a68ef9..d298d60cd 100644 --- a/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java +++ b/flytekit-examples/src/test/java/org/flyte/examples/WorkflowTest.java @@ -54,9 +54,13 @@ public void testMockTasks() { new SumTask(), SumTask.SumInput.create(SdkBindingDataFactory.of(0L), SdkBindingDataFactory.of(4L)), SdkBindingDataFactory.of(42L)) + .withTaskOutput( + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTask.Input.create(SdkBindingDataFactory.of(42L)), + DynamicFibonacciWorkflowTask.Output.create(SdkBindingDataFactory.of(123L))) .execute(); - assertEquals(42L, result.getIntegerOutput("result")); + assertEquals(123L, result.getIntegerOutput("result")); } @Test @@ -87,9 +91,12 @@ public void testMockSubWorkflow() { new SumTask(), SumInput.create(SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(4L)), SdkBindingDataFactory.of(15L)) + .withTask( + new DynamicFibonacciWorkflowTask(), + input -> DynamicFibonacciWorkflowTask.Output.create(SdkBindingDataFactory.of(42L))) .execute(); - assertEquals(15L, result.getIntegerOutput("result")); + assertEquals(42L, result.getIntegerOutput("result")); } @Test diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java index b187ff77f..1f1996fac 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java @@ -42,9 +42,11 @@ import org.flyte.api.v1.WorkflowNode; import org.flyte.api.v1.WorkflowNode.Reference; import org.flyte.api.v1.WorkflowTemplate; +import org.flyte.flytekit.SdkDynamicWorkflowTask; import org.flyte.flytekit.SdkRemoteLaunchPlan; import org.flyte.flytekit.SdkRemoteTask; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.SdkType; import org.flyte.flytekit.SdkWorkflow; import org.flyte.localengine.ExecutionContext; @@ -321,8 +323,13 @@ public SdkTestingExecutor withFixedInputs(SdkType type, T value) { public SdkTestingExecutor withTaskOutput( SdkRunnableTask task, InputT input, OutputT output) { + return withTaskOutput0(task, input, output); + } + + public SdkTestingExecutor withTaskOutput( + SdkRemoteTask task, InputT input, OutputT output) { TestingRunnableTask fixedTask = - getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType()); + getFixedTaskOrDefault(task.name(), task.inputs(), task.outputs()); return toBuilder() .putFixedTask(task.getName(), fixedTask.withFixedOutput(input, output)) @@ -330,11 +337,18 @@ public SdkTestingExecutor withTaskOutput( } public SdkTestingExecutor withTaskOutput( - SdkRemoteTask task, InputT input, OutputT output) { + SdkDynamicWorkflowTask task, InputT input, OutputT output) { + return withTaskOutput0(task, input, output); + } + + private SdkTestingExecutor withTaskOutput0( + SdkTransform task, InputT input, OutputT output) { TestingRunnableTask fixedTask = - getFixedTaskOrDefault(task.name(), task.inputs(), task.outputs()); + getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType()); - return toBuilder().putFixedTask(task.name(), fixedTask.withFixedOutput(input, output)).build(); + return toBuilder() + .putFixedTask(task.getName(), fixedTask.withFixedOutput(input, output)) + .build(); } public SdkTestingExecutor withLaunchPlanOutput( @@ -361,6 +375,16 @@ public SdkTestingExecutor withLaunchPlan( public SdkTestingExecutor withTask( SdkRunnableTask task, Function runFn) { + return withTask0(task, runFn); + } + + public SdkTestingExecutor withTask( + SdkDynamicWorkflowTask task, Function runFn) { + return withTask0(task, runFn); + } + + private SdkTestingExecutor withTask0( + SdkTransform task, Function runFn) { TestingRunnableTask fixedTask = getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType()); diff --git a/pom.xml b/pom.xml index c4ae8aca3..78692c3f5 100644 --- a/pom.xml +++ b/pom.xml @@ -49,12 +49,12 @@ flytekit-api flytekit-jackson flytekit-java - flytekit-scala_2.12 - flytekit-scala_2.13 - flytekit-scala-tests + + + flytekit-testing flytekit-examples - flytekit-examples-scala + flytekit-local-engine flyteidl-protos jflyte-api @@ -102,8 +102,8 @@ 5.6.2 - - -Xep:AutoValueImmutableFields:OFF -Xep:Var:ERROR + + 11 ${maven.compiler.release}