Commit
feat(airflow): implement parallel-isolated mode (#979)
* feat(airflow): implement parallel-isolated mode

* emit a warning for incremental args in resources

* in parallel-isolated mode use pipeline names


* run parallel-isolated mode in separate pipelines, parallel mode in a single pipeline

---------

Co-authored-by: Marcin Rudolf <[email protected]>
IlyaFaer and rudolfix authored Feb 26, 2024
1 parent e035e38 commit 14b0a66
Showing 5 changed files with 225 additions and 20 deletions.
89 changes: 75 additions & 14 deletions dlt/helpers/airflow_helper.py
@@ -1,6 +1,7 @@
import os
from tempfile import gettempdir
from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple

from tenacity import (
retry_if_exception,
wait_exponential,
@@ -9,9 +10,7 @@
RetryCallState,
)

from dlt.common import pendulum
from dlt.common.exceptions import MissingDependencyException
from dlt.common.runtime.telemetry import with_telemetry

try:
from airflow.configuration import conf
@@ -20,14 +19,17 @@
from airflow.operators.dummy import DummyOperator # type: ignore
from airflow.operators.python import PythonOperator, get_current_context
except ModuleNotFoundError:
raise MissingDependencyException("Airflow", ["airflow>=2.0.0"])
raise MissingDependencyException("Airflow", ["apache-airflow>=2.5"])


import dlt
from dlt.common import pendulum
from dlt.common import logger
from dlt.common.runtime.telemetry import with_telemetry
from dlt.common.data_writers import TLoaderFileFormat
from dlt.common.schema.typing import TWriteDisposition, TSchemaContract
from dlt.common.utils import uniq_id
from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
from dlt.common.runtime.collector import NULL_COLLECTOR
@@ -135,7 +137,7 @@ def add_run(
pipeline: Pipeline,
data: Any,
*,
decompose: Literal["none", "serialize", "parallel"] = "none",
decompose: Literal["none", "serialize", "parallel", "parallel-isolated"] = "none",
table_name: str = None,
write_disposition: TWriteDisposition = None,
loader_file_format: TLoaderFileFormat = None,
@@ -158,11 +160,22 @@
A source decomposition strategy into Airflow tasks:
none - no decomposition, default value.
serialize - decompose the source into a sequence of Airflow tasks.
parallel - decompose the source into a parallel Airflow task group.
NOTE: In case the SequentialExecutor is used by Airflow, the tasks
will remain sequential. Use another executor, e.g. CeleryExecutor)!
NOTE: The first component of the source is done first, after that
the rest are executed in parallel to each other.
parallel - decompose the source into a parallel Airflow task group,
except that the first resource must be completed first.
All tasks that are run in parallel share the same pipeline state.
If two of them modify the state, part of the state may be lost.
parallel-isolated - decompose the source into a parallel Airflow task group,
with the same exception as above. All tasks have a separate pipeline
state (via a separate pipeline name) but share the same dataset,
schemas and tables.
NOTE: In both parallel modes, the first component of the source runs first;
after that the rest are executed in parallel to each other.
NOTE: If Airflow uses the SequentialExecutor, the tasks will remain
sequential despite the 'parallel' or 'parallel-isolated' mode.
Use another executor (e.g. CeleryExecutor) to run tasks in parallel!
Parallel tasks are executed in different pipelines, all derived from the original
one, but with the state isolated from each other.
table_name: (str): The name of the table to which the data should be loaded within the `dataset`
write_disposition (TWriteDisposition, optional): Same as in `run` command. Defaults to None.
loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional): The file format the loader will use to create the load package.
@@ -194,12 +207,12 @@ def task_name(pipeline: Pipeline, data: Any) -> str:
# use factory function to make a task, in order to parametrize it
# passing arguments to task function (_run) is serializing
# them and running template engine on them
def make_task(pipeline: Pipeline, data: Any) -> PythonOperator:
def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator:
def _run() -> None:
# activate pipeline
pipeline.activate()
# drop local data
task_pipeline = pipeline.drop()
task_pipeline = pipeline.drop(pipeline_name=name)

# use task logger
if self.use_task_logger:
@@ -308,15 +321,60 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
if pipeline.full_refresh:
raise ValueError("Cannot decompose pipelines with full_refresh set")

# parallel tasks
tasks = []
sources = data.decompose("scc")
t_name = task_name(pipeline, data)
start = make_task(pipeline, sources[0])

# parallel tasks
for source in sources[1:]:
for resource in source.resources.values():
if resource.incremental:
logger.warn(
f"The resource {resource.name} in task {t_name} "
"is using incremental loading and may modify the "
"state. Resources that modify the state should not "
"run in parallel within the single pipeline as the "
"state will not be correctly merged. Please use "
"'serialize' or 'parallel-isolated' modes instead."
)
break

tasks.append(make_task(pipeline, source))

end = DummyOperator(task_id=f"{task_name(pipeline, data)}_end")
end = DummyOperator(task_id=f"{t_name}_end")

if tasks:
start >> tasks >> end
return [start] + tasks + [end]

start >> end
return [start, end]
elif decompose == "parallel-isolated":
if not isinstance(data, DltSource):
raise ValueError("Can only decompose dlt sources")

if pipeline.full_refresh:
raise ValueError("Cannot decompose pipelines with full_refresh set")

# parallel tasks
tasks = []
naming = SnakeCaseNamingConvention()
sources = data.decompose("scc")
start = make_task(
pipeline,
sources[0],
naming.normalize_identifier(task_name(pipeline, sources[0])),
)

# parallel tasks
for source in sources[1:]:
# name pipeline the same as task
new_pipeline_name = naming.normalize_identifier(task_name(pipeline, source))
tasks.append(make_task(pipeline, source, new_pipeline_name))

t_name = task_name(pipeline, data)
end = DummyOperator(task_id=f"{t_name}_end")

if tasks:
start >> tasks >> end
@@ -325,7 +383,10 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
start >> end
return [start, end]
else:
raise ValueError("decompose value must be one of ['none', 'serialize', 'parallel']")
raise ValueError(
"decompose value must be one of ['none', 'serialize', 'parallel',"
" 'parallel-isolated']"
)

def add_fun(self, f: Callable[..., Any], **kwargs: Any) -> Any:
"""Will execute a function `f` inside an Airflow task. It is up to the function to create pipeline and source(s)"""
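
For orientation, a minimal usage sketch of the new mode (the DAG, source and file names here are illustrative assumptions, not part of this commit; the call pattern follows the helper docstring and the tests below):

import os
import tempfile

import dlt
import pendulum
from airflow.decorators import dag

from dlt.helpers.airflow_helper import PipelineTasksGroup


@dlt.source
def demo_source():
    # two resources so that "scc" decomposition yields more than one component
    @dlt.resource
    def alpha():
        yield [{"id": 1}, {"id": 2}]

    @dlt.resource
    def beta():
        yield [{"id": 3}]

    return alpha, beta


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def demo_parallel_isolated():
    tasks = PipelineTasksGroup("demo_group", wipe_local_data=False)

    # keep the duckdb file outside the pipeline working folder,
    # which is dropped on each task (same pattern as the tests below)
    pipeline = dlt.pipeline(
        pipeline_name="demo_pipeline",
        dataset_name="demo_dataset",
        destination="duckdb",
        credentials=os.path.join(tempfile.gettempdir(), "demo_pipeline.duckdb"),
    )
    # every decomposed component becomes its own task running under its own
    # derived pipeline name, so pipeline state is not merged across tasks
    tasks.add_run(
        pipeline,
        demo_source(),
        decompose="parallel-isolated",
        trigger_rule="all_done",
        retries=0,
        provide_context=True,
    )


demo_parallel_isolated()
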
10 changes: 7 additions & 3 deletions dlt/pipeline/pipeline.py
@@ -324,13 +324,17 @@ def __init__(
self.credentials = credentials
self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline)

def drop(self) -> "Pipeline":
"""Deletes local pipeline state, schemas and any working files"""
def drop(self, pipeline_name: str = None) -> "Pipeline":
"""Deletes local pipeline state, schemas and any working files.
Args:
pipeline_name (str, optional): A new name for the cloned pipeline.
"""
# reset the pipeline working dir
self._create_pipeline()
# clone the pipeline
return Pipeline(
self.pipeline_name,
pipeline_name or self.pipeline_name,
self.pipelines_dir,
self.pipeline_salt,
self.destination,
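
The new optional pipeline_name argument is what lets the Airflow helper clone an isolated pipeline per task. A minimal sketch of the resulting behavior, mirroring the new test below (the pipeline names are illustrative):

import dlt

pipeline = dlt.pipeline(pipeline_name="old_pipeline_name", destination="duckdb")

# drop() wipes local state, schemas and working files and returns a clone;
# passing a new name produces a pipeline that keeps the same destination and
# working-dir settings but tracks its state under the new name
new_pipeline = pipeline.drop(pipeline_name="new_pipeline_name")
assert new_pipeline.pipeline_name == "new_pipeline_name"
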
134 changes: 132 additions & 2 deletions tests/helpers/airflow_tests/test_airflow_wrapper.py
@@ -1,5 +1,6 @@
import os
import pytest
from unittest import mock
from typing import List
from airflow import DAG
from airflow.decorators import dag
@@ -11,6 +12,8 @@
import dlt
from dlt.common import pendulum
from dlt.common.utils import uniq_id
from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention

from dlt.helpers.airflow_helper import PipelineTasksGroup, DEFAULT_RETRY_BACKOFF
from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineStepFailed

@@ -75,6 +78,23 @@ def resource():
return resource


@dlt.source
def mock_data_incremental_source():
@dlt.resource
def resource1(a: str = None, b=None, c=None):
yield ["s", "a"]

@dlt.resource
def resource2(
updated_at: dlt.sources.incremental[str] = dlt.sources.incremental(
"updated_at", initial_value="1970-01-01T00:00:00Z"
)
):
yield [{"updated_at": "1970-02-01T00:00:00Z"}]

return resource1, resource2


@dlt.source(section="mock_data_source_state")
def mock_data_source_state():
@dlt.resource(selected=True)
@@ -263,13 +283,123 @@ def dag_parallel():
dag_def = dag_parallel()
assert len(tasks_list) == 4
dag_def.test()

pipeline_dag_parallel = dlt.attach(pipeline_name="pipeline_dag_parallel")
pipeline_dag_decomposed_counts = load_table_counts(
results = load_table_counts(
pipeline_dag_parallel,
*[t["name"] for t in pipeline_dag_parallel.default_schema.data_tables()],
)
assert pipeline_dag_decomposed_counts == pipeline_standalone_counts

assert results == pipeline_standalone_counts

# verify tasks 1-2 in between tasks 0 and 3
for task in dag_def.tasks[1:3]:
assert task.downstream_task_ids == set([dag_def.tasks[-1].task_id])
assert task.upstream_task_ids == set([dag_def.tasks[0].task_id])


def test_parallel_incremental():
pipeline_standalone = dlt.pipeline(
pipeline_name="pipeline_parallel",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=":pipeline:",
)
pipeline_standalone.run(mock_data_incremental_source())

tasks_list: List[PythonOperator] = None

quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_parallel():
nonlocal tasks_list
tasks = PipelineTasksGroup(
"pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

# set duckdb to be outside of pipeline folder which is dropped on each task
pipeline_dag_parallel = dlt.pipeline(
pipeline_name="pipeline_dag_parallel",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
tasks.add_run(
pipeline_dag_parallel,
mock_data_incremental_source(),
decompose="parallel",
trigger_rule="all_done",
retries=0,
provide_context=True,
)

with mock.patch("dlt.helpers.airflow_helper.logger.warn") as warn_mock:
dag_def = dag_parallel()
dag_def.test()
warn_mock.assert_called_once()


def test_parallel_isolated_run():
pipeline_standalone = dlt.pipeline(
pipeline_name="pipeline_parallel",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=":pipeline:",
)
pipeline_standalone.run(mock_data_source())
pipeline_standalone_counts = load_table_counts(
pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()]
)

tasks_list: List[PythonOperator] = None

quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_parallel():
nonlocal tasks_list
tasks = PipelineTasksGroup(
"pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

# set duckdb to be outside of pipeline folder which is dropped on each task
pipeline_dag_parallel = dlt.pipeline(
pipeline_name="pipeline_dag_parallel",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
tasks_list = tasks.add_run(
pipeline_dag_parallel,
mock_data_source(),
decompose="parallel-isolated",
trigger_rule="all_done",
retries=0,
provide_context=True,
)

dag_def = dag_parallel()
assert len(tasks_list) == 4
dag_def.test()

results = {}
snake_case = SnakeCaseNamingConvention()
for i in range(0, 3):
pipeline_dag_parallel = dlt.attach(
pipeline_name=snake_case.normalize_identifier(
dag_def.tasks[i].task_id.replace("pipeline_dag_parallel.", "")
)
)
pipeline_dag_decomposed_counts = load_table_counts(
pipeline_dag_parallel,
*[t["name"] for t in pipeline_dag_parallel.default_schema.data_tables()],
)
results.update(pipeline_dag_decomposed_counts)

assert results == pipeline_standalone_counts

# verify tasks 1-2 in between tasks 0 and 3
for task in dag_def.tasks[1:3]:
assert task.downstream_task_ids == set([dag_def.tasks[-1].task_id])
assert task.upstream_task_ids == set([dag_def.tasks[0].task_id])
2 changes: 1 addition & 1 deletion tests/load/filesystem/test_filesystem_common.py
@@ -52,7 +52,7 @@ def check_file_exists():
def check_file_changed():
details = filesystem.info(file_url)
assert details["size"] == 11
assert (MTIME_DISPATCH[config.protocol](details) - now).seconds < 60
assert (MTIME_DISPATCH[config.protocol](details) - now).seconds < 120

bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"]
config = get_config()
10 changes: 10 additions & 0 deletions tests/pipeline/test_pipeline.py
@@ -1258,6 +1258,16 @@ def generic(start):
assert pipeline.default_schema.get_table("single_table")["resource"] == "state1"


def test_drop_with_new_name() -> None:
old_test_name = "old_pipeline_name"
new_test_name = "new_pipeline_name"

pipeline = dlt.pipeline(pipeline_name=old_test_name, destination="duckdb")
new_pipeline = pipeline.drop(pipeline_name=new_test_name)

assert new_pipeline.pipeline_name == new_test_name


def test_remove_autodetect() -> None:
now = pendulum.now()

