Skip to content

Commit

Permalink
Use Task SDK's Context dict in models/taskinstance.py (apache#45834)
Browse files Browse the repository at this point in the history
This PR re-uses the Context dict from the Task SDK in `models/taskinstance.py`.

Once CeleryExecutor & KubernetesExecutor are ported over to the Task SDK, we can remove all of this code. This PR unifies some of that code.
  • Loading branch information
kaxil authored Jan 22, 2025
1 parent 537ca7b commit 0e7c639
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 122 deletions.
4 changes: 3 additions & 1 deletion airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ def _get_dag_run(
dag_run = DagRun(
dag_id=dag.dag_id,
run_id=logical_date_or_run_id,
run_type=DagRunType.MANUAL,
logical_date=dag_run_logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=dag_run_logical_date),
triggered_by=DagRunTriggeredByType.CLI,
state=DagRunState.RUNNING,
)
return dag_run, True
elif create_if_necessary == "db":
Expand All @@ -186,7 +188,7 @@ def _get_dag_run(
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.CLI,
dag_version=None,
state=DagRunState.QUEUED,
state=DagRunState.RUNNING,
session=session,
)
return dag_run, True
Expand Down
2 changes: 2 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType
from airflow.utils.xcom import XCOM_RETURN_KEY

Expand Down Expand Up @@ -632,6 +633,7 @@ def run(
logical_date=info.logical_date,
data_interval=info.data_interval,
triggered_by=DagRunTriggeredByType.TEST,
state=DagRunState.RUNNING,
)
ti = TaskInstance(self, run_id=dr.run_id)
ti.dag_run = dr
Expand Down
133 changes: 62 additions & 71 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import (
Column,
DateTime,
Float,
ForeignKey,
ForeignKeyConstraint,
Expand Down Expand Up @@ -162,7 +161,6 @@
from airflow.sdk.definitions._internal.abstractoperator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup

Expand Down Expand Up @@ -928,6 +926,11 @@ def _get_template_context(
from airflow import macros
from airflow.models.abstractoperator import NotMapped
from airflow.models.baseoperator import BaseOperator
from airflow.sdk.api.datamodels._generated import (
DagRun as DagRunSDK,
PrevSuccessfulDagRunResponse,
TIRunContext,
)

integrate_macros_plugins()

Expand All @@ -938,50 +941,34 @@ def _get_template_context(
assert task.dag

dag_run = task_instance.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)

validated_params = process_params(dag, task, dag_run.conf, suppress_exception=ignore_param_exceptions)

logical_date: DateTime = timezone.coerce_datetime(task_instance.logical_date)
ds = logical_date.strftime("%Y-%m-%d")
ds_nodash = ds.replace("-", "")
ts = logical_date.isoformat()
ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
ti_context_from_server = TIRunContext(
dag_run=DagRunSDK.model_validate(dag_run, from_attributes=True),
max_tries=task_instance.max_tries,
)
runtime_ti = task_instance.to_runtime_ti(context_from_server=ti_context_from_server)

context: Context = runtime_ti.get_template_context()

@cache # Prevent multiple database access.
def _get_previous_dagrun_success() -> DagRun | None:
return task_instance.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)

def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
dagrun = _get_previous_dagrun_success()
if dagrun is None:
return None
return dag.get_run_data_interval(dagrun)
def _get_previous_dagrun_success() -> PrevSuccessfulDagRunResponse:
dr_from_db = task_instance.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
if dr_from_db:
return PrevSuccessfulDagRunResponse.model_validate(dr_from_db, from_attributes=True)
return PrevSuccessfulDagRunResponse()

def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
data_interval = _get_previous_dagrun_data_interval_success()
if data_interval is None:
return None
return data_interval.start
return timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_start)

def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
data_interval = _get_previous_dagrun_data_interval_success()
if data_interval is None:
return None
return data_interval.end
return timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_end)

def get_prev_start_date_success() -> pendulum.DateTime | None:
dagrun = _get_previous_dagrun_success()
if dagrun is None:
return None
return timezone.coerce_datetime(dagrun.start_date)
return timezone.coerce_datetime(_get_previous_dagrun_success().start_date)

def get_prev_end_date_success() -> pendulum.DateTime | None:
dagrun = _get_previous_dagrun_success()
if dagrun is None:
return None
return timezone.coerce_datetime(dagrun.end_date)
return timezone.coerce_datetime(_get_previous_dagrun_success().end_date)

