Commit
handles no deduplication case explicitly, more tests
rudolfix committed Mar 7, 2024
1 parent 31c035c commit e18b971
Showing 3 changed files with 162 additions and 49 deletions.
3 changes: 1 addition & 2 deletions dlt/extract/incremental/__init__.py
@@ -465,15 +465,14 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:

# write back state
self._cached_state["last_value"] = transformer.last_value
if self.primary_key != ():
if not transformer.deduplication_disabled:
# compute hashes for new last rows
unique_hashes = set(
transformer.compute_unique_value(row, self.primary_key)
for row in transformer.last_rows
)
# add directly computed hashes
unique_hashes.update(transformer.unique_hashes)
print(unique_hashes)
self._cached_state["unique_hashes"] = list(unique_hashes)

return rows
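For orientation, the hunk above writes `unique_hashes` back into the cached incremental state only when deduplication is enabled. A minimal sketch of how a resource ends up on the disabled path, assuming the same public API the tests below exercise (`dlt.resource`, `dlt.sources.incremental`, the `dummy` destination); the resource name `events` is made up for illustration:

```python
import os
from typing import Any

import dlt

os.environ["COMPLETED_PROB"] = "1.0"  # make the dummy destination complete immediately (as in the tests)


# An empty primary_key ([] or ()) now explicitly disables deduplication:
# the transform skips hash computation and no "unique_hashes" are cached in state.
@dlt.resource(primary_key=[])
def events(
    updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
        "updated_at", initial_value=10
    ),
) -> Any:
    # rows at the boundary value are re-yielded on the next run and NOT deduplicated
    yield [{"updated_at": i} for i in range(updated_at.start_value, updated_at.start_value + 5)]


pipeline = dlt.pipeline("dedup_disabled_sketch", destination="dummy")
pipeline.run(events())
```

Running it a second time would re-load the row whose `updated_at` equals the stored `last_value`, which is exactly the behavior the new `deduplication_disabled` branch makes explicit.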
77 changes: 42 additions & 35 deletions dlt/extract/incremental/transform.py
@@ -70,6 +70,11 @@ def compute_unique_value(
primary_key: Optional[TTableHintTemplate[TColumnNames]],
) -> str:
try:
assert not self.deduplication_disabled, (
f"{self.resource_name}: Attempt to compute unique values when deduplication is"
" disabled"
)

if primary_key:
return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True))
elif primary_key is None:
@@ -84,6 +89,11 @@ def __call__(
row: TDataItem,
) -> Tuple[bool, bool, bool]: ...

@property
def deduplication_disabled(self) -> bool:
"""Skip deduplication when length of the key is 0"""
return isinstance(self.primary_key, (list, tuple)) and len(self.primary_key) == 0


class JsonIncremental(IncrementalTransform):
def find_cursor_value(self, row: TDataItem) -> Any:
@@ -154,13 +164,13 @@ def __call__(
if new_value == self.start_value:
# if equal there's still a chance that item gets in
if processed_row_value == self.start_value:
unique_value = self.compute_unique_value(row, self.primary_key)
# if unique value exists then use it to deduplicate
if unique_value:
if not self.deduplication_disabled:
unique_value = self.compute_unique_value(row, self.primary_key)
# if unique value exists then use it to deduplicate
if unique_value in self.start_unique_hashes:
return None, True, False
else:
# smaller than start value gets out
# smaller than start value: gets out
return None, True, False

# we store row id for all records with the current "last_value" in state and use it to deduplicate
@@ -169,8 +179,7 @@ def __call__(
self.last_rows.append(row)
else:
self.last_value = new_value
# store rows with "max" values to compute hashes
# only when needed
# store rows with "max" values to compute hashes after processing full batch
self.last_rows = [row]
self.unique_hashes = set()

@@ -198,9 +207,7 @@ def compute_unique_values_with_index(
for index, row in zip(indices, rows)
]

def _add_unique_index(
self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str
) -> "pa.Table":
def _add_unique_index(self, tbl: "pa.Table") -> "pa.Table":
"""Creates unique index if necessary."""
# create unique index if necessary
if self._dlt_index not in tbl.schema.names:
@@ -231,22 +238,18 @@ def __call__(
self._dlt_index = primary_key
elif primary_key is None:
unique_columns = tbl.schema.names
else: # deduplicating is disabled
unique_columns = None

start_out_of_range = end_out_of_range = False
if not tbl: # row is None or empty arrow table
return tbl, start_out_of_range, end_out_of_range

if self.last_value_func is max:
compute = pa.compute.max
aggregate = "max"
end_compare = pa.compute.less
last_value_compare = pa.compute.greater_equal
new_value_compare = pa.compute.greater
elif self.last_value_func is min:
compute = pa.compute.min
aggregate = "min"
end_compare = pa.compute.greater
last_value_compare = pa.compute.less_equal
new_value_compare = pa.compute.less
@@ -287,21 +290,24 @@ def __call__(
keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar)
start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)
# Deduplicate after filtering old values
tbl = self._add_unique_index(tbl, unique_columns, aggregate, cursor_path)
# Remove already processed rows where the cursor is equal to the start value
eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar))
# compute index, unique hash mapping
unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns)
unique_values_index = [
(i, uq_val)
for i, uq_val in unique_values_index
if uq_val in self.start_unique_hashes
]
# find rows with unique ids that were stored from previous run
remove_idx = pa.array(i for i, _ in unique_values_index)
# Filter the table
tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)))
if not self.deduplication_disabled:
# Deduplicate after filtering old values
tbl = self._add_unique_index(tbl)
# Remove already processed rows where the cursor is equal to the start value
eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar))
# compute index, unique hash mapping
unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns)
unique_values_index = [
(i, uq_val)
for i, uq_val in unique_values_index
if uq_val in self.start_unique_hashes
]
# find rows with unique ids that were stored from previous run
remove_idx = pa.array(i for i, _ in unique_values_index)
# Filter the table
tbl = tbl.filter(
pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx))
)

if (
self.last_value is None
@@ -310,14 +316,15 @@ def __call__(
).as_py()
): # Last value has changed
self.last_value = row_value
# Compute unique hashes for all rows equal to row value
self.unique_hashes = set(
self.compute_unique_values(
tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
unique_columns,
if not self.deduplication_disabled:
# Compute unique hashes for all rows equal to row value
self.unique_hashes = set(
self.compute_unique_values(
tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
unique_columns,
)
)
)
elif self.last_value == row_value:
elif self.last_value == row_value and not self.deduplication_disabled:
# last value is unchanged, add the hashes
self.unique_hashes.update(
set(
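The toggle used throughout `transform.py` is a simple shape check on the primary key. A standalone restatement of that predicate (a sketch, not the library code itself) shows which key configurations keep deduplication on:

```python
from typing import List, Tuple, Union

TPrimaryKey = Union[str, List[str], Tuple[str, ...], None]


def deduplication_disabled(primary_key: TPrimaryKey) -> bool:
    """Mirrors the property added in this commit: only an explicitly empty
    list/tuple key switches deduplication off."""
    return isinstance(primary_key, (list, tuple)) and len(primary_key) == 0


assert deduplication_disabled([]) and deduplication_disabled(())
assert not deduplication_disabled(None)            # None: hash the full row content
assert not deduplication_disabled("updated_at")    # single column key
assert not deduplication_disabled(["id", "day"])   # compound key
```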
131 changes: 119 additions & 12 deletions tests/extract/test_incremental.py
@@ -15,13 +15,14 @@
from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration
from dlt.common.configuration import ConfigurationValueError
from dlt.common.pendulum import pendulum, timedelta
from dlt.common.pipeline import StateInjectableContext, resource_state
from dlt.common.pipeline import NormalizeInfo, StateInjectableContext, resource_state
from dlt.common.schema.schema import Schema
from dlt.common.utils import uniq_id, digest128, chunks
from dlt.common.json import json

from dlt.extract import DltSource
from dlt.extract.exceptions import InvalidStepFunctionArguments
from dlt.extract.resource import DltResource
from dlt.sources.helpers.transform import take_first
from dlt.extract.incremental.exceptions import (
IncrementalCursorPathMissing,
@@ -1456,35 +1457,141 @@ def ascending_desc(


@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
def test_unique_values_unordered_rows(item_type: TDataItemFormat) -> None:
@dlt.resource
@pytest.mark.parametrize("order", ["random", "desc", "asc"])
@pytest.mark.parametrize("primary_key", [[], None, "updated_at"])
@pytest.mark.parametrize(
"deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record")
)
def test_unique_values_unordered_rows(
item_type: TDataItemFormat, order: str, primary_key: Any, deterministic: bool
) -> None:
@dlt.resource(primary_key=primary_key)
def random_ascending_chunks(
order: str,
updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
"updated_at",
initial_value=10,
)
),
) -> Any:
range_ = list(range(updated_at.start_value, updated_at.start_value + 121))
random.shuffle(range_)
if order == "random":
random.shuffle(range_)
if order == "desc":
range_ = reversed(range_) # type: ignore[assignment]

for chunk in chunks(range_, 30):
# make sure that overlapping element is the last one
data = [{"updated_at": i} for i in chunk]
data = [
{"updated_at": i, "rand": random.random() if not deterministic else 0}
for i in chunk
]
# random.shuffle(data)
print(data)
yield data_to_item_format(item_type, data)

os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately
pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy")
# print(list(random_ascending_chunks()))
pipeline.run(random_ascending_chunks())
pipeline.run(random_ascending_chunks(order))
assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 121

# 120 rows (one overlap - incremental reacquires and deduplicates)
# print(list(random_ascending_chunks()))
pipeline.run(random_ascending_chunks())
assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 120
pipeline.run(random_ascending_chunks(order))
# the overlapping element must be deduped when:
# 1. the primary key is just updated_at
# 2. OR the key covers the full record and the record is deterministic, so the duplicate can be found
rows = 120 if primary_key == "updated_at" or (deterministic and primary_key != []) else 121
assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == rows


# test next batch adding to unique
@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
@pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) # [], None,
@pytest.mark.parametrize(
"deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record")
)
def test_carry_unique_hashes(
item_type: TDataItemFormat, primary_key: Any, deterministic: bool
) -> None:
# each day extends list of hashes and removes duplicates until the last day

@dlt.resource(primary_key=primary_key)
def random_ascending_chunks(
# order: str,
day: int,
updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
"updated_at",
initial_value=10,
),
) -> Any:
range_ = random.sample(
range(updated_at.initial_value, updated_at.initial_value + 10), k=10
) # list(range(updated_at.initial_value, updated_at.initial_value + 10))
range_ += [100]
if day == 4:
# on day 4 add an element that will reset all others
range_ += [1000]

for chunk in chunks(range_, 3):
# make sure that overlapping element is the last one
data = [
{"updated_at": i, "rand": random.random() if not deterministic else 0}
for i in chunk
]
# random.shuffle(data)
print(data)
yield data_to_item_format(item_type, data)

os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately
pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy")

def _assert_state(r_: DltResource, day: int, info: NormalizeInfo) -> None:
uniq_hashes = r_.state["incremental"]["updated_at"]["unique_hashes"]
row_count = info.row_counts.get("random_ascending_chunks", 0)
if primary_key == "updated_at":
# we keep only newest version of the record
assert len(uniq_hashes) == 1
if day == 1:
# all records loaded
assert row_count == 11
elif day == 4:
# new biggest item loaded
assert row_count == 1
else:
# all deduplicated
assert row_count == 0
elif primary_key is None:
# we deduplicate over full content
if day == 4:
assert len(uniq_hashes) == 1
# both the 100 and the 1000 rows are loaded if the content is non-deterministic
assert row_count == (2 if not deterministic else 1)
else:
# each day adds a new hash if the content is non-deterministic
assert len(uniq_hashes) == (day if not deterministic else 1)
if day == 1:
assert row_count == 11
else:
assert row_count == (1 if not deterministic else 0)
elif primary_key == []:
# no deduplication
assert len(uniq_hashes) == 0
if day == 4:
assert row_count == 2
else:
if day == 1:
assert row_count == 11
else:
assert row_count == 1

r_ = random_ascending_chunks(1)
pipeline.run(r_)
_assert_state(r_, 1, pipeline.last_trace.last_normalize_info)
r_ = random_ascending_chunks(2)
pipeline.run(r_)
_assert_state(r_, 2, pipeline.last_trace.last_normalize_info)
r_ = random_ascending_chunks(3)
pipeline.run(r_)
_assert_state(r_, 3, pipeline.last_trace.last_normalize_info)
r_ = random_ascending_chunks(4)
pipeline.run(r_)
_assert_state(r_, 4, pipeline.last_trace.last_normalize_info)


@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
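The expected row counts in `test_unique_values_unordered_rows` follow a small decision rule over the parametrized `primary_key` and `deterministic` values: 121 rows are generated and exactly one overlaps the previous run. A sketch of that rule restated outside pytest (helper name and structure are illustrative only):

```python
from typing import Any


def expected_second_run_rows(primary_key: Any, deterministic: bool, total: int = 121) -> int:
    # The single overlapping row is dropped only when deduplication can recognize it:
    # either the key is the cursor column itself, or the row content is hashed
    # (primary_key is None or a non-empty key) and is reproducible between runs.
    deduped = primary_key == "updated_at" or (deterministic and primary_key != [])
    return total - 1 if deduped else total


assert expected_second_run_rows("updated_at", deterministic=False) == 120
assert expected_second_run_rows(None, deterministic=True) == 120
assert expected_second_run_rows(None, deterministic=False) == 121  # random "rand" field changes the hash
assert expected_second_run_rows([], deterministic=True) == 121     # dedup disabled, overlap loaded again
```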
