From 66efab89623c7446116fca0b351367ffbd2685d2 Mon Sep 17 00:00:00 2001 From: Honnix Date: Mon, 2 Oct 2023 10:35:53 +0200 Subject: [PATCH] Fetch task template in dynamic workflow task (#254) * Fetch task template Signed-off-by: Hongxin Liang * Update integration test Signed-off-by: Hongxin Liang * Deploy scala examples to staging Signed-off-by: Hongxin Liang * Expose it Signed-off-by: Hongxin Liang --------- Signed-off-by: Hongxin Liang --- .../org.flyte.flytekit.SdkRunnableTask | 1 + .../DynamicFibonacciWorkflowTask.java | 19 ++++- .../org/flyte/examples/FlyteEnvironment.java | 3 +- integration-tests/pom.xml | 5 ++ .../test/java/org/flyte/JavaExamplesIT.java | 7 +- .../org/flyte/utils/FlyteSandboxClient.java | 16 +++-- .../flyte/jflyte/utils/FlyteAdminClient.java | 24 ++++--- .../flyte/jflyte/utils/ProjectClosure.java | 15 ++-- .../java/org/flyte/jflyte/utils/Fixtures.java | 62 ++++++++++++++++ .../jflyte/utils/FlyteAdminClientTest.java | 72 ++++++++++--------- .../jflyte/utils/ProjectClosureTest.java | 69 ++++++++++++++++++ .../flyte/jflyte/ExecuteDynamicWorkflow.java | 25 ++++++- 12 files changed, 254 insertions(+), 64 deletions(-) create mode 100644 jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask index 56bd0e960..0fc19c133 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask @@ -2,3 +2,4 @@ org.flyte.examples.flytekitscala.HelloWorldTask org.flyte.examples.flytekitscala.SumTask org.flyte.examples.flytekitscala.GreetTask org.flyte.examples.flytekitscala.AddQuestionTask +org.flyte.examples.flytekitscala.NoInputsTask diff --git 
a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java index 37c04d5e7..abc2bb121 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java @@ -16,8 +16,9 @@ */ package org.flyte.examples; -import static org.flyte.examples.FlyteEnvironment.DOMAIN; +import static org.flyte.examples.FlyteEnvironment.DEVELOPMENT_DOMAIN; import static org.flyte.examples.FlyteEnvironment.PROJECT; +import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; @@ -65,22 +66,34 @@ public Output run(SdkWorkflowBuilder builder, Input input) { } else if (input.n().get() == 0) { return Output.create(SdkBindingDataFactory.of(0)); } else { + // remote task that is discoverable in current classpath SdkNode hello = builder.apply( "hello", SdkRemoteTask.create( - DOMAIN, + DEVELOPMENT_DOMAIN, PROJECT, HelloWorldTask.class.getName(), SdkTypes.nulls(), SdkTypes.nulls())); + // a fully remote task + SdkNode world = + builder.apply( + "world", + SdkRemoteTask.create( + STAGING_DOMAIN, + PROJECT, + "org.flyte.examples.flytekitscala.NoInputsTask", + SdkTypes.nulls(), + SdkTypes.nulls()) + .withUpstreamNode(hello)); @Var SdkBindingData prev = SdkBindingDataFactory.of(0); @Var SdkBindingData value = SdkBindingDataFactory.of(1); for (int i = 2; i <= input.n().get(); i++) { SdkBindingData next = builder .apply( - "fib-" + i, new SumTask().withUpstreamNode(hello), SumInput.create(value, prev)) + "fib-" + i, new SumTask().withUpstreamNode(world), SumInput.create(value, prev)) .getOutputs(); prev = value; value = next; diff --git a/flytekit-examples/src/main/java/org/flyte/examples/FlyteEnvironment.java b/flytekit-examples/src/main/java/org/flyte/examples/FlyteEnvironment.java 
index 1d6575a45..3654e015c 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/FlyteEnvironment.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/FlyteEnvironment.java @@ -18,7 +18,8 @@ public final class FlyteEnvironment { - public static final String DOMAIN = "development"; + public static final String DEVELOPMENT_DOMAIN = "development"; + public static final String STAGING_DOMAIN = "staging"; public static final String PROJECT = "flytesnacks"; private FlyteEnvironment() { diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 19288bb95..c70e58377 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -58,6 +58,11 @@ flytekit-examples test + + org.flyte + flytekit-examples-scala_2.13 + test + org.flyte jflyte diff --git a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java index f3bf204a1..cbce85dee 100644 --- a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java +++ b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java @@ -17,6 +17,7 @@ package org.flyte; import static org.flyte.FlyteContainer.CLIENT; +import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; import static org.flyte.utils.Literal.ofIntegerMap; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -29,11 +30,13 @@ @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class JavaExamplesIT { - private static final String CLASSPATH = "flytekit-examples/target/lib"; + private static final String CLASSPATH_EXAMPLES = "flytekit-examples/target/lib"; + private static final String CLASSPATH_EXAMPLES_SCALA = "flytekit-examples-scala/target/lib"; @BeforeAll public static void beforeAll() { - CLIENT.registerWorkflows(CLASSPATH); + CLIENT.registerWorkflows(CLASSPATH_EXAMPLES); + CLIENT.registerWorkflows(CLASSPATH_EXAMPLES_SCALA, STAGING_DOMAIN); } @Test diff --git 
a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java index 0279ec7e4..f614ef596 100644 --- a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java +++ b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java @@ -16,7 +16,7 @@ */ package org.flyte.utils; -import static org.flyte.examples.FlyteEnvironment.DOMAIN; +import static org.flyte.examples.FlyteEnvironment.DEVELOPMENT_DOMAIN; import static org.flyte.examples.FlyteEnvironment.PROJECT; import flyteidl.admin.ExecutionOuterClass; @@ -59,7 +59,7 @@ public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap return createExecution( IdentifierOuterClass.Identifier.newBuilder() .setResourceType(IdentifierOuterClass.ResourceType.TASK) - .setDomain(DOMAIN) + .setDomain(DEVELOPMENT_DOMAIN) .setProject(PROJECT) .setName(name) .setVersion(version) @@ -71,7 +71,7 @@ public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inpu return createExecution( IdentifierOuterClass.Identifier.newBuilder() .setResourceType(IdentifierOuterClass.ResourceType.LAUNCH_PLAN) - .setDomain(DOMAIN) + .setDomain(DEVELOPMENT_DOMAIN) .setProject(PROJECT) .setName(name) .setVersion(version) @@ -84,7 +84,7 @@ private Literals.LiteralMap createExecution( ExecutionOuterClass.ExecutionCreateResponse response = stub.createExecution( ExecutionOuterClass.ExecutionCreateRequest.newBuilder() - .setDomain(DOMAIN) + .setDomain(DEVELOPMENT_DOMAIN) .setProject(PROJECT) .setInputs(inputs) .setSpec(ExecutionOuterClass.ExecutionSpec.newBuilder().setLaunchPlan(id).build()) @@ -148,14 +148,14 @@ private boolean isRunning(Execution.WorkflowExecution.Phase phase) { return false; } - public void registerWorkflows(String classpath) { + public void registerWorkflows(String classpath, String domain) { try { jflyte( "jflyte", "register", "workflows", "-p=" + PROJECT, - "-d=" + DOMAIN, + "-d=" + domain, 
"-v=" + version, "-cp=" + classpath); } catch (Exception e) { @@ -163,6 +163,10 @@ public void registerWorkflows(String classpath) { } } + public void registerWorkflows(String classpath) { + registerWorkflows(classpath, DEVELOPMENT_DOMAIN); + } + public void serializeWorkflows(String classpath, String folder) { jflyte("jflyte", "serialize", "workflows", "-cp=" + classpath, "-f=" + folder); } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/FlyteAdminClient.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/FlyteAdminClient.java index 3fc6bf812..bc687f4f1 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/FlyteAdminClient.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/FlyteAdminClient.java @@ -25,7 +25,6 @@ import flyteidl.admin.LaunchPlanOuterClass; import flyteidl.admin.TaskOuterClass; import flyteidl.admin.WorkflowOuterClass; -import flyteidl.core.IdentifierOuterClass; import flyteidl.service.AdminServiceGrpc; import io.grpc.Channel; import io.grpc.ClientInterceptor; @@ -185,8 +184,15 @@ public TaskIdentifier fetchLatestTaskId(NamedEntityIdentifier taskId) { return fetchLatestResource( taskId, request -> stub.listTasks(request).getTasksList(), - TaskOuterClass.Task::getId, - ProtoUtil::deserializeTaskId); + task -> ProtoUtil.deserializeTaskId(task.getId())); + } + + @Nullable + public TaskTemplate fetchLatestTaskTemplate(NamedEntityIdentifier taskId) { + return fetchLatestResource( + taskId, + request -> stub.listTasks(request).getTasksList(), + task -> ProtoUtil.deserialize(task.getClosure().getCompiledTask().getTemplate())); } @Nullable @@ -194,8 +200,7 @@ public WorkflowIdentifier fetchLatestWorkflowId(NamedEntityIdentifier workflowId return fetchLatestResource( workflowId, request -> stub.listWorkflows(request).getWorkflowsList(), - WorkflowOuterClass.Workflow::getId, - ProtoUtil::deserializeWorkflowId); + workflow -> ProtoUtil.deserializeWorkflowId(workflow.getId())); } @Nullable @@ -203,16 +208,14 @@ public 
LaunchPlanIdentifier fetchLatestLaunchPlanId(NamedEntityIdentifier launch return fetchLatestResource( launchPlanId, request -> stub.listLaunchPlans(request).getLaunchPlansList(), - LaunchPlanOuterClass.LaunchPlan::getId, - ProtoUtil::deserializeLaunchPlanId); + launchPlan -> ProtoUtil.deserializeLaunchPlanId(launchPlan.getId())); } @Nullable private T fetchLatestResource( NamedEntityIdentifier nameId, Function> performRequestFn, - Function extractIdFn, - Function deserializeFn) { + Function deserializeFn) { ResourceListRequest request = ResourceListRequest.newBuilder() .setLimit(1) @@ -230,8 +233,7 @@ private T fetchLatestResource( return null; } - IdentifierOuterClass.Identifier id = extractIdFn.apply(list.get(0)); - return deserializeFn.apply(id); + return deserializeFn.apply(list.get(0)); } private void idempotentCreate(String label, Object id, GrpcRetries.Retryable retryable) { diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index 54445a403..7ef4b9de8 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -47,6 +47,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Stream; import org.flyte.api.v1.Container; @@ -295,8 +296,8 @@ static void checkCycles(Map allWorkflows) checkCycles( workflowId, allWorkflows, - /*beingVisited=*/ new HashSet<>(), - /*visited=*/ new HashSet<>())) + /* beingVisited= */ new HashSet<>(), + /* visited= */ new HashSet<>())) .findFirst(); if (cycle.isPresent()) { throw new IllegalArgumentException( @@ -374,8 +375,10 @@ public static Map collectSubWorkflows( .collect(toUnmodifiableMap()); } - public static Map collectTasks( - List rewrittenNodes, Map allTasks) { + public 
static Map collectDynamicWorkflowTasks( + List rewrittenNodes, + Map allTasks, + Function remoteTaskTemplateFetcher) { return collectTaskIds(rewrittenNodes).stream() // all identifiers should be rewritten at this point .map( @@ -389,7 +392,9 @@ public static Map collectTasks( .distinct() .map( taskId -> { - TaskTemplate taskTemplate = allTasks.get(taskId); + TaskTemplate taskTemplate = + Optional.ofNullable(allTasks.get(taskId)) + .orElseGet(() -> remoteTaskTemplateFetcher.apply(taskId)); if (taskTemplate == null) { throw new NoSuchElementException("Can't find referenced task " + taskId); diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java new file mode 100644 index 000000000..37c58a958 --- /dev/null +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.flyte.jflyte.utils; + +import static java.util.Collections.emptyMap; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.flyte.api.v1.Container; +import org.flyte.api.v1.KeyValuePair; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.SimpleType; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.TaskTemplate; +import org.flyte.api.v1.TypedInterface; + +final class Fixtures { + static final String IMAGE_NAME = "alpine:latest"; + static final String COMMAND = "date"; + + static final Container CONTAINER = + Container.builder() + .command(ImmutableList.of(COMMAND)) + .args(ImmutableList.of()) + .image(IMAGE_NAME) + .env(ImmutableList.of(KeyValuePair.of("key", "value"))) + .build(); + static final TypedInterface INTERFACE_ = + TypedInterface.builder() + .inputs(ImmutableMap.of("x", ApiUtils.createVar(SimpleType.STRING))) + .outputs(ImmutableMap.of("y", ApiUtils.createVar(SimpleType.INTEGER))) + .build(); + static final RetryStrategy RETRIES = RetryStrategy.builder().retries(4).build(); + static final TaskTemplate TASK_TEMPLATE = + TaskTemplate.builder() + .container(CONTAINER) + .type("custom-task") + .interface_(INTERFACE_) + .custom(Struct.of(emptyMap())) + .retries(RETRIES) + .discoverable(false) + .cacheSerializable(false) + .build(); + + private Fixtures() { + throw new UnsupportedOperationException(); + } +} diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java index 0c4658b74..93344b321 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java @@ -17,7 +17,9 @@ package org.flyte.jflyte.utils; import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; +import static org.flyte.jflyte.utils.Fixtures.COMMAND; 
+import static org.flyte.jflyte.utils.Fixtures.IMAGE_NAME; +import static org.flyte.jflyte.utils.Fixtures.TASK_TEMPLATE; import static org.flyte.jflyte.utils.FlyteAdminClient.TRIGGERING_PRINCIPAL; import static org.flyte.jflyte.utils.FlyteAdminClient.USER_TRIGGERED_EXECUTION_NESTING; import static org.hamcrest.MatcherAssert.assertThat; @@ -32,7 +34,10 @@ import flyteidl.admin.LaunchPlanOuterClass; import flyteidl.admin.ScheduleOuterClass; import flyteidl.admin.TaskOuterClass; +import flyteidl.admin.TaskOuterClass.Task; +import flyteidl.admin.TaskOuterClass.TaskClosure; import flyteidl.admin.WorkflowOuterClass; +import flyteidl.core.Compiler.CompiledTask; import flyteidl.core.IdentifierOuterClass; import flyteidl.core.IdentifierOuterClass.ResourceType; import flyteidl.core.Interface; @@ -51,9 +56,7 @@ import java.util.Collections; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; -import org.flyte.api.v1.Container; import org.flyte.api.v1.CronSchedule; -import org.flyte.api.v1.KeyValuePair; import org.flyte.api.v1.LaunchPlan; import org.flyte.api.v1.LaunchPlanIdentifier; import org.flyte.api.v1.Literal; @@ -64,10 +67,8 @@ import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; import org.flyte.api.v1.Primitive; -import org.flyte.api.v1.RetryStrategy; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; -import org.flyte.api.v1.Struct; import org.flyte.api.v1.TaskIdentifier; import org.flyte.api.v1.TaskNode; import org.flyte.api.v1.TaskTemplate; @@ -95,8 +96,6 @@ public class FlyteAdminClientTest { private static final String WF_NAME = "workflow-foo"; private static final String WF_VERSION = "version-wf-foo"; private static final String WF_OLD_VERSION = "version-wf-bar"; - private static final String IMAGE_NAME = "alpine:latest"; - private static final String COMMAND = "date"; private FlyteAdminClient client; private TestAdminService stubService; @@ -138,33 +137,7 @@ public void 
shouldPropagateCreateTaskToStub() { .version(TASK_VERSION) .build(); - TypedInterface interface_ = - TypedInterface.builder() - .inputs(ImmutableMap.of("x", ApiUtils.createVar(SimpleType.STRING))) - .outputs(ImmutableMap.of("y", ApiUtils.createVar(SimpleType.INTEGER))) - .build(); - - Container container = - Container.builder() - .command(ImmutableList.of(COMMAND)) - .args(ImmutableList.of()) - .image(IMAGE_NAME) - .env(ImmutableList.of(KeyValuePair.of("key", "value"))) - .build(); - - RetryStrategy retries = RetryStrategy.builder().retries(4).build(); - TaskTemplate template = - TaskTemplate.builder() - .container(container) - .type("custom-task") - .interface_(interface_) - .custom(Struct.of(emptyMap())) - .retries(retries) - .discoverable(false) - .cacheSerializable(false) - .build(); - - client.createTask(identifier, template); + client.createTask(identifier, TASK_TEMPLATE); assertThat( stubService.createTaskRequest, @@ -397,6 +370,35 @@ public void fetchLatestTaskIdShouldReturnFirstTaskFromList() { .build())); } + @Test + public void fetchLatestTaskShouldReturnFirstTaskFromList() { + stubService.taskLists = + Arrays.asList( + Task.newBuilder() + .setId(newIdentifier(ResourceType.TASK, TASK_NAME, TASK_VERSION)) + .setClosure( + TaskClosure.newBuilder() + .setCompiledTask( + CompiledTask.newBuilder() + .setTemplate(ProtoUtil.serialize(TASK_TEMPLATE)) + .build()) + .build()) + .build(), + TaskOuterClass.Task.newBuilder() + .setId(newIdentifier(ResourceType.TASK, TASK_NAME, TASK_OLD_VERSION)) + .build()); + + TaskTemplate fetchLatestTaskTemplate = + client.fetchLatestTaskTemplate( + NamedEntityIdentifier.builder() + .project(PROJECT) + .domain(DOMAIN) + .name(TASK_NAME) + .build()); + + assertThat(fetchLatestTaskTemplate, equalTo(TASK_TEMPLATE)); + } + @Test public void fetchLatestTaskIdShouldReturnNullWhenEmptyList() { stubService.taskLists = Collections.emptyList(); @@ -517,7 +519,7 @@ private TaskOuterClass.TaskSpec newTaskSpec() { 
Tasks.TaskTemplate.newBuilder() .setContainer( Tasks.Container.newBuilder() - .setImage(FlyteAdminClientTest.IMAGE_NAME) + .setImage(IMAGE_NAME) .addCommand(COMMAND) .addEnv( Literals.KeyValuePair.newBuilder() diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java index 924107a60..cd1f698e7 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java @@ -18,8 +18,11 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; import static org.flyte.api.v1.Resources.ResourceName.CPU; import static org.flyte.api.v1.Resources.ResourceName.MEMORY; +import static org.flyte.jflyte.utils.Fixtures.TASK_TEMPLATE; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; @@ -48,6 +51,7 @@ import org.flyte.api.v1.Literal; import org.flyte.api.v1.Node; import org.flyte.api.v1.Operand; +import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Resources; @@ -55,6 +59,8 @@ import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.Struct; import org.flyte.api.v1.Task; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TaskNode; import org.flyte.api.v1.TaskTemplate; import org.flyte.api.v1.TypedInterface; import org.flyte.api.v1.WorkflowIdentifier; @@ -665,6 +671,69 @@ public void testCreateTaskTemplateForTasksWithCacheDisabled() { } } + @Test + void testCollectDynamicWorkflowTasks() { + Node node1 = + Node.builder() + .id("foo") + .upstreamNodeIds(singletonList("upstream-1")) + .inputs(List.of()) + .taskNode( + 
TaskNode.builder() + .referenceId( + PartialTaskIdentifier.builder() + .domain("domain1") + .project("project1") + .name("name1") + .version("version1") + .build()) + .build()) + .build(); + Node node2 = + Node.builder() + .id("bar") + .upstreamNodeIds(singletonList("upstream-1")) + .inputs(List.of()) + .taskNode( + TaskNode.builder() + .referenceId( + PartialTaskIdentifier.builder() + .domain("domain2") + .project("project2") + .name("nam2") + .version("versio2") + .build()) + .build()) + .build(); + + TaskIdentifier taskIdentifier1 = + TaskIdentifier.builder() + .domain(requireNonNull(node1.taskNode()).referenceId().domain()) + .project(requireNonNull(node1.taskNode()).referenceId().project()) + .name(requireNonNull(node1.taskNode()).referenceId().name()) + .version(requireNonNull(node1.taskNode()).referenceId().version()) + .build(); + TaskIdentifier taskIdentifier2 = + TaskIdentifier.builder() + .domain(requireNonNull(node2.taskNode()).referenceId().domain()) + .project(requireNonNull(node2.taskNode()).referenceId().project()) + .name(requireNonNull(node2.taskNode()).referenceId().name()) + .version(requireNonNull(node2.taskNode()).referenceId().version()) + .build(); + + Map taskTemplateMap = Map.of(taskIdentifier1, TASK_TEMPLATE); + + TaskTemplate taskTemplate = TASK_TEMPLATE.toBuilder().type("foo-bar-task").build(); + + Map collectedTaskTemplateMap = + ProjectClosure.collectDynamicWorkflowTasks( + List.of(node1, node2), taskTemplateMap, taskIdentifier -> taskTemplate); + + assertThat( + collectedTaskTemplateMap, + equalTo(Map.of(taskIdentifier1, TASK_TEMPLATE, taskIdentifier2, taskTemplate))); + } + private RunnableTask createRunnableTask( Resources expectedResources, List customJavaToolOptions) { return new RunnableTask() { diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index 21cb3338e..c5de2cc88 100644 --- 
a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -40,6 +40,7 @@ import org.flyte.api.v1.DynamicWorkflowTask; import org.flyte.api.v1.DynamicWorkflowTaskRegistrar; import org.flyte.api.v1.Literal; +import org.flyte.api.v1.NamedEntityIdentifier; import org.flyte.api.v1.Node; import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.RunnableTaskRegistrar; @@ -222,7 +223,8 @@ static DynamicJobSpec rewrite( ProjectClosure.collectSubWorkflows(rewrittenNodes, workflowTemplates); Map usedTaskTemplates = - ProjectClosure.collectTasks(rewrittenNodes, taskTemplates); + ProjectClosure.collectDynamicWorkflowTasks( + rewrittenNodes, taskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id)); // FIXME one sub-workflow can use more sub-workflows, we should recursively collect used tasks // and workflows @@ -246,6 +248,27 @@ static DynamicJobSpec rewrite( } } + // Note: this call is sometimes redundant, because the task template may already have been + // retrieved while resolving the latest task version. It is still needed when the user has + // provided a version for a remote task, since the latest version is not resolved in that case. + // The extra cost is accepted because remote tasks should be rare in a dynamic workflow. + // NOTE(review): only domain/project/name are passed to fetchLatestTaskTemplate below, so + // id.version() does not select the template -- confirm this is intended for pinned versions. + private static TaskTemplate fetchTaskTemplate( + FlyteAdminClient flyteAdminClient, TaskIdentifier id) { + LOG.info("fetching task template remotely for {}", id); + + TaskTemplate taskTemplate = + flyteAdminClient.fetchLatestTaskTemplate( + NamedEntityIdentifier.builder() + .domain(id.domain()) + .project(id.project()) + .name(id.name()) + .build()); + + return taskTemplate; + } + private static DynamicWorkflowTask getDynamicWorkflowTask(String name) { // be careful not to pass extra Map env = getEnv();