Skip to content

Commit 3ed428f

Browse files
authored
Allow pulling XCom from inlet events (#49054)
1 parent 8324dae commit 3ed428f

File tree

6 files changed

+199
-28
lines changed

6 files changed

+199
-28
lines changed

task-sdk/src/airflow/sdk/bases/xcom.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import Any
20+
from typing import Any, Protocol
2121

2222
import structlog
2323

@@ -26,6 +26,13 @@
2626
log = structlog.get_logger(logger_name="task")
2727

2828

29+
class TIKeyProtocol(Protocol):
30+
dag_id: str
31+
task_id: str
32+
run_id: str
33+
map_index: int
34+
35+
2936
class BaseXCom:
3037
"""BaseXcom is an interface now to interact with XCom backends."""
3138

@@ -116,7 +123,7 @@ def _set_xcom_in_db(
116123
def get_value(
117124
cls,
118125
*,
119-
ti_key: Any,
126+
ti_key: TIKeyProtocol,
120127
key: str,
121128
) -> Any:
122129
"""

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,19 @@
4343

4444
from __future__ import annotations
4545

46+
from collections.abc import Iterator
4647
from datetime import datetime
48+
from functools import cached_property
4749
from typing import Annotated, Any, Literal, Union
4850
from uuid import UUID
4951

52+
import attrs
5053
from fastapi import Body
5154
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer
5255

5356
from airflow.sdk.api.datamodels._generated import (
57+
AssetEventDagRunReference,
58+
AssetEventResponse,
5459
AssetEventsResponse,
5560
AssetResponse,
5661
BundleInfo,
@@ -108,6 +113,50 @@ def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
108113
return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult")
109114

110115

116+
@attrs.define(kw_only=True)
117+
class AssetEventSourceTaskInstance:
118+
"""Used in AssetEventResult."""
119+
120+
dag_id: str
121+
task_id: str
122+
run_id: str
123+
map_index: int
124+
125+
def xcom_pull(
126+
self,
127+
*,
128+
key: str = "return_value", # TODO: Make this a constant; see RuntimeTaskInstance.
129+
default: Any = None,
130+
) -> Any:
131+
from airflow.sdk.execution_time.xcom import XCom
132+
133+
if (value := XCom.get_value(ti_key=self, key=key)) is None:
134+
return default
135+
return value
136+
137+
138+
class AssetEventResult(AssetEventResponse):
139+
"""Used in AssetEventsResult."""
140+
141+
@classmethod
142+
def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
143+
return cls(**asset_event_response.model_dump(exclude_defaults=True))
144+
145+
@cached_property
146+
def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
147+
if not (self.source_task_id and self.source_dag_id and self.source_run_id):
148+
return None
149+
if self.source_map_index is None:
150+
return None
151+
152+
return AssetEventSourceTaskInstance(
153+
dag_id=self.source_dag_id,
154+
task_id=self.source_task_id,
155+
run_id=self.source_run_id,
156+
map_index=self.source_map_index,
157+
)
158+
159+
111160
class AssetEventsResult(AssetEventsResponse):
112161
"""Response to GetAssetEvent request."""
113162

@@ -129,6 +178,32 @@ def from_asset_events_response(cls, asset_events_response: AssetEventsResponse)
129178
type="AssetEventsResult",
130179
)
131180

181+
def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
182+
return (AssetEventResult.from_asset_event_response(event) for event in self.asset_events)
183+
184+
185+
class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
186+
@classmethod
187+
def from_asset_event_dag_run_reference(
188+
cls,
189+
asset_event_dag_run_reference: AssetEventDagRunReference,
190+
) -> AssetEventDagRunReferenceResult:
191+
return cls(**asset_event_dag_run_reference.model_dump(exclude_defaults=True))
192+
193+
@cached_property
194+
def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
195+
if not (self.source_task_id and self.source_dag_id and self.source_run_id):
196+
return None
197+
if self.source_map_index is None:
198+
return None
199+
200+
return AssetEventSourceTaskInstance(
201+
dag_id=self.source_dag_id,
202+
task_id=self.source_task_id,
203+
run_id=self.source_run_id,
204+
map_index=self.source_map_index,
205+
)
206+
132207

133208
class XComResult(XComResponse):
134209
"""Response to ReadXCom request."""

task-sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@
4444
from uuid import UUID
4545

4646
from airflow.sdk import Variable
47-
from airflow.sdk.api.datamodels._generated import AssetEventDagRunReference, AssetEventResponse
4847
from airflow.sdk.bases.operator import BaseOperator
4948
from airflow.sdk.definitions.connection import Connection
5049
from airflow.sdk.definitions.context import Context
5150
from airflow.sdk.execution_time.comms import (
51+
AssetEventDagRunReferenceResult,
52+
AssetEventResult,
5253
AssetEventsResult,
5354
AssetResult,
5455
ConnectionResult,
@@ -377,20 +378,20 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
377378
@attrs.define
378379
class TriggeringAssetEventsAccessor(
379380
_AssetRefResolutionMixin,
380-
Mapping[Union[Asset, AssetAlias, AssetRef], Sequence["AssetEventDagRunReference"]],
381+
Mapping[Union[Asset, AssetAlias, AssetRef], Sequence["AssetEventDagRunReferenceResult"]],
381382
):
382383
"""Lazy mapping of triggering asset events."""
383384

384-
_events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReference]]
385+
_events: Mapping[BaseAssetUniqueKey, Sequence[AssetEventDagRunReferenceResult]]
385386

386387
@classmethod
387-
def build(cls, events: Iterable[AssetEventDagRunReference]) -> TriggeringAssetEventsAccessor:
388-
collected: dict[BaseAssetUniqueKey, list[AssetEventDagRunReference]] = collections.defaultdict(list)
388+
def build(cls, events: Iterable[AssetEventDagRunReferenceResult]) -> TriggeringAssetEventsAccessor:
389+
coll: dict[BaseAssetUniqueKey, list[AssetEventDagRunReferenceResult]] = collections.defaultdict(list)
389390
for event in events:
390-
collected[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
391+
coll[AssetUniqueKey(name=event.asset.name, uri=event.asset.uri)].append(event)
391392
for alias in event.source_aliases:
392-
collected[AssetAliasUniqueKey(name=alias.name)].append(event)
393-
return cls(collected)
393+
coll[AssetAliasUniqueKey(name=alias.name)].append(event)
394+
return cls(coll)
394395

395396
def __str__(self) -> str:
396397
return f"TriggeringAssetEventAccessor(_events={self._events})"
@@ -404,7 +405,7 @@ def __iter__(self) -> Iterator[Asset | AssetAlias]:
404405
def __len__(self) -> int:
405406
return len(self._events)
406407

407-
def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReference]:
408+
def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEventDagRunReferenceResult]:
408409
hashable_key: BaseAssetUniqueKey
409410
if isinstance(key, Asset):
410411
hashable_key = AssetUniqueKey.from_asset(key)
@@ -531,7 +532,7 @@ def __iter__(self) -> Iterator[Asset | AssetAlias]:
531532
def __len__(self) -> int:
532533
return len(self._inlets)
533534

534-
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEventResponse]:
535+
def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEventResult]:
535536
from airflow.sdk.definitions.asset import Asset
536537
from airflow.sdk.execution_time.comms import (
537538
ErrorResponse,
@@ -573,7 +574,7 @@ def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> list[AssetEve
573574
if TYPE_CHECKING:
574575
assert isinstance(msg, AssetEventsResult)
575576

576-
return msg.asset_events
577+
return list(msg.iter_asset_event_results())
577578

578579

579580
@cache # Prevent multiple API access.

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
5757
from airflow.sdk.execution_time.callback_runner import create_executable_runner
5858
from airflow.sdk.execution_time.comms import (
59+
AssetEventDagRunReferenceResult,
5960
DagRunStateResult,
6061
DeferTask,
6162
DRCount,
@@ -170,7 +171,6 @@ def get_template_context(self) -> Context:
170171
"params": validated_params,
171172
# TODO: Make this go through Public API longer term.
172173
# "test_mode": task_instance.test_mode,
173-
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
174174
"var": {
175175
"json": VariableAccessor(deserialize_json=True),
176176
"value": VariableAccessor(deserialize_json=False),
@@ -182,7 +182,10 @@ def get_template_context(self) -> Context:
182182
context_from_server: Context = {
183183
# TODO: Assess if we need to pass these through timezone.coerce_datetime
184184
"dag_run": dag_run, # type: ignore[typeddict-item] # Removable after #46522
185-
"triggering_asset_events": TriggeringAssetEventsAccessor.build(dag_run.consumed_asset_events),
185+
"triggering_asset_events": TriggeringAssetEventsAccessor.build(
186+
AssetEventDagRunReferenceResult.from_asset_event_dag_run_reference(event)
187+
for event in dag_run.consumed_asset_events
188+
),
186189
"task_instance_key_str": f"{self.task.dag_id}__{self.task.task_id}__{dag_run.run_id}",
187190
"task_reschedule_count": from_server.task_reschedule_count or 0,
188191
"prev_start_date_success": lazy_object_proxy.Proxy(

0 commit comments

Comments
 (0)