WIP: Support dataset aliases #1217

Draft · wants to merge 5 commits into base: main · changes from all commits
9 changes: 8 additions & 1 deletion cosmos/core/airflow.py
@@ -1,16 +1,23 @@
from __future__ import annotations

import importlib

import airflow
from airflow.models import BaseOperator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
from packaging.version import Version

from cosmos.core.graph.entities import Task

# from cosmos.dataset import get_dataset_alias_name
from cosmos.log import get_logger

logger = get_logger(__name__)
AIRFLOW_VERSION = Version(airflow.__version__)


-def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
+def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) -> BaseOperator:
"""
Get the Airflow Operator class for a Task.

33 changes: 33 additions & 0 deletions cosmos/dataset.py
@@ -0,0 +1,33 @@
from __future__ import annotations

from airflow import DAG
from airflow.utils.task_group import TaskGroup


def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_id: str) -> str:
    """
    Given an Airflow DAG, an Airflow TaskGroup, and an Airflow task ID, return the name of the
    Airflow DatasetAlias associated with that task.
    """
    dag_id = None
    task_group_id = None

    if task_group:
        if task_group.dag_id is not None:
            dag_id = task_group.dag_id
        if task_group.group_id is not None:
            task_group_id = task_group.group_id
            task_group_id = task_group_id.replace(".", "__")
    elif dag:
        dag_id = dag.dag_id

    identifiers_list = []

    if dag_id:
        identifiers_list.append(dag_id)
    if task_group_id:
        identifiers_list.append(task_group_id)

    identifiers_list.append(task_id)

    return "__".join(identifiers_list)
54 changes: 42 additions & 12 deletions cosmos/operators/local.py
@@ -9,13 +9,15 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence

import airflow
import jinja2
from airflow import DAG
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from attr import define
from packaging.version import Version

from cosmos import cache
from cosmos.cache import (
@@ -24,6 +26,7 @@
    is_cache_package_lockfile_enabled,
)
from cosmos.constants import InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError
from cosmos.settings import LINEAGE_NAMESPACE
@@ -42,6 +45,7 @@
from dbt.cli.main import dbtRunner, dbtRunnerResult
from openlineage.client.run import RunEvent


from sqlalchemy.orm import Session

from cosmos.config import ProfileConfig
@@ -72,6 +76,8 @@
    DbtTestMixin,
)

AIRFLOW_VERSION = Version(airflow.__version__)

logger = get_logger(__name__)

try:
@@ -125,6 +131,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator):

    def __init__(
        self,
        task_id: str,
        profile_config: ProfileConfig,
        invocation_mode: InvocationMode | None = None,
        install_deps: bool = False,
@@ -133,6 +140,7 @@ def __init__(
        append_env: bool = True,
        **kwargs: Any,
    ) -> None:
        self.task_id = task_id
        self.profile_config = profile_config
        self.callback = callback
        self.compiled_sql = ""
@@ -145,7 +153,19 @@
        self._dbt_runner: dbtRunner | None = None
        if self.invocation_mode:
            self._set_invocation_methods()
-        super().__init__(**kwargs)

        if AIRFLOW_VERSION >= Version("2.10"):
Review comment (Collaborator, Author): only run this if emit_datasets is set to true
            from airflow.datasets import DatasetAlias

            # ignoring the type because older versions of Airflow raise the following error in mypy:
            # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")
            dag_id = kwargs.get("dag")
            task_group_id = kwargs.get("task_group")
            kwargs["outlets"] = [
                DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, task_id))
            ]  # type: ignore

        super().__init__(task_id=task_id, **kwargs)

        # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment
        # variables to the subprocess by default. Although this behavior is designed for ExecuteMode.LOCAL and
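
A sketch of what the review suggestion above might look like (hypothetical: it assumes an emit_datasets flag reaches the operator's kwargs, which this diff does not define):

# Hypothetical guard following the review comment; "emit_datasets" is an
# assumed flag, not something defined anywhere in this diff.
if AIRFLOW_VERSION >= Version("2.10") and kwargs.get("emit_datasets", True):
    ...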
@@ -387,7 +407,7 @@ def run_command(
            outlets = self.get_datasets("outputs")
            self.log.info("Inlets: %s", inlets)
            self.log.info("Outlets: %s", outlets)
-            self.register_dataset(inlets, outlets)
+            self.register_dataset(inlets, outlets, context)

            if self.partial_parse and self.cache_dir:
                partial_parse_file = get_partial_parse_path(tmp_dir_path)
@@ -468,20 +488,30 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
        )
        return datasets

-    def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]) -> None:
+    def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset], context: Context) -> None:
        """
        Register a list of datasets as outlets of the current task.
        Until Airflow 2.7, there was no better interface for associating outlets with a task during execution.
        This works before Airflow 2.10 with a few limitations, as described in the ticket:
        TODO: add the link to the GH issue related to orphaned nodes
        """
-        with create_session() as session:
-            self.outlets.extend(new_outlets)
-            self.inlets.extend(new_inlets)
-            for task in self.dag.tasks:
-                if task.task_id == self.task_id:
-                    task.outlets.extend(new_outlets)
-                    task.inlets.extend(new_inlets)
-            DAG.bulk_write_to_db([self.dag], session=session)
-            session.commit()
+        if AIRFLOW_VERSION < Version("2.10"):
+            logger.info("Assigning inlets/outlets without DatasetAlias")
+            with create_session() as session:
+                self.outlets.extend(new_outlets)
+                self.inlets.extend(new_inlets)
+                for task in self.dag.tasks:
+                    if task.task_id == self.task_id:
+                        task.outlets.extend(new_outlets)
+                        task.inlets.extend(new_inlets)
+                DAG.bulk_write_to_db([self.dag], session=session)
+                session.commit()
+        else:
+            logger.info("Assigning inlets/outlets with DatasetAlias")
+            dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id)
+            for outlet in new_outlets:
+                context["outlet_events"][dataset_alias_name].add(outlet)
+            # TODO: check equivalent to inlets

    def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> OperatorLineage:
        """
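
For context on the Airflow 2.10+ branch of register_dataset above (an illustrative sketch, not part of the diff): a DatasetAlias declared in outlets at parse time can have concrete datasets attached to it at runtime via outlet_events, and the scheduler resolves the alias into dataset events once the task finishes. The alias name and URI below are made-up examples:

from airflow.datasets import Dataset, DatasetAlias
from airflow.decorators import task

# The alias must be declared at parse time so Airflow can register it...
@task(outlets=[DatasetAlias("my_dag__run")])
def produce(*, outlet_events):
    # ...and the concrete dataset is attached at runtime, which is what lets
    # Cosmos emit URIs that are only known while dbt runs.
    outlet_events[DatasetAlias("my_dag__run")].add(Dataset("dbt://jaffle_shop/stg_customers"))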
57 changes: 56 additions & 1 deletion tests/operators/test_local.py
@@ -411,8 +411,10 @@ def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues

@pytest.mark.skipif(
    version.parse(airflow_version) < version.parse("2.4")
+    or version.parse(airflow_version) >= version.parse("2.10")
    or version.parse(airflow_version) in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
-    reason="Airflow DAG did not have datasets until the 2.4 release, inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1",
+    reason="Airflow DAG did not have datasets until the 2.4 release, inlets and outlets do not work by default in Airflow 2.9.0 and 2.9.1. \n"
+    "From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
)
@pytest.mark.integration
def test_run_operator_dataset_inlets_and_outlets(caplog):
@@ -453,6 +455,59 @@ def test_run_operator_dataset_inlets_and_outlets(caplog):
    assert test_operator.outlets == []


@pytest.mark.skipif(
    version.parse(airflow_version) < version.parse("2.10"),
    reason="From Airflow 2.10 onwards, we started using DatasetAlias, which changed this behaviour.",
)
@pytest.mark.integration
def test_run_operator_dataset_inlets_and_outlets_airflow_210_onwards(caplog):
    from airflow.models.dataset import DatasetAliasModel, DatasetModel
    from sqlalchemy import select

    with DAG("test-id-1", start_date=datetime(2022, 1, 1)) as dag:
        seed_operator = DbtSeedLocalOperator(
            profile_config=real_profile_config,
            project_dir=DBT_PROJ_DIR,
            task_id="seed",
            dag=dag,
            dbt_cmd_flags=["--select", "raw_customers"],
            install_deps=True,
            append_env=True,
        )
        run_operator = DbtRunLocalOperator(
            profile_config=real_profile_config,
            project_dir=DBT_PROJ_DIR,
            task_id="run",
            dag=dag,
            dbt_cmd_flags=["--models", "stg_customers"],
            install_deps=True,
            append_env=True,
        )
        test_operator = DbtTestLocalOperator(
            profile_config=real_profile_config,
            project_dir=DBT_PROJ_DIR,
            task_id="test",
            dag=dag,
            dbt_cmd_flags=["--models", "stg_customers"],
            install_deps=True,
            append_env=True,
        )
        seed_operator >> run_operator >> test_operator

    dag_run, session = run_test_dag(dag)

    assert session.scalars(select(DatasetModel)).all()
    assert session.scalars(select(DatasetAliasModel)).all()
    assert False
    # assert session == session
    # dataset_model = session.scalars(select(DatasetModel).where(DatasetModel.uri == "<something>"))
    # assert dataset_model == 1
    # dataset_alias_models = dataset_model.aliases  # Aliases associated with the URI.


# session.query(Dataset).filter_by
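
One way the commented-out assertions might be completed (a sketch only: the dataset URI below is a made-up example, and the real value depends on the test profile):

# Sketch: resolve a dataset by URI, then check the aliases attached to it.
dataset_model = session.scalars(
    select(DatasetModel).where(
        DatasetModel.uri == "postgres://0.0.0.0:5432/postgres.public.stg_customers"  # made-up URI
    )
).one()
# Per get_dataset_alias_name, the alias for task "run" in DAG "test-id-1"
# (no task group) would be "test-id-1__run".
assert "test-id-1__run" in {alias.name for alias in dataset_model.aliases}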


@pytest.mark.skipif(
    version.parse(airflow_version) not in PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS,
    reason="Airflow 2.9.0 and 2.9.1 have a breaking change in Dataset URIs",