Skip to content

Commit

Permalink
AIP-72: Get Previous Successful Dag Run in Task Context (apache#45813)
Browse files Browse the repository at this point in the history
closes apache#45814

Adds following keys to the Task Context:

- prev_data_interval_start_success
- prev_data_interval_end_success
- prev_start_date_success
- prev_end_date_success
  • Loading branch information
kaxil authored Jan 21, 2025
1 parent 916ca46 commit 41b151e
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 28 deletions.
9 changes: 9 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,12 @@ class TIRunContext(BaseModel):

connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
"""Connections that can be accessed by the task instance."""


class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for Task Template Context."""

data_interval_start: UtcDateTime | None = None
data_interval_end: UtcDateTime | None = None
start_date: UtcDateTime | None = None
end_date: UtcDateTime | None = None
39 changes: 38 additions & 1 deletion airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun,
PrevSuccessfulDagRunResponse,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
Expand All @@ -45,7 +46,7 @@
from airflow.models.trigger import Trigger
from airflow.models.xcom import XCom
from airflow.utils import timezone
from airflow.utils.state import State, TerminalTIState
from airflow.utils.state import DagRunState, State, TerminalTIState

# TODO: Add dependency on JWT token
router = AirflowRouter()
Expand Down Expand Up @@ -393,6 +394,42 @@ def ti_put_rtif(
return {"message": "Rendered task instance fields successfully set"}


@router.get(
"/{task_instance_id}/previous-successful-dagrun",
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance or Dag Run not found"},
},
)
def get_previous_successful_dagrun(
task_instance_id: UUID, session: SessionDep
) -> PrevSuccessfulDagRunResponse:
"""
Get the previous successful DagRun for a TaskInstance.
The data from this endpoint is used to get values for Task Context.
"""
ti_id_str = str(task_instance_id)
task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
if not task_instance:
return PrevSuccessfulDagRunResponse()

dag_run = session.scalar(
select(DR)
.where(
DR.dag_id == task_instance.dag_id,
DR.logical_date < task_instance.logical_date,
DR.state == DagRunState.SUCCESS,
)
.order_by(DR.logical_date.desc())
.limit(1)
)
if not dag_run:
return PrevSuccessfulDagRunResponse()

return PrevSuccessfulDagRunResponse.model_validate(dag_run)


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
Expand Down
11 changes: 11 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
AssetResponse,
ConnectionResponse,
DagRunType,
PrevSuccessfulDagRunResponse,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
Expand Down Expand Up @@ -161,6 +162,15 @@ def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
# decouple from the server response string
return {"ok": True}

def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunResponse:
"""
Get the previous successful dag run for a given task instance.
The data from it is used to get values for Task Context.
"""
resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())


class ConnectionOperations:
__slots__ = ("client",)
Expand All @@ -181,6 +191,7 @@ def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
raise
return ConnectionResponse.model_validate_json(resp.read())


Expand Down
28 changes: 15 additions & 13 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
from pydantic import BaseModel, ConfigDict, Field


class AssetAliasResponse(BaseModel):
class AssetResponse(BaseModel):
"""
Asset alias schema with fields that are needed for Runtime.
Asset schema for responses with fields that are needed for Runtime.
"""

name: Annotated[str, Field(title="Name")]
uri: Annotated[str, Field(title="Uri")]
group: Annotated[str, Field(title="Group")]
extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None


class ConnectionResponse(BaseModel):
Expand Down Expand Up @@ -78,6 +80,17 @@ class IntermediateTIState(str, Enum):
DEFERRED = "deferred"


class PrevSuccessfulDagRunResponse(BaseModel):
"""
Schema for response with previous successful DagRun information for Task Template Context.
"""

data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
start_date: Annotated[datetime | None, Field(title="Start Date")] = None
end_date: Annotated[datetime | None, Field(title="End Date")] = None


class TIDeferredStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a deferred state.
Expand Down Expand Up @@ -196,17 +209,6 @@ class TaskInstance(BaseModel):
hostname: Annotated[str | None, Field(title="Hostname")] = None


class AssetResponse(BaseModel):
"""
Asset schema for responses with fields that are needed for Runtime.
"""

name: Annotated[str, Field(title="Name")]
uri: Annotated[str, Field(title="Uri")]
group: Annotated[str, Field(title="Group")]
extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None


class DagRun(BaseModel):
"""
Schema for DagRun model with minimal required fields needed for Runtime.
Expand Down
46 changes: 38 additions & 8 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from datetime import datetime
from typing import Annotated, Literal, Union
from uuid import UUID

from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, JsonValue
Expand All @@ -53,6 +54,7 @@
AssetResponse,
BundleInfo,
ConnectionResponse,
PrevSuccessfulDagRunResponse,
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
Expand Down Expand Up @@ -146,14 +148,36 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")


class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

@classmethod
def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
"""
Get a result object from response object.
PrevSuccessfulDagRunResponse is autogenerated from the API schema, so we need to convert it to
PrevSuccessfulDagRunResult for communication between the Supervisor and the task process.
"""
return cls(**prev_dag_run.model_dump(exclude_defaults=True), type="PrevSuccessfulDagRunResult")


class ErrorResponse(BaseModel):
error: ErrorType = ErrorType.GENERIC_ERROR
detail: dict | None = None
type: Literal["ErrorResponse"] = "ErrorResponse"


ToTask = Annotated[
Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse, AssetResult],
Union[
AssetResult,
ConnectionResult,
ErrorResponse,
PrevSuccessfulDagRunResult,
StartupDetails,
VariableResult,
XComResult,
],
Field(discriminator="type"),
]

