Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, Protocol

import structlog

Expand All @@ -26,6 +26,13 @@
log = structlog.get_logger(logger_name="task")


class TIKeyProtocol(Protocol):
    """
    Structural type for an object that identifies a task instance.

    Any object exposing the four attributes below satisfies this protocol;
    no inheritance from this class is required.
    """

    dag_id: str
    task_id: str
    run_id: str
    map_index: int


class BaseXCom:
"""BaseXcom is an interface now to interact with XCom backends."""

Expand Down Expand Up @@ -116,7 +123,7 @@ def _set_xcom_in_db(
def get_value(
cls,
*,
ti_key: Any,
ti_key: TIKeyProtocol,
key: str,
) -> Any:
"""
Expand Down
75 changes: 75 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,19 @@

from __future__ import annotations

from collections.abc import Iterator
from datetime import datetime
from functools import cached_property
from typing import Annotated, Any, Literal, Union
from uuid import UUID

import attrs
from fastapi import Body
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer

from airflow.sdk.api.datamodels._generated import (
AssetEventDagRunReference,
AssetEventResponse,
AssetEventsResponse,
AssetResponse,
BundleInfo,
Expand Down Expand Up @@ -108,6 +113,50 @@ def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult")


@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """
    Identifying fields of a task instance, with XCom lookup support.

    Used in AssetEventResult.
    """

    dag_id: str
    task_id: str
    run_id: str
    map_index: int

    def xcom_pull(
        self,
        *,
        key: str = "return_value",  # TODO: Make this a constant; see RuntimeTaskInstance.
        default: Any = None,
    ) -> Any:
        """
        Look up an XCom value using this object as the task instance key.

        :param key: XCom key passed to ``XCom.get_value``.
        :param default: Value to return when the lookup result is ``None``.
        :return: The looked-up value, or *default* if the lookup gave ``None``.
        """
        # NOTE(review): imported at call time rather than module level --
        # presumably to avoid an import cycle; confirm before hoisting.
        from airflow.sdk.execution_time.xcom import XCom

        fetched = XCom.get_value(ti_key=self, key=key)
        return default if fetched is None else fetched


class AssetEventResult(AssetEventResponse):
    """
    An ``AssetEventResponse`` that can expose its source task instance.

    Used in AssetEventsResult.
    """

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Build an instance from ``asset_event_response.model_dump(exclude_defaults=True)``."""
        dumped = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**dumped)

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """
        Return the source task instance of this event, or ``None``.

        ``None`` is returned when the source task id, dag id or run id is
        falsy, or when the source map index is ``None``.
        """
        # The map index is tested with ``is None`` rather than truthiness, so
        # an index of 0 is not rejected.
        if (
            not self.source_task_id
            or not self.source_dag_id
            or not self.source_run_id
            or self.source_map_index is None
        ):
            return None

        return AssetEventSourceTaskInstance(
            dag_id=self.source_dag_id,
            task_id=self.source_task_id,
            run_id=self.source_run_id,
            map_index=self.source_map_index,
        )


class AssetEventsResult(AssetEventsResponse):
"""Response to GetAssetEvent request."""

Expand All @@ -129,6 +178,32 @@ def from_asset_events_response(cls, asset_events_response: AssetEventsResponse)
type="AssetEventsResult",
)

def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
    """Return an iterator yielding one ``AssetEventResult`` per entry in ``self.asset_events``."""
    convert = AssetEventResult.from_asset_event_response
    results = (convert(response) for response in self.asset_events)
    return results


class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    """An ``AssetEventDagRunReference`` that can expose its source task instance."""

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Build an instance from the reference's ``model_dump(exclude_defaults=True)`` output."""
        field_values = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**field_values)

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """
        Return the source task instance of this event reference, or ``None``.

        ``None`` is returned when the source task id, dag id or run id is
        falsy, or when the source map index is ``None``.
        """
        has_identifiers = bool(self.source_task_id and self.source_dag_id and self.source_run_id)
        # ``is None`` (not truthiness) keeps a map index of 0 acceptable.
        if not has_identifiers or self.source_map_index is None:
            return None

        return AssetEventSourceTaskInstance(
            dag_id=self.source_dag_id,
            task_id=self.source_task_id,
            run_id=self.source_run_id,
            map_index=self.source_map_index,
        )


