pipeline.run/extract refresh argument
steinitzu committed May 22, 2024
1 parent 35299fa commit 999982e
Showing 3 changed files with 88 additions and 25 deletions.
19 changes: 10 additions & 9 deletions dlt/pipeline/helpers.py
```diff
@@ -31,6 +31,7 @@
     StateInjectableContext,
     Container,
     pipeline_state as current_pipeline_state,
+    TRefreshMode,
 )
 from dlt.common.destination.reference import WithStagingDataset

@@ -179,29 +180,29 @@ def drop(
     return DropCommand(pipeline, resources, schema_name, state_paths, drop_all, state_only)()


-def refresh_source(pipeline: "Pipeline", source: DltSource) -> Dict[str, Any]:
+def refresh_source(
+    pipeline: "Pipeline", source: DltSource, refresh: TRefreshMode
+) -> Dict[str, Any]:
     """Run the pipeline's refresh mode on the given source, updating the source's schema and state.

     Returns:
         The new load package state containing tables that need to be dropped/truncated.
     """
-    pipeline_state, _ = current_pipeline_state(pipeline._container)
-    if pipeline.refresh is None or pipeline.first_run:
+    if pipeline.first_run:
         return {}
-    _resources_to_drop = (
-        list(source.resources.extracted) if pipeline.refresh != "drop_sources" else []
-    )
+    pipeline_state, _ = current_pipeline_state(pipeline._container)
+    _resources_to_drop = list(source.resources.extracted) if refresh != "drop_sources" else []
     drop_result = drop_resources(
         source.schema,
         pipeline_state,
         resources=_resources_to_drop,
-        drop_all=pipeline.refresh == "drop_sources",
-        state_paths="*" if pipeline.refresh == "drop_sources" else [],
+        drop_all=refresh == "drop_sources",
+        state_paths="*" if refresh == "drop_sources" else [],
         sources=source.name,
     )
     load_package_state = {}
     if drop_result.dropped_tables:
-        key = "dropped_tables" if pipeline.refresh != "drop_data" else "truncated_tables"
+        key = "dropped_tables" if refresh != "drop_data" else "truncated_tables"
         load_package_state[key] = drop_result.dropped_tables
     source.schema = drop_result.schema
     if "sources" in drop_result.state:
```
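For orientation, `refresh_source` now takes the refresh mode explicitly instead of reading `pipeline.refresh`, and only short-circuits on the very first run. A rough sketch of how the modes map onto the `drop_resources` call (assuming `TRefreshMode` is the string literal `"drop_sources" | "drop_resources" | "drop_data"`, which this diff does not show):

```python
from typing import Any, Dict, Literal

TRefreshMode = Literal["drop_sources", "drop_resources", "drop_data"]  # assumed definition


def drop_arguments_for(refresh: TRefreshMode, extracted_resources: list) -> Dict[str, Any]:
    # Illustration only: mirrors the branches in refresh_source above.
    if refresh == "drop_sources":
        # wipe the whole source: all tables and all of its state
        return {"resources": [], "drop_all": True, "state_paths": "*"}
    # "drop_resources" / "drop_data" only target the resources being extracted;
    # whether their tables are dropped or merely truncated is decided later
    return {"resources": extracted_resources, "drop_all": False, "state_paths": []}
```

The later decision is the `dropped_tables` vs. `truncated_tables` key: under `"drop_data"` the affected tables are recorded as `truncated_tables` in the load package state, otherwise as `dropped_tables`.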
15 changes: 11 additions & 4 deletions dlt/pipeline/pipeline.py
```diff
@@ -412,6 +412,7 @@ def extract(
         max_parallel_items: int = None,
         workers: int = None,
         schema_contract: TSchemaContract = None,
+        refresh: Optional[TRefreshMode] = None,
     ) -> ExtractInfo:
         """Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. See `run` method for the arguments' description."""

@@ -440,7 +441,11 @@
                         raise SourceExhausted(source.name)

                     self._extract_source(
-                        extract_step, source, max_parallel_items, workers, with_refresh=True
+                        extract_step,
+                        source,
+                        max_parallel_items,
+                        workers,
+                        refresh=refresh or self.refresh,
                     )
                 # extract state
                 state: TPipelineStateDoc = None

@@ -593,6 +598,7 @@ def run(
         schema: Schema = None,
         loader_file_format: TLoaderFileFormat = None,
         schema_contract: TSchemaContract = None,
+        refresh: Optional[TRefreshMode] = None,
     ) -> LoadInfo:
         """Loads the data from `data` argument into the destination specified in `destination` and dataset specified in `dataset_name`.

@@ -697,6 +703,7 @@
                 primary_key=primary_key,
                 schema=schema,
                 schema_contract=schema_contract,
+                refresh=refresh or self.refresh,
             )
             self.normalize(loader_file_format=loader_file_format)
             return self.load(destination, dataset_name, credentials=credentials)

@@ -1106,7 +1113,7 @@ def _extract_source(
         source: DltSource,
         max_parallel_items: int,
         workers: int,
-        with_refresh: bool = False,
+        refresh: Optional[TRefreshMode] = None,
         load_package_state_update: Optional[Dict[str, Any]] = None,
     ) -> str:
         # discover the existing pipeline schema

@@ -1127,8 +1134,8 @@
             pass

         load_package_state_update = dict(load_package_state_update or {})
-        if with_refresh:
-            load_package_state_update.update(refresh_source(self, source))
+        if refresh:
+            load_package_state_update.update(refresh_source(self, source, refresh))

         # extract into pipeline schema
         load_id = extract.extract(
```
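Because both call sites pass `refresh or self.refresh` into `_extract_source`, the per-call argument takes precedence over whatever the pipeline was configured with, and the override does not persist across runs. A usage sketch of that precedence (passing `refresh` to `dlt.pipeline` is assumed here; only the `run`/`extract` arguments are part of this diff, and the source below is illustrative):

```python
import dlt


@dlt.source
def demo_source():
    @dlt.resource
    def rows():
        # tiny illustrative resource
        yield from ({"n": i} for i in range(3))

    return rows


pipeline = dlt.pipeline(
    "refresh_demo",
    destination="duckdb",
    dataset_name="refresh_demo",
    refresh="drop_data",  # assumed: pipeline-level default refresh mode
)

pipeline.run(demo_source())                              # uses the pipeline-level mode
pipeline.run(demo_source(), refresh="drop_sources")      # one-off override for this call
pipeline.extract(demo_source(), refresh="drop_sources")  # extract-only; also does not persist
```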
79 changes: 67 additions & 12 deletions tests/pipeline/test_refresh_modes.py
```diff
@@ -36,7 +36,7 @@ def column_values(cursor: DBApiCursor, column_name: str) -> List[Any]:


 @dlt.source
-def refresh_source(first_run: bool = True, drop_dataset: bool = False):
+def refresh_source(first_run: bool = True, drop_sources: bool = False):
     @dlt.resource
     def some_data_1():
         if first_run:

@@ -49,7 +49,7 @@ def some_data_1():
         else:
             # Check state is cleared for this resource
             assert not resource_state("some_data_1")
-            if drop_dataset:
+            if drop_sources:
                 assert_source_state_is_wiped(dlt.state())
             # Second dataset without name column to test tables are re-created
             yield {"id": 3}

@@ -65,7 +65,7 @@ def some_data_2():
             yield {"id": 6, "name": "Jill"}
         else:
             assert not resource_state("some_data_2")
-            if drop_dataset:
+            if drop_sources:
                 assert_source_state_is_wiped(dlt.state())
             yield {"id": 7}
             yield {"id": 8}

@@ -79,7 +79,7 @@ def some_data_3():
             yield {"id": 10, "name": "Jill"}
         else:
             assert not resource_state("some_data_3")
-            if drop_dataset:
+            if drop_sources:
                 assert_source_state_is_wiped(dlt.state())
             yield {"id": 11}
             yield {"id": 12}
```
```diff
@@ -94,7 +94,7 @@ def some_data_4():
     yield some_data_4


-def test_refresh_drop_dataset():
+def test_refresh_drop_sources():

     # First run pipeline with load to destination so tables are created
     pipeline = dlt.pipeline(

@@ -104,12 +104,12 @@ def test_refresh_drop_dataset():
         dataset_name="refresh_full_test",
     )

-    info = pipeline.run(refresh_source(first_run=True, drop_dataset=True))
+    info = pipeline.run(refresh_source(first_run=True, drop_sources=True))
     assert_load_info(info)

     # Second run of pipeline with only selected resources
     info = pipeline.run(
-        refresh_source(first_run=False, drop_dataset=True).with_resources(
+        refresh_source(first_run=False, drop_sources=True).with_resources(
             "some_data_1", "some_data_2"
         )
     )

@@ -154,13 +154,13 @@ def test_existing_schema_hash():
         dataset_name="refresh_full_test",
     )

-    info = pipeline.run(refresh_source(first_run=True, drop_dataset=True))
+    info = pipeline.run(refresh_source(first_run=True, drop_sources=True))
     assert_load_info(info)
     first_schema_hash = pipeline.default_schema.version_hash

     # Second run with all tables dropped and only some tables re-created
     info = pipeline.run(
-        refresh_source(first_run=False, drop_dataset=True).with_resources(
+        refresh_source(first_run=False, drop_sources=True).with_resources(
             "some_data_1", "some_data_2"
         )
     )

@@ -173,7 +173,7 @@

     # Run again with all tables to ensure they are re-created
     # The new schema in this case should match the schema of the first run exactly
-    info = pipeline.run(refresh_source(first_run=True, drop_dataset=True))
+    info = pipeline.run(refresh_source(first_run=True, drop_sources=True))
     # Check table 3 was re-created
     with pipeline.sql_client() as client:
         result = client.execute_sql("SELECT id, name FROM some_data_3 ORDER BY id")
```
```diff
@@ -293,7 +293,7 @@ def test_refresh_drop_data_only():
     assert source_state["resources"]["some_data_3"] == {"run1_1": "value1_1"}


-def test_refresh_drop_dataset_multiple_sources():
+def test_refresh_drop_sources_multiple_sources():
     """
     Ensure only state and tables for currently selected source is dropped
     """

@@ -346,7 +346,7 @@ def source_2_data_2():

     # Run both sources
     info = pipeline.run(
-        [refresh_source(first_run=True, drop_dataset=True), refresh_source_2(first_run=True)]
+        [refresh_source(first_run=True, drop_sources=True), refresh_source_2(first_run=True)]
     )
     assert_load_info(info, 2)
     # breakpoint()
```
```diff
@@ -381,3 +381,58 @@ def source_2_data_2():
     with pytest.raises(DatabaseUndefinedRelation):
         with pipeline.sql_client() as client:
             result = client.execute_sql("SELECT * FROM source_2_data_2")
+
+
+def test_refresh_argument_to_run():
+    pipeline = dlt.pipeline(
+        "refresh_full_test",
+        destination="duckdb",
+        dataset_name="refresh_full_test",
+    )
+
+    info = pipeline.run(refresh_source(first_run=True))
+    assert_load_info(info)
+
+    info = pipeline.run(
+        refresh_source(first_run=False).with_resources("some_data_3"),
+        refresh="drop_sources",
+    )
+    assert_load_info(info)
+
+    # Check local schema to confirm refresh was at all applied
+    tables = set(t["name"] for t in pipeline.default_schema.data_tables())
+    assert tables == {"some_data_3"}
+
+    # Run again without refresh to confirm refresh option doesn't persist on pipeline
+    info = pipeline.run(refresh_source(first_run=False).with_resources("some_data_2"))
+    assert_load_info(info)
+
+    # Nothing is dropped
+    tables = set(t["name"] for t in pipeline.default_schema.data_tables())
+    assert tables == {"some_data_2", "some_data_3"}
+
+
+def test_refresh_argument_to_extract():
+    pipeline = dlt.pipeline(
+        "refresh_full_test",
+        destination="duckdb",
+        dataset_name="refresh_full_test",
+    )
+
+    info = pipeline.run(refresh_source(first_run=True))
+    assert_load_info(info)
+
+    pipeline.extract(
+        refresh_source(first_run=False).with_resources("some_data_3"),
+        refresh="drop_sources",
+    )
+
+    tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True))
+    # All other data tables removed
+    assert tables == {"some_data_3", "some_data_4"}
+
+    # Run again without refresh to confirm refresh option doesn't persist on pipeline
+    pipeline.extract(refresh_source(first_run=False).with_resources("some_data_2"))
+
+    tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True))
+    assert tables == {"some_data_2", "some_data_3", "some_data_4"}
```
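One detail worth flagging in `test_refresh_argument_to_extract`: after `extract` alone (with no normalize/load step), a newly seen table can still be incomplete in the schema, which is presumably why those assertions pass `include_incomplete=True`. A small helper, for illustration only, of the two listings the tests assert on:

```python
import dlt


def schema_table_names(pipeline: dlt.Pipeline, include_incomplete: bool = False) -> set:
    # Names of data tables known to the pipeline's default schema; with
    # include_incomplete=True this also counts tables that have no completed columns yet.
    return {
        t["name"]
        for t in pipeline.default_schema.data_tables(include_incomplete=include_incomplete)
    }
```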
