diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py
index 91721c457f..437602d3a4 100644
--- a/dlt/helpers/airflow_helper.py
+++ b/dlt/helpers/airflow_helper.py
@@ -1,6 +1,7 @@
 import os
 from tempfile import gettempdir
 from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple
+
 from tenacity import (
     retry_if_exception,
     wait_exponential,
@@ -9,9 +10,7 @@
     RetryCallState,
 )
 
-from dlt.common import pendulum
 from dlt.common.exceptions import MissingDependencyException
-from dlt.common.runtime.telemetry import with_telemetry
 
 try:
     from airflow.configuration import conf
@@ -20,14 +19,17 @@
     from airflow.operators.dummy import DummyOperator  # type: ignore
     from airflow.operators.python import PythonOperator, get_current_context
 except ModuleNotFoundError:
-    raise MissingDependencyException("Airflow", ["airflow>=2.0.0"])
+    raise MissingDependencyException("Airflow", ["apache-airflow>=2.5"])
 
 
 import dlt
+from dlt.common import pendulum
 from dlt.common import logger
+from dlt.common.runtime.telemetry import with_telemetry
 from dlt.common.data_writers import TLoaderFileFormat
 from dlt.common.schema.typing import TWriteDisposition, TSchemaContract
 from dlt.common.utils import uniq_id
+from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention
 from dlt.common.configuration.container import Container
 from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
 from dlt.common.runtime.collector import NULL_COLLECTOR
@@ -135,7 +137,7 @@ def add_run(
         pipeline: Pipeline,
         data: Any,
         *,
-        decompose: Literal["none", "serialize", "parallel"] = "none",
+        decompose: Literal["none", "serialize", "parallel", "parallel-isolated"] = "none",
         table_name: str = None,
         write_disposition: TWriteDisposition = None,
         loader_file_format: TLoaderFileFormat = None,
@@ -158,11 +160,22 @@
                 A source decomposition strategy into Airflow tasks:
                 none - no decomposition, default value.
                 serialize - decompose the source into a sequence of Airflow tasks.
-                parallel - decompose the source into a parallel Airflow task group.
-                    NOTE: In case the SequentialExecutor is used by Airflow, the tasks
-                    will remain sequential. Use another executor, e.g. CeleryExecutor)!
-                    NOTE: The first component of the source is done first, after that
-                    the rest are executed in parallel to each other.
+                parallel - decompose the source into a parallel Airflow task group,
+                    except the first resource, which must be completed first.
+                    All tasks that are run in parallel share the same pipeline state.
+                    If two of them modify the state, part of the state may be lost.
+                parallel-isolated - decompose the source into a parallel Airflow task group,
+                    with the same exception as above. All tasks have separate pipeline
+                    state (via separate pipeline name) but share the same dataset,
+                    schemas and tables.
+                NOTE: The first component of the source in both parallel modes is done first,
+                    after that the rest are executed in parallel to each other.
+                NOTE: If the SequentialExecutor is used by Airflow, the tasks
+                    will remain sequential despite 'parallel' or 'parallel-isolated' mode.
+                    Use another executor (e.g. CeleryExecutor) to make tasks parallel!
+
+                Parallel tasks are executed in different pipelines, all derived from the original
+                    one, but with the state isolated from each other.
             table_name: (str): The name of the table to which the data should be loaded within the `dataset`
             write_disposition (TWriteDisposition, optional): Same as in `run` command. Defaults to None.
             loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional): The file format the loader will use to create the load package.
@@ -194,12 +207,12 @@ def task_name(pipeline: Pipeline, data: Any) -> str:
         # use factory function to make a task, in order to parametrize it
         # passing arguments to task function (_run) is serializing
         # them and running template engine on them
-        def make_task(pipeline: Pipeline, data: Any) -> PythonOperator:
+        def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator:
             def _run() -> None:
                 # activate pipeline
                 pipeline.activate()
                 # drop local data
-                task_pipeline = pipeline.drop()
+                task_pipeline = pipeline.drop(pipeline_name=name)
 
                 # use task logger
                 if self.use_task_logger:
@@ -308,15 +321,60 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
             if pipeline.full_refresh:
                 raise ValueError("Cannot decompose pipelines with full_refresh set")
 
-            # parallel tasks
             tasks = []
             sources = data.decompose("scc")
+            t_name = task_name(pipeline, data)
             start = make_task(pipeline, sources[0])
 
+            # parallel tasks
             for source in sources[1:]:
+                for resource in source.resources.values():
+                    if resource.incremental:
+                        logger.warn(
+                            f"The resource {resource.name} in task {t_name} "
+                            "is using incremental loading and may modify the "
+                            "state. Resources that modify the state should not "
+                            "run in parallel within the single pipeline as the "
+                            "state will not be correctly merged. Please use "
+                            "'serialize' or 'parallel-isolated' modes instead."
+                        )
+                        break
+
                 tasks.append(make_task(pipeline, source))
 
-            end = DummyOperator(task_id=f"{task_name(pipeline, data)}_end")
+            end = DummyOperator(task_id=f"{t_name}_end")
+
+            if tasks:
+                start >> tasks >> end
+                return [start] + tasks + [end]
+
+            start >> end
+            return [start, end]
+        elif decompose == "parallel-isolated":
+            if not isinstance(data, DltSource):
+                raise ValueError("Can only decompose dlt sources")
+
+            if pipeline.full_refresh:
+                raise ValueError("Cannot decompose pipelines with full_refresh set")
+
+            # parallel tasks
+            tasks = []
+            naming = SnakeCaseNamingConvention()
+            sources = data.decompose("scc")
+            start = make_task(
+                pipeline,
+                sources[0],
+                naming.normalize_identifier(task_name(pipeline, sources[0])),
+            )
+
+            # parallel tasks
+            for source in sources[1:]:
+                # name pipeline the same as task
+                new_pipeline_name = naming.normalize_identifier(task_name(pipeline, source))
+                tasks.append(make_task(pipeline, source, new_pipeline_name))
+
+            t_name = task_name(pipeline, data)
+            end = DummyOperator(task_id=f"{t_name}_end")
 
             if tasks:
                 start >> tasks >> end
@@ -325,7 +383,10 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
             start >> end
             return [start, end]
         else:
-            raise ValueError("decompose value must be one of ['none', 'serialize', 'parallel']")
+            raise ValueError(
+                "decompose value must be one of ['none', 'serialize', 'parallel',"
+                " 'parallel-isolated']"
+            )
 
     def add_fun(self, f: Callable[..., Any], **kwargs: Any) -> Any:
         """Will execute a function `f` inside an Airflow task. It is up to the function to create pipeline and source(s)"""
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index 73c8f076d1..f75548a390 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -324,13 +324,17 @@ def __init__(
         self.credentials = credentials
         self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline)
 
-    def drop(self) -> "Pipeline":
-        """Deletes local pipeline state, schemas and any working files"""
+    def drop(self, pipeline_name: str = None) -> "Pipeline":
+        """Deletes local pipeline state, schemas and any working files.
+
+        Args:
+            pipeline_name (str): Optional new name for the cloned pipeline.
+        """
         # reset the pipeline working dir
         self._create_pipeline()
         # clone the pipeline
         return Pipeline(
-            self.pipeline_name,
+            pipeline_name or self.pipeline_name,
             self.pipelines_dir,
             self.pipeline_salt,
             self.destination,
diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py b/tests/helpers/airflow_tests/test_airflow_wrapper.py
index f6c0320635..0399e3875d 100644
--- a/tests/helpers/airflow_tests/test_airflow_wrapper.py
+++ b/tests/helpers/airflow_tests/test_airflow_wrapper.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+from unittest import mock
 from typing import List
 from airflow import DAG
 from airflow.decorators import dag
@@ -11,6 +12,8 @@
 import dlt
 from dlt.common import pendulum
 from dlt.common.utils import uniq_id
+from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention
+
 from dlt.helpers.airflow_helper import PipelineTasksGroup, DEFAULT_RETRY_BACKOFF
 from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineStepFailed
 
@@ -75,6 +78,23 @@ def resource():
     return resource
 
 
+@dlt.source
+def mock_data_incremental_source():
+    @dlt.resource
+    def resource1(a: str = None, b=None, c=None):
+        yield ["s", "a"]
+
+    @dlt.resource
+    def resource2(
+        updated_at: dlt.sources.incremental[str] = dlt.sources.incremental(
+            "updated_at", initial_value="1970-01-01T00:00:00Z"
+        )
+    ):
+        yield [{"updated_at": "1970-02-01T00:00:00Z"}]
+
+    return resource1, resource2
+
+
 @dlt.source(section="mock_data_source_state")
 def mock_data_source_state():
     @dlt.resource(selected=True)
@@ -263,13 +283,123 @@ def dag_parallel():
     dag_def = dag_parallel()
     assert len(tasks_list) == 4
     dag_def.test()
+
     pipeline_dag_parallel = dlt.attach(pipeline_name="pipeline_dag_parallel")
-    pipeline_dag_decomposed_counts = load_table_counts(
+    results = load_table_counts(
         pipeline_dag_parallel,
         *[t["name"] for t in pipeline_dag_parallel.default_schema.data_tables()],
     )
-    assert pipeline_dag_decomposed_counts == pipeline_standalone_counts
+    assert results == pipeline_standalone_counts
+
+    # verify tasks 1-2 in between tasks 0 and 3
+    for task in dag_def.tasks[1:3]:
+        assert task.downstream_task_ids == set([dag_def.tasks[-1].task_id])
+        assert task.upstream_task_ids == set([dag_def.tasks[0].task_id])
+
+
+def test_parallel_incremental():
+    pipeline_standalone = dlt.pipeline(
+        pipeline_name="pipeline_parallel",
+        dataset_name="mock_data_" + uniq_id(),
+        destination="duckdb",
+        credentials=":pipeline:",
+    )
+    pipeline_standalone.run(mock_data_incremental_source())
+
+    tasks_list: List[PythonOperator] = None
+
+    quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb")
+
+    @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
+    def dag_parallel():
+        nonlocal tasks_list
+        tasks = PipelineTasksGroup(
+            "pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
+        )
+
+        # set duckdb to be outside of pipeline folder which is dropped on each task
+        pipeline_dag_parallel = dlt.pipeline(
+            pipeline_name="pipeline_dag_parallel",
+            dataset_name="mock_data_" + uniq_id(),
+            destination="duckdb",
+            credentials=quackdb_path,
+        )
+        tasks.add_run(
+            pipeline_dag_parallel,
+            mock_data_incremental_source(),
+            decompose="parallel",
+            trigger_rule="all_done",
+            retries=0,
+            provide_context=True,
+        )
+
+    with mock.patch("dlt.helpers.airflow_helper.logger.warn") as warn_mock:
+        dag_def = dag_parallel()
+        dag_def.test()
+        warn_mock.assert_called_once()
+
+
+def test_parallel_isolated_run():
+    pipeline_standalone = dlt.pipeline(
+        pipeline_name="pipeline_parallel",
+        dataset_name="mock_data_" + uniq_id(),
+        destination="duckdb",
+        credentials=":pipeline:",
+    )
+    pipeline_standalone.run(mock_data_source())
+    pipeline_standalone_counts = load_table_counts(
+        pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()]
+    )
+
+    tasks_list: List[PythonOperator] = None
+
+    quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb")
+
+    @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
+    def dag_parallel():
+        nonlocal tasks_list
+        tasks = PipelineTasksGroup(
+            "pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
+        )
+
+        # set duckdb to be outside of pipeline folder which is dropped on each task
+        pipeline_dag_parallel = dlt.pipeline(
+            pipeline_name="pipeline_dag_parallel",
+            dataset_name="mock_data_" + uniq_id(),
+            destination="duckdb",
+            credentials=quackdb_path,
+        )
+        tasks_list = tasks.add_run(
+            pipeline_dag_parallel,
+            mock_data_source(),
+            decompose="parallel-isolated",
+            trigger_rule="all_done",
+            retries=0,
+            provide_context=True,
+        )
+
+    dag_def = dag_parallel()
+    assert len(tasks_list) == 4
+    dag_def.test()
+
+    results = {}
+    snake_case = SnakeCaseNamingConvention()
+    for i in range(0, 3):
+        pipeline_dag_parallel = dlt.attach(
+            pipeline_name=snake_case.normalize_identifier(
+                dag_def.tasks[i].task_id.replace("pipeline_dag_parallel.", "")
+            )
+        )
+        pipeline_dag_decomposed_counts = load_table_counts(
+            pipeline_dag_parallel,
+            *[t["name"] for t in pipeline_dag_parallel.default_schema.data_tables()],
+        )
+        results.update(pipeline_dag_decomposed_counts)
+
+    assert results == pipeline_standalone_counts
+
     # verify tasks 1-2 in between tasks 0 and 3
     for task in dag_def.tasks[1:3]:
         assert task.downstream_task_ids == set([dag_def.tasks[-1].task_id])
         assert task.upstream_task_ids == set([dag_def.tasks[0].task_id])
diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py
index 4d370fc786..4c94766097 100644
--- a/tests/load/filesystem/test_filesystem_common.py
+++ b/tests/load/filesystem/test_filesystem_common.py
@@ -52,7 +52,7 @@ def check_file_exists():
     def check_file_changed():
         details = filesystem.info(file_url)
         assert details["size"] == 11
-        assert (MTIME_DISPATCH[config.protocol](details) - now).seconds < 60
+        assert (MTIME_DISPATCH[config.protocol](details) - now).seconds < 120
 
     bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"]
     config = get_config()
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index cdc4a02509..e1f7397ef9 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -1258,6 +1258,16 @@ def generic(start):
     assert pipeline.default_schema.get_table("single_table")["resource"] == "state1"
 
 
+def test_drop_with_new_name() -> None:
+    old_test_name = "old_pipeline_name"
+    new_test_name = "new_pipeline_name"
+
+    pipeline = dlt.pipeline(pipeline_name=old_test_name, destination="duckdb")
+    new_pipeline = pipeline.drop(pipeline_name=new_test_name)
+
+    assert new_pipeline.pipeline_name == new_test_name
+
+
 def test_remove_autodetect() -> None:
     now = pendulum.now()
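For reference, a minimal usage sketch of the new `parallel-isolated` decomposition (not part of the diff itself), modeled on `test_parallel_isolated_run` above. The source, DAG, pipeline and dataset names and the `/tmp/dlt_airflow` folder are illustrative placeholders; the `add_run` arguments mirror the ones used in the tests:

```python
import pendulum
from airflow.decorators import dag

import dlt
from dlt.helpers.airflow_helper import PipelineTasksGroup


@dlt.source
def my_source():
    # two toy resources; the "scc" decomposition turns them into separate Airflow tasks
    @dlt.resource
    def customers():
        yield [{"id": 1}, {"id": 2}]

    @dlt.resource
    def orders():
        yield [{"id": 10, "customer_id": 1}]

    return customers, orders


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def my_dlt_dag():
    # group the decomposed dlt tasks; keep working files in a known local folder
    tasks = PipelineTasksGroup(
        "my_pipeline_group", local_data_folder="/tmp/dlt_airflow", wipe_local_data=True
    )

    pipeline = dlt.pipeline(
        pipeline_name="my_pipeline",
        dataset_name="my_dataset",
        destination="duckdb",
    )

    # "parallel-isolated": the first decomposed component runs first, the remaining ones run
    # in parallel, each in its own derived pipeline (isolated state, shared dataset and schema)
    tasks.add_run(
        pipeline,
        my_source(),
        decompose="parallel-isolated",
        trigger_rule="all_done",
        retries=0,
    )


my_dlt_dag()
```

As the updated docstring notes, with the SequentialExecutor the group still runs one task at a time; use another executor (e.g. CeleryExecutor) for real parallelism.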