Skip to content

Commit

Permalink
Fix subworkflow collecting
Browse files (browse the repository at this point in the history)
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 9, 2023
1 parent 27be6d4 commit 1f50534
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ static ProjectClosure load(
rewrittenWorkflowTemplates,
workflowTemplate -> {
Map<WorkflowIdentifier, WorkflowTemplate> subWorkflows =
collectSubWorkflows(workflowTemplate.nodes(), rewrittenWorkflowTemplates);
collectSubWorkflows(
workflowTemplate.nodes(), rewrittenWorkflowTemplates, Function.identity());

return WorkflowSpec.builder()
.workflowTemplate(workflowTemplate)
Expand Down Expand Up @@ -344,7 +345,10 @@ && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) {

@VisibleForTesting
public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
List<Node> rewrittenNodes, Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
List<Node> nodes,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows,
Function<List<Node>, List<Node>> nodesRewriter) {
List<Node> rewrittenNodes = nodesRewriter.apply(nodes);
return collectSubWorkflowIds(rewrittenNodes).stream()
// all identifiers should be rewritten at this point
.map(
Expand All @@ -366,7 +370,7 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
}

Map<WorkflowIdentifier, WorkflowTemplate> nestedSubWorkflows =
collectSubWorkflows(subWorkflow.nodes(), allWorkflows);
collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter);

return Stream.concat(
Stream.of(Maps.immutableEntry(workflowId, subWorkflow)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.flyte.api.v1.BooleanExpression;
import org.flyte.api.v1.BranchNode;
import org.flyte.api.v1.ComparisonExpression;
Expand Down Expand Up @@ -258,7 +259,7 @@ public void testCollectSubWorkflows() {
nestedOtherSubWorkflowRef, emptyWorkflowTemplate);

Map<WorkflowIdentifier, WorkflowTemplate> collectedSubWorkflows =
ProjectClosure.collectSubWorkflows(nodes, allWorkflows);
ProjectClosure.collectSubWorkflows(nodes, allWorkflows, Function.identity());

assertThat(
collectedSubWorkflows,
Expand Down
85 changes: 45 additions & 40 deletions jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BindingData;
Expand Down Expand Up @@ -202,8 +203,8 @@ private static DynamicJobSpec rewrite(
Config config,
ExecutionConfig executionConfig,
DynamicJobSpec spec,
Map<TaskIdentifier, TaskTemplate> taskTemplates,
Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates) {
Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflowTemplates) {

try (FlyteAdminClient flyteAdminClient =
FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), null)) {
Expand All @@ -216,24 +217,41 @@ private static DynamicJobSpec rewrite(
.adminClient(flyteAdminClient)
.build()
.visitor();
Function<List<Node>, List<Node>> nodesRewriter =
nodes -> nodes.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList());

Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows =
collectAllUsedSubWorkflows(
spec.nodes(), allWorkflowTemplates, workflowNodeVisitor, nodesRewriter);

Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates = new HashMap<>();
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows = new HashMap<>();
Map<TaskIdentifier, TaskTemplate> cache = new HashMap<>();
Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates = new HashMap<>();

List<Node> nodes =
recursivelyCollect(
// collect directly used task templates
List<Node> rewrittenNodes =
collectTaskTemplates(
spec.nodes(),
nodesRewriter,
allUsedTaskTemplates,
allUsedSubWorkflows,
taskTemplates,
workflowTemplates,
workflowNodeVisitor,
allTaskTemplates,
flyteAdminClient,
cache);

// collect task templates used by subworkflows
allUsedSubWorkflows
.values()
.forEach(
workflowTemplate ->
collectTaskTemplates(
workflowTemplate.nodes(),
nodesRewriter,
allUsedTaskTemplates,
allTaskTemplates,
flyteAdminClient,
cache));

return spec.toBuilder()
.nodes(nodes)
.nodes(rewrittenNodes)
.subWorkflows(
ImmutableMap.<WorkflowIdentifier, WorkflowTemplate>builder()
.putAll(spec.subWorkflows())
Expand All @@ -248,45 +266,32 @@ private static DynamicJobSpec rewrite(
}
}

private static List<Node> recursivelyCollect(
List<Node> startPoint,
/**
 * Collects the sub-workflows referenced by {@code nodes} and returns them with every collected
 * template passed through {@code workflowNodeVisitor}.
 *
 * <p>Collection is delegated to {@link ProjectClosure#collectSubWorkflows}, which receives
 * {@code nodesRewriter} so that it can rewrite node lists before resolving sub-workflow
 * identifiers against {@code workflowTemplates}.
 *
 * @param nodes nodes to start collecting from
 * @param workflowTemplates known workflow templates used to resolve sub-workflow identifiers
 * @param workflowNodeVisitor visitor whose {@code visitWorkflowTemplate} is applied to each
 *     collected template
 * @param nodesRewriter function applied to node lists by {@code collectSubWorkflows} before
 *     sub-workflow identifiers are collected
 * @return the collected sub-workflows after {@code visitWorkflowTemplate} has been applied to
 *     each value (presumably keyed by the same identifiers; confirm against {@code mapValues})
 */
private static Map<WorkflowIdentifier, WorkflowTemplate> collectAllUsedSubWorkflows(
List<Node> nodes,
Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates,
WorkflowNodeVisitor workflowNodeVisitor,
Function<List<Node>, List<Node>> nodesRewriter) {

// Collect first, then rewrite the collected templates themselves in a separate pass.
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows =
ProjectClosure.collectSubWorkflows(nodes, workflowTemplates, nodesRewriter);
return mapValues(allUsedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate);
}

private static List<Node> collectTaskTemplates(
List<Node> nodes,
Function<List<Node>, List<Node>> nodesRewriter,
Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates,
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows,
Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflowTemplates,
WorkflowNodeVisitor workflowNodeVisitor,
FlyteAdminClient flyteAdminClient,
Map<TaskIdentifier, TaskTemplate> cache) {

List<Node> rewrittenNodes =
startPoint.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList());
List<Node> rewrittenNodes = nodesRewriter.apply(nodes);

Map<TaskIdentifier, TaskTemplate> usedTaskTemplates =
ProjectClosure.collectDynamicWorkflowTasks(
rewrittenNodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache));
nodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache));
allUsedTaskTemplates.putAll(usedTaskTemplates);

Map<WorkflowIdentifier, WorkflowTemplate> usedSubWorkflows =
ProjectClosure.collectSubWorkflows(rewrittenNodes, allWorkflowTemplates);
Map<WorkflowIdentifier, WorkflowTemplate> rewrittenUsedSubWorkflows =
mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate);

rewrittenUsedSubWorkflows.forEach(
(key, value) -> {
if (!allUsedSubWorkflows.containsKey(key)) {
allUsedSubWorkflows.put(key, value);
recursivelyCollect(
value.nodes(),
allUsedTaskTemplates,
allUsedSubWorkflows,
allTaskTemplates,
allWorkflowTemplates,
workflowNodeVisitor,
flyteAdminClient,
cache);
}
});

return rewrittenNodes;
}

Expand Down

0 comments on commit 1f50534

Please sign in to comment.