diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 9e1d08ac1..ca0d3cd6f 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -1,13 +1,17 @@ import importlib +import airflow from airflow.models import BaseOperator from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup +from packaging.version import Version from cosmos.core.graph.entities import Task +from cosmos.dataset import get_dataset_alias_name from cosmos.log import get_logger logger = get_logger(__name__) +AIRFLOW_VERSION = Version(airflow.__version__) def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator: @@ -29,6 +33,13 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None if task.owner != "": task_kwargs["owner"] = task.owner + if module_name == "local" and AIRFLOW_VERSION >= Version("2.10"): + from airflow.datasets import DatasetAlias + + # ignoring the type because older versions of Airflow raise the follow error in MyPU + # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") [assignment] Found 1 error in 1 file (checked 3 source files) + task_kwargs["outlets"] = [DatasetAlias(name=get_dataset_alias_name(dag, task_group, task.id))] # type: ignore + airflow_task = Operator( task_id=task.id, dag=dag, diff --git a/cosmos/dataset.py b/cosmos/dataset.py new file mode 100644 index 000000000..3c43eb1bb --- /dev/null +++ b/cosmos/dataset.py @@ -0,0 +1,30 @@ +from airflow import DAG +from airflow.utils.task_group import TaskGroup + + +def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_id: str) -> str: + """ + Given the Airflow DAG, Airflow TaskGroup and the Airflow Task ID, return the name of the + Airflow DatasetAlias associated to that task. + """ + dag_id = None + task_group_id = None + + if task_group: + if task_group.dag_id is not None: + dag_id = task_group.dag_id + if task_group.group_id is not None: + task_group_id = task_group.group_id + elif dag: + dag_id = dag.dag_id + + identifiers_list = [] + dag_id = dag_id + task_group_id = task_group_id + + if dag_id: + identifiers_list.append(dag_id) + if task_group_id: + identifiers_list.append(task_group_id.replace(".", "__")) + + return "__".join(identifiers_list) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 2ba2b18ff..e95907042 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence +import airflow import jinja2 from airflow import DAG from airflow.exceptions import AirflowException, AirflowSkipException @@ -16,6 +17,7 @@ from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session from attr import define +from packaging.version import Version from cosmos import cache from cosmos.cache import ( @@ -24,6 +26,7 @@ is_cache_package_lockfile_enabled, ) from cosmos.constants import InvocationMode +from cosmos.dataset import get_dataset_alias_name from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file from cosmos.exceptions import AirflowCompatibilityError from cosmos.settings import LINEAGE_NAMESPACE @@ -42,6 +45,7 @@ from dbt.cli.main import dbtRunner, dbtRunnerResult from openlineage.client.run import RunEvent + from sqlalchemy.orm import Session from cosmos.config import ProfileConfig @@ -72,6 +76,8 @@ DbtTestMixin, ) +AIRFLOW_VERSION = Version(airflow.__version__) + logger = get_logger(__name__) try: @@ -387,7 +393,7 @@ def run_command( outlets = self.get_datasets("outputs") self.log.info("Inlets: %s", inlets) self.log.info("Outlets: %s", outlets) - self.register_dataset(inlets, outlets) + self.register_dataset(inlets, outlets, context) if self.partial_parse and self.cache_dir: partial_parse_file = get_partial_parse_path(tmp_dir_path) @@ -468,20 +474,26 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]: ) return datasets - def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]) -> None: + def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None: """ Register a list of datasets as outlets of the current task. Until Airflow 2.7, there was not a better interface to associate outlets to a task during execution. """ - with create_session() as session: - self.outlets.extend(new_outlets) - self.inlets.extend(new_inlets) - for task in self.dag.tasks: - if task.task_id == self.task_id: - task.outlets.extend(new_outlets) - task.inlets.extend(new_inlets) - DAG.bulk_write_to_db([self.dag], session=session) - session.commit() + if AIRFLOW_VERSION < Version("2.10"): + with create_session() as session: + self.outlets.extend(new_outlets) + self.inlets.extend(new_inlets) + for task in self.dag.tasks: + if task.task_id == self.task_id: + task.outlets.extend(new_outlets) + task.inlets.extend(new_inlets) + DAG.bulk_write_to_db([self.dag], session=session) + session.commit() + else: + dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) + for outlet in new_outlets: + context["outlet_events"][dataset_alias_name].add(outlet) + # TODO: check equivalent to inlets def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage: """