diff --git a/dlt/extract/incremental.py b/dlt/extract/incremental/__init__.py
similarity index 84%
rename from dlt/extract/incremental.py
rename to dlt/extract/incremental/__init__.py
index ebc54530cc..c310697bda 100644
--- a/dlt/extract/incremental.py
+++ b/dlt/extract/incremental/__init__.py
@@ -17,39 +17,17 @@ from dlt.common.data_types.type_helpers import coerce_from_date_types, coerce_value, py_type_to_sc_type
 from dlt.extract.exceptions import IncrementalUnboundError, PipeException
+from dlt.extract.incremental.exceptions import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing
+from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc
 from dlt.extract.pipe import Pipe
 from dlt.extract.utils import resolve_column_value
-from dlt.extract.typing import FilterItem, SupportsPipe, TTableHintTemplate
+from dlt.extract.typing import SupportsPipe, TTableHintTemplate, MapItem, YieldMapItem, FilterItem
+from dlt.extract.incremental.transform import get_transformer
 
-TCursorValue = TypeVar("TCursorValue", bound=Any)
-LastValueFunc = Callable[[Sequence[TCursorValue]], Any]
-
-
-class IncrementalColumnState(TypedDict):
-    initial_value: Optional[Any]
-    last_value: Optional[Any]
-    unique_hashes: List[str]
-
-
-class IncrementalCursorPathMissing(PipeException):
-    def __init__(self, pipe_name: str, json_path: str, item: TDataItem) -> None:
-        self.json_path = json_path
-        self.item = item
-        msg = f"Cursor element with JSON path {json_path} was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database."
-        super().__init__(pipe_name, msg)
-
-
-class IncrementalPrimaryKeyMissing(PipeException):
-    def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None:
-        self.primary_key_column = primary_key_column
-        self.item = item
-        msg = f"Primary key column {primary_key_column} was not found in extracted data item. All data items must contain this column. Use the same names of fields as in your JSON document."
-        super().__init__(pipe_name, msg)
-
 @configspec
-class Incremental(FilterItem, BaseConfiguration, Generic[TCursorValue]):
+class Incremental(YieldMapItem, BaseConfiguration, Generic[TCursorValue]):
     """Adds incremental extraction for a resource by storing a cursor value in persistent state.
 
     The cursor could for example be a timestamp for when the record was created and you can use this to load only
@@ -244,62 +222,20 @@ def unique_value(self, row: TDataItem) -> str:
         except KeyError as k_err:
             raise IncrementalPrimaryKeyMissing(self.resource_name, k_err.args[0], row)
 
-    def transform(self, row: TDataItem) -> bool:
+    def transform(self, row: TDataItem) -> TDataItem:
         if row is None:
-            return True
-
-        row_values = find_values(self.cursor_path_p, row)
-        if not row_values:
-            raise IncrementalCursorPathMissing(self.resource_name, self.cursor_path, row)
-        row_value = row_values[0]
-
-        # For datetime cursor, ensure the value is a timezone aware datetime.
-        # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
-        if isinstance(row_value, datetime):
-            row_value = pendulum.instance(row_value)
-
-        incremental_state = self._cached_state
-        last_value = incremental_state['last_value']
-        last_value_func = self.last_value_func
-
-        # Check whether end_value has been reached
-        # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
-        if self.end_value is not None and (
-            last_value_func((row_value, self.end_value)) != self.end_value or last_value_func((row_value, )) == self.end_value
-        ):
-            self.end_out_of_range = True
-            return False
-
-        check_values = (row_value,) + ((last_value, ) if last_value is not None else ())
-        new_value = last_value_func(check_values)
-        if last_value == new_value:
-            processed_row_value = last_value_func((row_value, ))
-            # we store row id for all records with the current "last_value" in state and use it to deduplicate
-            if processed_row_value == last_value:
-                unique_value = self.unique_value(row)
-                # if unique value exists then use it to deduplicate
-                if unique_value:
-                    if unique_value in incremental_state['unique_hashes']:
-                        return False
-                    # add new hash only if the record row id is same as current last value
-                    incremental_state['unique_hashes'].append(unique_value)
-                return True
-            # skip the record that is not a last_value or new_value: that record was already processed
-            check_values = (row_value,) + ((self.start_value,) if self.start_value is not None else ())
-            new_value = last_value_func(check_values)
-            # Include rows == start_value but exclude "lower"
-            if new_value == self.start_value and processed_row_value != self.start_value:
-                self.start_out_of_range = True
-                return False
-            else:
-                return True
-        else:
-            incremental_state["last_value"] = new_value
-            unique_value = self.unique_value(row)
-            if unique_value:
-                incremental_state["unique_hashes"] = [unique_value]
+            yield row
+            return
-
-        return True
+        transformer = get_transformer(row)
+
+        row, start_out_of_range, end_out_of_range = transformer(
+            row, self.resource_name, self.cursor_path_p, self.start_value, self.end_value, self._cached_state, self.last_value_func, self.primary_key
+        )
+        self.start_out_of_range = start_out_of_range
+        self.end_out_of_range = end_out_of_range
+        if row is not None:
+            yield row
 
     def get_incremental_value_type(self) -> Type[Any]:
         """Infers the type of incremental value from a class of an instance if those preserve the Generic arguments information."""
@@ -377,7 +313,7 @@ def __str__(self) -> str:
         return f"Incremental at {id(self)} for resource {self.resource_name} with cursor path: {self.cursor_path} initial {self.initial_value} lv_func {self.last_value_func}"
 
 
-class IncrementalResourceWrapper(FilterItem):
+class IncrementalResourceWrapper(YieldMapItem):
     _incremental: Optional[Incremental[Any]] = None
     """Keeps the injectable incremental"""
     _resource_name: str = None
@@ -485,7 +421,11 @@ def bind(self, pipe: SupportsPipe) -> "IncrementalResourceWrapper":
 
     def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:
         if not self._incremental:
-            return item
+            yield item
+            return
         if self._incremental.primary_key is None:
             self._incremental.primary_key = self.primary_key
-        return self._incremental(item, meta)
+        if isinstance(item, list):
+            yield list(self._incremental(item, meta))
+        else:
+            yield self._incremental(item, meta)
diff --git a/dlt/extract/incremental/exceptions.py b/dlt/extract/incremental/exceptions.py
new file mode 100644
index 0000000000..1bf635e981
--- /dev/null
+++ b/dlt/extract/incremental/exceptions.py
@@ -0,0 +1,18 @@
+from dlt.extract.exceptions import PipeException
+from dlt.common.typing import TDataItem
+
+
+class IncrementalCursorPathMissing(PipeException):
+    def __init__(self, pipe_name: str, json_path: str, item: TDataItem) -> None:
+        self.json_path = json_path
+        self.item = item
+        msg = f"Cursor element with JSON path {json_path} was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database."
+        super().__init__(pipe_name, msg)
+
+
+class IncrementalPrimaryKeyMissing(PipeException):
+    def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None:
+        self.primary_key_column = primary_key_column
+        self.item = item
+        msg = f"Primary key column {primary_key_column} was not found in extracted data item. All data items must contain this column. Use the same names of fields as in your JSON document."
+        super().__init__(pipe_name, msg)
diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py
new file mode 100644
index 0000000000..e48d4232d2
--- /dev/null
+++ b/dlt/extract/incremental/transform.py
@@ -0,0 +1,251 @@
+from datetime import datetime  # noqa: I251
+from typing import Optional, Tuple, Protocol, Mapping, Union, List
+
+try:
+    import pandas as pd
+except ModuleNotFoundError:
+    pd = None
+
+from dlt.common.exceptions import MissingDependencyException
+from dlt.common.utils import digest128
+from dlt.common.json import json
+from dlt.common import pendulum
+from dlt.common.typing import TDataItem
+from dlt.common.jsonpath import TJsonPath, find_values
+from dlt.extract.incremental.exceptions import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing
+from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc
+from dlt.extract.utils import resolve_column_value
+from dlt.extract.typing import TTableHintTemplate
+from dlt.common.schema.typing import TColumnNames
+try:
+    from dlt.common.libs.pyarrow import is_arrow_item, pyarrow as pa, TAnyArrowItem
+except MissingDependencyException:
+    is_arrow_item = lambda x: False
+
+
+class IncrementalTransformer(Protocol):
+    def __call__(
+        self,
+        row: TDataItem,
+        resource_name: str,
+        cursor_path: TJsonPath,
+        start_value: Optional[TCursorValue],
+        end_value: Optional[TCursorValue],
+        incremental_state: IncrementalColumnState,
+        last_value_func: LastValueFunc[TCursorValue],
+        primary_key: Optional[TTableHintTemplate[TColumnNames]]
+    ) -> Tuple[Optional[TDataItem], bool, bool]:
+        ...
+
+
+class JsonIncremental(IncrementalTransformer):
+    def unique_value(
+        self,
+        row: TDataItem,
+        primary_key: Optional[TTableHintTemplate[TColumnNames]],
+        resource_name: str
+    ) -> str:
+        try:
+            if primary_key:
+                return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True))
+            elif primary_key is None:
+                return digest128(json.dumps(row, sort_keys=True))
+            else:
+                return None
+        except KeyError as k_err:
+            raise IncrementalPrimaryKeyMissing(resource_name, k_err.args[0], row)
+
+    def __call__(
+        self,
+        row: TDataItem,
+        resource_name: str,
+        cursor_path: TJsonPath,
+        start_value: Optional[TCursorValue],
+        end_value: Optional[TCursorValue],
+        incremental_state: IncrementalColumnState,
+        last_value_func: LastValueFunc[TCursorValue],
+        primary_key: Optional[TTableHintTemplate[TColumnNames]]
+    ) -> Tuple[Optional[TDataItem], bool, bool]:
+        """
+        Returns:
+            Tuple (row, start_out_of_range, end_out_of_range) where row is either the data item or `None` if it is completely filtered out
+        """
+        start_out_of_range = end_out_of_range = False
+        if row is None:
+            return row, start_out_of_range, end_out_of_range
+
+        row_values = find_values(cursor_path, row)
+        if not row_values:
+            raise IncrementalCursorPathMissing(resource_name, str(cursor_path), row)
+        row_value = row_values[0]
+
+        # For datetime cursor, ensure the value is a timezone aware datetime.
+        # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
+        if isinstance(row_value, datetime):
+            row_value = pendulum.instance(row_value)
+
+        last_value = incremental_state['last_value']
+
+        # Check whether end_value has been reached
+        # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
+        if end_value is not None and (
+            last_value_func((row_value, end_value)) != end_value or last_value_func((row_value, )) == end_value
+        ):
+            end_out_of_range = True
+            return None, start_out_of_range, end_out_of_range
+
+        check_values = (row_value,) + ((last_value, ) if last_value is not None else ())
+        new_value = last_value_func(check_values)
+        if last_value == new_value:
+            processed_row_value = last_value_func((row_value, ))
+            # we store row id for all records with the current "last_value" in state and use it to deduplicate
+            if processed_row_value == last_value:
+                unique_value = self.unique_value(row, primary_key, resource_name)
+                # if unique value exists then use it to deduplicate
+                if unique_value:
+                    if unique_value in incremental_state['unique_hashes']:
+                        return None, start_out_of_range, end_out_of_range
+                    # add new hash only if the record row id is same as current last value
+                    incremental_state['unique_hashes'].append(unique_value)
+                return row, start_out_of_range, end_out_of_range
+            # skip the record that is not a last_value or new_value: that record was already processed
+            check_values = (row_value,) + ((start_value,) if start_value is not None else ())
+            new_value = last_value_func(check_values)
+            # Include rows == start_value but exclude "lower"
+            if new_value == start_value and processed_row_value != start_value:
+                start_out_of_range = True
+                return None, start_out_of_range, end_out_of_range
+            else:
+                return row, start_out_of_range, end_out_of_range
+        else:
+            incremental_state["last_value"] = new_value
+            unique_value = self.unique_value(row, primary_key, resource_name)
+            if unique_value:
+                incremental_state["unique_hashes"] = [unique_value]
+
+        return row, start_out_of_range, end_out_of_range
+
+
+class ArrowIncremental(IncrementalTransformer):
+    def unique_values(
+        self,
+        item: "TAnyArrowItem",
+        primary_key: Optional[TTableHintTemplate[TColumnNames]],
+        resource_name: str
+    ) -> List[Tuple[int, str]]:
+        item = item
+        indices = item["_dlt_index"]
+        item = item.drop(["_dlt_index"])  # Don't include the index in unique hash
+        if primary_key:
+            columns = primary_key(item) if callable(primary_key) else primary_key
+            if isinstance(columns, str):
+                item = item[columns]
+            else:
+                item = item.select(columns)
+        rows = item.to_pylist()
+        return [
+            (index, digest128(json.dumps(row, sort_keys=True))) for index, row in zip(indices, rows)
+        ]
+
+    def __call__(
+        self,
+        tbl: "TAnyArrowItem",
+        resource_name: str,
+        cursor_path: TJsonPath,
+        start_value: Optional[TCursorValue],
+        end_value: Optional[TCursorValue],
+        incremental_state: IncrementalColumnState,
+        last_value_func: LastValueFunc[TCursorValue],
+        primary_key: Optional[TTableHintTemplate[TColumnNames]]
+    ) -> Tuple[TDataItem, bool, bool]:
+        is_pandas = pd is not None and isinstance(tbl, pd.DataFrame)
+        if is_pandas:
+            tbl = pa.Table.from_pandas(tbl)
+
+        start_out_of_range = end_out_of_range = False
+        if not tbl:  # row is None or empty arrow table
+            return tbl, start_out_of_range, end_out_of_range
+
+        last_value = incremental_state['last_value']
+
+        if last_value_func is max:
+            compute = pa.compute.max
+            end_compare = pa.compute.less
+            start_compare = pa.compute.greater_equal
+            last_value_compare = pa.compute.greater_equal
+            new_value_compare = pa.compute.greater
+        elif last_value_func is min:
+            compute = pa.compute.min
+            end_compare = pa.compute.greater
+            start_compare = pa.compute.less_equal
+            last_value_compare = pa.compute.less_equal
+            new_value_compare = pa.compute.less
+        else:
+            raise NotImplementedError("Only min or max last_value_func is supported for arrow tables")
+
+
+        # TODO: Json path support. For now assume the cursor_path is a column name
+        cursor_path = str(cursor_path)
+        # The new max/min value
+        row_value = compute(tbl[cursor_path]).as_py()
+
+        # If end_value is provided, filter to include table rows that are "less" than end_value
+        if end_value is not None:
+            tbl = tbl.filter(end_compare(tbl[cursor_path], end_value))
+            # Is max row value higher than end value?
+            end_out_of_range = not end_compare(row_value, end_value)
+            if end_out_of_range:
+                if is_pandas:
+                    tbl = tbl.to_pandas()
+                return tbl, start_out_of_range, end_out_of_range
+
+        # Filter out all rows which have cursor value equal to last value
+        # and unique id exists in state
+        tbl = tbl.append_column("_dlt_index", pa.array(range(tbl.num_rows)))
+        if last_value is not None:
+            tbl = tbl.filter(last_value_compare(tbl[cursor_path], last_value))
+            # Exclude rows from the table which have unique hashes already seen before
+
+            # Rows with same cursor as stored last value
+            eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value))
+            # compute index, unique hash mapping
+            unique_values = self.unique_values(eq_rows, primary_key, resource_name)
+            unique_values = [(i, uq_val) for i, uq_val in unique_values if uq_val not in incremental_state['unique_hashes']]
+            keep_idx = pa.array(i for i, _ in unique_values)
+            # Filter the table
+            tbl = tbl.filter(pa.compute.is_in(tbl["_dlt_index"], keep_idx))
+
+            if new_value_compare(row_value, last_value):  # Last value has changed
+                incremental_state['last_value'] = row_value
+                # Compute unique hashes for all rows equal to row value
+                incremental_state['unique_hashes'] = [uq_val for _, uq_val in self.unique_values(
+                    tbl.filter(pa.compute.equal(tbl[cursor_path], row_value)), primary_key, resource_name
+                )]
+            else:
+                # last value is unchanged, add the hashes
+                incremental_state['unique_hashes'].extend(uq_val for _, uq_val in unique_values)
+        else:
+            incremental_state['last_value'] = row_value
+            incremental_state['unique_hashes'] = [uq_val for _, uq_val in self.unique_values(
+                tbl.filter(pa.compute.equal(tbl[cursor_path], row_value)), primary_key, resource_name
+            )]
+
+        if start_value is not None:
+            # Is any value lower than start value
+            start_out_of_range = pa.compute.any(end_compare(tbl[cursor_path], start_value)).as_py()
+            # Include rows >= start_value
+            tbl = tbl.filter(start_compare(tbl[cursor_path], start_value))
+
+        if is_pandas:
+            return tbl.drop(["_dlt_index"]).to_pandas(), start_out_of_range, end_out_of_range
+        return tbl.drop(["_dlt_index"]), start_out_of_range, end_out_of_range
+
+
+def get_transformer(item: TDataItem) -> IncrementalTransformer:
+    if is_arrow_item(item):
+        return ArrowIncremental()
+    elif pd is not None and isinstance(item, pd.DataFrame):
+        return ArrowIncremental()
+    return JsonIncremental()
diff --git a/dlt/extract/incremental/typing.py b/dlt/extract/incremental/typing.py
new file mode 100644
index 0000000000..03f36121be
--- /dev/null
+++ b/dlt/extract/incremental/typing.py
@@ -0,0 +1,10 @@
+from typing import TypedDict, Optional, Any, List, TypeVar, Callable, Sequence
+
+
+TCursorValue = TypeVar("TCursorValue", bound=Any)
+LastValueFunc = Callable[[Sequence[TCursorValue]], Any]
+
+class IncrementalColumnState(TypedDict):
+    initial_value: Optional[Any]
+    last_value: Optional[Any]
+    unique_hashes: List[str]
diff --git a/tests/pipeline/test_arrow_loading.py b/tests/pipeline/test_arrow_loading.py
index c3e064443c..3ddb35e8ee 100644
--- a/tests/pipeline/test_arrow_loading.py
+++ b/tests/pipeline/test_arrow_loading.py
@@ -71,3 +71,17 @@ def some_data():
 
     pipeline.extract(some_data())
     pipeline.normalize()
+
+
+@pytest.mark.parametrize("item_type", ["pandas", "table"])
+def test_extract_with_incremental(item_type: str):
+    item = make_data_item(item_type)
+
+    pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="filesystem")
+
+    @dlt.resource
+    def some_data(incremental = dlt.sources.incremental("datetime")):
+        yield item
+
+    pipeline.extract(some_data())
+    pipeline.normalize()
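
Not part of the diff: a minimal usage sketch, mirroring the added test, of how the refactored incremental extraction is driven from user code. The resource, column, pipeline name, and the duckdb destination are made up for illustration; the point is that a single `dlt.sources.incremental` declaration now dispatches per item type via `get_transformer()` in `transform.py` — dict rows go through `JsonIncremental`, pyarrow tables and pandas frames through `ArrowIncremental`.

```python
import dlt
import pyarrow as pa

@dlt.resource
def events(updated_at=dlt.sources.incremental("updated_at", initial_value=0)):
    # get_transformer() picks ArrowIncremental for this item, so the whole table is
    # filtered against the stored cursor; yielding plain dicts would instead be
    # checked row by row through JsonIncremental.
    yield pa.table({"id": [1, 2, 3], "updated_at": [1, 2, 3]})

pipeline = dlt.pipeline("incremental_example", destination="duckdb")  # hypothetical names
pipeline.run(events())  # later runs keep only rows newer than the stored "updated_at" cursor
```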