def get_triggering_events() -> dict[str, list[AssetEvent]]:
if TYPE_CHECKING:
Expand All @@ -1005,41 +992,29 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]:
# * Context in task_sdk/src/airflow/sdk/definitions/context.py
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
# * Table in docs/apache-airflow/templates-ref.rst
context: Context = {
"dag": dag,
"dag_run": dag_run,
"data_interval_end": timezone.coerce_datetime(data_interval.end),
"data_interval_start": timezone.coerce_datetime(data_interval.start),
"outlet_events": OutletEventAccessors(),
"ds": ds,
"ds_nodash": ds_nodash,
"inlets": task.inlets,
"inlet_events": InletEventsAccessors(task.inlets, session=session),
"logical_date": logical_date,
"macros": macros,
"map_index_template": task.map_index_template,
"outlets": task.outlets,
"params": validated_params,
"prev_data_interval_start_success": get_prev_data_interval_start_success(),
"prev_data_interval_end_success": get_prev_data_interval_end_success(),
"prev_start_date_success": get_prev_start_date_success(),
"prev_end_date_success": get_prev_end_date_success(),
"run_id": task_instance.run_id,
"task": task, # type: ignore[typeddict-item]
"task_instance": task_instance,
"task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
"test_mode": task_instance.test_mode,
"ti": task_instance,
"triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
"ts": ts,
"ts_nodash": ts_nodash,
"ts_nodash_with_tz": ts_nodash_with_tz,
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
}

context.update(
{
"outlet_events": OutletEventAccessors(),
"inlet_events": InletEventsAccessors(task.inlets, session=session),
"macros": macros,
"params": validated_params,
"prev_data_interval_start_success": get_prev_data_interval_start_success(),
"prev_data_interval_end_success": get_prev_data_interval_end_success(),
"prev_start_date_success": get_prev_start_date_success(),
"prev_end_date_success": get_prev_end_date_success(),
"test_mode": task_instance.test_mode,
# ti/task_instance are added here for ti.xcom_{push,pull}
"task_instance": task_instance,
"ti": task_instance,
"triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
}
)

try:
expanded_ti_count: int | None = BaseOperator.get_mapped_ti_count(
Expand All @@ -1058,8 +1033,6 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]:
except NotMapped:
pass

# Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
# is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
return context


Expand Down Expand Up @@ -1902,6 +1875,24 @@ def from_runtime_ti(cls, runtime_ti: RuntimeTaskInstanceProtocol) -> TaskInstanc
assert isinstance(ti, TaskInstance)
return ti

def to_runtime_ti(self, context_from_server) -> RuntimeTaskInstanceProtocol:
    """
    Build a Task SDK ``RuntimeTaskInstance`` from this task instance.

    Identifying fields (id, task/dag/run ids, try number, map index), the
    task, ``max_tries`` and ``hostname`` are copied across from ``self``.

    :param context_from_server: stored on the result as
        ``_ti_context_from_server``.
    :return: the constructed runtime task instance.
    """
    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    # NOTE(review): ``model_construct`` presumably skips validation (pydantic
    # semantics), so a misspelled keyword is not rejected -- keep these names
    # in sync with the fields declared on ``RuntimeTaskInstance``.
    return RuntimeTaskInstance.model_construct(
        id=self.id,
        task_id=self.task_id,
        dag_id=self.dag_id,
        run_id=self.run_id,
        # Previously misspelled as ``try_numer``, which left ``try_number``
        # unset on the runtime task instance.
        try_number=self.try_number,
        map_index=self.map_index,
        task=self.task,
        max_tries=self.max_tries,
        hostname=self.hostname,
        _ti_context_from_server=context_from_server,
    )

@staticmethod
def _command_as_list(
ti: TaskInstance,
Expand Down
35 changes: 12 additions & 23 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@
AssetUriRef,
)
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.context import OutletEventAccessors as OutletEventAccessorsSDK
from airflow.sdk.execution_time.context import (
ConnectionAccessor as ConnectionAccessorSDK,
OutletEventAccessors as OutletEventAccessorsSDK,
VariableAccessor as VariableAccessorSDK,
)
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -107,21 +111,13 @@
}


class VariableAccessor:
class VariableAccessor(VariableAccessorSDK):
"""Wrapper to access Variable values in template."""

def __init__(self, *, deserialize_json: bool) -> None:
self._deserialize_json = deserialize_json
self.var: Any = None

def __getattr__(self, key: str) -> Any:
from airflow.models.variable import Variable

self.var = Variable.get(key, deserialize_json=self._deserialize_json)
return self.var

def __repr__(self) -> str:
return str(self.var)
return Variable.get(key, deserialize_json=self._deserialize_json)

def get(self, key, default: Any = NOTSET) -> Any:
from airflow.models.variable import Variable
Expand All @@ -131,27 +127,20 @@ def get(self, key, default: Any = NOTSET) -> Any:
return Variable.get(key, default, deserialize_json=self._deserialize_json)


class ConnectionAccessor:
class ConnectionAccessor(ConnectionAccessorSDK):
"""Wrapper to access Connection entries in template."""

