Skip to content

Commit

Permalink
fix: task flow dynamic mapping with default_args (apache#41592)
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday authored and joaopamaral committed Oct 21, 2024
1 parent 71112f0 commit bdee2d3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
25 changes: 18 additions & 7 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,18 +431,29 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)

partial_kwargs, partial_params = get_merged_defaults(
default_args, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=task_kwargs.pop("params", None),
task_default_args=task_kwargs.pop("default_args", None),
)
partial_kwargs.update(
task_kwargs,
is_setup=self.is_setup,
is_teardown=self.is_teardown,
on_failure_fail_dagrun=self.on_failure_fail_dagrun,
)
partial_kwargs: dict[str, Any] = {
"is_setup": self.is_setup,
"is_teardown": self.is_teardown,
"on_failure_fail_dagrun": self.on_failure_fail_dagrun,
}
base_signature = inspect.signature(BaseOperator)
ignore = {
"default_args", # This is target we are working on now.
"kwargs", # A common name for a keyword argument.
"do_xcom_push", # In the same boat as `multiple_outputs`
"multiple_outputs", # We will use `self.multiple_outputs` instead.
"params", # Already handled above `partial_params`.
"task_concurrency", # Deprecated(replaced by `max_active_tis_per_dag`).
}
partial_keys = set(base_signature.parameters) - ignore
partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys})
partial_kwargs.update(task_kwargs)

task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
if task_group:
Expand Down
24 changes: 24 additions & 0 deletions tests/decorators/test_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# under the License.
from __future__ import annotations

import pytest

from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
from tests.models import DEFAULT_DATE
Expand All @@ -36,3 +39,24 @@ def f(z):

dag.get_task("t1") == x1.operator
dag.get_task("g.t2") == x2.operator


@pytest.mark.db_test
def test_mapped_task_with_arbitrary_default_args(dag_maker, session):
default_args = {"some": "value", "not": "in", "the": "task", "or": "dag"}
with dag_maker(session=session, default_args=default_args):

@task.python(do_xcom_push=True)
def f(x: int, y: int) -> int:
return x + y

f.partial(y=10).expand(x=[1, 2, 3])

dag_run = dag_maker.create_dagrun(session=session)
decision = dag_run.task_instance_scheduling_decisions(session=session)
xcoms = set()
for ti in decision.schedulable_tis:
ti.run(session=session)
xcoms.add(ti.xcom_pull(session=session, task_ids=ti.task_id, map_indexes=ti.map_index))

assert xcoms == {11, 12, 13}

0 comments on commit bdee2d3

Please sign in to comment.