Skip to content

Commit

Permalink
persist sink load state in pipeline state
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 25, 2024
1 parent 872c75a commit 8b3da5b
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 75 deletions.
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/config_section_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class ConfigSectionContext(ContainerInjectableContext):
sections: Tuple[str, ...] = ()
merge_style: TMergeFunc = None
source_state_key: str = None
destination_state_key: str = None

def merge(self, existing: "ConfigSectionContext") -> None:
"""Merges existing context into incoming using a merge style function"""
Expand Down Expand Up @@ -79,4 +80,5 @@ def __init__(
sections: Tuple[str, ...] = (),
merge_style: TMergeFunc = None,
source_state_key: str = None,
destination_state_key: str = None,
) -> None: ...
51 changes: 50 additions & 1 deletion dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ class TPipelineState(TypedDict, total=False):
"""A section of state that is not synchronized with the destination and does not participate in change merging and version control"""

sources: NotRequired[Dict[str, Dict[str, Any]]]
destinations: NotRequired[Dict[str, Dict[str, Any]]]


class TSourceState(TPipelineState):
Expand Down Expand Up @@ -594,9 +595,13 @@ class StateInjectableContext(ContainerInjectableContext):

can_create_default: ClassVar[bool] = False

commit: Optional[Callable[[], None]] = None

if TYPE_CHECKING:

def __init__(self, state: TPipelineState = None) -> None: ...
def __init__(
self, state: TPipelineState = None, commit: Optional[Callable[[], None]] = None
) -> None: ...


def pipeline_state(
Expand Down Expand Up @@ -679,6 +684,50 @@ def source_state() -> DictStrAny:
_last_full_state: TPipelineState = None


def destination_state() -> DictStrAny:
container = Container()

# get the destination name from the section context
destination_state_key: str = None
with contextlib.suppress(ContextDefaultCannotBeCreated):
sections_context = container[ConfigSectionContext]
destination_state_key = sections_context.destination_state_key

if not destination_state_key:
raise SourceSectionNotAvailable()

state, _ = pipeline_state(Container())

destination_state: DictStrAny = state.setdefault("destinations", {}).setdefault(
destination_state_key, {}
)
return destination_state


def reset_destination_state() -> None:
container = Container()

# get the destination name from the section context
destination_state_key: str = None
with contextlib.suppress(ContextDefaultCannotBeCreated):
sections_context = container[ConfigSectionContext]
destination_state_key = sections_context.destination_state_key

if not destination_state_key:
raise SourceSectionNotAvailable()

state, _ = pipeline_state(Container())

state.setdefault("destinations", {}).pop(destination_state_key)


def commit_pipeline_state() -> None:
container = Container()
# get injected state if present. injected state is typically "managed" so changes will be persisted
state_ctx = container[StateInjectableContext]
state_ctx.commit()


def _delete_source_state_keys(
key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, /
) -> None:
Expand Down
83 changes: 16 additions & 67 deletions dlt/destinations/impl/sink/sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.common.typing import TDataItems
from dlt.common import json
from dlt.common.configuration.container import Container
from dlt.common.pipeline import StateInjectableContext
from dlt.common.pipeline import destination_state, reset_destination_state, commit_pipeline_state

from dlt.common.storages.load_storage import ParsedLoadJobFileName
from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.typing import TTableSchema
from dlt.common.storages import FileStorage
Expand All @@ -21,38 +23,35 @@
from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable


# TODO: implement proper state storage somewhere, can this somehow go into the loadpackage?
job_execution_storage: Dict[str, int] = {}


class SinkLoadJob(LoadJob, ABC):
def __init__(
self,
table: TTableSchema,
file_path: str,
config: SinkClientConfiguration,
schema: Schema,
job_execution_storage: Dict[str, int],
load_state: Dict[str, int],
) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
self._file_path = file_path
self._config = config
self._table = table
self._schema = schema
self._job_execution_storage = job_execution_storage

# TODO: is this the correct way to tell dlt to retry this job in the next attempt?
self._state: TLoadJobState = "running"
try:
current_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0)
current_index = load_state.get(self._parsed_file_name.file_id, 0)
for batch in self.run(current_index):
self.call_callable_with_items(batch)
current_index += len(batch)
self._job_execution_storage[self._parsed_file_name.file_id] = current_index
load_state[self._parsed_file_name.file_id] = current_index
self._state = "completed"
except Exception as e:
self._state = "retry"
raise e
finally:
# save progress
commit_pipeline_state()

@abstractmethod
def run(self, start_index: int) -> Iterable[TDataItems]:
Expand Down Expand Up @@ -121,52 +120,6 @@ def run(self, start_index: int) -> Iterable[TDataItems]:
yield current_batch


# class SinkInsertValueslLoadJob(SinkLoadJob):
# def run(self) -> None:
# from dlt.common import json

# # stream items
# with FileStorage.open_zipsafe_ro(self._file_path) as f:
# header = f.readline().strip()
# values_mark = f.readline()

# # properly formatted file has a values marker at the beginning
# assert values_mark == "VALUES\n"

# # extract column names
# assert header.startswith("INSERT INTO") and header.endswith(")")
# header = header[15:-1]
# column_names = header.split(",")

# # build batches
# current_batch: TDataItems = []
# current_row: str = ""
# for line in f:
# current_row += line
# if line.endswith(");"):
# current_row = current_row[1:-2]
# elif line.endswith("),\n"):
# current_row = current_row[1:-3]
# else:
# continue

# values = current_row.split(",")
# values = [None if v == "NULL" else v for v in values]
# current_row = ""
# print(values)
# print(current_row)

