Skip to content

Commit

Permalink
recognizes ARRAY data type and converts to nested, fixes minimal nest…
Browse files Browse the repository at this point in the history
…ed types handling for sqlalchemy backend
  • Loading branch information
rudolfix committed Dec 22, 2024
1 parent 5a1cb69 commit 6ae82ba
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 47 deletions.
2 changes: 2 additions & 0 deletions dlt/extract/incremental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,8 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
self._cached_state["last_value"] = transformer.last_value
if not transformer.deduplication_disabled:
# compute hashes for new last rows
# NOTE: object transform uses last_rows to pass rows to dedup, arrow computes
# hashes directly
unique_hashes = set(
transformer.compute_unique_value(row, self.primary_key)
for row in transformer.last_rows
Expand Down
2 changes: 1 addition & 1 deletion dlt/sources/sql_database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def sql_table(
engine = engine_adapter_callback(engine)
metadata = metadata or MetaData(schema=schema)

skip_nested_on_minimal = backend == "sqlalchemy"
# Table object is only created when reflecting, we don't want empty tables in metadata
# as it breaks foreign key resolution
table_obj = metadata.tables.get(table)
Expand All @@ -222,6 +221,7 @@ def sql_table(
if table_obj is not None:
if not defer_table_reflect:
table_obj = _execute_table_adapter(table_obj, table_adapter_callback, included_columns)
skip_nested_on_minimal = backend == "sqlalchemy"
hints = table_to_resource_hints(
table_obj,
reflection_level,
Expand Down
2 changes: 2 additions & 0 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def table_rows(
table,
reflection_level,
type_adapter_callback,
backend == "sqlalchemy", # skip nested types
resolve_foreign_keys=resolve_foreign_keys,
)

Expand All @@ -285,6 +286,7 @@ def table_rows(
table,
reflection_level,
type_adapter_callback,
backend == "sqlalchemy", # skip nested types
resolve_foreign_keys=resolve_foreign_keys,
)

Expand Down
7 changes: 6 additions & 1 deletion dlt/sources/sql_database/schema_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def sqla_col_to_column_schema(

if reflection_level == "minimal":
# normalized into subtables
if isinstance(sql_col.type, sqltypes.JSON) and skip_nested_columns_on_minimal:
if (
isinstance(sql_col.type, (sqltypes.JSON, sqltypes.ARRAY))
and skip_nested_columns_on_minimal
):
return None
return col

Expand Down Expand Up @@ -139,6 +142,8 @@ def sqla_col_to_column_schema(
col["data_type"] = "time"
elif isinstance(sql_t, sqltypes.JSON):
col["data_type"] = "json"
elif isinstance(sql_t, sqltypes.ARRAY):
col["data_type"] = "json"
elif isinstance(sql_t, sqltypes.Boolean):
col["data_type"] = "bool"
else:
Expand Down
13 changes: 9 additions & 4 deletions tests/load/sources/sql_database/sql_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
schema as sqla_schema,
)

from sqlalchemy.dialects.postgresql import DATERANGE, JSONB
from sqlalchemy.dialects.postgresql import JSONB

from dlt.common.pendulum import pendulum, timedelta
from dlt.common.utils import chunks, uniq_id
Expand Down Expand Up @@ -182,6 +182,8 @@ def _make_precision_table(table_name: str, nullable: bool) -> None:
Column("float_col", Float, nullable=nullable),
Column("json_col", JSONB, nullable=nullable),
Column("bool_col", Boolean, nullable=nullable),
Column("uuid_col", Uuid, nullable=nullable),
Column("array_col", ARRAY(Integer), nullable=nullable),
)

_make_precision_table("has_precision", False)
Expand All @@ -194,7 +196,7 @@ def _make_precision_table(table_name: str, nullable: bool) -> None:
# Column("unsupported_daterange_1", DATERANGE, nullable=False),
Column("supported_text", Text, nullable=False),
Column("supported_int", Integer, nullable=False),
Column("unsupported_array_1", ARRAY(Integer), nullable=False),
# Column("unsupported_array_1", ARRAY(Integer), nullable=False),
# Column("supported_datetime", DateTime(timezone=True), nullable=False),
)

Expand Down Expand Up @@ -325,8 +327,11 @@ def _fake_precision_data(self, table_name: str, n: int = 100, null_n: int = 0) -
date_col=mimesis.Datetime().date(),
time_col=mimesis.Datetime().time(),
float_col=random.random(),
json_col='{"data": [1, 2, 3]}', # NOTE: can we do this?
# NOTE: do not use strings. pandas mangles them (or spend time adding ifs to tests)
json_col=[1, 2.1, -1.1],
bool_col=random.randint(0, 1) == 1,
uuid_col=str(uuid4()) if Uuid is str else uuid4(),
array_col=[1, 2, 3],
)
for _ in range(n + null_n)
]
Expand All @@ -350,7 +355,7 @@ def _fake_unsupported_data(self, n: int = 100) -> None:
# unsupported_daterange_1="[2020-01-01, 2020-09-01]",
supported_text=mimesis.Text().word(),
supported_int=random.randint(0, 100),
unsupported_array_1=[1, 2, 3],
# unsupported_array_1=[1, 2, 3],
# supported_datetime="2015-08-12T01:25:22.468126+0100",
)
for _ in range(n)
Expand Down
18 changes: 13 additions & 5 deletions tests/load/sources/sql_database/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Any, TYPE_CHECKING
from functools import partial
from typing import Callable, Any, Literal
from dataclasses import dataclass

