feat(airflow): allow data to be a callable #1318

Merged · 5 commits · May 8, 2024
dlt/helpers/airflow_helper.py: 33 additions & 3 deletions
@@ -171,6 +171,7 @@ def run(
loader_file_format: TLoaderFileFormat = None,
schema_contract: TSchemaContract = None,
pipeline_name: str = None,
on_before_run: Callable[[], None] = None,
**kwargs: Any,
) -> PythonOperator:
"""
@@ -179,7 +180,12 @@ def run(

Args:
pipeline (Pipeline): The pipeline to run
data (Any): The data to run the pipeline with
data (Any):
The data to run the pipeline with. If a non-resource
callable is given, it is evaluated during DAG execution,
right before the actual pipeline run.
NOTE: If `on_before_run` is provided, `on_before_run` is
called first, and then the callable `data` is evaluated.
table_name (str, optional): The name of the table to
which the data should be loaded within the `dataset`.
write_disposition (TWriteDispositionConfig, optional): Same as
@@ -191,6 +197,8 @@ def run(
for the schema contract settings, this will replace
the schema contract settings for all tables in the schema.
pipeline_name (str, optional): The name of the derived pipeline.
on_before_run (Callable, optional): A callable to be
executed right before the actual pipeline run.

Returns:
PythonOperator: Airflow task instance.
@@ -204,6 +212,7 @@ def run(
loader_file_format=loader_file_format,
schema_contract=schema_contract,
pipeline_name=pipeline_name,
on_before_run=on_before_run,
)
return PythonOperator(task_id=self._task_name(pipeline, data), python_callable=f, **kwargs)

@@ -216,12 +225,18 @@ def _run(
loader_file_format: TLoaderFileFormat = None,
schema_contract: TSchemaContract = None,
pipeline_name: str = None,
on_before_run: Callable[[], None] = None,
) -> None:
"""Run the given pipeline with the given data.

Args:
pipeline (Pipeline): The pipeline to run
data (Any): The data to run the pipeline with
data (Any):
The data to run the pipeline with. If a non-resource
callable is given, it is evaluated during DAG execution,
right before the actual pipeline run.
NOTE: If `on_before_run` is provided, `on_before_run` is
called first, and then the callable `data` is evaluated.
table_name (str, optional): The name of the
table to which the data should be loaded
within the `dataset`.
@@ -236,6 +251,8 @@ def _run(
for all tables in the schema.
pipeline_name (str, optional): The name of the
derived pipeline.
on_before_run (Callable, optional): A callable
to be executed right before the actual pipeline run.
"""
# activate pipeline
pipeline.activate()
@@ -271,6 +288,12 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
)

try:
if on_before_run is not None:
on_before_run()

if callable(data):
data = data()

# retry with given policy on selected pipeline steps
for attempt in self.retry_policy.copy(
retry=retry_if_exception(
@@ -325,6 +348,7 @@ def add_run(
write_disposition: TWriteDispositionConfig = None,
loader_file_format: TLoaderFileFormat = None,
schema_contract: TSchemaContract = None,
on_before_run: Callable[[], None] = None,
**kwargs: Any,
) -> List[PythonOperator]:
"""Creates a task or a group of tasks to run `data` with `pipeline`
@@ -338,7 +362,10 @@

Args:
pipeline (Pipeline): An instance of pipeline used to run the source
data (Any): Any data supported by `run` method of the pipeline
data (Any):
Any data supported by the `run` method of the pipeline.
If a non-resource callable is given, it is called before
the load to get the data.
decompose (Literal["none", "serialize", "parallel"], optional):
A source decomposition strategy into Airflow tasks:
none - no decomposition, default value.
@@ -365,6 +392,8 @@
Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination.
schema_contract (TSchemaContract, optional): An override for the schema contract settings;
this will replace the schema contract settings for all tables in the schema. Defaults to None.
on_before_run (Callable, optional):
A callable to be executed right before the actual pipeline run.

Returns:
Any: Airflow tasks created in order of creation.
@@ -391,6 +420,7 @@ def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator
loader_file_format=loader_file_format,
schema_contract=schema_contract,
pipeline_name=name,
on_before_run=on_before_run,
)
return PythonOperator(
task_id=self._task_name(pipeline, data), python_callable=f, **kwargs
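
For orientation (not part of the diff): a minimal usage sketch of the new arguments. `PipelineTasksGroup` comes from `dlt.helpers.airflow_helper`, the file changed above; all DAG, group, pipeline, and dataset names are illustrative, and only `tasks.run(pipeline, data, on_before_run=...)` plus the evaluation order (hook first, then the callable `data`) are taken from this PR.

import dlt
import pendulum
from airflow.decorators import dag
from airflow.operators.python import get_current_context
from dlt.helpers.airflow_helper import PipelineTasksGroup


def build_source():
    # Evaluated inside the Airflow task, so the runtime context is available.
    context = get_current_context()

    @dlt.resource
    def events():
        yield [{"id": 1, "ds": context["ds"]}]

    return events


def before_run():
    # Fires right before `build_source` is evaluated and the pipeline runs.
    print("about to run the pipeline")


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def example_callable_data():
    tasks = PipelineTasksGroup("example_group", wipe_local_data=False)
    pipeline = dlt.pipeline(
        pipeline_name="example_pipeline",
        dataset_name="example_data",
        destination="duckdb",
    )
    # `data` is a callable here: `before_run` is called first, then
    # `build_source()` is evaluated and its result is loaded.
    tasks.run(pipeline, build_source, on_before_run=before_run)


example_callable_data()
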
tests/helpers/airflow_tests/test_airflow_wrapper.py: 80 additions & 2 deletions
@@ -4,13 +4,13 @@
from typing import List
from airflow import DAG
from airflow.decorators import dag
from airflow.operators.python import PythonOperator
from airflow.operators.python import PythonOperator, get_current_context
from airflow.models import TaskInstance
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

import dlt
from dlt.common import pendulum
from dlt.common import logger, pendulum
from dlt.common.utils import uniq_id
from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention

@@ -917,3 +917,81 @@ def dag_parallel():
dag_def = dag_parallel()
assert len(tasks_list) == 1
dag_def.test()


def callable_source():
@dlt.resource
def test_res():
context = get_current_context()
yield [
{"id": 1, "tomorrow": context["tomorrow_ds"]},
{"id": 2, "tomorrow": context["tomorrow_ds"]},
{"id": 3, "tomorrow": context["tomorrow_ds"]},
]

return test_res


def test_run_callable() -> None:
quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_regular():
tasks = PipelineTasksGroup(
"callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

call_dag = dlt.pipeline(
pipeline_name="callable_dag",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
tasks.run(call_dag, callable_source)

dag_def: DAG = dag_regular()
dag_def.test()

pipeline_dag = dlt.attach(pipeline_name="callable_dag")

with pipeline_dag.sql_client() as client:
with client.execute_query("SELECT * FROM test_res") as result:
results = result.fetchall()

assert len(results) == 3

for row in results:
assert row[1] == pendulum.tomorrow().format("YYYY-MM-DD")


def on_before_run():
context = get_current_context()
logger.info(f'on_before_run test: {context["tomorrow_ds"]}')


def test_on_before_run() -> None:
quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_regular():
tasks = PipelineTasksGroup(
"callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

call_dag = dlt.pipeline(
pipeline_name="callable_dag",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
tasks.run(call_dag, mock_data_source, on_before_run=on_before_run)

dag_def: DAG = dag_regular()

with mock.patch("dlt.helpers.airflow_helper.logger.info") as logger_mock:
dag_def.test()
logger_mock.assert_has_calls(
[
mock.call(f'on_before_run test: {pendulum.tomorrow().format("YYYY-MM-DD")}'),
]
)
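

A second sketch, also not part of the PR: the same pair of arguments on `add_run`, which creates the task(s) inside the task group. Names are again illustrative; `decompose`, `on_before_run`, and the callable `data` come from the diff above, and the extra `retries` keyword is assumed to be forwarded to the underlying `PythonOperator` via `**kwargs`.

import dlt
import pendulum
from airflow.decorators import dag
from dlt.helpers.airflow_helper import PipelineTasksGroup


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def example_add_run():
    tasks = PipelineTasksGroup("add_run_group", wipe_local_data=False)
    pipeline = dlt.pipeline(
        pipeline_name="add_run_pipeline",
        dataset_name="add_run_data",
        destination="duckdb",
    )

    def build_data():
        # Called at task run time, right after `log_start`.
        @dlt.resource
        def items():
            yield [{"id": 1}, {"id": 2}]

        return items

    def log_start():
        print("starting load")

    # decompose="none" keeps a single task; whether a callable `data` can be
    # combined with the other decomposition modes is not shown in this diff.
    tasks.add_run(
        pipeline,
        build_data,
        decompose="none",
        on_before_run=log_start,
        retries=0,
    )


example_add_run()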