Skip to content

Commit

Permalink
Add open/closed range arguments for incremental (#1991)
Browse files Browse the repository at this point in the history
* Add open/closed range arguments for incremental

* Docs for incremental range args

* Docstring

* Typo

* Ensure deduplication is disabled when range_start=='open'

* Cache transformer settings
  • Loading branch information
steinitzu authored and donotpush committed Dec 11, 2024
1 parent 8a650b1 commit 2a749a5
Show file tree
Hide file tree
Showing 9 changed files with 434 additions and 136 deletions.
4 changes: 4 additions & 0 deletions dlt/common/incremental/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
LastValueFunc = Callable[[Sequence[TCursorValue]], Any]
OnCursorValueMissing = Literal["raise", "include", "exclude"]

TIncrementalRange = Literal["open", "closed"]


class IncrementalColumnState(TypedDict):
initial_value: Optional[Any]
Expand All @@ -26,3 +28,5 @@ class IncrementalArgs(TypedDict, total=False):
allow_external_schedulers: Optional[bool]
lag: Optional[Union[float, int]]
on_cursor_value_missing: Optional[OnCursorValueMissing]
range_start: Optional[TIncrementalRange]
range_end: Optional[TIncrementalRange]
60 changes: 38 additions & 22 deletions dlt/extract/incremental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
LastValueFunc,
OnCursorValueMissing,
IncrementalArgs,
TIncrementalRange,
)
from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform
from dlt.extract.incremental.transform import (
Expand Down Expand Up @@ -104,6 +105,11 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa
Note that if logical "end date" is present then also "end_value" will be set which means that resource state is not used and exactly this range of date will be loaded
on_cursor_value_missing: Specify what happens when the cursor_path does not exist in a record or a record has `None` at the cursor_path: raise, include, exclude
lag: Optional value used to define a lag or attribution window. For datetime cursors, this is interpreted as seconds. For other types, it uses the + or - operator depending on the last_value_func.
range_start: Decide whether the incremental filtering range is `open` or `closed` on the start value side. Default is `closed`.
Setting this to `open` means that items with the same cursor value as the last value from the previous run (or `initial_value`) are excluded from the result.
The `open` range disables deduplication logic so it can serve as an optimization when you know cursors don't overlap between pipeline runs.
range_end: Decide whether the incremental filtering range is `open` or `closed` on the end value side. Default is `open` (exact `end_value` is excluded).
Setting this to `closed` means that items with the exact same cursor value as the `end_value` are included in the result.
"""

# this is config/dataclass so declare members
Expand All @@ -116,6 +122,8 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa
on_cursor_value_missing: OnCursorValueMissing = "raise"
lag: Optional[float] = None
duplicate_cursor_warning_threshold: ClassVar[int] = 200
range_start: TIncrementalRange = "closed"
range_end: TIncrementalRange = "open"

# incremental acting as empty
EMPTY: ClassVar["Incremental[Any]"] = None
Expand All @@ -132,6 +140,8 @@ def __init__(
allow_external_schedulers: bool = False,
on_cursor_value_missing: OnCursorValueMissing = "raise",
lag: Optional[float] = None,
range_start: TIncrementalRange = "closed",
range_end: TIncrementalRange = "open",
) -> None:
# make sure that path is valid
if cursor_path:
Expand Down Expand Up @@ -174,9 +184,11 @@ def __init__(
self.start_out_of_range: bool = False
"""Becomes true on the first item that is out of range of `start_value`. I.e. when using `max` this is a value that is lower than `start_value`"""

self._transformers: Dict[str, IncrementalTransform] = {}
self._transformers: Dict[Type[IncrementalTransform], IncrementalTransform] = {}
self._bound_pipe: SupportsPipe = None
"""Bound pipe"""
self.range_start = range_start
self.range_end = range_end

@property
def primary_key(self) -> Optional[TTableHintTemplate[TColumnNames]]:
Expand All @@ -190,22 +202,6 @@ def primary_key(self, value: str) -> None:
for transform in self._transformers.values():
transform.primary_key = value

def _make_transforms(self) -> None:
types = [("arrow", ArrowIncremental), ("json", JsonIncremental)]
for dt, kls in types:
self._transformers[dt] = kls(
self.resource_name,
self.cursor_path,
self.initial_value,
self.start_value,
self.end_value,
self.last_value_func,
self._primary_key,
set(self._cached_state["unique_hashes"]),
self.on_cursor_value_missing,
self.lag,
)

@classmethod
def from_existing_state(
cls, resource_name: str, cursor_path: str
Expand Down Expand Up @@ -489,7 +485,8 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]":
)
# cache state
self._cached_state = self.get_state()
self._make_transforms()
# Clear transforms so we get new instances
self._transformers.clear()
return self

def can_close(self) -> bool:
Expand Down Expand Up @@ -520,15 +517,34 @@ def __str__(self) -> str:
f" {self.last_value_func}"
)

def _make_or_get_transformer(self, cls: Type[IncrementalTransform]) -> IncrementalTransform:
if transformer := self._transformers.get(cls):
return transformer
transformer = self._transformers[cls] = cls(
self.resource_name,
self.cursor_path,
self.initial_value,
self.start_value,
self.end_value,
self.last_value_func,
self._primary_key,
set(self._cached_state["unique_hashes"]),
self.on_cursor_value_missing,
self.lag,
self.range_start,
self.range_end,
)
return transformer

def _get_transformer(self, items: TDataItems) -> IncrementalTransform:
# Assume list is all of the same type
for item in items if isinstance(items, list) else [items]:
if is_arrow_item(item):
return self._transformers["arrow"]
return self._make_or_get_transformer(ArrowIncremental)
elif pandas is not None and isinstance(item, pandas.DataFrame):
return self._transformers["arrow"]
return self._transformers["json"]
return self._transformers["json"]
return self._make_or_get_transformer(ArrowIncremental)
return self._make_or_get_transformer(JsonIncremental)
return self._make_or_get_transformer(JsonIncremental)

def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
if rows is None:
Expand Down
75 changes: 50 additions & 25 deletions dlt/extract/incremental/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
IncrementalPrimaryKeyMissing,
IncrementalCursorPathHasValueNone,
)
from dlt.common.incremental.typing import TCursorValue, LastValueFunc, OnCursorValueMissing
from dlt.common.incremental.typing import (
TCursorValue,
LastValueFunc,
OnCursorValueMissing,
TIncrementalRange,
)
from dlt.extract.utils import resolve_column_value
from dlt.extract.items import TTableHintTemplate

Expand Down Expand Up @@ -57,6 +62,8 @@ def __init__(
unique_hashes: Set[str],
on_cursor_value_missing: OnCursorValueMissing = "raise",
lag: Optional[float] = None,
range_start: TIncrementalRange = "closed",
range_end: TIncrementalRange = "open",
) -> None:
self.resource_name = resource_name
self.cursor_path = cursor_path
Expand All @@ -71,6 +78,9 @@ def __init__(
self.start_unique_hashes = set(unique_hashes)
self.on_cursor_value_missing = on_cursor_value_missing
self.lag = lag
self.range_start = range_start
self.range_end = range_end

# compile jsonpath
self._compiled_cursor_path = compile_path(cursor_path)
# for simple column name we'll fallback to search in dict
Expand Down Expand Up @@ -107,6 +117,8 @@ def __call__(
def deduplication_disabled(self) -> bool:
"""Skip deduplication when length of the key is 0 or if lag is applied."""
# disable deduplication if end value is set - state is not saved
if self.range_start == "open":
return True
if self.end_value is not None:
return True
# disable deduplication if lag is applied - destination must deduplicate ranges
Expand Down Expand Up @@ -191,10 +203,10 @@ def __call__(
# Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
if self.end_value is not None:
try:
if (
last_value_func((row_value, self.end_value)) != self.end_value
or last_value_func((row_value,)) == self.end_value
):
if last_value_func((row_value, self.end_value)) != self.end_value:
return None, False, True

if self.range_end == "open" and last_value_func((row_value,)) == self.end_value:
return None, False, True
except Exception as ex:
raise IncrementalCursorInvalidCoercion(
Expand All @@ -221,6 +233,9 @@ def __call__(
) from ex
# new_value is "less" or equal to last_value (the actual max)
if last_value == new_value:
if self.range_start == "open":
# We only want greater than last_value
return None, False, False
# use func to compute row_value into last_value compatible
processed_row_value = last_value_func((row_value,))
# skip the record that is not a start_value or new_value: that record was already processed
Expand Down Expand Up @@ -258,6 +273,31 @@ def __call__(
class ArrowIncremental(IncrementalTransform):
_dlt_index = "_dlt_index"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if self.last_value_func is max:
self.compute = pa.compute.max
self.end_compare = (
pa.compute.less if self.range_end == "open" else pa.compute.less_equal
)
self.last_value_compare = (
pa.compute.greater_equal if self.range_start == "closed" else pa.compute.greater
)
self.new_value_compare = pa.compute.greater
elif self.last_value_func is min:
self.compute = pa.compute.min
self.end_compare = (
pa.compute.greater if self.range_end == "open" else pa.compute.greater_equal
)
self.last_value_compare = (
pa.compute.less_equal if self.range_start == "closed" else pa.compute.less
)
self.new_value_compare = pa.compute.less
else:
raise NotImplementedError(
"Only min or max last_value_func is supported for arrow tables"
)

def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str]) -> List[str]:
if not unique_columns:
return []
Expand Down Expand Up @@ -312,28 +352,13 @@ def __call__(
if not tbl: # row is None or empty arrow table
return tbl, start_out_of_range, end_out_of_range

if self.last_value_func is max:
compute = pa.compute.max
end_compare = pa.compute.less
last_value_compare = pa.compute.greater_equal
new_value_compare = pa.compute.greater
elif self.last_value_func is min:
compute = pa.compute.min
end_compare = pa.compute.greater
last_value_compare = pa.compute.less_equal
new_value_compare = pa.compute.less
else:
raise NotImplementedError(
"Only min or max last_value_func is supported for arrow tables"
)

# TODO: Json path support. For now assume the cursor_path is a column name
cursor_path = self.cursor_path

# The new max/min value
try:
# NOTE: datetimes are always pendulum in UTC
row_value = from_arrow_scalar(compute(tbl[cursor_path]))
row_value = from_arrow_scalar(self.compute(tbl[cursor_path]))
cursor_data_type = tbl.schema.field(cursor_path).type
row_value_scalar = to_arrow_scalar(row_value, cursor_data_type)
except KeyError as e:
Expand Down Expand Up @@ -364,10 +389,10 @@ def __call__(
cursor_data_type,
str(ex),
) from ex
tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar))
tbl = tbl.filter(self.end_compare(tbl[cursor_path], end_value_scalar))
# Is max row value higher than end value?
# NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary
end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py()
end_out_of_range = not self.end_compare(row_value_scalar, end_value_scalar).as_py()

if self.start_value is not None:
try:
Expand All @@ -383,7 +408,7 @@ def __call__(
str(ex),
) from ex
# Remove rows lower or equal than the last start value
keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar)
keep_filter = self.last_value_compare(tbl[cursor_path], start_value_scalar)
start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)
if not self.deduplication_disabled:
Expand All @@ -407,7 +432,7 @@ def __call__(

if (
self.last_value is None
or new_value_compare(
or self.new_value_compare(
row_value_scalar, to_arrow_scalar(self.last_value, cursor_data_type)
).as_py()
): # Last value has changed
Expand Down
12 changes: 8 additions & 4 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,16 @@ def __init__(
self.end_value = incremental.end_value
self.row_order: TSortOrder = self.incremental.row_order
self.on_cursor_value_missing = self.incremental.on_cursor_value_missing
self.range_start = self.incremental.range_start
self.range_end = self.incremental.range_end
else:
self.cursor_column = None
self.last_value = None
self.end_value = None
self.row_order = None
self.on_cursor_value_missing = None
self.range_start = None
self.range_end = None

def _make_query(self) -> SelectAny:
table = self.table
Expand All @@ -110,11 +114,11 @@ def _make_query(self) -> SelectAny:

# generate where
if last_value_func is max: # Query ordered and filtered according to last_value function
filter_op = operator.ge
filter_op_end = operator.lt
filter_op = operator.ge if self.range_start == "closed" else operator.gt
filter_op_end = operator.lt if self.range_end == "open" else operator.le
elif last_value_func is min:
filter_op = operator.le
filter_op_end = operator.gt
filter_op = operator.le if self.range_start == "closed" else operator.lt
filter_op_end = operator.gt if self.range_end == "open" else operator.ge
else: # Custom last_value, load everything and let incremental handle filtering
return query # type: ignore[no-any-return]

Expand Down
Loading

0 comments on commit 2a749a5

Please sign in to comment.