Skip to content

Commit

Permalink
Use Protocol for OutletEventAccessor (apache#45762)
Browse files Browse the repository at this point in the history
Follow-up of apache#45727 to use Protocol to allow auto-completion on IDE while not introducing runtime dep
  • Loading branch information
kaxil authored Jan 20, 2025
1 parent c68083e commit 08d0273
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 28 deletions.
6 changes: 3 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup
Expand Down Expand Up @@ -2730,7 +2730,7 @@ def _run_raw_task(
)

def _register_asset_changes(
self, *, events: OutletEventAccessors, session: Session | None = None
self, *, events: OutletEventAccessorsProtocol, session: Session | None = None
) -> None:
if session:
TaskInstance._register_asset_changes_int(ti=self, events=events, session=session)
Expand All @@ -2740,7 +2740,7 @@ def _register_asset_changes(
@staticmethod
@provide_session
def _register_asset_changes_int(
ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION
ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION
) -> None:
if TYPE_CHECKING:
assert ti.task
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAliasUniqueKey,
AssetAll,
AssetAny,
Expand All @@ -64,7 +65,7 @@
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
Expand All @@ -80,7 +81,6 @@
from airflow.utils.context import (
ConnectionAccessor,
Context,
OutletEventAccessors,
VariableAccessor,
)
from airflow.utils.db import LazySelectSequence
Expand Down
3 changes: 2 additions & 1 deletion airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from sqlalchemy.sql.expression import Select, TextClause

from airflow.models.baseoperator import BaseOperator
from airflow.sdk.types import OutletEventAccessorsProtocol

# NOTE: Please keep this in sync with the following:
# * Context in task_sdk/src/airflow/sdk/definitions/context.py
Expand Down Expand Up @@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context:
return cast(Context, new)


def context_get_outlet_events(context: Context) -> OutletEventAccessors:
def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
try:
return context["outlet_events"]
except KeyError:
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/operator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.utils.context import OutletEventAccessors
from airflow.sdk.types import OutletEventAccessorsProtocol

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -230,7 +230,7 @@ def run(*args, **kwargs): ...

def ExecutionCallableRunner(
func: Callable[P, R],
outlet_events: OutletEventAccessors,
outlet_events: OutletEventAccessorsProtocol,
*,
logger: logging.Logger,
) -> _ExecutionCallableRunner:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

if TYPE_CHECKING:
try:
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol as TaskInstance
from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
except ImportError:
from airflow.models import TaskInstance # type: ignore[assignment]
from airflow.utils.context import Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

if TYPE_CHECKING:
try:
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
from airflow.sdk.types import RuntimeTaskInstanceProtocol
except ImportError:
from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol # type: ignore[assignment]
from airflow.utils.context import Context
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,12 @@ def as_expression(self) -> Any:
:meta private:
"""
return {"all": [o.as_expression() for o in self.objects]}


@attrs.define
class AssetAliasEvent:
"""Representation of asset event to be triggered by an asset alias."""

source_alias_name: str
dest_asset_key: AssetUniqueKey
extra: dict[str, Any]
9 changes: 6 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from airflow.models.operator import Operator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol
from airflow.sdk.types import (
DagRunProtocol,
OutletEventAccessorsProtocol,
RuntimeTaskInstanceProtocol,
)


class Context(TypedDict, total=False):
Expand All @@ -38,8 +42,7 @@ class Context(TypedDict, total=False):
dag_run: DagRunProtocol
data_interval_end: datetime | None
data_interval_start: datetime | None
# outlet_events: OutletEventAccessors
outlet_events: Any
outlet_events: OutletEventAccessorsProtocol
ds: str
ds_nodash: str
expanded_ti_count: int | None
Expand Down
10 changes: 1 addition & 9 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAliasUniqueKey,
AssetNameRef,
AssetRef,
Expand Down Expand Up @@ -174,15 +175,6 @@ def __eq__(self, other: object) -> bool:
return True


@attrs.define
class AssetAliasEvent:
"""Representation of asset event to be triggered by an asset alias."""

source_alias_name: str
dest_asset_key: AssetUniqueKey
extra: dict[str, Any]


@attrs.define
class OutletEventAccessor:
"""Wrapper to access an outlet asset event in template."""
Expand Down
1 change: 0 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def get_template_context(self) -> Context:
}
context.update(context_from_server)

# TODO: We should use/move TypeDict from airflow.utils.context.Context
return context

def render_templates(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime

from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey
from airflow.sdk.definitions.baseoperator import BaseOperator


Expand Down Expand Up @@ -65,3 +67,28 @@ def xcom_pull(
) -> Any: ...

def xcom_push(self, key: str, value: Any) -> None: ...


class OutletEventAccessorProtocol(Protocol):
"""Protocol for managing access to a specific outlet event accessor."""

key: BaseAssetUniqueKey
extra: dict[str, Any]
asset_alias_events: list[AssetAliasEvent]

def __init__(
self,
*,
key: BaseAssetUniqueKey,
extra: dict[str, Any],
asset_alias_events: list[AssetAliasEvent],
) -> None: ...
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...


class OutletEventAccessorsProtocol(Protocol):
"""Protocol for managing access to outlet event accessors."""

def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ...
9 changes: 7 additions & 2 deletions task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@
import pytest

from airflow.sdk import get_current_context
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAliasUniqueKey,
AssetUniqueKey,
)
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult
from airflow.sdk.execution_time.context import (
AssetAliasEvent,
ConnectionAccessor,
OutletEventAccessor,
OutletEventAccessors,
Expand Down
5 changes: 2 additions & 3 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,12 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.context import OutletEventAccessors
from airflow.utils.db import LazySelectSequence
from airflow.utils.operator_resources import Resources
from airflow.utils.state import DagRunState, State
Expand Down

0 comments on commit 08d0273

Please sign in to comment.