import pytest
Expand Down Expand Up @@ -409,24 +410,31 @@ def test_make_query_incremental_range_end_closed(
assert query.compare(expected)


def mock_json_column(field: str) -> TDataItem:
def mock_column(field: str, mock_type: Literal["json", "array"] = "json") -> TDataItem:
""""""
import pyarrow as pa
import pandas as pd

json_mock_str = '{"data": [1, 2, 3]}'
if mock_type == "json":
mock_str = '{"data": [1, 2, 3]}'
elif mock_type == "array":
mock_str = "[1, 2, 3]"

def _unwrap(table: TDataItem) -> TDataItem:
if isinstance(table, pd.DataFrame):
table[field] = [None if s is None else json_mock_str for s in table[field]]
table[field] = [None if s is None else mock_str for s in table[field]]
return table
else:
col_index = table.column_names.index(field)
json_str_array = pa.array([None if s is None else json_mock_str for s in table[field]])
json_str_array = pa.array([None if s is None else mock_str for s in table[field]])
return table.set_column(
col_index,
pa.field(field, pa.string(), nullable=table.schema.field(field).nullable),
json_str_array,
)

return _unwrap


mock_json_column = partial(mock_column, mock_type="json")
mock_array_column = partial(mock_column, mock_type="array")
51 changes: 16 additions & 35 deletions tests/load/sources/sql_database/test_sql_database_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
assert_schema_on_data,
load_tables_to_dicts,
)
from tests.load.sources.sql_database.test_helpers import mock_json_column
from tests.load.sources.sql_database.test_helpers import mock_json_column, mock_array_column
from tests.utils import data_item_length, load_table_counts


Expand Down Expand Up @@ -535,8 +535,11 @@ def dummy_source():
expected_col_names = [col["name"] for col in PRECISION_COLUMNS]

# on sqlalchemy json col is not written to schema if no types are discovered
if backend == "sqlalchemy" and reflection_level == "minimal" and not with_defer:
expected_col_names = [col for col in expected_col_names if col != "json_col"]
# nested types are converted into nested tables, not columns
if backend == "sqlalchemy" and reflection_level == "minimal":
expected_col_names = [
col for col in expected_col_names if col not in ("json_col", "array_col")
]

assert col_names == expected_col_names

