diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py
index 89fe06349b..e68c330765 100644
--- a/dlt/helpers/airflow_helper.py
+++ b/dlt/helpers/airflow_helper.py
@@ -171,6 +171,7 @@ def run(
         loader_file_format: TLoaderFileFormat = None,
         schema_contract: TSchemaContract = None,
         pipeline_name: str = None,
+        on_before_run: Callable[[], None] = None,
         **kwargs: Any,
     ) -> PythonOperator:
         """
@@ -179,7 +180,12 @@ def run(

         Args:
             pipeline (Pipeline): The pipeline to run
-            data (Any): The data to run the pipeline with
+            data (Any):
+                The data to run the pipeline with. If a non-resource
+                callable is given, it is evaluated during DAG execution,
+                right before the actual pipeline run.
+                NOTE: If `on_before_run` is provided, it is evaluated
+                first, and then the callable `data`.
             table_name (str, optional): The name of the table to
                 which the data should be loaded within the `dataset`.
             write_disposition (TWriteDispositionConfig, optional): Same as
@@ -191,6 +197,8 @@ def run(
                 for the schema contract settings, this will replace the schema
                 contract settings for all tables in the schema.
             pipeline_name (str, optional): The name of the derived pipeline.
+            on_before_run (Callable, optional): A callable to be
+                executed right before the actual pipeline run.

         Returns:
             PythonOperator: Airflow task instance.
@@ -204,6 +212,7 @@ def run(
             loader_file_format=loader_file_format,
             schema_contract=schema_contract,
             pipeline_name=pipeline_name,
+            on_before_run=on_before_run,
         )

         return PythonOperator(task_id=self._task_name(pipeline, data), python_callable=f, **kwargs)
@@ -216,12 +225,18 @@ def _run(
         loader_file_format: TLoaderFileFormat = None,
         schema_contract: TSchemaContract = None,
         pipeline_name: str = None,
+        on_before_run: Callable[[], None] = None,
     ) -> None:
         """Run the given pipeline with the given data.

         Args:
             pipeline (Pipeline): The pipeline to run
-            data (Any): The data to run the pipeline with
+            data (Any):
+                The data to run the pipeline with. If a non-resource
+                callable is given, it is evaluated during DAG execution,
+                right before the actual pipeline run.
+                NOTE: If `on_before_run` is provided, it is evaluated
+                first, and then the callable `data`.
             table_name (str, optional): The name of the table to
                 which the data should be loaded within the `dataset`.

@@ -236,6 +251,8 @@ def _run(
                 for all tables in the schema.
             pipeline_name (str, optional): The name
                 of the derived pipeline.
+            on_before_run (Callable, optional): A callable
+                to be executed right before the actual pipeline run.
         """
         # activate pipeline
         pipeline.activate()
@@ -271,6 +288,12 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
             )

         try:
+            if on_before_run is not None:
+                on_before_run()
+
+            if callable(data):
+                data = data()
+
             # retry with given policy on selected pipeline steps
             for attempt in self.retry_policy.copy(
                 retry=retry_if_exception(
@@ -325,6 +348,7 @@ def add_run(
         write_disposition: TWriteDispositionConfig = None,
         loader_file_format: TLoaderFileFormat = None,
         schema_contract: TSchemaContract = None,
+        on_before_run: Callable[[], None] = None,
         **kwargs: Any,
     ) -> List[PythonOperator]:
         """Creates a task or a group of tasks to run `data` with `pipeline`
@@ -338,7 +362,10 @@ def add_run(

         Args:
             pipeline (Pipeline): An instance of pipeline used to run the source
-            data (Any): Any data supported by `run` method of the pipeline
+            data (Any):
+                Any data supported by the `run` method of the pipeline.
+                If a non-resource callable is given, it is called before
+                the load to get the data.
             decompose (Literal["none", "serialize", "parallel"], optional): A source decomposition
                 strategy into Airflow tasks:
                 none - no decomposition, default value.
@@ -365,6 +392,8 @@ def add_run(
                 Not all file_formats are compatible with all destinations. Defaults to the
                 preferred file format of the selected destination.
             schema_contract (TSchemaContract, optional): On override for the schema contract settings, this will replace the schema contract settings for all tables in the schema. Defaults to None.
+            on_before_run (Callable, optional):
+                A callable to be executed right before the actual pipeline run.

         Returns:
             Any: Airflow tasks created in order of creation.
@@ -391,6 +420,7 @@ def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator
                 loader_file_format=loader_file_format,
                 schema_contract=schema_contract,
                 pipeline_name=name,
+                on_before_run=on_before_run,
             )
             return PythonOperator(
                 task_id=self._task_name(pipeline, data), python_callable=f, **kwargs
diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py b/tests/helpers/airflow_tests/test_airflow_wrapper.py
index a328403ba0..845800e47f 100644
--- a/tests/helpers/airflow_tests/test_airflow_wrapper.py
+++ b/tests/helpers/airflow_tests/test_airflow_wrapper.py
@@ -4,13 +4,13 @@ from typing import List

 from airflow import DAG
 from airflow.decorators import dag
-from airflow.operators.python import PythonOperator
+from airflow.operators.python import PythonOperator, get_current_context
 from airflow.models import TaskInstance
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunType

 import dlt
-from dlt.common import pendulum
+from dlt.common import logger, pendulum
 from dlt.common.utils import uniq_id
 from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention

@@ -917,3 +917,81 @@ def dag_parallel():
     dag_def = dag_parallel()
     assert len(tasks_list) == 1
     dag_def.test()
+
+
+def callable_source():
+    @dlt.resource
+    def test_res():
+        context = get_current_context()
+        yield [
+            {"id": 1, "tomorrow": context["tomorrow_ds"]},
+            {"id": 2, "tomorrow": context["tomorrow_ds"]},
+            {"id": 3, "tomorrow": context["tomorrow_ds"]},
+        ]
+
+    return test_res
+
+
+def test_run_callable() -> None:
+    quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")
+
+    @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
+    def dag_regular():
+        tasks = PipelineTasksGroup(
+            "callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
+        )
+
+        call_dag = dlt.pipeline(
+            pipeline_name="callable_dag",
+            dataset_name="mock_data_" + uniq_id(),
+            destination="duckdb",
+            credentials=quackdb_path,
+        )
+        tasks.run(call_dag, callable_source)
+
+    dag_def: DAG = dag_regular()
+    dag_def.test()
+
+    pipeline_dag = dlt.attach(pipeline_name="callable_dag")
+
+    with pipeline_dag.sql_client() as client:
+        with client.execute_query("SELECT * FROM test_res") as result:
+            results = result.fetchall()
+
+    assert len(results) == 3
+
+    for row in results:
+        assert row[1] == pendulum.tomorrow().format("YYYY-MM-DD")
+
+
+def on_before_run():
+    context = get_current_context()
+    logger.info(f'on_before_run test: {context["tomorrow_ds"]}')
+
+
+def test_on_before_run() -> None:
+    quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")
+
+    @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
+    def dag_regular():
+        tasks = PipelineTasksGroup(
+            "callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
+        )
+
+        call_dag = dlt.pipeline(
+            pipeline_name="callable_dag",
+            dataset_name="mock_data_" + uniq_id(),
+            destination="duckdb",
+            credentials=quackdb_path,
+        )
+        tasks.run(call_dag, mock_data_source, on_before_run=on_before_run)
+
+    dag_def: DAG = dag_regular()
+
+    with mock.patch("dlt.helpers.airflow_helper.logger.info") as logger_mock:
+        dag_def.test()
+        logger_mock.assert_has_calls(
+            [
+                mock.call(f'on_before_run test: {pendulum.tomorrow().format("YYYY-MM-DD")}'),
+            ]
+        )
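For reference, a minimal usage sketch of the new surface from a DAG author's point of view. The names `events_dag`, `events_pipeline`, `events_group`, `refresh_auth_token`, and `build_events_source` are illustrative only and not part of this change; the ordering it demonstrates follows the docstrings above: `on_before_run` fires inside the Airflow task first, then a callable passed as `data` is resolved, and only then does the pipeline run.

```py
import pendulum

import dlt
from airflow.decorators import dag
from airflow.operators.python import get_current_context

from dlt.helpers.airflow_helper import PipelineTasksGroup


def refresh_auth_token() -> None:
    # Hypothetical setup hook: executed inside the Airflow task, right before
    # the pipeline run, so the Airflow context is already available.
    context = get_current_context()
    print(f"preparing run for logical date {context['ds']}")


def build_events_source():
    # Because `data` is passed as a plain callable (not a dlt resource), it is
    # only resolved at DAG execution time, after `on_before_run` has fired.
    @dlt.resource
    def events():
        yield [{"id": 1}, {"id": 2}]

    return events


@dag(schedule="@daily", start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def events_dag():
    tasks = PipelineTasksGroup("events_group", wipe_local_data=False)

    pipeline = dlt.pipeline(
        pipeline_name="events_pipeline",
        dataset_name="events_data",
        destination="duckdb",
    )
    # on_before_run runs first, then the data callable is evaluated,
    # then the pipeline loads the resulting resource.
    tasks.run(pipeline, build_events_source, on_before_run=refresh_auth_token)


events_dag()
```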