Skip to content

Commit

Permalink
Cache transformer settings
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Dec 6, 2024
1 parent 674736c commit 10e0770
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 50 deletions.
50 changes: 26 additions & 24 deletions dlt/extract/incremental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(
self.start_out_of_range: bool = False
"""Becomes true on the first item that is out of range of `start_value`. I.e. when using `max` this is a value that is lower than `start_value`"""

self._transformers: Dict[str, IncrementalTransform] = {}
self._transformers: Dict[Type[IncrementalTransform], IncrementalTransform] = {}
self._bound_pipe: SupportsPipe = None
"""Bound pipe"""
self.range_start = range_start
Expand All @@ -202,24 +202,6 @@ def primary_key(self, value: str) -> None:
for transform in self._transformers.values():
transform.primary_key = value

def _make_transforms(self) -> None:
    """Instantiate one incremental transformer per supported item format.

    Populates `self._transformers` with an "arrow" and a "json" entry, each
    built from the incremental's current settings and cached state.
    """
    for fmt, transform_cls in (("arrow", ArrowIncremental), ("json", JsonIncremental)):
        self._transformers[fmt] = transform_cls(
            self.resource_name,
            self.cursor_path,
            self.initial_value,
            self.start_value,
            self.end_value,
            self.last_value_func,
            self._primary_key,
            # a fresh set per transformer so each tracks dedup hashes independently
            set(self._cached_state["unique_hashes"]),
            self.on_cursor_value_missing,
            self.lag,
            self.range_start,
            self.range_end,
        )

@classmethod
def from_existing_state(
cls, resource_name: str, cursor_path: str
Expand Down Expand Up @@ -503,7 +485,8 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]":
)
# cache state
self._cached_state = self.get_state()
self._make_transforms()
# Clear transforms so we get new instances
self._transformers.clear()
return self

def can_close(self) -> bool:
Expand Down Expand Up @@ -534,15 +517,34 @@ def __str__(self) -> str:
f" {self.last_value_func}"
)

def _make_transformer(self, cls: Type[IncrementalTransform]) -> IncrementalTransform:
if transformer := self._transformers.get(cls):
return transformer
transformer = self._transformers[cls] = cls(
self.resource_name,
self.cursor_path,
self.initial_value,
self.start_value,
self.end_value,
self.last_value_func,
self._primary_key,
set(self._cached_state["unique_hashes"]),
self.on_cursor_value_missing,
self.lag,
self.range_start,
self.range_end,
)
return transformer

def _get_transformer(self, items: TDataItems) -> IncrementalTransform:
    """Pick the transformer matching the type of `items`.

    Arrow items and pandas DataFrames are routed to `ArrowIncremental`;
    anything else falls back to `JsonIncremental`. Only the first item is
    inspected -- a list is assumed to be homogeneous.
    """
    # NOTE: the scraped diff interleaved the deleted `self._transformers[...]`
    # lookups with the added `self._make_transformer(...)` calls, leaving
    # duplicate unreachable returns; this is the coherent post-commit version.
    # Assume list is all of the same type
    for item in items if isinstance(items, list) else [items]:
        if is_arrow_item(item):
            return self._make_transformer(ArrowIncremental)
        elif pandas is not None and isinstance(item, pandas.DataFrame):
            return self._make_transformer(ArrowIncremental)
        return self._make_transformer(JsonIncremental)
    # empty list: default to the json transformer
    return self._make_transformer(JsonIncremental)

def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
if rows is None:
Expand Down
56 changes: 30 additions & 26 deletions dlt/extract/incremental/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,31 @@ def __call__(
class ArrowIncremental(IncrementalTransform):
_dlt_index = "_dlt_index"

def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Resolve the pyarrow compute functions implied by `last_value_func` once.

    Caches `compute`, `end_compare`, `last_value_compare` and
    `new_value_compare` on the instance at construction time, so they do not
    have to be re-selected on every call.

    Raises:
        NotImplementedError: if `last_value_func` is neither `min` nor `max`.
    """
    super().__init__(*args, **kwargs)
    open_end = self.range_end == "open"
    closed_start = self.range_start == "closed"
    if self.last_value_func is max:
        self.compute = pa.compute.max
        self.end_compare = pa.compute.less if open_end else pa.compute.less_equal
        self.last_value_compare = (
            pa.compute.greater_equal if closed_start else pa.compute.greater
        )
        self.new_value_compare = pa.compute.greater
    elif self.last_value_func is min:
        self.compute = pa.compute.min
        self.end_compare = pa.compute.greater if open_end else pa.compute.greater_equal
        self.last_value_compare = (
            pa.compute.less_equal if closed_start else pa.compute.less
        )
        self.new_value_compare = pa.compute.less
    else:
        raise NotImplementedError(
            "Only min or max last_value_func is supported for arrow tables"
        )

def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str]) -> List[str]:
if not unique_columns:
return []
Expand Down Expand Up @@ -327,34 +352,13 @@ def __call__(
if not tbl: # row is None or empty arrow table
return tbl, start_out_of_range, end_out_of_range

if self.last_value_func is max:
compute = pa.compute.max
end_compare = pa.compute.less if self.range_end == "open" else pa.compute.less_equal
last_value_compare = (
pa.compute.greater_equal if self.range_start == "closed" else pa.compute.greater
)
new_value_compare = pa.compute.greater
elif self.last_value_func is min:
compute = pa.compute.min
end_compare = (
pa.compute.greater if self.range_end == "open" else pa.compute.greater_equal
)
last_value_compare = (
pa.compute.less_equal if self.range_start == "closed" else pa.compute.less
)
new_value_compare = pa.compute.less
else:
raise NotImplementedError(
"Only min or max last_value_func is supported for arrow tables"
)

# TODO: Json path support. For now assume the cursor_path is a column name
cursor_path = self.cursor_path

# The new max/min value
try:
# NOTE: datetimes are always pendulum in UTC
row_value = from_arrow_scalar(compute(tbl[cursor_path]))
row_value = from_arrow_scalar(self.compute(tbl[cursor_path]))
cursor_data_type = tbl.schema.field(cursor_path).type
row_value_scalar = to_arrow_scalar(row_value, cursor_data_type)
except KeyError as e:
Expand Down Expand Up @@ -385,10 +389,10 @@ def __call__(
cursor_data_type,
str(ex),
) from ex
tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar))
tbl = tbl.filter(self.end_compare(tbl[cursor_path], end_value_scalar))
# Is max row value higher than end value?
# NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary
end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py()
end_out_of_range = not self.end_compare(row_value_scalar, end_value_scalar).as_py()

if self.start_value is not None:
try:
Expand All @@ -404,7 +408,7 @@ def __call__(
str(ex),
) from ex
# Remove rows lower or equal than the last start value
keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar)
keep_filter = self.last_value_compare(tbl[cursor_path], start_value_scalar)
start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)
if not self.deduplication_disabled:
Expand All @@ -428,7 +432,7 @@ def __call__(

if (
self.last_value is None
or new_value_compare(
or self.new_value_compare(
row_value_scalar, to_arrow_scalar(self.last_value, cursor_data_type)
).as_py()
): # Last value has changed
Expand Down

0 comments on commit 10e0770

Please sign in to comment.