Expand Down Expand Up @@ -869,6 +872,7 @@ def test_deferred_reflect_in_source(
# mock the right json values for backends not supporting it
if backend in ("connectorx", "pandas"):
source.resources["has_precision"].add_map(mock_json_column("json_col"))
source.resources["has_precision"].add_map(mock_array_column("array_col"))

# no columns in both tables
assert source.has_precision.columns == {}
Expand Down Expand Up @@ -926,6 +930,7 @@ def test_deferred_reflect_in_resource(
# mock the right json values for backends not supporting it
if backend in ("connectorx", "pandas"):
table.add_map(mock_json_column("json_col"))
table.add_map(mock_array_column("array_col"))

# no columns in both tables
assert table.columns == {}
Expand Down Expand Up @@ -1041,28 +1046,17 @@ def test_sql_database_include_view_in_table_names(
@pytest.mark.parametrize("backend", ["pyarrow", "pandas", "sqlalchemy"])
@pytest.mark.parametrize("standalone_resource", [True, False])
@pytest.mark.parametrize("reflection_level", ["minimal", "full", "full_with_precision"])
@pytest.mark.parametrize("type_adapter", [True, False])
def test_infer_unsupported_types(
sql_source_db_unsupported_types: SQLAlchemySourceDB,
backend: TableBackend,
reflection_level: ReflectionLevel,
standalone_resource: bool,
type_adapter: bool,
) -> None:
def type_adapter_callback(t):
if isinstance(t, sa.ARRAY):
return sa.JSON
return t

if backend == "pyarrow" and type_adapter:
pytest.skip("Arrow does not support type adapter for arrays")

common_kwargs = dict(
credentials=sql_source_db_unsupported_types.credentials,
schema=sql_source_db_unsupported_types.schema,
reflection_level=reflection_level,
backend=backend,
type_adapter_callback=type_adapter_callback if type_adapter else None,
)
if standalone_resource:

Expand All @@ -1084,40 +1078,19 @@ def dummy_source():

pipeline = make_pipeline("duckdb")
pipeline.extract(source)

columns = pipeline.default_schema.tables["has_unsupported_types"]["columns"]

pipeline.normalize()
pipeline.load()

assert_row_counts(pipeline, sql_source_db_unsupported_types, ["has_unsupported_types"])

schema = pipeline.default_schema
assert "has_unsupported_types" in schema.tables
columns = schema.tables["has_unsupported_types"]["columns"]

rows = load_tables_to_dicts(pipeline, "has_unsupported_types")["has_unsupported_types"]

if backend == "pyarrow":
# TODO: duckdb writes structs as strings (not json encoded) to json columns
# Just check that it has a value

assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list)
assert columns["unsupported_array_1"]["data_type"] == "json"
# Other columns are loaded
assert isinstance(rows[0]["supported_text"], str)
assert isinstance(rows[0]["supported_int"], int)
elif backend == "sqlalchemy":
# sqla value is a dataclass and is inferred as json

assert columns["unsupported_array_1"]["data_type"] == "json"

elif backend == "pandas":
# pandas parses it as string
if type_adapter and reflection_level != "minimal":
assert columns["unsupported_array_1"]["data_type"] == "json"

assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
Expand Down Expand Up @@ -1438,6 +1411,14 @@ def add_default_decimal_precision(columns: List[TColumnSchema]) -> List[TColumnS
"data_type": "bool",
"name": "bool_col",
},
{
"data_type": "text",
"name": "uuid_col",
},
{
"data_type": "json",
"name": "array_col",
},
]

NOT_NULL_PRECISION_COLUMNS = [{"nullable": False, **column} for column in PRECISION_COLUMNS]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

try:
from dlt.sources.sql_database import TableBackend, sql_database, sql_table
from tests.load.sources.sql_database.test_helpers import mock_json_column
from tests.load.sources.sql_database.test_helpers import mock_json_column, mock_array_column
from tests.load.sources.sql_database.test_sql_database_source import (
assert_row_counts,
convert_time_to_us,
Expand Down Expand Up @@ -63,6 +63,9 @@ def test_load_sql_schema_loads_all_tables(
# always use mock json
source.has_precision.add_map(mock_json_column("json_col"))
source.has_precision_nullable.add_map(mock_json_column("json_col"))
# always use mock array
source.has_precision.add_map(mock_array_column("array_col"))
source.has_precision_nullable.add_map(mock_array_column("array_col"))

assert "chat_message_view" not in source.resources # Views are not reflected by default

Expand Down Expand Up @@ -103,6 +106,9 @@ def test_load_sql_schema_loads_all_tables_parallel(
# always use mock json
source.has_precision.add_map(mock_json_column("json_col"))
source.has_precision_nullable.add_map(mock_json_column("json_col"))
# always use mock array
source.has_precision.add_map(mock_array_column("array_col"))
source.has_precision_nullable.add_map(mock_array_column("array_col"))

load_info = pipeline.run(source)
print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at))
Expand Down

0 comments on commit 6ae82ba

Please sign in to comment.