def __init__(self) -> None:
self.var: Any = None

def __getattr__(self, key: str) -> Any:
def __getattr__(self, conn_id: str) -> Any:
from airflow.models.connection import Connection

self.var = Connection.get_connection_from_secrets(key)
return self.var

def __repr__(self) -> str:
return str(self.var)
return Connection.get_connection_from_secrets(conn_id)

def get(self, key: str, default_conn: Any = None) -> Any:
def get(self, conn_id: str, default_conn: Any = None) -> Any:
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection

try:
return Connection.get_connection_from_secrets(key)
return Connection.get_connection_from_secrets(conn_id)
except AirflowNotFoundException:
return default_conn

Expand Down
63 changes: 51 additions & 12 deletions scripts/ci/pre_commit/template_context_key_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,65 @@

ROOT_DIR = pathlib.Path(__file__).resolve().parents[3]

TASKINSTANCE_PY = ROOT_DIR.joinpath("airflow", "models", "taskinstance.py")
TASKRUNNER_PY = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "execution_time", "task_runner.py")
CONTEXT_PY = ROOT_DIR.joinpath("airflow", "utils", "context.py")
CONTEXT_HINT = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "context.py")
TEMPLATES_REF_RST = ROOT_DIR.joinpath("docs", "apache-airflow", "templates-ref.rst")


def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
ti_mod = ast.parse(TASKINSTANCE_PY.read_text("utf-8"), str(TASKINSTANCE_PY))
fn_get_template_context = next(
ti_mod = ast.parse(TASKRUNNER_PY.read_text("utf-8"), str(TASKRUNNER_PY))

# Locate the RuntimeTaskInstance class definition
runtime_task_instance_class = next(
node
for node in ast.iter_child_nodes(ti_mod)
if isinstance(node, ast.FunctionDef) and node.name == "_get_template_context"
if isinstance(node, ast.ClassDef) and node.name == "RuntimeTaskInstance"
)
st_context_value = next(
stmt.value

# Locate the get_template_context method in RuntimeTaskInstance
fn_get_template_context = next(
node
for node in ast.iter_child_nodes(runtime_task_instance_class)
if isinstance(node, ast.FunctionDef) and node.name == "get_template_context"
)

# Helper function to extract keys from a dictionary node
def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
    """
    Yield the keys of a dict-literal AST node, in the order they appear.

    :param node: the ``ast.Dict`` node whose keys are read.
    :raises ValueError: on the first key that is not a string constant.
    """
    for key_node in node.keys:
        is_str_literal = isinstance(key_node, ast.Constant) and isinstance(key_node.value, str)
        if is_str_literal:
            yield key_node.value
            continue
        raise ValueError("Key in dictionary is not a string literal")

# Extract keys from the main `context` dictionary assignment
context_assignment = next(
stmt
for stmt in fn_get_template_context.body
if isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == "context"
)
if not isinstance(st_context_value, ast.Dict):
raise ValueError("'context' is not assigned a dict literal")
for expr in st_context_value.keys:
if not isinstance(expr, ast.Constant) or not isinstance(expr.value, str):
raise ValueError("key in 'context' dict is not a str literal")
yield expr.value

if not isinstance(context_assignment.value, ast.Dict):
raise ValueError("'context' is not assigned a dictionary literal")
yield from extract_keys_from_dict(context_assignment.value)

# Handle keys added conditionally in `if self._ti_context_from_server`
for stmt in fn_get_template_context.body:
if (
isinstance(stmt, ast.If)
and isinstance(stmt.test, ast.Attribute)
and stmt.test.attr == "_ti_context_from_server"
):
for sub_stmt in stmt.body:
# Get keys from `context_from_server` assignment
if (
isinstance(sub_stmt, ast.AnnAssign)
and isinstance(sub_stmt.target, ast.Name)
and isinstance(sub_stmt.value, ast.Dict)
and sub_stmt.target.id == "context_from_server"
):
yield from extract_keys_from_dict(sub_stmt.value)


def _iter_template_context_keys_from_declaration() -> typing.Iterator[str]:
Expand Down Expand Up @@ -105,6 +138,12 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
# Compat shim for task-sdk, not actually designed for user use
retn_keys.add("expanded_ti_count")

# TODO: These are the keys that are yet to be ported over to the Task SDK.
retn_keys.add("inlet_events")
retn_keys.add("params")
retn_keys.add("test_mode")
retn_keys.add("triggering_asset_events")

# Only present in callbacks. Not listed in templates-ref (that doc is for task execution).
retn_keys.update(("exception", "reason", "try_number"))
docs_keys.update(("exception", "reason", "try_number"))
Expand Down
Loading

0 comments on commit 0e7c639

Please sign in to comment.