Skip to content

Commit 470b3cd

Browse files
setting dependencies correctly between task group and other (#47947)
1 parent df67f84 commit 470b3cd

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,28 @@ def _set_relatives(
343343
if not isinstance(task_or_task_list, Sequence):
344344
task_or_task_list = [task_or_task_list]
345345

346+
# Helper function to find leaves from a task list or task group
347+
def find_leaves(group_or_task) -> list:
348+
while group_or_task:
349+
if list(group_or_task.get_leaves()):
350+
return list(group_or_task.get_leaves())
351+
if group_or_task.upstream_task_ids:
352+
upstream_task_ids_list = list(group_or_task.upstream_task_ids)
353+
return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
354+
group_or_task = group_or_task.parent_group
355+
return []
356+
357+
# Check if the current TaskGroup is empty
358+
leaves = find_leaves(self)
359+
346360
for task_like in task_or_task_list:
347361
self.update_relative(task_like, upstream, edge_modifier=edge_modifier)
348362

349363
if upstream:
350364
for task in self.get_roots():
351365
task.set_upstream(task_or_task_list)
352366
else:
353-
for task in self.get_leaves():
367+
for task in leaves: # Use the fetched leaves
354368
task.set_downstream(task_or_task_list)
355369

356370
def __enter__(self) -> TaskGroup:

0 commit comments

Comments
 (0)