diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py
index f0c6803b51..5e7bae49c6 100644
--- a/dlt/extract/incremental/__init__.py
+++ b/dlt/extract/incremental/__init__.py
@@ -517,7 +517,7 @@ def __str__(self) -> str:
             f" {self.last_value_func}"
         )
 
-    def _make_transformer(self, cls: Type[IncrementalTransform]) -> IncrementalTransform:
+    def _make_or_get_transformer(self, cls: Type[IncrementalTransform]) -> IncrementalTransform:
         if transformer := self._transformers.get(cls):
             return transformer
         transformer = self._transformers[cls] = cls(
@@ -540,11 +540,11 @@ def _get_transformer(self, items: TDataItems) -> IncrementalTransform:
         # Assume list is all of the same type
         for item in items if isinstance(items, list) else [items]:
             if is_arrow_item(item):
-                return self._make_transformer(ArrowIncremental)
+                return self._make_or_get_transformer(ArrowIncremental)
             elif pandas is not None and isinstance(item, pandas.DataFrame):
-                return self._make_transformer(ArrowIncremental)
-            return self._make_transformer(JsonIncremental)
-        return self._make_transformer(JsonIncremental)
+                return self._make_or_get_transformer(ArrowIncremental)
+            return self._make_or_get_transformer(JsonIncremental)
+        return self._make_or_get_transformer(JsonIncremental)
 
     def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
         if rows is None:
diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py
index 5c96098343..3ebc9d1201 100644
--- a/tests/extract/test_incremental.py
+++ b/tests/extract/test_incremental.py
@@ -3864,7 +3864,7 @@ def test_start_range_open(item_type: TestDataItemFormat, last_value_func: Any) -
         expected_items = list(range(6, 12))
         order_dir = "ASC"
     elif last_value_func == min:
-        data_range = reversed(data_range)
+        data_range = reversed(data_range)  # type: ignore[call-overload]
         initial_value = 5
         # Only items lower than inital extracted
         expected_items = list(reversed(range(1, 5)))
diff --git a/tests/load/sources/sql_database/test_sql_database_source.py b/tests/load/sources/sql_database/test_sql_database_source.py
index 9079638586..b5d4e000ae 100644
--- a/tests/load/sources/sql_database/test_sql_database_source.py
+++ b/tests/load/sources/sql_database/test_sql_database_source.py
@@ -13,6 +13,7 @@
 from dlt.common.utils import uniq_id
 
 from dlt.extract.exceptions import ResourceExtractionError
+from dlt.extract.incremental.transform import JsonIncremental, ArrowIncremental
 from dlt.sources import DltResource
 
 from tests.pipeline.utils import (
@@ -831,8 +832,8 @@ def _assert_incremental(item):
         else:
             assert _r.incremental.primary_key == ["id"]
             assert _r.incremental._incremental.primary_key == ["id"]
-            assert _r.incremental._incremental._transformers["json"].primary_key == ["id"]
-            assert _r.incremental._incremental._transformers["arrow"].primary_key == ["id"]
+            assert _r.incremental._incremental._make_or_get_transformer(JsonIncremental).primary_key == ["id"]
+            assert _r.incremental._incremental._make_or_get_transformer(ArrowIncremental).primary_key == ["id"]
         return item
 
     pipeline = make_pipeline("duckdb")
@@ -841,8 +842,8 @@ def _assert_incremental(item):
 
     assert resource.incremental.primary_key == ["id"]
     assert resource.incremental._incremental.primary_key == ["id"]
-    assert resource.incremental._incremental._transformers["json"].primary_key == ["id"]
-    assert resource.incremental._incremental._transformers["arrow"].primary_key == ["id"]
+    assert resource.incremental._incremental._make_or_get_transformer(JsonIncremental).primary_key == ["id"]
+    assert resource.incremental._incremental._make_or_get_transformer(ArrowIncremental).primary_key == ["id"]
 
 
 @pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
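Note (not part of the patch): as the first hunk shows, `_make_or_get_transformer` is a get-or-create cache keyed by the transform class, which is why the updated tests look transformers up by `JsonIncremental` / `ArrowIncremental` class instead of by the old `"json"` / `"arrow"` string keys. Below is a minimal standalone sketch of that caching pattern; the `Demo*` classes, `Registry`, and the `primary_key`-only constructor are hypothetical stand-ins, not dlt code.

```python
# Hypothetical stand-ins (DemoTransform, JsonDemo, ArrowDemo, Registry) -- not dlt classes.
from typing import Dict, List, Type


class DemoTransform:
    def __init__(self, primary_key: List[str]) -> None:
        self.primary_key = primary_key


class JsonDemo(DemoTransform):
    pass


class ArrowDemo(DemoTransform):
    pass


class Registry:
    def __init__(self, primary_key: List[str]) -> None:
        self.primary_key = primary_key
        # one instance per transform class, created lazily
        self._transformers: Dict[Type[DemoTransform], DemoTransform] = {}

    def _make_or_get_transformer(self, cls: Type[DemoTransform]) -> DemoTransform:
        # return the cached instance for this class if it already exists
        if transformer := self._transformers.get(cls):
            return transformer
        # otherwise construct it once and cache it under the class key
        transformer = self._transformers[cls] = cls(self.primary_key)
        return transformer


registry = Registry(primary_key=["id"])
# repeated calls hand back the same cached object per class
assert registry._make_or_get_transformer(JsonDemo) is registry._make_or_get_transformer(JsonDemo)
assert registry._make_or_get_transformer(ArrowDemo).primary_key == ["id"]
```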