WIP: support dataset aliases

tatiana committed Sep 23, 2024
1 parent 11de5ba commit 0cc9774
Showing 3 changed files with 64 additions and 11 deletions.
11 changes: 11 additions & 0 deletions cosmos/core/airflow.py
@@ -1,13 +1,17 @@
import importlib

import airflow
from airflow.models import BaseOperator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
from packaging.version import Version

from cosmos.core.graph.entities import Task
from cosmos.dataset import get_dataset_alias_name
from cosmos.log import get_logger

logger = get_logger(__name__)
AIRFLOW_VERSION = Version(airflow.__version__)


def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
@@ -29,6 +33,13 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
    if task.owner != "":
        task_kwargs["owner"] = task.owner

    if module_name == "local" and AIRFLOW_VERSION >= Version("2.10"):
        from airflow.datasets import DatasetAlias

        # Ignoring the type because, on older versions of Airflow, MyPy raises:
        # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")  [assignment]
        task_kwargs["outlets"] = [DatasetAlias(name=get_dataset_alias_name(dag, task_group, task.id))]  # type: ignore

    airflow_task = Operator(
        task_id=task.id,
        dag=dag,
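For context on what this change enables: in Airflow 2.10+, a downstream DAG can be scheduled on the alias emitted above. A minimal sketch, assuming Airflow 2.10+ and a hypothetical alias name my_dag__transform__run (the format produced by get_dataset_alias_name below):

from datetime import datetime

from airflow.datasets import DatasetAlias
from airflow.decorators import dag, task


@dag(schedule=[DatasetAlias(name="my_dag__transform__run")], start_date=datetime(2024, 1, 1), catchup=False)
def downstream_consumer():
    @task
    def consume() -> None:
        # Runs whenever the aliased Cosmos task attaches a dataset event.
        print("triggered by a Cosmos dataset alias event")

    consume()


downstream_consumer()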
30 changes: 30 additions & 0 deletions cosmos/dataset.py
@@ -0,0 +1,30 @@
from __future__ import annotations

from airflow import DAG
from airflow.utils.task_group import TaskGroup


def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_id: str) -> str:
"""
Given the Airflow DAG, Airflow TaskGroup and the Airflow Task ID, return the name of the
Airflow DatasetAlias associated to that task.
"""
    dag_id = None
    task_group_id = None

    if task_group:
        if task_group.dag_id is not None:
            dag_id = task_group.dag_id
        if task_group.group_id is not None:
            task_group_id = task_group.group_id
    elif dag:
        dag_id = dag.dag_id

    identifiers_list = []

    if dag_id:
        identifiers_list.append(dag_id)
    if task_group_id:
        identifiers_list.append(task_group_id.replace(".", "__"))

    # Include the task_id so each task gets a distinct alias name.
    identifiers_list.append(task_id)

    return "__".join(identifiers_list)
34 changes: 23 additions & 11 deletions cosmos/operators/local.py
@@ -9,13 +9,15 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence

import airflow
import jinja2
from airflow import DAG
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from attr import define
from packaging.version import Version

from cosmos import cache
from cosmos.cache import (
@@ -24,6 +26,7 @@
    is_cache_package_lockfile_enabled,
)
from cosmos.constants import InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError
from cosmos.settings import LINEAGE_NAMESPACE
@@ -42,6 +45,7 @@
    from dbt.cli.main import dbtRunner, dbtRunnerResult
    from openlineage.client.run import RunEvent


    from sqlalchemy.orm import Session

    from cosmos.config import ProfileConfig
@@ -72,6 +76,8 @@
    DbtTestMixin,
)

AIRFLOW_VERSION = Version(airflow.__version__)

logger = get_logger(__name__)

try:
@@ -387,7 +393,7 @@ def run_command(
            outlets = self.get_datasets("outputs")
            self.log.info("Inlets: %s", inlets)
            self.log.info("Outlets: %s", outlets)
            self.register_dataset(inlets, outlets)
            self.register_dataset(inlets, outlets, context)

        if self.partial_parse and self.cache_dir:
            partial_parse_file = get_partial_parse_path(tmp_dir_path)
@@ -468,20 +474,26 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
        )
        return datasets

    def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]) -> None:
    def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None:
"""
Register a list of datasets as outlets of the current task.
Until Airflow 2.7, there was not a better interface to associate outlets to a task during execution.
"""
        with create_session() as session:
            self.outlets.extend(new_outlets)
            self.inlets.extend(new_inlets)
            for task in self.dag.tasks:
                if task.task_id == self.task_id:
                    task.outlets.extend(new_outlets)
                    task.inlets.extend(new_inlets)
            DAG.bulk_write_to_db([self.dag], session=session)
            session.commit()
if AIRFLOW_VERSION < Version("2.10"):
with create_session() as session:
self.outlets.extend(new_outlets)
self.inlets.extend(new_inlets)
for task in self.dag.tasks:
if task.task_id == self.task_id:
task.outlets.extend(new_outlets)
task.inlets.extend(new_inlets)
DAG.bulk_write_to_db([self.dag], session=session)
session.commit()
else:
dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id)
for outlet in new_outlets:
context["outlet_events"][dataset_alias_name].add(outlet)
# TODO: check equivalent to inlets

    def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage:
        """
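For reference, the Airflow 2.10+ branch above follows the outlet-events pattern: the operator declares a DatasetAlias outlet (see cosmos/core/airflow.py) and attaches concrete Dataset events to it at runtime. A minimal standalone sketch of the same pattern, with a hypothetical alias name and dataset URI:

from airflow.datasets import Dataset, DatasetAlias
from airflow.decorators import task


@task(outlets=[DatasetAlias(name="my_dag__transform__run")])
def emit(**context) -> None:
    # Attaching a dataset event to the alias resolves it when the task succeeds,
    # triggering any DAG scheduled on that alias.
    context["outlet_events"]["my_dag__transform__run"].add(
        Dataset("postgres://example_host/example_db/example_table")
    )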