class XComResult(XComResponse):
"""Response to ReadXCom request."""
Expand Down
23 changes: 12 additions & 11 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@
from uuid import UUID

from airflow.sdk import Variable
from airflow.sdk.api.datamodels._generated import AssetEventDagRunReference, AssetEventResponse
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.comms import (
AssetEventDagRunReferenceResult,
AssetEventResult,
AssetEventsResult,
AssetResult,
ConnectionResult,
Expand Down Expand Up @@ -331,20 +332,20 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
@attrs.define
class TriggeringAssetEventsAccessor(
_AssetRefResolutionMixin,
Mapping[Union[Asset, AssetAlias, AssetRef], Sequence["AssetEventDagRunReference"]],
Mapping[Union[Asset, AssetAlias, AssetRef], Sequence["AssetEventDagRunReferenceResult"]],
):
"""Lazy mapping of triggering asset events."""

_events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReference]]
_events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReferenceResult]]

@classmethod
def build(cls, events: Iterable[AssetEventDagRunReference]) -> TriggeringAssetEventsAccessor:
collected: dict[BaseAssetUniqueKey, list[AssetEventDagRunReference]] = collections.defaultdict(list)
def build(cls, events: Iterable[AssetEventDagRunReferenceResult]) -> TriggeringAssetEventsAccessor:
coll: dict[BaseAssetUniqueKey, list[AssetEventDagRunReferenceResult]] = collections.defaultdict(list)
for event in events:
collected[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
coll[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
for alias in event.source_aliases:
collected[AssetAliasUniqueKey(name=alias.name)].append(event)
return cls(collected)
coll[AssetAliasUniqueKey(name=alias.name)].append(event)
return cls(coll)

def __str__(self) -> str:
return f"TriggeringAssetEventAccessor(_events={self._events})"
Expand All @@ -358,7 +359,7 @@ def __iter__(self) -> Iterator[Asset | AssetAlias]:
def __len__(self) -> int:
return len(self._events)

def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReference]:
def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReferenceResult]:
hashable_key: BaseAssetUniqueKey
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
Expand Down Expand Up @@ -485,7 +486,7 @@ def __iter__(self) -> Iterator[Asset | AssetAlias]:
def __len__(self) -> int:
return len(self._inlets)

def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEventResponse]:
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEventResult]:
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.execution_time.comms import (
ErrorResponse,
Expand Down Expand Up @@ -527,7 +528,7 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve
if TYPE_CHECKING:
assert isinstance(msg, AssetEventsResult)

return msg.asset_events
return list(msg.iter_asset_event_results())


@cache # Prevent multiple API access.
Expand Down
7 changes: 5 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.comms import (
AssetEventDagRunReferenceResult,
DagRunStateResult,
DeferTask,
DRCount,
Expand Down Expand Up @@ -170,7 +171,6 @@ def get_template_context(self) -> Context:
"params": validated_params,
# TODO: Make this go through Public API longer term.
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
Expand All @@ -182,7 +182,10 @@ def get_template_context(self) -> Context:
context_from_server: Context = {
# TODO: Assess if we need to pass these through timezone.coerce_datetime
"dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522
"triggering_asset_events": TriggeringAssetEventsAccessor.build(dag_run.consumed_asset_events),
"triggering_asset_events": TriggeringAssetEventsAccessor.build(
AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
for event in dag_run.consumed_asset_events
),
"task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
"task_reschedule_count": from_server.task_reschedule_count or 0,
"prev_start_date_success": lazy_object_proxy.Proxy(
Expand Down
Loading
Loading