From 999982ec1308b273daaf81f5b117497645d12248 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 22 May 2024 14:47:11 -0400 Subject: [PATCH] pipeline.run/extract refresh argument --- dlt/pipeline/helpers.py | 19 +++---- dlt/pipeline/pipeline.py | 15 ++++-- tests/pipeline/test_refresh_modes.py | 79 +++++++++++++++++++++++----- 3 files changed, 88 insertions(+), 25 deletions(-) diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index 0359c53ca2..811fe31733 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -31,6 +31,7 @@ StateInjectableContext, Container, pipeline_state as current_pipeline_state, + TRefreshMode, ) from dlt.common.destination.reference import WithStagingDataset @@ -179,29 +180,29 @@ def drop( return DropCommand(pipeline, resources, schema_name, state_paths, drop_all, state_only)() -def refresh_source(pipeline: "Pipeline", source: DltSource) -> Dict[str, Any]: +def refresh_source( + pipeline: "Pipeline", source: DltSource, refresh: TRefreshMode +) -> Dict[str, Any]: """Run the pipeline's refresh mode on the given source, updating the source's schema and state. Returns: The new load package state containing tables that need to be dropped/truncated. """ - pipeline_state, _ = current_pipeline_state(pipeline._container) - if pipeline.refresh is None or pipeline.first_run: + if pipeline.first_run: return {} - _resources_to_drop = ( - list(source.resources.extracted) if pipeline.refresh != "drop_sources" else [] - ) + pipeline_state, _ = current_pipeline_state(pipeline._container) + _resources_to_drop = list(source.resources.extracted) if refresh != "drop_sources" else [] drop_result = drop_resources( source.schema, pipeline_state, resources=_resources_to_drop, - drop_all=pipeline.refresh == "drop_sources", - state_paths="*" if pipeline.refresh == "drop_sources" else [], + drop_all=refresh == "drop_sources", + state_paths="*" if refresh == "drop_sources" else [], sources=source.name, ) load_package_state = {} if drop_result.dropped_tables: - key = "dropped_tables" if pipeline.refresh != "drop_data" else "truncated_tables" + key = "dropped_tables" if refresh != "drop_data" else "truncated_tables" load_package_state[key] = drop_result.dropped_tables source.schema = drop_result.schema if "sources" in drop_result.state: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 5ab65df423..81b50a8326 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -412,6 +412,7 @@ def extract( max_parallel_items: int = None, workers: int = None, schema_contract: TSchemaContract = None, + refresh: Optional[TRefreshMode] = None, ) -> ExtractInfo: """Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. See `run` method for the arguments' description.""" @@ -440,7 +441,11 @@ def extract( raise SourceExhausted(source.name) self._extract_source( - extract_step, source, max_parallel_items, workers, with_refresh=True + extract_step, + source, + max_parallel_items, + workers, + refresh=refresh or self.refresh, ) # extract state state: TPipelineStateDoc = None @@ -593,6 +598,7 @@ def run( schema: Schema = None, loader_file_format: TLoaderFileFormat = None, schema_contract: TSchemaContract = None, + refresh: Optional[TRefreshMode] = None, ) -> LoadInfo: """Loads the data from `data` argument into the destination specified in `destination` and dataset specified in `dataset_name`. 
@@ -697,6 +703,7 @@ def run( primary_key=primary_key, schema=schema, schema_contract=schema_contract, + refresh=refresh or self.refresh, ) self.normalize(loader_file_format=loader_file_format) return self.load(destination, dataset_name, credentials=credentials) @@ -1106,7 +1113,7 @@ def _extract_source( source: DltSource, max_parallel_items: int, workers: int, - with_refresh: bool = False, + refresh: Optional[TRefreshMode] = None, load_package_state_update: Optional[Dict[str, Any]] = None, ) -> str: # discover the existing pipeline schema @@ -1127,8 +1134,8 @@ def _extract_source( pass load_package_state_update = dict(load_package_state_update or {}) - if with_refresh: - load_package_state_update.update(refresh_source(self, source)) + if refresh: + load_package_state_update.update(refresh_source(self, source, refresh)) # extract into pipeline schema load_id = extract.extract( diff --git a/tests/pipeline/test_refresh_modes.py b/tests/pipeline/test_refresh_modes.py index 49c8a9c3b0..a967a36877 100644 --- a/tests/pipeline/test_refresh_modes.py +++ b/tests/pipeline/test_refresh_modes.py @@ -36,7 +36,7 @@ def column_values(cursor: DBApiCursor, column_name: str) -> List[Any]: @dlt.source -def refresh_source(first_run: bool = True, drop_dataset: bool = False): +def refresh_source(first_run: bool = True, drop_sources: bool = False): @dlt.resource def some_data_1(): if first_run: @@ -49,7 +49,7 @@ def some_data_1(): else: # Check state is cleared for this resource assert not resource_state("some_data_1") - if drop_dataset: + if drop_sources: assert_source_state_is_wiped(dlt.state()) # Second dataset without name column to test tables are re-created yield {"id": 3} @@ -65,7 +65,7 @@ def some_data_2(): yield {"id": 6, "name": "Jill"} else: assert not resource_state("some_data_2") - if drop_dataset: + if drop_sources: assert_source_state_is_wiped(dlt.state()) yield {"id": 7} yield {"id": 8} @@ -79,7 +79,7 @@ def some_data_3(): yield {"id": 10, "name": "Jill"} else: assert not resource_state("some_data_3") - if drop_dataset: + if drop_sources: assert_source_state_is_wiped(dlt.state()) yield {"id": 11} yield {"id": 12} @@ -94,7 +94,7 @@ def some_data_4(): yield some_data_4 -def test_refresh_drop_dataset(): +def test_refresh_drop_sources(): # First run pipeline with load to destination so tables are created pipeline = dlt.pipeline( @@ -104,12 +104,12 @@ def test_refresh_drop_dataset(): dataset_name="refresh_full_test", ) - info = pipeline.run(refresh_source(first_run=True, drop_dataset=True)) + info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) assert_load_info(info) # Second run of pipeline with only selected resources info = pipeline.run( - refresh_source(first_run=False, drop_dataset=True).with_resources( + refresh_source(first_run=False, drop_sources=True).with_resources( "some_data_1", "some_data_2" ) ) @@ -154,13 +154,13 @@ def test_existing_schema_hash(): dataset_name="refresh_full_test", ) - info = pipeline.run(refresh_source(first_run=True, drop_dataset=True)) + info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) assert_load_info(info) first_schema_hash = pipeline.default_schema.version_hash # Second run with all tables dropped and only some tables re-created info = pipeline.run( - refresh_source(first_run=False, drop_dataset=True).with_resources( + refresh_source(first_run=False, drop_sources=True).with_resources( "some_data_1", "some_data_2" ) ) @@ -173,7 +173,7 @@ def test_existing_schema_hash(): # Run again with all tables to ensure they are re-created 
# The new schema in this case should match the schema of the first run exactly - info = pipeline.run(refresh_source(first_run=True, drop_dataset=True)) + info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) # Check table 3 was re-created with pipeline.sql_client() as client: result = client.execute_sql("SELECT id, name FROM some_data_3 ORDER BY id") @@ -293,7 +293,7 @@ def test_refresh_drop_data_only(): assert source_state["resources"]["some_data_3"] == {"run1_1": "value1_1"} -def test_refresh_drop_dataset_multiple_sources(): +def test_refresh_drop_sources_multiple_sources(): """ Ensure only state and tables for currently selected source is dropped """ @@ -346,7 +346,7 @@ def source_2_data_2(): # Run both sources info = pipeline.run( - [refresh_source(first_run=True, drop_dataset=True), refresh_source_2(first_run=True)] + [refresh_source(first_run=True, drop_sources=True), refresh_source_2(first_run=True)] ) assert_load_info(info, 2) # breakpoint() @@ -381,3 +381,58 @@ def source_2_data_2(): with pytest.raises(DatabaseUndefinedRelation): with pipeline.sql_client() as client: result = client.execute_sql("SELECT * FROM source_2_data_2") + + +def test_refresh_argument_to_run(): + pipeline = dlt.pipeline( + "refresh_full_test", + destination="duckdb", + dataset_name="refresh_full_test", + ) + + info = pipeline.run(refresh_source(first_run=True)) + assert_load_info(info) + + info = pipeline.run( + refresh_source(first_run=False).with_resources("some_data_3"), + refresh="drop_sources", + ) + assert_load_info(info) + + # Check local schema to confirm refresh was at all applied + tables = set(t["name"] for t in pipeline.default_schema.data_tables()) + assert tables == {"some_data_3"} + + # Run again without refresh to confirm refresh option doesn't persist on pipeline + info = pipeline.run(refresh_source(first_run=False).with_resources("some_data_2")) + assert_load_info(info) + + # Nothing is dropped + tables = set(t["name"] for t in pipeline.default_schema.data_tables()) + assert tables == {"some_data_2", "some_data_3"} + + +def test_refresh_argument_to_extract(): + pipeline = dlt.pipeline( + "refresh_full_test", + destination="duckdb", + dataset_name="refresh_full_test", + ) + + info = pipeline.run(refresh_source(first_run=True)) + assert_load_info(info) + + pipeline.extract( + refresh_source(first_run=False).with_resources("some_data_3"), + refresh="drop_sources", + ) + + tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) + # All other data tables removed + assert tables == {"some_data_3", "some_data_4"} + + # Run again without refresh to confirm refresh option doesn't persist on pipeline + pipeline.extract(refresh_source(first_run=False).with_resources("some_data_2")) + + tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) + assert tables == {"some_data_2", "some_data_3", "some_data_4"}
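Usage sketch (illustrative note, not part of the patch): the tests above exercise the new `refresh` argument, and the snippet below condenses them. The source `my_source`, the pipeline/dataset names, and the duckdb destination are placeholders; the behaviour shown — the refresh mode applies only to the single run()/extract() call and otherwise falls back to the pipeline-level setting — follows from `refresh=refresh or self.refresh` in pipeline.py and the tests in test_refresh_modes.py.

import dlt


@dlt.source
def my_source():
    # Placeholder source standing in for refresh_source() from the tests.
    @dlt.resource
    def some_data():
        yield {"id": 1, "name": "Jane"}

    return some_data


pipeline = dlt.pipeline(
    "refresh_example",
    destination="duckdb",
    dataset_name="refresh_example",
)

# First run: tables and state are created normally.
pipeline.run(my_source())

# One-off refresh: drop this source's tables and wipe its state before extracting,
# for this call only.
pipeline.run(my_source(), refresh="drop_sources")

# extract() accepts the same argument; per the helper in helpers.py, "drop_data"
# records the tables under "truncated_tables" in the load package state instead
# of dropping them.
pipeline.extract(my_source(), refresh="drop_data")

# No refresh argument here: nothing extra is dropped, since the per-call value
# does not persist on the pipeline.
pipeline.run(my_source())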