Expand Down Expand Up @@ -261,19 +285,25 @@ class GetAssetByUri(BaseModel):
type: Literal["GetAssetByUri"] = "GetAssetByUri"


class GetPrevSuccessfulDagRun(BaseModel):
ti_id: UUID
type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"


ToSupervisor = Annotated[
Union[
TaskState,
GetXCom,
GetConnection,
GetVariable,
DeferTask,
GetAssetByName,
GetAssetByUri,
DeferTask,
GetConnection,
GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
PutVariable,
SetXCom,
SetRenderedFields,
RescheduleTask,
SetRenderedFields,
SetXCom,
TaskState,
],
Field(discriminator="type"),
]
27 changes: 26 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import contextlib
from collections.abc import Generator, Iterator, Mapping
from functools import cache
from typing import TYPE_CHECKING, Any, Union

import attrs
Expand All @@ -39,10 +40,17 @@
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

if TYPE_CHECKING:
from uuid import UUID

from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, VariableResult
from airflow.sdk.execution_time.comms import (
AssetResult,
ConnectionResult,
PrevSuccessfulDagRunResponse,
VariableResult,
)

log = structlog.get_logger(logger_name="task")

Expand Down Expand Up @@ -272,6 +280,23 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
return Asset(**msg.model_dump(exclude={"type"}))


@cache # Prevent multiple API access.
def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
from airflow.sdk.execution_time.comms import (
GetPrevSuccessfulDagRun,
PrevSuccessfulDagRunResponse,
PrevSuccessfulDagRunResult,
)
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id))
msg = SUPERVISOR_COMMS.get_message()

if TYPE_CHECKING:
assert isinstance(msg, PrevSuccessfulDagRunResult)
return PrevSuccessfulDagRunResponse(**msg.model_dump(exclude={"type"}))


@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
"""
Expand Down
6 changes: 6 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@
GetAssetByName,
GetAssetByUri,
GetConnection,
GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
SetRenderedFields,
Expand Down Expand Up @@ -798,6 +800,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
asset_resp = self.client.assets.get(uri=msg.uri)
asset_result = AssetResult.from_asset_response(asset_resp)
resp = asset_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetPrevSuccessfulDagRun):
dagrun_resp = self.client.task_instances.get_previous_successful_dagrun(self.id)
dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
resp = dagrun_result.model_dump_json(exclude_unset=True).encode()
else:
log.error("Unhandled request", msg=msg)
return
Expand Down
18 changes: 14 additions & 4 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar

import attrs
import lazy_object_proxy
import structlog
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter

Expand All @@ -52,6 +53,7 @@
MacrosAccessor,
OutletEventAccessors,
VariableAccessor,
get_previous_dagrun_success,
set_current_context,
)
from airflow.utils.net import get_hostname
Expand Down Expand Up @@ -100,10 +102,6 @@ def get_template_context(self) -> Context:
"macros": MacrosAccessor(),
# "params": validated_params,
# TODO: Make this go through Public API longer term.
# "prev_data_interval_start_success": get_prev_data_interval_start_success(),
# "prev_data_interval_end_success": get_prev_data_interval_end_success(),
# "prev_start_date_success": get_prev_start_date_success(),
# "prev_end_date_success": get_prev_end_date_success(),
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
"var": {
Expand Down Expand Up @@ -134,6 +132,18 @@ def get_template_context(self) -> Context:
"ts": ts,
"ts_nodash": ts_nodash,
"ts_nodash_with_tz": ts_nodash_with_tz,
"prev_data_interval_start_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).data_interval_start
),
"prev_data_interval_end_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).data_interval_end
),
"prev_start_date_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).start_date
),
"prev_end_date_success": lazy_object_proxy.Proxy(
lambda: get_previous_dagrun_success(self.id).end_date
),
}
context.update(context_from_server)

Expand Down
Loading

0 comments on commit 41b151e

Please sign in to comment.