diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b0c33dfc6..baa535d4b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,7 +2,7 @@ name: test on: push: # Run on pushes to the default branch - branches: [main] + branches: [main,poc-dbt-compile-task] pull_request_target: # Also run on pull requests originated from forks branches: [main] @@ -176,6 +176,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -248,6 +250,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -316,6 +320,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -393,6 +399,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -537,6 +545,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 17ee22c95..9de21292e 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -8,6 +8,7 @@ from cosmos.config import RenderConfig from cosmos.constants import ( + DBT_COMPILE_TASK_ID, DEFAULT_DBT_RESOURCES, TESTABLE_DBT_RESOURCES, DbtResourceType, @@ -252,6 +253,31 @@ def generate_task_or_group( return task_or_group +def _add_dbt_compile_task( + nodes: dict[str, DbtNode], + dag: DAG, + execution_mode: ExecutionMode, + task_args: dict[str, Any], + tasks_map: dict[str, Any], + task_group: TaskGroup | None, +) -> None: + if execution_mode != ExecutionMode.AIRFLOW_ASYNC: + return + + compile_task_metadata = TaskMetadata( + id=DBT_COMPILE_TASK_ID, + operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator", + arguments=task_args, + extra_context={}, + ) + compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group) + tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task + + for node_id, node in nodes.items(): + if not node.depends_on and node_id in tasks_map: + tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id] + + def build_airflow_graph( nodes: dict[str, DbtNode], dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups @@ -332,11 +358,14 @@ def build_airflow_graph( for leaf_node_id in leaves_ids: tasks_map[leaf_node_id] >> test_task + _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group) + create_airflow_task_dependencies(nodes, tasks_map) def create_airflow_task_dependencies( - nodes: dict[str, DbtNode], tasks_map: dict[str, 
Union[TaskGroup, BaseOperator]] + nodes: dict[str, DbtNode], + tasks_map: dict[str, Union[TaskGroup, BaseOperator]], ) -> None: """ Create the Airflow task dependencies between non-test nodes. diff --git a/cosmos/constants.py b/cosmos/constants.py index e9d1aaa6b..f42cfc4fc 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -86,6 +86,7 @@ class ExecutionMode(Enum): """ LOCAL = "local" + AIRFLOW_ASYNC = "airflow_async" DOCKER = "docker" KUBERNETES = "kubernetes" AWS_EKS = "aws_eks" @@ -147,3 +148,5 @@ def _missing_value_(cls, value): # type: ignore # It expects that you have already created those resources through the appropriate commands. # https://docs.getdbt.com/reference/commands/test TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED} + +DBT_COMPILE_TASK_ID = "dbt_compile" diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py new file mode 100644 index 000000000..05f762702 --- /dev/null +++ b/cosmos/operators/airflow_async.py @@ -0,0 +1,67 @@ +from cosmos.operators.local import ( + DbtBuildLocalOperator, + DbtCompileLocalOperator, + DbtDocsAzureStorageLocalOperator, + DbtDocsGCSLocalOperator, + DbtDocsLocalOperator, + DbtDocsS3LocalOperator, + DbtLSLocalOperator, + DbtRunLocalOperator, + DbtRunOperationLocalOperator, + DbtSeedLocalOperator, + DbtSnapshotLocalOperator, + DbtSourceLocalOperator, + DbtTestLocalOperator, +) + + +class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator): + pass + + +class DbtLSAirflowAsyncOperator(DbtLSLocalOperator): + pass + + +class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator): + pass + + +class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator): + pass + + +class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator): + pass + + +class DbtRunAirflowAsyncOperator(DbtRunLocalOperator): + pass + + +class DbtTestAirflowAsyncOperator(DbtTestLocalOperator): + pass + + +class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator): + pass + + +class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator): + pass + + +class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator): + pass + + +class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator): + pass + + +class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator): + pass + + +class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator): + pass diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index d82083a23..ed7969ebd 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -429,3 +429,12 @@ def add_cmd_flags(self) -> list[str]: flags.append("--args") flags.append(yaml.dump(self.args)) return flags + + +class DbtCompileMixin: + """ + Mixin for dbt compile command. 
+ """ + + base_cmd = ["compile"] + ui_color = "#877c7c" diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 557bfe500..db5993609 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -9,6 +9,7 @@ from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence +from urllib.parse import urlparse import airflow import jinja2 @@ -17,6 +18,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.version import version as airflow_version from attr import define from packaging.version import Version @@ -26,10 +28,11 @@ _get_latest_cached_package_lockfile, is_cache_package_lockfile_enabled, ) -from cosmos.constants import InvocationMode +from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode from cosmos.dataset import get_dataset_alias_name from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file -from cosmos.exceptions import AirflowCompatibilityError +from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError +from cosmos.settings import AIRFLOW_IO_AVAILABLE, remote_target_path, remote_target_path_conn_id try: from airflow.datasets import Dataset @@ -67,6 +70,7 @@ from cosmos.operators.base import ( AbstractDbtBaseOperator, DbtBuildMixin, + DbtCompileMixin, DbtLSMixin, DbtRunMixin, DbtRunOperationMixin, @@ -137,6 +141,7 @@ def __init__( install_deps: bool = False, callback: Callable[[str], None] | None = None, should_store_compiled_sql: bool = True, + should_upload_compiled_sql: bool = False, append_env: bool = True, **kwargs: Any, ) -> None: @@ -146,6 +151,7 @@ def __init__( self.compiled_sql = "" self.freshness = "" self.should_store_compiled_sql = should_store_compiled_sql + self.should_upload_compiled_sql = should_upload_compiled_sql self.openlineage_events_completes: list[RunEvent] = [] self.invocation_mode = invocation_mode self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] @@ -271,6 +277,84 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se else: self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.") + @staticmethod + def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]: + """Configure the remote target path if it is provided.""" + if not remote_target_path: + return None, None + + _configured_target_path = None + + target_path_str = str(remote_target_path) + + remote_conn_id = remote_target_path_conn_id + if not remote_conn_id: + target_path_schema = urlparse(target_path_str).scheme + remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None) # type: ignore[assignment] + if remote_conn_id is None: + return None, None + + if not AIRFLOW_IO_AVAILABLE: + raise CosmosValueError( + f"You're trying to specify remote target path {target_path_str}, but the required " + f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to " + "Airflow 2.8 or later." 
+ ) + + from airflow.io.path import ObjectStoragePath + + _configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id) + + if not _configured_target_path.exists(): # type: ignore[no-untyped-call] + _configured_target_path.mkdir(parents=True, exist_ok=True) + + return _configured_target_path, remote_conn_id + + def _construct_dest_file_path( + self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, context: Context + ) -> str: + """ + Construct the destination path for the compiled SQL files to be uploaded to the remote store. + """ + dest_target_dir_str = str(dest_target_dir).rstrip("/") + + task = context["task"] + dag_id = task.dag_id + task_group_id = task.task_group.group_id if task.task_group else None + identifiers_list = [] + if dag_id: + identifiers_list.append(dag_id) + if task_group_id: + identifiers_list.append(task_group_id) + dag_task_group_identifier = "__".join(identifiers_list) + + rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/") + + return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}" + + def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None: + """ + Uploads the compiled SQL files from the dbt compile output to the remote store. + """ + if not self.should_upload_compiled_sql: + return + + dest_target_dir, dest_conn_id = self._configure_remote_target_path() + if not dest_target_dir: + raise CosmosValueError( + "You're trying to upload compiled SQL files, but the remote target path is not configured. " + ) + + from airflow.io.path import ObjectStoragePath + + source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled" + files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()] + for file_path in files: + dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context) + dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id) + ObjectStoragePath(file_path).copy(dest_object_storage_path) + self.log.debug("Copied %s to %s", file_path, dest_object_storage_path) + @provide_session def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: """ @@ -416,6 +500,7 @@ def run_command( self.store_freshness_json(tmp_project_dir, context) self.store_compiled_sql(tmp_project_dir, context) + self.upload_compiled_sql(tmp_project_dir, context) self.handle_exception(result) if self.callback: self.callback(tmp_project_dir) @@ -920,3 +1005,9 @@ def __init__(self, **kwargs: str) -> None: raise DeprecationWarning( "The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead." 
) + + +class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator): + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["should_upload_compiled_sql"] = True + super().__init__(*args, **kwargs) diff --git a/cosmos/settings.py b/cosmos/settings.py index 6449630ae..2cae79968 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -35,6 +35,9 @@ remote_cache_dir = conf.get("cosmos", "remote_cache_dir", fallback=None) remote_cache_dir_conn_id = conf.get("cosmos", "remote_cache_dir_conn_id", fallback=None) +remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None) +remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None) + try: LINEAGE_NAMESPACE = conf.get("openlineage", "namespace") except airflow.exceptions.AirflowConfigException: diff --git a/dev/dags/simple_dag_async.py b/dev/dags/simple_dag_async.py new file mode 100644 index 000000000..787461236 --- /dev/null +++ b/dev/dags/simple_dag_async.py @@ -0,0 +1,39 @@ +import os +from datetime import datetime +from pathlib import Path + +from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +# [START airflow_async_execution_mode_example] +simple_dag_async = DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=profile_config, + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.AIRFLOW_ASYNC, + ), + # normal dag parameters + schedule_interval=None, + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="simple_dag_async", + tags=["simple"], + operator_args={"install_deps": True}, +) +# [END airflow_async_execution_mode_example] diff --git a/docs/configuration/cosmos-conf.rst b/docs/configuration/cosmos-conf.rst index 95a4adcad..3035cfd7a 100644 --- a/docs/configuration/cosmos-conf.rst +++ b/docs/configuration/cosmos-conf.rst @@ -126,6 +126,27 @@ This page lists all available Airflow configurations that affect ``astronomer-co - Default: ``None`` - Environment Variable: ``AIRFLOW__COSMOS__REMOTE_CACHE_DIR_CONN_ID`` +.. _remote_target_path: + +`remote_target_path`_: + (Introduced in Cosmos 1.7.0) The path to the remote target directory. This is the directory to which the files + generated and stored by dbt in the dbt project's target directory are remotely copied and stored. The value + for the remote target path can be any of the schemes that are supported by the + `Airflow Object Store `_ + feature introduced in Airflow 2.8.0 (e.g. ``s3://your_s3_bucket/target_dir/``, ``gs://your_gs_bucket/target_dir/``, + ``abfs://your_azure_container/target_dir``, etc.) + + - Default: ``None`` + - Environment Variable: ``AIRFLOW__COSMOS__REMOTE_TARGET_PATH`` + +.. _remote_target_path_conn_id: + +`remote_target_path_conn_id`_: + (Introduced in Cosmos 1.7.0) The connection ID for the remote target path. If this is not set, the default + Airflow connection ID for the path's scheme will be used. 
+ + - Default: ``None`` + - Environment Variable: ``AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID`` [openlineage] ~~~~~~~~~~~~~ diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index f80c3da9d..ec150992d 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -12,12 +12,13 @@ Cosmos can run ``dbt`` commands using five different approaches, called ``execut 5. **aws_eks**: Run ``dbt`` commands from AWS EKS Pods managed by Cosmos (requires a pre-existing Docker image) 6. **azure_container_instance**: Run ``dbt`` commands from Azure Container Instances managed by Cosmos (requires a pre-existing Docker image) 7. **gcp_cloud_run_job**: Run ``dbt`` commands from GCP Cloud Run Job instances managed by Cosmos (requires a pre-existing Docker image) +8. **airflow_async**: (Experimental, introduced in Cosmos 1.7.0) Run the dbt resources from your dbt project asynchronously by submitting the corresponding compiled SQLs to Apache Airflow's `Deferrable operators `__ The choice of the ``execution mode`` can vary based on each user's needs and concerns. For more details, check each execution mode described below. .. list-table:: Execution Modes Comparison - :widths: 20 20 20 20 20 + :widths: 25 25 25 25 :header-rows: 1 * - Execution Mode @@ -52,6 +53,10 @@ The choice of the ``execution mode`` can vary based on each user's needs and con - Slow - High - No + * - Airflow Async + - Medium + - None + - Yes Local ----- @@ -238,6 +243,42 @@ Each task will create a new Cloud Run Job execution, giving full isolation. The }, ) +Airflow Async (experimental) +---------------------------- + +.. versionadded:: 1.7.0 + + +(**Experimental**) The ``airflow_async`` execution mode is a way to run the dbt resources from your dbt project using Apache Airflow's +`Deferrable operators `__. +This execution mode is preferable when you have long-running resources that you want to run asynchronously by +leveraging Airflow's deferrable operators. It can yield higher task throughput, since more dbt nodes run in parallel +without blocking Airflow's worker slots. + +In this mode, Cosmos adds a new operator, ``DbtCompileAirflowAsyncOperator``, as a root task in the DbtDag or DbtTaskGroup. This task runs +the ``dbt compile`` command on your dbt project, which outputs the compiled SQLs in the project's target directory. +As part of the same task run, these compiled SQLs are then uploaded to the remote path set using the +:ref:`remote_target_path` configuration. Subsequent tasks in the DAG then fetch the compiled SQLs from the remote path +and run them asynchronously using e.g. the ``DbtRunAirflowAsyncOperator``. +The compile task may take a bit longer to run due to the latency of uploading the compiled SQLs to the remote path +(e.g. compiling the classic ``jaffle_shop`` dbt project produces about 31 files totalling about 124KB, and on a local +machine it took approximately 25 seconds for the task to compile & upload the compiled SQLs to the remote path). +However, this is a one-time overhead, and the subsequent tasks run asynchronously using Airflow's +deferrable operators, with the compiled SQLs supplied to them. + +Note that the ``airflow_async`` execution mode is currently released as Experimental and has the following limitations: + +1. 
Only the ``model`` dbt resource type is run asynchronously using Airflow deferrable operators. All other resource types are executed synchronously using dbt commands, as in the ``local`` execution mode. +2. Only BigQuery is supported as the target database. If a profile target other than BigQuery is specified, Cosmos raises an error stating that the target database is not supported with this execution mode. +3. Only works for ``full_refresh`` models. There is pending work to support other modes. + +Example DAG: + +.. literalinclude:: ../../dev/dags/simple_dag_async.py + :language: python + :start-after: [START airflow_async_execution_mode_example] + :end-before: [END airflow_async_execution_mode_example] + .. _invocation_modes: Invocation Modes ================ diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 72a09a5e5..6fc7cdc0a 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -21,6 +21,7 @@ ) from cosmos.config import ProfileConfig, RenderConfig from cosmos.constants import ( + DBT_COMPILE_TASK_ID, DbtResourceType, ExecutionMode, SourceRenderingBehavior, @@ -226,6 +227,41 @@ def test_build_airflow_graph_with_after_all(): assert dag.leaves[0].select == ["tag:some"] + +@pytest.mark.integration +def test_build_airflow_graph_with_dbt_compile_task(): + with DAG("test-id-dbt-compile", start_date=datetime(2022, 1, 1)) as dag: + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + render_config = RenderConfig( + select=["tag:some"], + test_behavior=TestBehavior.AFTER_ALL, + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ) + build_airflow_graph( + nodes=sample_nodes, + dag=dag, + execution_mode=ExecutionMode.AIRFLOW_ASYNC, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + dbt_project_name="astro_shop", + render_config=render_config, + ) + + task_ids = [task.task_id for task in dag.tasks] + assert DBT_COMPILE_TASK_ID in task_ids + assert DBT_COMPILE_TASK_ID in dag.tasks[0].upstream_task_ids + + def test_calculate_operator_class(): class_module_import_path = calculate_operator_class(execution_mode=ExecutionMode.KUBERNETES, dbt_class="DbtSeed") assert class_module_import_path == "cosmos.operators.kubernetes.DbtSeedKubernetesOperator" diff --git a/tests/operators/test_airflow_async.py b/tests/operators/test_airflow_async.py new file mode 100644 index 000000000..fc085c7d0 --- /dev/null +++ b/tests/operators/test_airflow_async.py @@ -0,0 +1,82 @@ +from cosmos.operators.airflow_async import ( + DbtBuildAirflowAsyncOperator, + DbtCompileAirflowAsyncOperator, + DbtDocsAirflowAsyncOperator, + DbtDocsAzureStorageAirflowAsyncOperator, + DbtDocsGCSAirflowAsyncOperator, + DbtDocsS3AirflowAsyncOperator, + DbtLSAirflowAsyncOperator, + DbtRunAirflowAsyncOperator, + DbtRunOperationAirflowAsyncOperator, + DbtSeedAirflowAsyncOperator, + DbtSnapshotAirflowAsyncOperator, + DbtSourceAirflowAsyncOperator, + DbtTestAirflowAsyncOperator, +) +from cosmos.operators.local import ( + DbtBuildLocalOperator, + DbtCompileLocalOperator, + DbtDocsAzureStorageLocalOperator, + DbtDocsGCSLocalOperator, + DbtDocsLocalOperator, + DbtDocsS3LocalOperator, + DbtLSLocalOperator, + DbtRunLocalOperator, + DbtRunOperationLocalOperator, + DbtSeedLocalOperator, + DbtSnapshotLocalOperator, + 
DbtSourceLocalOperator, + DbtTestLocalOperator, +) + + +def test_dbt_build_airflow_async_operator_inheritance(): + assert issubclass(DbtBuildAirflowAsyncOperator, DbtBuildLocalOperator) + + +def test_dbt_ls_airflow_async_operator_inheritance(): + assert issubclass(DbtLSAirflowAsyncOperator, DbtLSLocalOperator) + + +def test_dbt_seed_airflow_async_operator_inheritance(): + assert issubclass(DbtSeedAirflowAsyncOperator, DbtSeedLocalOperator) + + +def test_dbt_snapshot_airflow_async_operator_inheritance(): + assert issubclass(DbtSnapshotAirflowAsyncOperator, DbtSnapshotLocalOperator) + + +def test_dbt_source_airflow_async_operator_inheritance(): + assert issubclass(DbtSourceAirflowAsyncOperator, DbtSourceLocalOperator) + + +def test_dbt_run_airflow_async_operator_inheritance(): + assert issubclass(DbtRunAirflowAsyncOperator, DbtRunLocalOperator) + + +def test_dbt_test_airflow_async_operator_inheritance(): + assert issubclass(DbtTestAirflowAsyncOperator, DbtTestLocalOperator) + + +def test_dbt_run_operation_airflow_async_operator_inheritance(): + assert issubclass(DbtRunOperationAirflowAsyncOperator, DbtRunOperationLocalOperator) + + +def test_dbt_docs_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsAirflowAsyncOperator, DbtDocsLocalOperator) + + +def test_dbt_docs_s3_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsS3AirflowAsyncOperator, DbtDocsS3LocalOperator) + + +def test_dbt_docs_azure_storage_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsAzureStorageAirflowAsyncOperator, DbtDocsAzureStorageLocalOperator) + + +def test_dbt_docs_gcs_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsGCSAirflowAsyncOperator, DbtDocsGCSLocalOperator) + + +def test_dbt_compile_airflow_async_operator_inheritance(): + assert issubclass(DbtCompileAirflowAsyncOperator, DbtCompileLocalOperator) diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py index 6f4425282..e97c2d396 100644 --- a/tests/operators/test_base.py +++ b/tests/operators/test_base.py @@ -8,6 +8,7 @@ from cosmos.operators.base import ( AbstractDbtBaseOperator, DbtBuildMixin, + DbtCompileMixin, DbtLSMixin, DbtRunMixin, DbtRunOperationMixin, @@ -143,6 +144,7 @@ def test_dbt_base_operator_context_merge( ("seed", DbtSeedMixin), ("run", DbtRunMixin), ("build", DbtBuildMixin), + ("compile", DbtCompileMixin), ], ) def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class): diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 04001ca75..c7615225f 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -25,9 +25,11 @@ parse_number_of_warnings_dbt_runner, parse_number_of_warnings_subprocess, ) +from cosmos.exceptions import CosmosValueError from cosmos.hooks.subprocess import FullOutputSubprocessResult from cosmos.operators.local import ( DbtBuildLocalOperator, + DbtCompileLocalOperator, DbtDocsAzureStorageLocalOperator, DbtDocsGCSLocalOperator, DbtDocsLocalOperator, @@ -42,6 +44,7 @@ DbtTestLocalOperator, ) from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.settings import AIRFLOW_IO_AVAILABLE from tests.utils import test_dag as run_test_dag DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" @@ -1131,3 +1134,136 @@ def test_store_freshness_not_store_compiled_sql(mock_context, mock_session): # Verify the freshness attribute is set correctly assert instance.freshness == "" + + +def test_dbt_compile_local_operator_initialisation(): + operator = DbtCompileLocalOperator( 
+ task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + assert operator.should_upload_compiled_sql is True + assert "compile" in operator.base_cmd + + +@patch("cosmos.operators.local.remote_target_path", new="s3://some-bucket/target") +@patch("cosmos.operators.local.AIRFLOW_IO_AVAILABLE", new=False) +def test_configure_remote_target_path_object_storage_unavailable_on_earlier_airflow_versions(): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + with pytest.raises(CosmosValueError, match="Object Storage feature is unavailable"): + operator._configure_remote_target_path() + + +@pytest.mark.parametrize( + "rem_target_path, rem_target_path_conn_id", + [ + (None, "aws_s3_conn"), + ("unknown://some-bucket/cache", None), + ], +) +def test_config_remote_target_path_unset_settings(rem_target_path, rem_target_path_conn_id): + with patch("cosmos.operators.local.remote_target_path", new=rem_target_path): + with patch("cosmos.operators.local.remote_target_path_conn_id", new=rem_target_path_conn_id): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + target_path, target_conn = operator._configure_remote_target_path() + assert target_path is None + assert target_conn is None + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("cosmos.operators.local.remote_target_path", new="s3://some-bucket/target") +@patch("cosmos.operators.local.remote_target_path_conn_id", new="aws_s3_conn") +@patch("airflow.io.path.ObjectStoragePath") +def test_configure_remote_target_path(mock_object_storage_path): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + mock_remote_path = MagicMock() + mock_object_storage_path.return_value.exists.return_value = True + mock_object_storage_path.return_value = mock_remote_path + target_path, target_conn = operator._configure_remote_target_path() + assert target_path == mock_remote_path + assert target_conn == "aws_s3_conn" + mock_object_storage_path.assert_called_with("s3://some-bucket/target", conn_id="aws_s3_conn") + + mock_object_storage_path.return_value.exists.return_value = False + mock_object_storage_path.return_value.mkdir.return_value = MagicMock() + _, _ = operator._configure_remote_target_path() + mock_object_storage_path.return_value.mkdir.assert_called_with(parents=True, exist_ok=True) + + +@patch.object(DbtLocalBaseOperator, "_configure_remote_target_path") +def test_no_compiled_sql_upload_for_other_operators(mock_configure_remote_target_path): + operator = DbtSeedLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + assert operator.should_upload_compiled_sql is False + operator.upload_compiled_sql("fake-dir", MagicMock()) + mock_configure_remote_target_path.assert_not_called() + + +@patch("cosmos.operators.local.DbtCompileLocalOperator._configure_remote_target_path") +def test_upload_compiled_sql_no_remote_path_raises_error(mock_configure_remote): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + + mock_configure_remote.return_value = (None, None) + + tmp_project_dir = "/fake/tmp/project" + context = {"dag": MagicMock(dag_id="test_dag")} + + with pytest.raises(CosmosValueError, match="remote target path is not configured"): + 
operator.upload_compiled_sql(tmp_project_dir, context) + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("airflow.io.path.ObjectStoragePath.copy") +@patch("airflow.io.path.ObjectStoragePath") +@patch("cosmos.operators.local.DbtCompileLocalOperator._configure_remote_target_path") +def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_storage_path, mock_copy): + """Test upload_compiled_sql when should_upload_compiled_sql is True and uploads files.""" + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + dag=DAG("test_dag", start_date=datetime(2024, 4, 16)), + ) + + mock_configure_remote.return_value = ("mock_remote_path", "mock_conn_id") + + tmp_project_dir = "/fake/tmp/project" + source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled" + + file1 = MagicMock(spec=Path) + file1.is_file.return_value = True + file1.__str__.return_value = str(source_compiled_dir / "file1.sql") + + file2 = MagicMock(spec=Path) + file2.is_file.return_value = True + file2.__str__.return_value = str(source_compiled_dir / "file2.sql") + + files = [file1, file2] + + with patch.object(Path, "rglob", return_value=files): + operator.upload_compiled_sql(tmp_project_dir, context={"task": operator}) + + for file_path in files: + rel_path = os.path.relpath(str(file_path), str(source_compiled_dir)) + expected_dest_path = f"mock_remote_path/test_dag/compiled/{rel_path.lstrip('/')}" + mock_object_storage_path.assert_any_call(expected_dest_path, conn_id="mock_conn_id") + mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value) diff --git a/tests/test_example_dags.py b/tests/test_example_dags.py index 9f8601156..9aa66432d 100644 --- a/tests/test_example_dags.py +++ b/tests/test_example_dags.py @@ -28,7 +28,7 @@ MIN_VER_DAG_FILE: dict[str, list[str]] = { "2.4": ["cosmos_seed_dag.py"], - "2.8": ["cosmos_manifest_example.py"], + "2.8": ["cosmos_manifest_example.py", "simple_dag_async.py"], } IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py"]
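
As a quick reference for the settings wired up in ``cosmos/settings.py`` above, the sketch below shows one way to supply them through Airflow configuration environment variables before the scheduler and workers start; the bucket path and the ``aws_s3_conn`` connection ID simply mirror the values used in the CI workflow and are placeholders, not requirements.

import os

# Remote path where the dbt_compile task uploads the compiled SQL files.
# Any scheme supported by Airflow 2.8+ Object Storage works (s3://, gs://, abfs://, ...).
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH"] = "s3://cosmos-remote-cache/target_compiled/"

# Airflow connection used to reach the remote path. If unset, Cosmos falls back to the
# default connection ID mapped to the path's scheme via FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID"] = "aws_s3_conn"

# cosmos/settings.py then reads these as:
#   remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
#   remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)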
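
To make the remote layout produced by ``upload_compiled_sql`` concrete, here is a small self-contained sketch that mirrors the logic of ``_construct_dest_file_path`` from this patch; the temporary project directory, DAG ID, and model path are assumed placeholders.

from __future__ import annotations

import os
from pathlib import Path


def construct_dest_file_path(
    dest_target_dir: str,
    file_path: str,
    source_compiled_dir: Path,
    dag_id: str,
    task_group_id: str | None = None,
) -> str:
    """Mirror of _construct_dest_file_path: <remote_target_path>/<dag_id>[__<task_group_id>]/compiled/<rel_path>."""
    identifiers = [identifier for identifier in (dag_id, task_group_id) if identifier]
    dag_task_group_identifier = "__".join(identifiers)
    rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")
    return f"{dest_target_dir.rstrip('/')}/{dag_task_group_identifier}/compiled/{rel_path}"


source_compiled_dir = Path("/tmp/project/target/compiled")
print(
    construct_dest_file_path(
        "s3://cosmos-remote-cache/target_compiled/",
        str(source_compiled_dir / "jaffle_shop/models/customers.sql"),
        source_compiled_dir,
        dag_id="simple_dag_async",
    )
)
# s3://cosmos-remote-cache/target_compiled/simple_dag_async/compiled/jaffle_shop/models/customers.sql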
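
In the ``airflow_async`` execution mode the ``dbt_compile`` root task is injected automatically by ``_add_dbt_compile_task``, but the new ``DbtCompileLocalOperator`` can also be used on its own to compile a project and push the compiled SQL to the configured remote target path; in this hypothetical sketch the connection ID, project directory, and DAG parameters are placeholders.

from datetime import datetime

from airflow import DAG

from cosmos import ProfileConfig
from cosmos.operators.local import DbtCompileLocalOperator
from cosmos.profiles import PostgresUserPasswordProfileMapping

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=PostgresUserPasswordProfileMapping(
        conn_id="example_conn",
        profile_args={"schema": "public"},
    ),
)

with DAG("compile_and_upload_example", start_date=datetime(2023, 1, 1), schedule_interval=None, catchup=False):
    # DbtCompileLocalOperator.__init__ forces should_upload_compiled_sql=True, so the files
    # under target/compiled are copied to the remote_target_path after `dbt compile` runs.
    DbtCompileLocalOperator(
        task_id="dbt_compile",
        profile_config=profile_config,
        project_dir="/usr/local/airflow/dags/dbt/jaffle_shop",
    )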