Commit 9ef9b9e

uses the same logic for naive datetimes in arrow and object incrementals
rudolfix committed Jan 1, 2025
1 parent 95fedad commit 9ef9b9e
Showing 3 changed files with 20 additions and 18 deletions.
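
Background for the change: Python raises a TypeError when a naive datetime is compared with a timezone-aware one, and the value kept in incremental state is always a tz-aware pendulum datetime, so a naive cursor value coming from the data has to be normalized before the comparison. This commit moves that normalization into a single helper used by both the object and the arrow incremental transforms. A minimal standalone sketch of the problem and of the normalization (plain Python and pendulum, not the dlt code itself):

from datetime import datetime

import pendulum

# What incremental state stores: always a tz-aware pendulum datetime.
last_value = pendulum.datetime(2025, 1, 1, tz="UTC")
# What a source may yield for the cursor column: a naive datetime.
row_value = datetime(2025, 1, 2, 12, 0)

try:
    row_value > last_value
except TypeError as exc:
    print(exc)  # can't compare offset-naive and offset-aware datetimes

# The normalization applied by the commit: treat the naive value as UTC.
row_value = pendulum.instance(row_value).in_tz("UTC")
print(row_value > last_value)  # True, both sides are now comparable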
1 change: 0 additions & 1 deletion dlt/extract/incremental/__init__.py
@@ -30,7 +30,6 @@
     coerce_value,
     py_type_to_sc_type,
 )
-from dlt.common.utils import without_none
 
 from dlt.extract.exceptions import IncrementalUnboundError
 from dlt.extract.incremental.exceptions import (
36 changes: 20 additions & 16 deletions dlt/extract/incremental/transform.py
@@ -113,6 +113,19 @@ def __call__(
         row: TDataItem,
     ) -> Tuple[bool, bool, bool]: ...
 
+    @staticmethod
+    def _adapt_if_datetime(row_value: Any, last_value: Any) -> Any:
+        # For datetime cursor, ensure the value is a timezone aware datetime.
+        # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
+        if (
+            isinstance(row_value, datetime)
+            and row_value.tzinfo is None
+            and isinstance(last_value, datetime)
+            and last_value.tzinfo is not None
+        ):
+            row_value = pendulum.instance(row_value).in_tz("UTC")
+        return row_value
+
     @property
     def deduplication_disabled(self) -> bool:
         """Skip deduplication when length of the key is 0 or if lag is applied."""
@@ -185,19 +198,9 @@ def __call__(
                 return None, False, False
             else:
                 return row, False, False
 
         last_value = self.last_value
         last_value_func = self.last_value_func
-
-        # For datetime cursor, ensure the value is a timezone aware datetime.
-        # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
-        if (
-            isinstance(row_value, datetime)
-            and row_value.tzinfo is None
-            and isinstance(last_value, datetime)
-            and last_value.tzinfo is not None
-        ):
-            row_value = pendulum.instance(row_value).in_tz("UTC")
+        row_value = self._adapt_if_datetime(row_value, last_value)
 
         # Check whether end_value has been reached
         # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
@@ -354,13 +357,8 @@ def __call__(
 
         # TODO: Json path support. For now assume the cursor_path is a column name
         cursor_path = self.cursor_path
-
-        # The new max/min value
         try:
-            # NOTE: datetimes are always pendulum in UTC
-            row_value = from_arrow_scalar(self.compute(tbl[cursor_path]))
             cursor_data_type = tbl.schema.field(cursor_path).type
-            row_value_scalar = to_arrow_scalar(row_value, cursor_data_type)
         except KeyError as e:
             raise IncrementalCursorPathMissing(
                 self.resource_name,
@@ -371,6 +369,12 @@
                 " must be a column name.",
             ) from e
 
+        # The new max/min value
+        row_value_scalar = self.compute(
+            tbl[cursor_path]
+        )  # to_arrow_scalar(row_value, cursor_data_type)
+        row_value = self._adapt_if_datetime(from_arrow_scalar(row_value_scalar), self.last_value)
+
         if tbl.schema.field(cursor_path).nullable:
             tbl_without_null, tbl_with_null = self._process_null_at_cursor_path(tbl)
             tbl = tbl_without_null
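
The arrow path needs the same treatment because a pyarrow timestamp column declared without a timezone converts back to a naive Python datetime, while last_value stays tz-aware. A small sketch of that behavior using plain pyarrow calls (pc.max stands in for the commit's self.compute / from_arrow_scalar helpers, which are dlt internals):

from datetime import datetime

import pendulum
import pyarrow as pa
import pyarrow.compute as pc

# A cursor column typed as timestamp without a timezone.
tbl = pa.table(
    {
        "updated_at": pa.array(
            [datetime(2025, 1, 1, 8, 0), datetime(2025, 1, 2, 9, 30)],
            type=pa.timestamp("us"),
        )
    }
)

# The max scalar converts back to a *naive* Python datetime ...
row_value = pc.max(tbl["updated_at"]).as_py()
print(row_value.tzinfo)  # None

# ... so it gets the same normalization as the object path before being
# compared with the tz-aware last_value kept in incremental state.
last_value = pendulum.datetime(2025, 1, 1, tz="UTC")
if row_value.tzinfo is None and last_value.tzinfo is not None:
    row_value = pendulum.instance(row_value).in_tz("UTC")
print(row_value > last_value)  # True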
1 change: 0 additions & 1 deletion tests/libs/test_parquet_writer.py
@@ -12,7 +12,6 @@
 from dlt.common.schema.utils import new_column
 from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
 from dlt.common.time import ensure_pendulum_datetime
-from dlt.common.libs.pyarrow import from_arrow_scalar
 
 from tests.common.data_writers.utils import get_writer
 from tests.cases import (
