From 1b17ded7ece109d7f738bbe82af83199b3a87e10 Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Fri, 10 May 2024 17:48:01 -0400
Subject: [PATCH] Cleanup

---
 dlt/pipeline/helpers.py  | 31 +++++++++++++++++++++++++++++++
 dlt/pipeline/pipeline.py | 28 ++--------------------------
 2 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py
index 2970c6f6e1..5d4a3f3ba3 100644
--- a/dlt/pipeline/helpers.py
+++ b/dlt/pipeline/helpers.py
@@ -30,6 +30,7 @@
     _get_matching_resources,
     StateInjectableContext,
     Container,
+    pipeline_state as current_pipeline_state,
 )
 
 from dlt.common.destination.reference import WithStagingDataset
@@ -43,6 +44,7 @@
 from dlt.pipeline.typing import TPipelineStep
 from dlt.pipeline.drop import drop_resources
 from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated
+from dlt.extract import DltSource
 
 if TYPE_CHECKING:
     from dlt.pipeline import Pipeline
@@ -175,3 +177,32 @@
     state_only: bool = False,
 ) -> None:
     return DropCommand(pipeline, resources, schema_name, state_paths, drop_all, state_only)()
+
+
+def refresh_source(pipeline: "Pipeline", source: DltSource) -> Dict[str, Any]:
+    """Run the pipeline's refresh mode on the given source, updating the source's schema and state.
+
+    Returns:
+        The new load package state containing tables that need to be dropped/truncated.
+    """
+    pipeline_state, _ = current_pipeline_state(pipeline._container)
+    if pipeline.refresh is None or pipeline.first_run:
+        return {}
+    _resources_to_drop = (
+        list(source.resources.extracted) if pipeline.refresh != "drop_dataset" else []
+    )
+    drop_result = drop_resources(
+        source.schema,
+        pipeline_state,
+        resources=_resources_to_drop,
+        drop_all=pipeline.refresh == "drop_dataset",
+        state_paths="*" if pipeline.refresh == "drop_dataset" else [],
+    )
+    load_package_state = {}
+    if drop_result.dropped_tables:
+        key = "dropped_tables" if pipeline.refresh != "drop_data" else "truncated_tables"
+        load_package_state[key] = drop_result.dropped_tables
+    source.schema = drop_result.schema
+    if "sources" in drop_result.state:
+        pipeline_state["sources"] = drop_result.state["sources"]
+    return load_package_state
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index fa301a1261..58e30fbdf6 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -150,7 +150,7 @@
 )
 from dlt.pipeline.warnings import credentials_argument_deprecated
 from dlt.common.storages.load_package import TLoadPackageState
-from dlt.pipeline.drop import drop_resources
+from dlt.pipeline.helpers import refresh_source
 
 
 def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]:
@@ -1095,25 +1095,6 @@ def _wipe_working_folder(self) -> None:
     def _attach_pipeline(self) -> None:
         pass
 
-    def _refresh_source(self, source: DltSource) -> Tuple[Schema, TPipelineState, Dict[str, Any]]:
-        if self.refresh is None or self.first_run:
-            return source.schema, self.state, {}
-        _resources_to_drop = (
-            list(source.resources.extracted) if self.refresh != "drop_dataset" else []
-        )
-        drop_result = drop_resources(
-            source.schema,
-            self.state,
-            resources=_resources_to_drop,
-            drop_all=self.refresh == "drop_dataset",
-            state_paths="*" if self.refresh == "drop_dataset" else [],
-        )
-        load_package_state = {}
-        if drop_result.dropped_tables:
-            key = "dropped_tables" if self.refresh != "drop_data" else "truncated_tables"
-            load_package_state[key] = drop_result.dropped_tables
-        return drop_result.schema, drop_result.state, load_package_state
-
     def _extract_source(
         self,
         extract: Extract,
@@ -1142,12 +1123,7 @@ def _extract_source(
 
         load_package_state_update = dict(load_package_state_update or {})
         if with_refresh:
-            new_schema, new_state, load_package_state = self._refresh_source(source)
-            load_package_state_update.update(load_package_state)
-            source.schema = new_schema
-            state, _ = current_pipeline_state(self._container)
-            if "sources" in new_state:
-                state["sources"] = new_state["sources"]
+            load_package_state_update.update(refresh_source(self, source))
 
         # extract into pipeline schema
         load_id = extract.extract(
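
Reviewer note: a minimal sketch, not part of this patch, of how the refresh modes
that refresh_source() branches on are driven end to end. The pipeline and resource
names are illustrative, and the refresh= keyword on dlt.pipeline() is assumed to
behave as elsewhere in this PR; user code never calls refresh_source() directly,
since Pipeline._extract_source() invokes it when with_refresh is set.

    import dlt

    @dlt.resource
    def items():
        # Illustrative resource; any extracted resource participates in refresh.
        yield from [{"id": 1}, {"id": 2}]

    # refresh="drop_dataset": refresh_source() calls drop_resources() with
    #   drop_all=True and state_paths="*", recording "dropped_tables" in the
    #   load package state.
    # refresh="drop_data": only source.resources.extracted are selected, and
    #   the affected tables are recorded under "truncated_tables" instead.
    pipeline = dlt.pipeline("refresh_demo", destination="duckdb", refresh="drop_dataset")
    pipeline.run(items())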