diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py
index 1c76fd39cd..8870024de9 100644
--- a/dlt/common/storages/load_package.py
+++ b/dlt/common/storages/load_package.py
@@ -61,6 +61,8 @@ class TLoadPackageState(TVersionedState, total=False):
     """A section of state that does not participate in change merging and version control"""
     destination_state: NotRequired[Dict[str, Any]]
     """private space for destinations to store state relevant only to the load package"""
+    source_state: NotRequired[Dict[str, Any]]
+    """private space for sources to store state relevant only to the load package, currently used for storing pipeline state"""
 
 
 class TLoadPackage(TypedDict, total=False):
@@ -689,6 +691,12 @@ def destination_state() -> DictStrAny:
     return lp["state"].setdefault("destination_state", {})
 
 
+def load_package_source_state() -> DictStrAny:
+    """Get segment of load package state that is specific to the current source."""
+    lp = load_package()
+    return lp["state"].setdefault("source_state", {})
+
+
 def clear_destination_state(commit: bool = True) -> None:
     """Clear segment of load package state that is specific to the current destination. Optionally commit to load package."""
     lp = load_package()
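A minimal usage sketch of the new accessor (illustrative, not part of the diff). It assumes an active load package context, the same precondition `destination_state()` has; the key and payload are made up, and `commit_load_package_state` is the existing helper defined alongside it in this module.

    from dlt.common.storages.load_package import (
        commit_load_package_state,
        load_package_source_state,
    )

    # the "source_state" segment is created on first access via setdefault
    source_state = load_package_source_state()
    source_state["my_key"] = {"anything": "json-serializable"}
    # persist the updated package state to the load package folder
    commit_load_package_state()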
diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py
index 0f8d85b26d..a8cd3e9422 100644
--- a/dlt/destinations/impl/filesystem/filesystem.py
+++ b/dlt/destinations/impl/filesystem/filesystem.py
@@ -194,12 +194,12 @@ def update_stored_schema(
             self.fs_client.touch(posixpath.join(directory, INIT_FILE_NAME))
 
         # write schema to destination
-        self.store_current_schema(load_id or "1")
+        self._store_current_schema(load_id or "1")
 
         return expected_update
 
     def _get_table_dirs(self, table_names: Iterable[str]) -> List[str]:
-        """Gets unique directories where table data is stored."""
+        """Gets directories where table data is stored."""
         table_dirs: List[str] = []
         for table_name in table_names:
             # dlt tables do not respect layout (for now)
@@ -268,10 +268,7 @@ def _list_dlt_dir(self, dirname: str) -> Iterator[Tuple[str, List[str]]]:
                 continue
             yield filepath, fileparts
 
-    def complete_load(self, load_id: str) -> None:
-        # store current state
-        self.store_current_state(load_id)
-
+    def _store_load(self, load_id: str) -> None:
         # write entry to load "table"
         # TODO: this is also duplicate across all destinations. DRY this.
         load_data = {
@@ -282,9 +279,13 @@ def complete_load(self, load_id: str) -> None:
             "schema_version_hash": self.schema.version_hash,
         }
         filepath = f"{self.dataset_path}/{self.schema.loads_table_name}/{self.schema.name}__{load_id}.jsonl"
-
         self._write_to_json_file(filepath, load_data)
 
+    def complete_load(self, load_id: str) -> None:
+        # store current state
+        self._store_current_state(load_id)
+        self._store_load(load_id)
+
     #
     # state read/write
     #
@@ -293,19 +294,20 @@ def _get_state_file_name(self, pipeline_name: str, version_hash: str, load_id: s
         """gets full path for schema file for a given hash"""
         return f"{self.dataset_path}/{self.schema.state_table_name}/{pipeline_name}__{load_id}__{self._to_path_safe_string(version_hash)}.jsonl"
 
-    def store_current_state(self, load_id: str) -> None:
+    def _store_current_state(self, load_id: str) -> None:
         # get state doc from current pipeline
-        from dlt.common.configuration.container import Container
-        from dlt.common.pipeline import PipelineContext
-        from dlt.pipeline.state_sync import state_doc
+        from dlt.pipeline.current import load_package_source_state
+        from dlt.pipeline.state_sync import LOAD_PACKAGE_STATE_KEY
+
+        doc = load_package_source_state().get(LOAD_PACKAGE_STATE_KEY, {})
 
-        pipeline = Container()[PipelineContext].pipeline()
-        state = pipeline.state
-        doc = state_doc(state)
+        if not doc:
+            return
 
         # get paths
+        pipeline_name = doc["pipeline_name"]
         hash_path = self._get_state_file_name(
-            pipeline.pipeline_name, self.schema.stored_version_hash, load_id
+            pipeline_name, self.schema.stored_version_hash, load_id
         )
 
         # write
@@ -323,7 +325,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
                 newest_load_id = fileparts[1]
                 selected_path = filepath
 
-        """Loads compressed state from destination storage"""
+        # Load compressed state from destination
         if selected_path:
             state_json = json.loads(self.fs_client.read_text(selected_path))
             state_json.pop("version_hash")
@@ -339,13 +341,6 @@ def _get_schema_file_name(self, version_hash: str, load_id: str) -> str:
         """gets full path for schema file for a given hash"""
         return f"{self.dataset_path}/{self.schema.version_table_name}/{self.schema.name}__{load_id}__{self._to_path_safe_string(version_hash)}.jsonl"
 
-    def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
-        """Retrieves newest schema from destination storage"""
-        return self._get_stored_schema_by_hash_or_newest()
-
-    def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
-        return self._get_stored_schema_by_hash_or_newest(version_hash)
-
     def _get_stored_schema_by_hash_or_newest(
         self, version_hash: str = None
     ) -> Optional[StorageSchemaInfo]:
@@ -372,7 +367,7 @@ def _get_stored_schema_by_hash_or_newest(
 
         return None
 
-    def store_current_schema(self, load_id: str) -> None:
+    def _store_current_schema(self, load_id: str) -> None:
         # get paths
         hash_path = self._get_schema_file_name(self.schema.stored_version_hash, load_id)
 
@@ -388,3 +383,10 @@ def store_current_schema(self, load_id: str) -> None:
 
         # we always keep tabs on what the current schema is
         self._write_to_json_file(hash_path, version_info)
+
+    def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
+        """Retrieves newest schema from destination storage"""
+        return self._get_stored_schema_by_hash_or_newest()
+
+    def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
+        return self._get_stored_schema_by_hash_or_newest(version_hash)
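Spelled out as a short illustration (not part of the diff): `_store_current_state` now reads the state doc that extract placed into the load package, instead of pulling the live pipeline out of `PipelineContext`. This only works while a load package context is active, which is why `complete_package` in `dlt/load/load.py` below injects `LoadPackageStateInjectableContext` before calling `complete_load`.

    from dlt.pipeline.current import load_package_source_state
    from dlt.pipeline.state_sync import LOAD_PACKAGE_STATE_KEY

    doc = load_package_source_state().get(LOAD_PACKAGE_STATE_KEY, {})
    if not doc:
        # the package carries no pipeline state doc (e.g. none was extracted
        # with it), so there is nothing to write to the destination
        pass
    else:
        # the doc carries its own pipeline name, so no PipelineContext lookup is needed
        pipeline_name = doc["pipeline_name"]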
diff --git a/dlt/extract/__init__.py b/dlt/extract/__init__.py
index 03b2e59539..7e0dd3d0fc 100644
--- a/dlt/extract/__init__.py
+++ b/dlt/extract/__init__.py
@@ -1,4 +1,4 @@
-from dlt.extract.resource import DltResource, with_table_name, with_hints
+from dlt.extract.resource import DltResource, with_table_name, with_hints, with_package_state
 from dlt.extract.hints import make_hints
 from dlt.extract.source import DltSource
 from dlt.extract.decorators import source, resource, transformer, defer
diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py
index cc2b03c50b..f07c4ccbc6 100644
--- a/dlt/extract/extract.py
+++ b/dlt/extract/extract.py
@@ -27,7 +27,12 @@
     TWriteDispositionConfig,
 )
 from dlt.common.storages import NormalizeStorageConfiguration, LoadPackageInfo, SchemaStorage
-from dlt.common.storages.load_package import ParsedLoadJobFileName
+from dlt.common.storages.load_package import (
+    ParsedLoadJobFileName,
+    LoadPackageStateInjectableContext,
+)
+
+
 from dlt.common.utils import get_callable_name, get_full_class_name
 from dlt.extract.decorators import SourceInjectableContext, SourceSchemaInjectableContext
@@ -367,7 +372,13 @@ def extract(
         load_id = self.extract_storage.create_load_package(source.discover_schema())
         with Container().injectable_context(
             SourceSchemaInjectableContext(source.schema)
-        ), Container().injectable_context(SourceInjectableContext(source)):
+        ), Container().injectable_context(
+            SourceInjectableContext(source)
+        ), Container().injectable_context(
+            LoadPackageStateInjectableContext(
+                storage=self.extract_storage.new_packages, load_id=load_id
+            )
+        ):
             # inject the config section with the current source name
             with inject_section(
                 ConfigSectionContext(
diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py
index b4afc5b1f8..86888fed0a 100644
--- a/dlt/extract/extractors.py
+++ b/dlt/extract/extractors.py
@@ -19,7 +19,7 @@
     TPartialTableSchema,
 )
 from dlt.extract.hints import HintsMeta
-from dlt.extract.resource import DltResource
+from dlt.extract.resource import DltResource, LoadPackageStateMeta
 from dlt.extract.items import TableNameMeta
 from dlt.extract.storage import ExtractorItemStorage
 
@@ -88,6 +88,13 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No
             meta = TableNameMeta(meta.hints["name"])  # type: ignore[arg-type]
             self._reset_contracts_cache()
 
+        # if we have a load package state meta, store to load package
+        if isinstance(meta, LoadPackageStateMeta):
+            from dlt.pipeline.current import load_package_source_state, commit_load_package_state
+
+            load_package_source_state()[meta.state_key_name] = items
+            commit_load_package_state()
+
         if table_name := self._get_static_table_name(resource, meta):
             # write item belonging to table with static name
             self._write_to_static_table(resource, table_name, items, meta)
diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py
index 4776158bbb..87e4fd7f76 100644
--- a/dlt/extract/resource.py
+++ b/dlt/extract/resource.py
@@ -76,6 +76,22 @@ def with_hints(
     return DataItemWithMeta(HintsMeta(hints, create_table_variant), item)
 
 
+class LoadPackageStateMeta:
+    __slots__ = "state_key_name"
+
+    def __init__(self, state_key_name: str) -> None:
+        self.state_key_name = state_key_name
+
+
+def with_package_state(item: TDataItems, state_key_name: str) -> DataItemWithMeta:
+    """Marks `item` to also be inserted into the load package state under `state_key_name`.
+
+    The item is written to its table as usual and additionally stored in the package state, where it travels with the load package.
+
+    """
+    return DataItemWithMeta(LoadPackageStateMeta(state_key_name), item)
+
+
 class DltResource(Iterable[TDataItem], DltResourceHints):
     """Implements dlt resource. Contains a data pipe that wraps a generating item and table schema that can be adjusted"""
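An illustrative use of the new marker (the resource and key names are made up; the mechanism mirrors `dlt.mark.with_table_name`):

    import dlt

    @dlt.resource(name="documents")
    def documents():
        doc = {"id": 1, "payload": "..."}
        # written to the "documents" table as usual, and additionally copied
        # into the load package state under "important_doc" by Extractor.write_items
        yield dlt.mark.with_package_state(doc, "important_doc")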
+ + """ + return DataItemWithMeta(LoadPackageStateMeta(state_key_name), item) + + class DltResource(Iterable[TDataItem], DltResourceHints): """Implements dlt resource. Contains a data pipe that wraps a generating item and table schema that can be adjusted""" diff --git a/dlt/load/load.py b/dlt/load/load.py index 4da64e472a..7bedb3dfa6 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -341,7 +341,13 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) # do not commit load id for aborted packages if not aborted: with self.get_destination_client(schema) as job_client: - job_client.complete_load(load_id) + with Container().injectable_context( + LoadPackageStateInjectableContext( + storage=self.load_storage.normalized_packages, + load_id=load_id, + ) + ): + job_client.complete_load(load_id) self.load_storage.complete_load_package(load_id, aborted) # collect package info self._loaded_packages.append(self.load_storage.get_load_package_info(load_id)) @@ -471,10 +477,9 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: schema = self.load_storage.normalized_packages.load_schema(load_id) logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") - container = Container() # get top load id and mark as being processed with self.collector(f"Load {schema.name} in {load_id}"): - with container.injectable_context( + with Container().injectable_context( LoadPackageStateInjectableContext( storage=self.load_storage.normalized_packages, load_id=load_id, diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index 25fd398623..4bbe74a123 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -7,6 +7,7 @@ load_package, commit_load_package_state, destination_state, + load_package_source_state, clear_destination_state, ) from dlt.extract.decorators import get_source_schema, get_source diff --git a/dlt/pipeline/mark.py b/dlt/pipeline/mark.py index 3956d9bbe2..0b753539be 100644 --- a/dlt/pipeline/mark.py +++ b/dlt/pipeline/mark.py @@ -3,5 +3,6 @@ with_table_name, with_hints, make_hints, + with_package_state, materialize_schema_item as materialize_table_schema, ) diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index bc9e35bafe..70a58d1f98 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -22,6 +22,7 @@ ) PIPELINE_STATE_ENGINE_VERSION = 4 +LOAD_PACKAGE_STATE_KEY = "pipeline_state" # state table columns STATE_TABLE_COLUMNS: TTableSchemaColumns = { @@ -109,7 +110,7 @@ def state_doc(state: TPipelineState) -> DictStrAny: def state_resource(state: TPipelineState) -> DltResource: - doc = state_doc(state) + doc = dlt.mark.with_package_state(state_doc(state), LOAD_PACKAGE_STATE_KEY) return dlt.resource( [doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS ) diff --git a/tests/load/utils.py b/tests/load/utils.py index 110c2b433d..487fd588be 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -64,6 +64,8 @@ "r2", ] +ALL_FILESYSTEM_DRIVERS = ["memory", "file"] + # Filter out buckets not in all filesystem drivers WITH_GDRIVE_BUCKETS = [GCS_BUCKET, AWS_BUCKET, FILE_BUCKET, MEMORY_BUCKET, AZ_BUCKET, GDRIVE_BUCKET] WITH_GDRIVE_BUCKETS = [