# # zip and send to callable
# current_batch.append(dict(zip(column_names, values)))
# d = dict(zip(column_names, values))
# print(json.dumps(d, pretty=True))
# if len(current_batch) == self._config.batch_size:
# self.call_callable_with_items(current_batch)
# current_batch = []

# self.call_callable_with_items(current_batch)


class SinkClient(JobClientBase):
"""Sink Client"""

Expand All @@ -175,8 +128,6 @@ class SinkClient(JobClientBase):
def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None:
super().__init__(schema, config)
self.config: SinkClientConfiguration = config
global job_execution_storage
self.job_execution_storage = job_execution_storage

def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
pass
Expand All @@ -193,23 +144,21 @@ def update_stored_schema(
return super().update_stored_schema(only_tables, expected_update)

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
load_state = destination_state().setdefault(load_id, {})
if file_path.endswith("parquet"):
return SinkParquetLoadJob(
table, file_path, self.config, self.schema, job_execution_storage
)
return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state)
if file_path.endswith("jsonl"):
return SinkJsonlLoadJob(
table, file_path, self.config, self.schema, job_execution_storage
)
# if file_path.endswith("insert_values"):
# return SinkInsertValueslLoadJob(table, file_path, self.config, self.schema)
return SinkJsonlLoadJob(table, file_path, self.config, self.schema, load_state)
return None

def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

def complete_load(self, load_id: str) -> None:
pass
# pop all state for this load on success
state = destination_state()
state.pop(load_id, None)
commit_pipeline_state()

def __enter__(self) -> "SinkClient":
return self
Expand Down
16 changes: 12 additions & 4 deletions dlt/load/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dlt.common import sleep, logger
from dlt.common.configuration import with_config, known_sections
from dlt.common.configuration.resolve import inject_section
from dlt.common.configuration.accessors import config
from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo
from dlt.common.schema.utils import get_child_tables, get_top_level_table
Expand Down Expand Up @@ -35,6 +36,7 @@
SupportsStagingDestination,
TDestination,
)
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext

from dlt.destinations.job_impl import EmptyLoadJob

Expand Down Expand Up @@ -558,10 +560,16 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics:

# get top load id and mark as being processed
with self.collector(f"Load {schema.name} in {load_id}"):
# the same load id may be processed across multiple runs
if not self.current_load_id:
self._step_info_start_load_id(load_id)
self.load_single_package(load_id, schema)
with inject_section(
ConfigSectionContext(
sections=(known_sections.LOAD,),
destination_state_key=self.destination.destination_name,
)
):
# the same load id may be processed across multiple runs
if not self.current_load_id:
self._step_info_start_load_id(load_id)
self.load_single_package(load_id, schema)

return TRunMetrics(False, len(self.load_storage.list_normalized_packages()))

Expand Down
20 changes: 18 additions & 2 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,20 @@ def decorator(f: TFun) -> TFun:
def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
# activate pipeline so right state is always provided
self.activate()

# backup and restore state
should_extract_state = may_extract_state and self.config.restore_from_destination
with self.managed_state(extract_state=should_extract_state) as state:
# commit hook
def commit_state() -> None:
# save the state
bump_version_if_modified(state)
self._save_state(state)

# add the state to container as a context
with self._container.injectable_context(StateInjectableContext(state=state)):
with self._container.injectable_context(
StateInjectableContext(state=state, commit=commit_state)
):
return f(self, *args, **kwargs)

return _wrap # type: ignore
Expand Down Expand Up @@ -246,7 +255,14 @@ class Pipeline(SupportsPipeline):
STATE_FILE: ClassVar[str] = "state.json"
STATE_PROPS: ClassVar[List[str]] = list(
set(get_type_hints(TPipelineState).keys())
- {"sources", "destination_type", "destination_name", "staging_type", "staging_name"}
- {
"sources",
"destination_type",
"destination_name",
"staging_type",
"staging_name",
"destinations",
}
)
LOCAL_STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineLocalState).keys())
DEFAULT_DATASET_SUFFIX: ClassVar[str] = "_dataset"
Expand Down
15 changes: 14 additions & 1 deletion tests/load/sink/test_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
p.run([items(), items2()])
assert_items_in_range(calls["items"], 0, 100)
assert_items_in_range(calls["items2"], 0, 100)
# destination state should be cleared after load
assert p.state["destinations"]["sink"] == {}

# provoke errors
calls = {}
Expand All @@ -242,16 +244,25 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
with pytest.raises(PipelineStepFailed):
p.run([items(), items2()])

# partly loaded
# we should have data for one load id saved here
assert len(p.state["destinations"]["sink"]) == 1
# get saved indexes
values = list(list(p.state["destinations"]["sink"].values())[0].values())

# partly loaded, pointers in state should be right
if batch_size == 1:
assert_items_in_range(calls["items"], 0, 25)
assert_items_in_range(calls["items2"], 0, 45)
# one pointer for state, one for items, one for items2...
assert values == [1, 25, 45]
elif batch_size == 10:
assert_items_in_range(calls["items"], 0, 20)
assert_items_in_range(calls["items2"], 0, 40)
assert values == [1, 20, 40]
elif batch_size == 23:
assert_items_in_range(calls["items"], 0, 23)
assert_items_in_range(calls["items2"], 0, 23)
assert values == [1, 23, 23]
else:
raise AssertionError("Unknown batch size")

Expand All @@ -260,6 +271,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
provoke_error = {}
calls = {}
p.load()
# state should be cleared again
assert p.state["destinations"]["sink"] == {}

# both calls combined should have every item called just once
assert_items_in_range(calls["items"] + first_calls["items"], 0, 100)
Expand Down

0 comments on commit 8b3da5b

Please sign in to comment.