diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py
index a656a2b0fe..b4d0bc7731 100644
--- a/dlt/common/configuration/specs/config_section_context.py
+++ b/dlt/common/configuration/specs/config_section_context.py
@@ -12,6 +12,7 @@ class ConfigSectionContext(ContainerInjectableContext):
     sections: Tuple[str, ...] = ()
     merge_style: TMergeFunc = None
     source_state_key: str = None
+    destination_state_key: str = None
 
     def merge(self, existing: "ConfigSectionContext") -> None:
         """Merges existing context into incoming using a merge style function"""
@@ -79,4 +80,5 @@ def __init__(
             sections: Tuple[str, ...] = (),
             merge_style: TMergeFunc = None,
             source_state_key: str = None,
+            destination_state_key: str = None,
         ) -> None: ...
diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py
index 6b7b308b44..1971b83410 100644
--- a/dlt/common/pipeline.py
+++ b/dlt/common/pipeline.py
@@ -470,6 +470,7 @@ class TPipelineState(TypedDict, total=False):
     """A section of state that is not synchronized with the destination and does not participate in change merging and version control"""
 
     sources: NotRequired[Dict[str, Dict[str, Any]]]
+    destinations: NotRequired[Dict[str, Dict[str, Any]]]
 
 
 class TSourceState(TPipelineState):
@@ -594,9 +595,13 @@ class StateInjectableContext(ContainerInjectableContext):
 
     can_create_default: ClassVar[bool] = False
 
+    commit: Optional[Callable[[], None]] = None
+
     if TYPE_CHECKING:
 
-        def __init__(self, state: TPipelineState = None) -> None: ...
+        def __init__(
+            self, state: TPipelineState = None, commit: Optional[Callable[[], None]] = None
+        ) -> None: ...
 
 
 def pipeline_state(
@@ -679,6 +684,50 @@ def source_state() -> DictStrAny:
 _last_full_state: TPipelineState = None
 
 
+def destination_state() -> DictStrAny:
+    container = Container()
+
+    # get the destination name from the section context
+    destination_state_key: str = None
+    with contextlib.suppress(ContextDefaultCannotBeCreated):
+        sections_context = container[ConfigSectionContext]
+        destination_state_key = sections_context.destination_state_key
+
+    if not destination_state_key:
+        raise SourceSectionNotAvailable()
+
+    state, _ = pipeline_state(Container())
+
+    destination_state: DictStrAny = state.setdefault("destinations", {}).setdefault(
+        destination_state_key, {}
+    )
+    return destination_state
+
+
+def reset_destination_state() -> None:
+    container = Container()
+
+    # get the destination name from the section context
+    destination_state_key: str = None
+    with contextlib.suppress(ContextDefaultCannotBeCreated):
+        sections_context = container[ConfigSectionContext]
+        destination_state_key = sections_context.destination_state_key
+
+    if not destination_state_key:
+        raise SourceSectionNotAvailable()
+
+    state, _ = pipeline_state(Container())
+
+    state.setdefault("destinations", {}).pop(destination_state_key, None)
+
+
+def commit_pipeline_state() -> None:
+    container = Container()
+    # get the injected state; it is typically "managed" so changes will be persisted
+    state_ctx = container[StateInjectableContext]
+    state_ctx.commit()
+
+
 def _delete_source_state_keys(
     key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, /
 ) -> None:
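# ---------------------------------------------------------------------------
# Usage sketch (not part of the diff), assuming the Load step below has injected
# a ConfigSectionContext with destination_state_key set and the pipeline step
# wrapper has injected a StateInjectableContext with a commit hook.
# save_progress and its arguments are hypothetical names for illustration.
from dlt.common.pipeline import commit_pipeline_state, destination_state


def save_progress(load_id: str, file_id: str, index: int) -> None:
    # a plain dict scoped to the active destination name, stored in pipeline state
    state = destination_state()
    state.setdefault(load_id, {})[file_id] = index
    # persist immediately via the injected commit hook instead of waiting
    # for the pipeline step to finish
    commit_pipeline_state()
# ---------------------------------------------------------------------------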
diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py
index a89e14f2cf..b9e0904e80 100644
--- a/dlt/destinations/impl/sink/sink.py
+++ b/dlt/destinations/impl/sink/sink.py
@@ -5,8 +5,10 @@
 from dlt.destinations.job_impl import EmptyLoadJob
 from dlt.common.typing import TDataItems
 from dlt.common import json
+from dlt.common.configuration.container import Container
+from dlt.common.pipeline import StateInjectableContext
+from dlt.common.pipeline import destination_state, reset_destination_state, commit_pipeline_state
 
-from dlt.common.storages.load_storage import ParsedLoadJobFileName
 from dlt.common.schema import Schema, TTableSchema, TSchemaTables
 from dlt.common.schema.typing import TTableSchema
 from dlt.common.storages import FileStorage
@@ -21,10 +23,6 @@
 from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable
 
 
-# TODO: implement proper state storage somewhere, can this somehow go into the loadpackage?
-job_execution_storage: Dict[str, int] = {}
-
-
 class SinkLoadJob(LoadJob, ABC):
     def __init__(
         self,
@@ -32,27 +30,28 @@ def __init__(
         file_path: str,
         config: SinkClientConfiguration,
         schema: Schema,
-        job_execution_storage: Dict[str, int],
+        load_state: Dict[str, int],
     ) -> None:
         super().__init__(FileStorage.get_file_name_from_file_path(file_path))
         self._file_path = file_path
         self._config = config
         self._table = table
         self._schema = schema
-        self._job_execution_storage = job_execution_storage
-
         # TODO: is this the correct way to tell dlt to retry this job in the next attempt?
         self._state: TLoadJobState = "running"
         try:
-            current_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0)
+            current_index = load_state.get(self._parsed_file_name.file_id, 0)
             for batch in self.run(current_index):
                 self.call_callable_with_items(batch)
                 current_index += len(batch)
-                self._job_execution_storage[self._parsed_file_name.file_id] = current_index
+                load_state[self._parsed_file_name.file_id] = current_index
             self._state = "completed"
         except Exception as e:
             self._state = "retry"
             raise e
+        finally:
+            # save progress
+            commit_pipeline_state()
 
     @abstractmethod
     def run(self, start_index: int) -> Iterable[TDataItems]:
@@ -121,52 +120,6 @@ def run(self, start_index: int) -> Iterable[TDataItems]:
             yield current_batch
 
 
-# class SinkInsertValueslLoadJob(SinkLoadJob):
-#     def run(self) -> None:
-#         from dlt.common import json
-
-#         # stream items
-#         with FileStorage.open_zipsafe_ro(self._file_path) as f:
-#             header = f.readline().strip()
-#             values_mark = f.readline()
-
-#             # properly formatted file has a values marker at the beginning
-#             assert values_mark == "VALUES\n"
-
-#             # extract column names
-#             assert header.startswith("INSERT INTO") and header.endswith(")")
-#             header = header[15:-1]
-#             column_names = header.split(",")
-
-#             # build batches
-#             current_batch: TDataItems = []
-#             current_row: str = ""
-#             for line in f:
-#                 current_row += line
-#                 if line.endswith(");"):
-#                     current_row = current_row[1:-2]
-#                 elif line.endswith("),\n"):
-#                     current_row = current_row[1:-3]
-#                 else:
-#                     continue
-
-#                 values = current_row.split(",")
-#                 values = [None if v == "NULL" else v for v in values]
-#                 current_row = ""
-#                 print(values)
-#                 print(current_row)
-
-#                 # zip and send to callable
-#                 current_batch.append(dict(zip(column_names, values)))
-#                 d = dict(zip(column_names, values))
-#                 print(json.dumps(d, pretty=True))
-#                 if len(current_batch) == self._config.batch_size:
-#                     self.call_callable_with_items(current_batch)
-#                     current_batch = []
-
-#             self.call_callable_with_items(current_batch)
-
-
 class SinkClient(JobClientBase):
     """Sink Client"""
 
@@ -175,8 +128,6 @@ class SinkClient(JobClientBase):
     def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None:
         super().__init__(schema, config)
         self.config: SinkClientConfiguration = config
-        global job_execution_storage
-        self.job_execution_storage = job_execution_storage
 
     def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
         pass
@@ -193,23 +144,21 @@ def update_stored_schema(
         return super().update_stored_schema(only_tables, expected_update)
 
     def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
+        load_state = destination_state().setdefault(load_id, {})
         if file_path.endswith("parquet"):
-            return SinkParquetLoadJob(
-                table, file_path, self.config, self.schema, job_execution_storage
-            )
+            return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state)
         if file_path.endswith("jsonl"):
-            return SinkJsonlLoadJob(
-                table, file_path, self.config, self.schema, job_execution_storage
-            )
-        # if file_path.endswith("insert_values"):
-        #     return SinkInsertValueslLoadJob(table, file_path, self.config, self.schema)
+            return SinkJsonlLoadJob(table, file_path, self.config, self.schema, load_state)
         return None
 
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
     def complete_load(self, load_id: str) -> None:
-        pass
+        # pop all state for this load on success
+        state = destination_state()
+        state.pop(load_id, None)
+        commit_pipeline_state()
 
     def __enter__(self) -> "SinkClient":
         return self
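# ---------------------------------------------------------------------------
# Sketch (not part of the diff) of the resume contract the run() overrides above
# must honor: skip items already delivered in a previous attempt, then yield
# batches; the base class advances the per-file pointer after each delivered
# batch. resume_batches is a hypothetical stand-in for SinkJsonlLoadJob.run.
from typing import Any, Iterable, Iterator, List


def resume_batches(
    items: Iterable[Any], start_index: int, batch_size: int
) -> Iterator[List[Any]]:
    batch: List[Any] = []
    for i, item in enumerate(items):
        if i < start_index:
            continue  # already delivered in a previous attempt
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# list(resume_batches(range(10), start_index=4, batch_size=3)) == [[4, 5, 6], [7, 8, 9]]
# ---------------------------------------------------------------------------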
diff --git a/dlt/load/load.py b/dlt/load/load.py
index b0b52d61d6..beb52cc296 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -7,6 +7,7 @@
 
 from dlt.common import sleep, logger
 from dlt.common.configuration import with_config, known_sections
+from dlt.common.configuration.resolve import inject_section
 from dlt.common.configuration.accessors import config
 from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo
 from dlt.common.schema.utils import get_child_tables, get_top_level_table
@@ -35,6 +36,7 @@
     SupportsStagingDestination,
     TDestination,
 )
+from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
 
 from dlt.destinations.job_impl import EmptyLoadJob
 
@@ -558,10 +560,16 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics:
 
         # get top load id and mark as being processed
        with self.collector(f"Load {schema.name} in {load_id}"):
-            # the same load id may be processed across multiple runs
-            if not self.current_load_id:
-                self._step_info_start_load_id(load_id)
-            self.load_single_package(load_id, schema)
+            with inject_section(
+                ConfigSectionContext(
+                    sections=(known_sections.LOAD,),
+                    destination_state_key=self.destination.destination_name,
+                )
+            ):
+                # the same load id may be processed across multiple runs
+                if not self.current_load_id:
+                    self._step_info_start_load_id(load_id)
+                self.load_single_package(load_id, schema)
 
         return TRunMetrics(False, len(self.load_storage.list_normalized_packages()))
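# ---------------------------------------------------------------------------
# Sketch (not part of the diff): the hunk above scopes destination state access
# by injecting a ConfigSectionContext whose destination_state_key is the
# destination name; destination_state() resolves that key from the container
# instead of taking it as an argument. "sink" stands in for any destination name.
from dlt.common.configuration.container import Container
from dlt.common.configuration.resolve import inject_section
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext

with inject_section(ConfigSectionContext(destination_state_key="sink")):
    # in this scope destination_state() reads and writes state["destinations"]["sink"]
    assert Container()[ConfigSectionContext].destination_state_key == "sink"
# ---------------------------------------------------------------------------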
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index 73c8f076d1..613e4794ce 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -143,11 +143,20 @@ def decorator(f: TFun) -> TFun:
         def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
             # activate pipeline so right state is always provided
             self.activate()
+
             # backup and restore state
             should_extract_state = may_extract_state and self.config.restore_from_destination
             with self.managed_state(extract_state=should_extract_state) as state:
+                # commit hook
+                def commit_state() -> None:
+                    # save the state
+                    bump_version_if_modified(state)
+                    self._save_state(state)
+
                 # add the state to container as a context
-                with self._container.injectable_context(StateInjectableContext(state=state)):
+                with self._container.injectable_context(
+                    StateInjectableContext(state=state, commit=commit_state)
+                ):
                     return f(self, *args, **kwargs)
 
         return _wrap  # type: ignore
@@ -246,7 +255,14 @@ class Pipeline(SupportsPipeline):
     STATE_FILE: ClassVar[str] = "state.json"
     STATE_PROPS: ClassVar[List[str]] = list(
         set(get_type_hints(TPipelineState).keys())
-        - {"sources", "destination_type", "destination_name", "staging_type", "staging_name"}
+        - {
+            "sources",
+            "destination_type",
+            "destination_name",
+            "staging_type",
+            "staging_name",
+            "destinations",
+        }
     )
     LOCAL_STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineLocalState).keys())
     DEFAULT_DATASET_SUFFIX: ClassVar[str] = "_dataset"
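# ---------------------------------------------------------------------------
# Sketch (not part of the diff) of the commit flow wired up above: the step
# wrapper injects StateInjectableContext with a commit callback that bumps the
# state version and saves it; commit_pipeline_state() looks that context up and
# invokes the callback. The lambda and toy state below are illustrations only.
from dlt.common.configuration.container import Container
from dlt.common.pipeline import StateInjectableContext, commit_pipeline_state

snapshots = []
state = {"destinations": {"sink": {}}}  # toy state, not a full TPipelineState

with Container().injectable_context(
    StateInjectableContext(state=state, commit=lambda: snapshots.append(dict(state)))
):
    state["destinations"]["sink"]["load-1"] = {"file-1": 25}
    commit_pipeline_state()  # invokes the injected commit hook

assert snapshots[0]["destinations"]["sink"]["load-1"] == {"file-1": 25}
# ---------------------------------------------------------------------------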
diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py
index 2be7af5513..d96fb2d9eb 100644
--- a/tests/load/sink/test_sink.py
+++ b/tests/load/sink/test_sink.py
@@ -234,6 +234,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
     p.run([items(), items2()])
     assert_items_in_range(calls["items"], 0, 100)
     assert_items_in_range(calls["items2"], 0, 100)
+    # destination state should be cleared after load
+    assert p.state["destinations"]["sink"] == {}
 
     # provoke errors
     calls = {}
@@ -242,16 +244,25 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
     with pytest.raises(PipelineStepFailed):
         p.run([items(), items2()])
 
-    # partly loaded
+    # we should have data for one load id saved here
+    assert len(p.state["destinations"]["sink"]) == 1
+    # get saved indexes
+    values = list(list(p.state["destinations"]["sink"].values())[0].values())
+
+    # partly loaded, pointers in state should be right
     if batch_size == 1:
         assert_items_in_range(calls["items"], 0, 25)
         assert_items_in_range(calls["items2"], 0, 45)
+        # one pointer for the pipeline state table, one for items, one for items2
+        assert values == [1, 25, 45]
     elif batch_size == 10:
         assert_items_in_range(calls["items"], 0, 20)
         assert_items_in_range(calls["items2"], 0, 40)
+        assert values == [1, 20, 40]
     elif batch_size == 23:
         assert_items_in_range(calls["items"], 0, 23)
         assert_items_in_range(calls["items2"], 0, 23)
+        assert values == [1, 23, 23]
     else:
         raise AssertionError("Unknown batch size")
 
@@ -260,6 +271,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
     provoke_error = {}
     calls = {}
     p.load()
+    # state should be cleared again
+    assert p.state["destinations"]["sink"] == {}
 
     # both calls combined should have every item called just once
     assert_items_in_range(calls["items"] + first_calls["items"], 0, 100)
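# ---------------------------------------------------------------------------
# Sketch (not part of the diff): what the assertions above inspect. After a
# failed run exactly one load id remains under the destination's state key,
# mapping each job file id to the index of the first undelivered item; a
# completing load pops the load id again. The pipeline name is hypothetical.
import dlt

p = dlt.pipeline("sink_test", destination="sink")
sink_state = p.state.get("destinations", {}).get("sink", {})
for load_id, pointers in sink_state.items():
    print(load_id, pointers)  # e.g. {"<state file id>": 1, "<items file id>": 25, ...}
# ---------------------------------------------------------------------------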