From 6ae82ba82b6ba6d17cd9e3bf53a6446339360311 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Sun, 22 Dec 2024 23:25:32 +0100
Subject: [PATCH] recognizes ARRAY data type and converts to nested, fixes minimal nested types handling for sqlalchemy backend

---
 dlt/extract/incremental/__init__.py           |  2 +
 dlt/sources/sql_database/__init__.py          |  2 +-
 dlt/sources/sql_database/helpers.py           |  2 +
 dlt/sources/sql_database/schema_types.py      |  7 ++-
 tests/load/sources/sql_database/sql_source.py | 13 +++--
 .../load/sources/sql_database/test_helpers.py | 18 +++++--
 .../sql_database/test_sql_database_source.py  | 51 ++++++-------------
 ...st_sql_database_source_all_destinations.py |  8 ++-
 8 files changed, 56 insertions(+), 47 deletions(-)

diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py
index ce06292864..86ba34e6c7 100644
--- a/dlt/extract/incremental/__init__.py
+++ b/dlt/extract/incremental/__init__.py
@@ -564,6 +564,8 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
         self._cached_state["last_value"] = transformer.last_value
         if not transformer.deduplication_disabled:
             # compute hashes for new last rows
+            # NOTE: object transform uses last_rows to pass rows to dedup, arrow computes
+            # hashes directly
             unique_hashes = set(
                 transformer.compute_unique_value(row, self.primary_key)
                 for row in transformer.last_rows
diff --git a/dlt/sources/sql_database/__init__.py b/dlt/sources/sql_database/__init__.py
index f7d238641e..c37b956c20 100644
--- a/dlt/sources/sql_database/__init__.py
+++ b/dlt/sources/sql_database/__init__.py
@@ -212,7 +212,6 @@ def sql_table(
         engine = engine_adapter_callback(engine)

     metadata = metadata or MetaData(schema=schema)
-    skip_nested_on_minimal = backend == "sqlalchemy"
     # Table object is only created when reflecting, we don't want empty tables in metadata
     # as it breaks foreign key resolution
     table_obj = metadata.tables.get(table)
@@ -222,6 +221,7 @@ def sql_table(
     if table_obj is not None:
         if not defer_table_reflect:
             table_obj = _execute_table_adapter(table_obj, table_adapter_callback, included_columns)
+        skip_nested_on_minimal = backend == "sqlalchemy"
         hints = table_to_resource_hints(
             table_obj,
             reflection_level,
diff --git a/dlt/sources/sql_database/helpers.py b/dlt/sources/sql_database/helpers.py
index ee38c7dd98..b6b14c93bb 100644
--- a/dlt/sources/sql_database/helpers.py
+++ b/dlt/sources/sql_database/helpers.py
@@ -263,6 +263,7 @@ def table_rows(
             table,
             reflection_level,
             type_adapter_callback,
+            backend == "sqlalchemy",  # skip nested types
             resolve_foreign_keys=resolve_foreign_keys,
         )

@@ -285,6 +286,7 @@ def table_rows(
             table,
             reflection_level,
             type_adapter_callback,
+            backend == "sqlalchemy",  # skip nested types
             resolve_foreign_keys=resolve_foreign_keys,
         )
diff --git a/dlt/sources/sql_database/schema_types.py b/dlt/sources/sql_database/schema_types.py
index e75cbdc5d5..1cca0b81bb 100644
--- a/dlt/sources/sql_database/schema_types.py
+++ b/dlt/sources/sql_database/schema_types.py
@@ -78,7 +78,10 @@ def sqla_col_to_column_schema(
     if reflection_level == "minimal":
         # normalized into subtables
-        if isinstance(sql_col.type, sqltypes.JSON) and skip_nested_columns_on_minimal:
+        if (
+            isinstance(sql_col.type, (sqltypes.JSON, sqltypes.ARRAY))
+            and skip_nested_columns_on_minimal
+        ):
             return None
         return col
@@ -139,6 +142,8 @@ def sqla_col_to_column_schema(
         col["data_type"] = "time"
     elif isinstance(sql_t, sqltypes.JSON):
         col["data_type"] = "json"
+    elif isinstance(sql_t, sqltypes.ARRAY):
+        col["data_type"] = "json"
     elif isinstance(sql_t, sqltypes.Boolean):
         col["data_type"] = "bool"
     else:
diff --git a/tests/load/sources/sql_database/sql_source.py b/tests/load/sources/sql_database/sql_source.py
index 7f2deaf13c..3f8b89a2b5 100644
--- a/tests/load/sources/sql_database/sql_source.py
+++ b/tests/load/sources/sql_database/sql_source.py
@@ -39,7 +39,7 @@
     schema as sqla_schema,
 )

-from sqlalchemy.dialects.postgresql import DATERANGE, JSONB
+from sqlalchemy.dialects.postgresql import JSONB

 from dlt.common.pendulum import pendulum, timedelta
 from dlt.common.utils import chunks, uniq_id
@@ -182,6 +182,8 @@ def _make_precision_table(table_name: str, nullable: bool) -> None:
                 Column("float_col", Float, nullable=nullable),
                 Column("json_col", JSONB, nullable=nullable),
                 Column("bool_col", Boolean, nullable=nullable),
+                Column("uuid_col", Uuid, nullable=nullable),
+                Column("array_col", ARRAY(Integer), nullable=nullable),
             )

         _make_precision_table("has_precision", False)
@@ -194,7 +196,7 @@ def _make_precision_table(table_name: str, nullable: bool) -> None:
             # Column("unsupported_daterange_1", DATERANGE, nullable=False),
             Column("supported_text", Text, nullable=False),
             Column("supported_int", Integer, nullable=False),
-            Column("unsupported_array_1", ARRAY(Integer), nullable=False),
+            # Column("unsupported_array_1", ARRAY(Integer), nullable=False),
             # Column("supported_datetime", DateTime(timezone=True), nullable=False),
         )
@@ -325,8 +327,11 @@ def _fake_precision_data(self, table_name: str, n: int = 100, null_n: int = 0) -> None:
                 date_col=mimesis.Datetime().date(),
                 time_col=mimesis.Datetime().time(),
                 float_col=random.random(),
-                json_col='{"data": [1, 2, 3]}',  # NOTE: can we do this?
+                # NOTE: do not use strings. pandas mangles them (or spend time adding ifs to tests)
+                json_col=[1, 2.1, -1.1],
                 bool_col=random.randint(0, 1) == 1,
+                uuid_col=str(uuid4()) if Uuid is str else uuid4(),
+                array_col=[1, 2, 3],
             )
             for _ in range(n + null_n)
         ]
@@ -350,7 +355,7 @@ def _fake_unsupported_data(self, n: int = 100) -> None:
                 # unsupported_daterange_1="[2020-01-01, 2020-09-01]",
                 supported_text=mimesis.Text().word(),
                 supported_int=random.randint(0, 100),
-                unsupported_array_1=[1, 2, 3],
+                # unsupported_array_1=[1, 2, 3],
                 # supported_datetime="2015-08-12T01:25:22.468126+0100",
             )
             for _ in range(n)
diff --git a/tests/load/sources/sql_database/test_helpers.py b/tests/load/sources/sql_database/test_helpers.py
index 43da9c955f..81c2b3c2e0 100644
--- a/tests/load/sources/sql_database/test_helpers.py
+++ b/tests/load/sources/sql_database/test_helpers.py
@@ -1,4 +1,5 @@
-from typing import Callable, Any, TYPE_CHECKING
+from functools import partial
+from typing import Callable, Any, Literal
 from dataclasses import dataclass

 import pytest
@@ -409,20 +410,23 @@ def test_make_query_incremental_range_end_closed(
     assert query.compare(expected)


-def mock_json_column(field: str) -> TDataItem:
+def mock_column(field: str, mock_type: Literal["json", "array"] = "json") -> TDataItem:
     """"""
     import pyarrow as pa
     import pandas as pd

-    json_mock_str = '{"data": [1, 2, 3]}'
+    if mock_type == "json":
+        mock_str = '{"data": [1, 2, 3]}'
+    elif mock_type == "array":
+        mock_str = "[1, 2, 3]"

     def _unwrap(table: TDataItem) -> TDataItem:
         if isinstance(table, pd.DataFrame):
-            table[field] = [None if s is None else json_mock_str for s in table[field]]
+            table[field] = [None if s is None else mock_str for s in table[field]]
             return table
         else:
             col_index = table.column_names.index(field)
-            json_str_array = pa.array([None if s is None else json_mock_str for s in table[field]])
+            json_str_array = pa.array([None if s is None else mock_str for s in table[field]])
             return table.set_column(
                 col_index,
                 pa.field(field, pa.string(), nullable=table.schema.field(field).nullable),
@@ -430,3 +434,7 @@ def _unwrap(table: TDataItem) -> TDataItem:
             )

     return _unwrap
+
+
+mock_json_column = partial(mock_column, mock_type="json")
+mock_array_column = partial(mock_column, mock_type="array")
diff --git a/tests/load/sources/sql_database/test_sql_database_source.py b/tests/load/sources/sql_database/test_sql_database_source.py
index 2de923fe38..d709ace086 100644
--- a/tests/load/sources/sql_database/test_sql_database_source.py
+++ b/tests/load/sources/sql_database/test_sql_database_source.py
@@ -21,7 +21,7 @@
     assert_schema_on_data,
     load_tables_to_dicts,
 )
-from tests.load.sources.sql_database.test_helpers import mock_json_column
+from tests.load.sources.sql_database.test_helpers import mock_json_column, mock_array_column
 from tests.utils import data_item_length, load_table_counts
@@ -535,8 +535,11 @@ def dummy_source():
     expected_col_names = [col["name"] for col in PRECISION_COLUMNS]

     # on sqlalchemy json col is not written to schema if no types are discovered
-    if backend == "sqlalchemy" and reflection_level == "minimal" and not with_defer:
-        expected_col_names = [col for col in expected_col_names if col != "json_col"]
+    # nested types are converted into nested tables, not columns
+    if backend == "sqlalchemy" and reflection_level == "minimal":
+        expected_col_names = [
+            col for col in expected_col_names if col not in ("json_col", "array_col")
+        ]

     assert col_names == expected_col_names
@@ -869,6 +872,7 @@ def test_deferred_reflect_in_source(
     # mock the right json values for backends not supporting it
     if backend in ("connectorx", "pandas"):
         source.resources["has_precision"].add_map(mock_json_column("json_col"))
+        source.resources["has_precision"].add_map(mock_array_column("array_col"))

     # no columns in both tables
     assert source.has_precision.columns == {}
@@ -926,6 +930,7 @@ def test_deferred_reflect_in_resource(
     # mock the right json values for backends not supporting it
     if backend in ("connectorx", "pandas"):
         table.add_map(mock_json_column("json_col"))
+        table.add_map(mock_array_column("array_col"))

     # no columns in both tables
     assert table.columns == {}
@@ -1041,28 +1046,17 @@ def test_sql_database_include_view_in_table_names(
 @pytest.mark.parametrize("backend", ["pyarrow", "pandas", "sqlalchemy"])
 @pytest.mark.parametrize("standalone_resource", [True, False])
 @pytest.mark.parametrize("reflection_level", ["minimal", "full", "full_with_precision"])
-@pytest.mark.parametrize("type_adapter", [True, False])
 def test_infer_unsupported_types(
     sql_source_db_unsupported_types: SQLAlchemySourceDB,
     backend: TableBackend,
     reflection_level: ReflectionLevel,
     standalone_resource: bool,
-    type_adapter: bool,
 ) -> None:
-    def type_adapter_callback(t):
-        if isinstance(t, sa.ARRAY):
-            return sa.JSON
-        return t
-
-    if backend == "pyarrow" and type_adapter:
-        pytest.skip("Arrow does not support type adapter for arrays")
-
     common_kwargs = dict(
         credentials=sql_source_db_unsupported_types.credentials,
         schema=sql_source_db_unsupported_types.schema,
         reflection_level=reflection_level,
         backend=backend,
-        type_adapter_callback=type_adapter_callback if type_adapter else None,
     )

     if standalone_resource:
@@ -1084,9 +1078,6 @@ def dummy_source():
     pipeline = make_pipeline("duckdb")
     pipeline.extract(source)
-
-    columns = pipeline.default_schema.tables["has_unsupported_types"]["columns"]
-
     pipeline.normalize()
     pipeline.load()
@@ -1094,30 +1085,12 @@ def dummy_source():
     schema = pipeline.default_schema
     assert "has_unsupported_types" in schema.tables
-    columns = schema.tables["has_unsupported_types"]["columns"]

     rows = load_tables_to_dicts(pipeline, "has_unsupported_types")["has_unsupported_types"]

     if backend == "pyarrow":
-        # TODO: duckdb writes structs as strings (not json encoded) to json columns
-        # Just check that it has a value
-
-        assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list)
-        assert columns["unsupported_array_1"]["data_type"] == "json"
-
         # Other columns are loaded
         assert isinstance(rows[0]["supported_text"], str)
         assert isinstance(rows[0]["supported_int"], int)
-    elif backend == "sqlalchemy":
-        # sqla value is a dataclass and is inferred as json
-
-        assert columns["unsupported_array_1"]["data_type"] == "json"
-
-    elif backend == "pandas":
-        # pandas parses it as string
-        if type_adapter and reflection_level != "minimal":
-            assert columns["unsupported_array_1"]["data_type"] == "json"
-
-        assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list)
@@ -1438,6 +1411,14 @@ def add_default_decimal_precision(columns: List[TColumnSchema]) -> List[TColumnSchema]:
         "data_type": "bool",
         "name": "bool_col",
     },
+    {
+        "data_type": "text",
+        "name": "uuid_col",
+    },
+    {
+        "data_type": "json",
+        "name": "array_col",
+    },
 ]

 NOT_NULL_PRECISION_COLUMNS = [{"nullable": False, **column} for column in PRECISION_COLUMNS]
diff --git a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py
index 4f4e876fb6..004366b145 100644
--- a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py
+++ b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py
@@ -20,7 +20,7 @@
 try:
     from dlt.sources.sql_database import TableBackend, sql_database, sql_table
-    from tests.load.sources.sql_database.test_helpers import mock_json_column
+    from tests.load.sources.sql_database.test_helpers import mock_json_column, mock_array_column
     from tests.load.sources.sql_database.test_sql_database_source import (
         assert_row_counts,
         convert_time_to_us,
@@ -63,6 +63,9 @@ def test_load_sql_schema_loads_all_tables(
     # always use mock json
     source.has_precision.add_map(mock_json_column("json_col"))
     source.has_precision_nullable.add_map(mock_json_column("json_col"))
+    # always use mock array
+    source.has_precision.add_map(mock_array_column("array_col"))
+    source.has_precision_nullable.add_map(mock_array_column("array_col"))

     assert "chat_message_view" not in source.resources  # Views are not reflected by default
@@ -103,6 +106,9 @@ def test_load_sql_schema_loads_all_tables_parallel(
     # always use mock json
     source.has_precision.add_map(mock_json_column("json_col"))
     source.has_precision_nullable.add_map(mock_json_column("json_col"))
+    # always use mock array
+    source.has_precision.add_map(mock_array_column("array_col"))
+    source.has_precision_nullable.add_map(mock_array_column("array_col"))

     load_info = pipeline.run(source)
     print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at))
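
For reviewers, a minimal standalone sketch of the type mapping this patch introduces. This is not the dlt helper itself; the function, table, and column names below are made up for illustration. It mirrors the schema_types.py change: sqlalchemy ARRAY columns are treated like JSON columns (mapped to dlt's "json" data type), and on "minimal" reflection they can be skipped so their values are normalized into nested tables rather than emitted as column hints.

import sqlalchemy as sa
from sqlalchemy.sql import sqltypes


def map_nested_type(sql_col, reflection_level, skip_nested_on_minimal):
    # Return a partial dlt-style column schema for a reflected column, or None to skip it.
    is_nested = isinstance(sql_col.type, (sqltypes.JSON, sqltypes.ARRAY))
    if reflection_level == "minimal" and skip_nested_on_minimal and is_nested:
        return None  # no column hint; the values end up in a nested table
    col = {"name": sql_col.name, "nullable": bool(sql_col.nullable)}
    if is_nested:
        col["data_type"] = "json"
    return col


# hypothetical table, for illustration only
table = sa.Table(
    "chat_message",
    sa.MetaData(),
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("tags", sa.ARRAY(sa.Integer)),
)
print(map_nested_type(table.c.tags, "full", False))    # {'name': 'tags', 'nullable': True, 'data_type': 'json'}
print(map_nested_type(table.c.tags, "minimal", True))  # None, as for the sqlalchemy backend on minimal reflection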
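
And a hedged end-to-end sketch of how the change surfaces when using the sqlalchemy backend; the connection string, table name, and pipeline name are placeholders and not part of the patch:

import dlt
from dlt.sources.sql_database import sql_table

# placeholder credentials and table name
chat_messages = sql_table(
    credentials="postgresql://loader:loader@localhost:5432/dlt_data",
    table="chat_message",
    backend="sqlalchemy",
    reflection_level="minimal",
)

pipeline = dlt.pipeline(pipeline_name="arrays_demo", destination="duckdb")
pipeline.run(chat_messages)

# With this patch, an ARRAY column reflected on "minimal" level is no longer emitted as a
# column hint for the sqlalchemy backend; like JSON, its values are normalized into a
# nested table (chat_message__<field>) instead of a flat column. On "full" reflection the
# column is kept and typed as "json".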