From 1f50534caa6077a8da1377db2232c6f018588324 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Mon, 9 Oct 2023 16:05:29 +0200 Subject: [PATCH] Fix subworkflow collecting Signed-off-by: Hongxin Liang --- .../flyte/jflyte/utils/ProjectClosure.java | 10 ++- .../jflyte/utils/ProjectClosureTest.java | 3 +- .../flyte/jflyte/ExecuteDynamicWorkflow.java | 85 ++++++++++--------- 3 files changed, 54 insertions(+), 44 deletions(-) diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index 7ef4b9de8..fe1c2f80c 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -270,7 +270,8 @@ static ProjectClosure load( rewrittenWorkflowTemplates, workflowTemplate -> { Map subWorkflows = - collectSubWorkflows(workflowTemplate.nodes(), rewrittenWorkflowTemplates); + collectSubWorkflows( + workflowTemplate.nodes(), rewrittenWorkflowTemplates, Function.identity()); return WorkflowSpec.builder() .workflowTemplate(workflowTemplate) @@ -344,7 +345,10 @@ && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) { @VisibleForTesting public static Map collectSubWorkflows( - List rewrittenNodes, Map allWorkflows) { + List nodes, + Map allWorkflows, + Function, List> nodesRewriter) { + List rewrittenNodes = nodesRewriter.apply(nodes); return collectSubWorkflowIds(rewrittenNodes).stream() // all identifiers should be rewritten at this point .map( @@ -366,7 +370,7 @@ public static Map collectSubWorkflows( } Map nestedSubWorkflows = - collectSubWorkflows(subWorkflow.nodes(), allWorkflows); + collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter); return Stream.concat( Stream.of(Maps.immutableEntry(workflowId, subWorkflow)), diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java index cd1f698e7..36e6da6e0 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java @@ -38,6 +38,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import org.flyte.api.v1.BooleanExpression; import org.flyte.api.v1.BranchNode; import org.flyte.api.v1.ComparisonExpression; @@ -258,7 +259,7 @@ public void testCollectSubWorkflows() { nestedOtherSubWorkflowRef, emptyWorkflowTemplate); Map collectedSubWorkflows = - ProjectClosure.collectSubWorkflows(nodes, allWorkflows); + ProjectClosure.collectSubWorkflows(nodes, allWorkflows, Function.identity()); assertThat( collectedSubWorkflows, diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index 5b72b18f5..5e6ec8e2e 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -31,6 +31,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; import java.util.stream.Collectors; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; @@ -202,8 +203,8 @@ private static DynamicJobSpec rewrite( Config config, ExecutionConfig executionConfig, DynamicJobSpec spec, - Map taskTemplates, - Map workflowTemplates) { + Map allTaskTemplates, + Map allWorkflowTemplates) { try (FlyteAdminClient flyteAdminClient = FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), null)) { @@ -216,24 +217,41 @@ private static DynamicJobSpec rewrite( .adminClient(flyteAdminClient) .build() .visitor(); + Function, List> nodesRewriter = + nodes -> nodes.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); + + Map allUsedSubWorkflows = + collectAllUsedSubWorkflows( + spec.nodes(), allWorkflowTemplates, workflowNodeVisitor, nodesRewriter); - Map allUsedTaskTemplates = new HashMap<>(); - Map allUsedSubWorkflows = new HashMap<>(); Map cache = new HashMap<>(); + Map allUsedTaskTemplates = new HashMap<>(); - List nodes = - recursivelyCollect( + // collect directly used task templates + List rewrittenNodes = + collectTaskTemplates( spec.nodes(), + nodesRewriter, allUsedTaskTemplates, - allUsedSubWorkflows, - taskTemplates, - workflowTemplates, - workflowNodeVisitor, + allTaskTemplates, flyteAdminClient, cache); + // collect task templates used by subworkflows + allUsedSubWorkflows + .values() + .forEach( + workflowTemplate -> + collectTaskTemplates( + workflowTemplate.nodes(), + nodesRewriter, + allUsedTaskTemplates, + allTaskTemplates, + flyteAdminClient, + cache)); + return spec.toBuilder() - .nodes(nodes) + .nodes(rewrittenNodes) .subWorkflows( ImmutableMap.builder() .putAll(spec.subWorkflows()) @@ -248,45 +266,32 @@ private static DynamicJobSpec rewrite( } } - private static List recursivelyCollect( - List startPoint, + private static Map collectAllUsedSubWorkflows( + List nodes, + Map workflowTemplates, + WorkflowNodeVisitor workflowNodeVisitor, + Function, List> nodesRewriter) { + + Map allUsedSubWorkflows = + ProjectClosure.collectSubWorkflows(nodes, workflowTemplates, nodesRewriter); + return mapValues(allUsedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); + } + + private static List collectTaskTemplates( + List nodes, + Function, List> nodesRewriter, Map allUsedTaskTemplates, - Map allUsedSubWorkflows, Map allTaskTemplates, - Map allWorkflowTemplates, - WorkflowNodeVisitor workflowNodeVisitor, FlyteAdminClient flyteAdminClient, Map cache) { - List rewrittenNodes = - startPoint.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); + List rewrittenNodes = nodesRewriter.apply(nodes); Map usedTaskTemplates = ProjectClosure.collectDynamicWorkflowTasks( - rewrittenNodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache)); + nodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache)); allUsedTaskTemplates.putAll(usedTaskTemplates); - Map usedSubWorkflows = - ProjectClosure.collectSubWorkflows(rewrittenNodes, allWorkflowTemplates); - Map rewrittenUsedSubWorkflows = - mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); - - rewrittenUsedSubWorkflows.forEach( - (key, value) -> { - if (!allUsedSubWorkflows.containsKey(key)) { - allUsedSubWorkflows.put(key, value); - recursivelyCollect( - value.nodes(), - allUsedTaskTemplates, - allUsedSubWorkflows, - allTaskTemplates, - allWorkflowTemplates, - workflowNodeVisitor, - flyteAdminClient, - cache); - } - }); - return rewrittenNodes; }