diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py
index bfc9542a1d..9ad174fd63 100644
--- a/dlt/extract/incremental/__init__.py
+++ b/dlt/extract/incremental/__init__.py
@@ -465,7 +465,7 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
 
         # write back state
         self._cached_state["last_value"] = transformer.last_value
-        if self.primary_key != ():
+        if not transformer.deduplication_disabled:
             # compute hashes for new last rows
             unique_hashes = set(
                 transformer.compute_unique_value(row, self.primary_key)
@@ -473,7 +473,6 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
             )
             # add directly computed hashes
             unique_hashes.update(transformer.unique_hashes)
-            print(unique_hashes)
             self._cached_state["unique_hashes"] = list(unique_hashes)
 
         return rows
diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py
index 74b0966cee..2ad827b755 100644
--- a/dlt/extract/incremental/transform.py
+++ b/dlt/extract/incremental/transform.py
@@ -70,6 +70,11 @@ def compute_unique_value(
         primary_key: Optional[TTableHintTemplate[TColumnNames]],
     ) -> str:
         try:
+            assert not self.deduplication_disabled, (
+                f"{self.resource_name}: Attempt to compute unique values when deduplication is"
+                " disabled"
+            )
+
             if primary_key:
                 return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True))
             elif primary_key is None:
@@ -84,6 +89,11 @@ def __call__(
         row: TDataItem,
     ) -> Tuple[bool, bool, bool]: ...
 
+    @property
+    def deduplication_disabled(self) -> bool:
+        """Skip deduplication when length of the key is 0"""
+        return isinstance(self.primary_key, (list, tuple)) and len(self.primary_key) == 0
+
 
 class JsonIncremental(IncrementalTransform):
     def find_cursor_value(self, row: TDataItem) -> Any:
@@ -154,13 +164,13 @@ def __call__(
             if new_value == self.start_value:
                 # if equal there's still a chance that item gets in
                 if processed_row_value == self.start_value:
-                    unique_value = self.compute_unique_value(row, self.primary_key)
-                    # if unique value exists then use it to deduplicate
-                    if unique_value:
+                    if not self.deduplication_disabled:
+                        unique_value = self.compute_unique_value(row, self.primary_key)
+                        # if unique value exists then use it to deduplicate
                         if unique_value in self.start_unique_hashes:
                             return None, True, False
                 else:
-                    # smaller than start value gets out
+                    # smaller than start value: gets out
                     return None, True, False
 
             # we store row id for all records with the current "last_value" in state and use it to deduplicate
@@ -169,8 +179,7 @@ def __call__(
                 self.last_rows.append(row)
         else:
             self.last_value = new_value
-            # store rows with "max" values to compute hashes
-            # only when needed
+            # store rows with "max" values to compute hashes after processing full batch
            self.last_rows = [row]
             self.unique_hashes = set()
 
@@ -198,9 +207,7 @@ def compute_unique_values_with_index(
             for index, row in zip(indices, rows)
         ]
 
-    def _add_unique_index(
-        self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str
-    ) -> "pa.Table":
+    def _add_unique_index(self, tbl: "pa.Table") -> "pa.Table":
         """Creates unique index if necessary."""
         # create unique index if necessary
         if self._dlt_index not in tbl.schema.names:
@@ -231,8 +238,6 @@ def __call__(
                 self._dlt_index = primary_key
         elif primary_key is None:
             unique_columns = tbl.schema.names
-        else:  # deduplicating is disabled
-            unique_columns = None
 
         start_out_of_range = end_out_of_range = False
         if not tbl:  # row is None or empty arrow table
@@ -240,13 +245,11 @@ def __call__(
 
         if self.last_value_func is max:
             compute = pa.compute.max
-            aggregate = "max"
             end_compare = pa.compute.less
             last_value_compare = pa.compute.greater_equal
             new_value_compare = pa.compute.greater
         elif self.last_value_func is min:
             compute = pa.compute.min
-            aggregate = "min"
             end_compare = pa.compute.greater
             last_value_compare = pa.compute.less_equal
             new_value_compare = pa.compute.less
@@ -287,21 +290,24 @@ def __call__(
             keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar)
             start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
             tbl = tbl.filter(keep_filter)
-            # Deduplicate after filtering old values
-            tbl = self._add_unique_index(tbl, unique_columns, aggregate, cursor_path)
-            # Remove already processed rows where the cursor is equal to the start value
-            eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar))
-            # compute index, unique hash mapping
-            unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns)
-            unique_values_index = [
-                (i, uq_val)
-                for i, uq_val in unique_values_index
-                if uq_val in self.start_unique_hashes
-            ]
-            # find rows with unique ids that were stored from previous run
-            remove_idx = pa.array(i for i, _ in unique_values_index)
-            # Filter the table
-            tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)))
+            if not self.deduplication_disabled:
+                # Deduplicate after filtering old values
+                tbl = self._add_unique_index(tbl)
+                # Remove already processed rows where the cursor is equal to the start value
+                eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar))
+                # compute index, unique hash mapping
+                unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns)
+                unique_values_index = [
+                    (i, uq_val)
+                    for i, uq_val in unique_values_index
+                    if uq_val in self.start_unique_hashes
+                ]
+                # find rows with unique ids that were stored from previous run
+                remove_idx = pa.array(i for i, _ in unique_values_index)
+                # Filter the table
+                tbl = tbl.filter(
+                    pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx))
+                )
 
         if (
             self.last_value is None
@@ -310,14 +316,15 @@ def __call__(
             ).as_py()
         ):  # Last value has changed
             self.last_value = row_value
-            # Compute unique hashes for all rows equal to row value
-            self.unique_hashes = set(
-                self.compute_unique_values(
-                    tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
-                    unique_columns,
+            if not self.deduplication_disabled:
+                # Compute unique hashes for all rows equal to row value
+                self.unique_hashes = set(
+                    self.compute_unique_values(
+                        tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
+                        unique_columns,
+                    )
                 )
-            )
-        elif self.last_value == row_value:
+        elif self.last_value == row_value and not self.deduplication_disabled:
             # last value is unchanged, add the hashes
             self.unique_hashes.update(
                 set(
diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py
index b3db3b852b..895b8b1d4a 100644
--- a/tests/extract/test_incremental.py
+++ b/tests/extract/test_incremental.py
@@ -15,13 +15,14 @@
 from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration
 from dlt.common.configuration import ConfigurationValueError
 from dlt.common.pendulum import pendulum, timedelta
-from dlt.common.pipeline import StateInjectableContext, resource_state
+from dlt.common.pipeline import NormalizeInfo, StateInjectableContext, resource_state
 from dlt.common.schema.schema import Schema
 from dlt.common.utils import uniq_id, digest128, chunks
 from dlt.common.json import json
 
 from dlt.extract import DltSource
 from dlt.extract.exceptions import InvalidStepFunctionArguments
+from dlt.extract.resource import DltResource
 from dlt.sources.helpers.transform import take_first
 from dlt.extract.incremental.exceptions import (
     IncrementalCursorPathMissing,
@@ -1456,35 +1457,141 @@ def ascending_desc(
 
 
 @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
-def test_unique_values_unordered_rows(item_type: TDataItemFormat) -> None:
-    @dlt.resource
+@pytest.mark.parametrize("order", ["random", "desc", "asc"])
+@pytest.mark.parametrize("primary_key", [[], None, "updated_at"])
+@pytest.mark.parametrize(
+    "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record")
+)
+def test_unique_values_unordered_rows(
+    item_type: TDataItemFormat, order: str, primary_key: Any, deterministic: bool
+) -> None:
+    @dlt.resource(primary_key=primary_key)
     def random_ascending_chunks(
+        order: str,
         updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
             "updated_at",
             initial_value=10,
-        )
+        ),
     ) -> Any:
         range_ = list(range(updated_at.start_value, updated_at.start_value + 121))
-        random.shuffle(range_)
+        if order == "random":
+            random.shuffle(range_)
+        if order == "desc":
+            range_ = reversed(range_)  # type: ignore[assignment]
+
         for chunk in chunks(range_, 30):
             # make sure that overlapping element is the last one
-            data = [{"updated_at": i} for i in chunk]
+            data = [
+                {"updated_at": i, "rand": random.random() if not deterministic else 0}
+                for i in chunk
+            ]
             # random.shuffle(data)
             print(data)
             yield data_to_item_format(item_type, data)
 
     os.environ["COMPLETED_PROB"] = "1.0"  # make it complete immediately
     pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy")
-    # print(list(random_ascending_chunks()))
-    pipeline.run(random_ascending_chunks())
+    pipeline.run(random_ascending_chunks(order))
     assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 121
 
     # 120 rows (one overlap - incremental reacquires and deduplicates)
-    # print(list(random_ascending_chunks()))
-    pipeline.run(random_ascending_chunks())
-    assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 120
+    pipeline.run(random_ascending_chunks(order))
+    # overlapping element must be deduped when:
+    # 1. we have primary key on just updated at
+    # OR we have a key on full record but the record is deterministic so duplicate may be found
+    rows = 120 if primary_key == "updated_at" or (deterministic and primary_key != []) else 121
+    assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == rows
+
 
-    # test next batch adding to unique
 
+@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
+@pytest.mark.parametrize("primary_key", [[], None, "updated_at"])  # [], None,
+@pytest.mark.parametrize(
+    "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record")
+)
+def test_carry_unique_hashes(
+    item_type: TDataItemFormat, primary_key: Any, deterministic: bool
+) -> None:
+    # each day extends list of hashes and removes duplicates until the last day
+
+    @dlt.resource(primary_key=primary_key)
+    def random_ascending_chunks(
+        # order: str,
+        day: int,
+        updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
+            "updated_at",
+            initial_value=10,
+        ),
+    ) -> Any:
+        range_ = random.sample(
+            range(updated_at.initial_value, updated_at.initial_value + 10), k=10
+        )  # list(range(updated_at.initial_value, updated_at.initial_value + 10))
+        range_ += [100]
+        if day == 4:
+            # on day 4 add an element that will reset all others
+            range_ += [1000]
+
+        for chunk in chunks(range_, 3):
+            # make sure that overlapping element is the last one
+            data = [
+                {"updated_at": i, "rand": random.random() if not deterministic else 0}
+                for i in chunk
+            ]
+            # random.shuffle(data)
+            print(data)
+            yield data_to_item_format(item_type, data)
+
+    os.environ["COMPLETED_PROB"] = "1.0"  # make it complete immediately
+    pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy")
+
+    def _assert_state(r_: DltResource, day: int, info: NormalizeInfo) -> None:
+        uniq_hashes = r_.state["incremental"]["updated_at"]["unique_hashes"]
+        row_count = info.row_counts.get("random_ascending_chunks", 0)
+        if primary_key == "updated_at":
+            # we keep only newest version of the record
+            assert len(uniq_hashes) == 1
+            if day == 1:
+                # all records loaded
+                assert row_count == 11
+            elif day == 4:
+                # new biggest item loaded
+                assert row_count == 1
+            else:
+                # all deduplicated
+                assert row_count == 0
+        elif primary_key is None:
+            # we deduplicate over full content
+            if day == 4:
+                assert len(uniq_hashes) == 1
+                # both the 100 or 1000 are in if non deterministic content
+                assert row_count == (2 if not deterministic else 1)
+            else:
+                # each day adds new hash if content non deterministic
+                assert len(uniq_hashes) == (day if not deterministic else 1)
+                if day == 1:
+                    assert row_count == 11
+                else:
+                    assert row_count == (1 if not deterministic else 0)
+        elif primary_key == []:
+            # no deduplication
+            assert len(uniq_hashes) == 0
+            if day == 4:
+                assert row_count == 2
+            else:
+                if day == 1:
+                    assert row_count == 11
+                else:
+                    assert row_count == 1
+
+    r_ = random_ascending_chunks(1)
+    pipeline.run(r_)
+    _assert_state(r_, 1, pipeline.last_trace.last_normalize_info)
+    r_ = random_ascending_chunks(2)
+    pipeline.run(r_)
+    _assert_state(r_, 2, pipeline.last_trace.last_normalize_info)
+    r_ = random_ascending_chunks(3)
+    pipeline.run(r_)
+    _assert_state(r_, 3, pipeline.last_trace.last_normalize_info)
+    r_ = random_ascending_chunks(4)
+    pipeline.run(r_)
+    _assert_state(r_, 4, pipeline.last_trace.last_normalize_info)
 
 
 @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
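
Taken together, these changes make hash-based deduplication opt-out: a resource declared with an empty primary key (`primary_key=[]` or `()`) makes `deduplication_disabled` return `True`, so no unique hashes are computed or stored for rows that share the boundary cursor value. A minimal usage sketch, not part of the patch; the `events` resource, the `updated_at` field, and the `duckdb` destination are illustrative assumptions:

```py
import dlt

# Sketch only: with primary_key=[] the incremental transform skips hash-based
# deduplication, so rows equal to the last cursor value may load again on the next run.
@dlt.resource(primary_key=[])
def events(
    updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", initial_value=0),
):
    yield [{"updated_at": i} for i in range(updated_at.start_value, updated_at.start_value + 100)]


if __name__ == "__main__":
    pipeline = dlt.pipeline("dedup_disabled_example", destination="duckdb")
    print(pipeline.run(events()))
```

This trade-off is exactly what the new `primary_key=[]` parametrization of `test_unique_values_unordered_rows` asserts: the overlapping boundary row is not deduplicated, so the second run loads 121 rows instead of 120.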