Skip to content

Commit

Permalink
Clean up and fix non-nested foreach input params
Browse files Browse the repository at this point in the history
  • Loading branch information
saikonen committed Jan 27, 2024
1 parent 72a01c7 commit 9ad36fe
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,11 +999,7 @@ def _visit(
.inputs(
Inputs().parameters(
[Parameter("input-paths"), Parameter("split-index")]
+ (
[Parameter("root-input-path")]
if node.is_inside_foreach
else []
)
+ ([Parameter("root-input-path")] if parent_foreach else [])
)
)
.outputs(
Expand Down Expand Up @@ -1126,11 +1122,22 @@ def _container_templates(self):
task_idx = "{{inputs.parameters.split-index}}"
if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
join_node = self.graph[node.out_funcs[0]]
if any(
join_is_foreach = any(
self.graph[parent].matching_join == join_node.name
for parent in join_node.split_parents
if self.graph[parent].type == "foreach"
):
)
node_in_nested_foreach = (
len(
[
parent
for parent in node.split_parents
if self.graph[parent].type == "foreach"
]
)
> 1
)
if join_is_foreach and node_in_nested_foreach:
# we need to use the split index in case this is the last step in a nested foreach
task_idx = "{{inputs.parameters.split-index}}"
root_input = "{{inputs.parameters.root-input-path}}"
Expand Down Expand Up @@ -1464,11 +1471,13 @@ def _container_templates(self):
# to thank the designers of Argo Workflows for making this so
# straightforward!
inputs = []
has_split_index = False
if node.name != "start":
inputs.append(Parameter("input-paths"))
if any(self.graph[n].type == "foreach" for n in node.in_funcs):
# Fetch split-index from parent
inputs.append(Parameter("split-index"))
has_split_index = True
if (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
Expand All @@ -1477,16 +1486,28 @@ def _container_templates(self):
inputs.append(Parameter("max-split"))
if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
join_node = self.graph[node.out_funcs[0]]
if any(
join_is_foreach = any(
self.graph[parent].matching_join == join_node.name
for parent in join_node.split_parents
if self.graph[parent].type == "foreach"
):
# we need to carry the split-index info for the last step inside a foreach
# for correctly joining nested foreaches
inputs.extend(
[Parameter("split-index"), Parameter("root-input-path")]
)
node_in_nested_foreach = (
len(
[
parent
for parent in node.split_parents
if self.graph[parent].type == "foreach"
]
)
> 1
)
if join_is_foreach and node_in_nested_foreach:
# we need to carry the split-index and root-input-path info for the last step inside a nested foreach
# so that the nested foreaches can be joined correctly
if not has_split_index:
# Don't add duplicate split index parameters.
inputs.append(Parameter("split-index"))
inputs.append(Parameter("root-input-path"))

outputs = []
if node.name != "end":
Expand All @@ -1501,13 +1522,6 @@ def _container_templates(self):
outputs.append(
Parameter("max-split").valueFrom({"path": "/mnt/out/max_split"})
)
# if node.is_inside_foreach:
# # outer foreach index
# outputs.append(
# Parameter("root-index").valueFrom(
# {"path": "/mnt/out/root_index"}
# )
# )

# It makes no sense to set env vars to None (shows up as "None" string)
# Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
Expand Down

0 comments on commit 9ad36fe

Please sign in to comment.