Pydantic improvements #901

Merged · 36 commits · Feb 7, 2024

Commits (changes shown are from 30 of the 36 commits)
0afb273
add example tests
sh-rp Jan 18, 2024
001aa94
Add sub-model hints for pydantic
Jan 19, 2024
89dcd20
Handle list of pydantic models field type
Jan 22, 2024
3cbd490
Add container_type field to TColumnSchema
Jan 22, 2024
3b7fd4e
Revert handling for the list of pydantic models
Jan 29, 2024
ba7814c
Use snake_case naming convention to generate field names
Jan 29, 2024
9acab62
Remove test
Jan 29, 2024
980f1f2
Add comment
Jan 29, 2024
aed23bb
Adjust logic
Jan 30, 2024
3a79838
Add more assertions
Jan 30, 2024
5a0a4cb
Fix mypy linting issue
Jan 30, 2024
38959ac
Fix mypy linting issue
Jan 30, 2024
17a2821
Remove unused code
Jan 30, 2024
0fd9f5d
Add duckdb to extra installs
Jan 30, 2024
ed85e47
Remove trailing spaces
Jan 30, 2024
fb44e9e
Add duckdb extras
Jan 31, 2024
cf5a279
Add duckdb extra to common tests workflow
Jan 31, 2024
650d0c6
Detect pydantic model in try..catch block
Feb 6, 2024
4f59c19
Swap if branch conditions
Feb 6, 2024
7780f5a
Revert old changes
Feb 6, 2024
18f8858
Revert old changes
Feb 6, 2024
41f93ca
Remove redundant if branches and adjust tests
Feb 6, 2024
091db6d
Revert some changes
Feb 6, 2024
dd8e766
Remove duplicate test
Feb 6, 2024
7bd766c
Enable test only with duckdb available
Feb 6, 2024
8fbe6f8
Return case when explicit complex types for field is expected
Feb 6, 2024
ea969ac
Revert
Feb 6, 2024
d73ff81
Mark test for duckdb only
Feb 6, 2024
91e5553
Remove duplicate import
Feb 6, 2024
ebe6ced
Move tests to pipeline extra
Feb 7, 2024
b7e89f4
Simplify tests
Feb 7, 2024
5acc282
Add unit tests
Feb 7, 2024
879d338
Fix typing issues
Feb 7, 2024
841aafa
Check if type=complex for lists
Feb 7, 2024
2be9de4
Add one more test case
Feb 7, 2024
bac3fac
Add one more test case
Feb 7, 2024
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
@@ -25,7 +25,7 @@ jobs:
    defaults:
      run:
        shell: bash
-    runs-on: ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}

    steps:

@@ -42,7 +42,7 @@ jobs:
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
-          installer-parallel: true
+          installer-parallel: true

      - name: Load cached venv
        id: cached-poetry-dependencies

@@ -57,7 +57,7 @@ jobs:

      - name: Run make lint
        run: |
-          export PATH=$PATH:"/c/Program Files/usr/bin" # needed for Windows
+          export PATH=$PATH:"/c/Program Files/usr/bin" # needed for Windows
          make lint

      # - name: print envs
2 changes: 1 addition & 1 deletion .github/workflows/test_destination_athena_iceberg.yml
@@ -65,7 +65,7 @@ jobs:

      - name: Install dependencies
        # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
-        run: poetry install --no-interaction -E --with sentry-sdk --with pipeline
+        run: poetry install --no-interaction -E --with sentry-sdk --with pipeline

      - name: create secrets.toml
        run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml
56 changes: 42 additions & 14 deletions dlt/common/libs/pydantic.py
@@ -14,10 +14,11 @@
)
from typing_extensions import Annotated, get_args, get_origin

-from dlt.common.data_types import py_type_to_sc_type
from dlt.common.exceptions import MissingDependencyException
from dlt.common.schema import DataValidationError
from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns
+from dlt.common.data_types import py_type_to_sc_type
+from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention
from dlt.common.typing import (
    TDataItem,
    TDataItems,
@@ -52,6 +53,9 @@
_TPydanticModel = TypeVar("_TPydanticModel", bound=BaseModel)


+snake_case_naming_convention = SnakeCaseNamingConvention()
+
+
class ListModel(BaseModel, Generic[_TPydanticModel]):
    items: List[_TPydanticModel]

@@ -71,7 +75,7 @@ class DltConfig(TypedDict, total=False):


def pydantic_to_table_schema_columns(
-    model: Union[BaseModel, Type[BaseModel]]
+    model: Union[BaseModel, Type[BaseModel]],
) -> TTableSchemaColumns:
    """Convert a pydantic model to a table schema columns dict

@@ -111,24 +115,47 @@ def pydantic_to_table_schema_columns(

        if is_list_generic_type(inner_type):
            inner_type = list
-        elif is_dict_generic_type(inner_type) or issubclass(inner_type, BaseModel):
+        elif is_dict_generic_type(inner_type):
            inner_type = dict

+        is_inner_type_pydantic_model = False
        name = field.alias or field_name
        try:
            data_type = py_type_to_sc_type(inner_type)
        except TypeError:
-            # try to coerce unknown type to text
-            data_type = "text"
+            if issubclass(inner_type, BaseModel):
+                data_type = "complex"
+                is_inner_type_pydantic_model = True
+            else:
+                # try to coerce unknown type to text
+                data_type = "text"

-        if data_type == "complex" and skip_complex_types:
+        if is_inner_type_pydantic_model and not skip_complex_types:
+            result[name] = {
+                "name": name,
+                "data_type": "complex",
+                "nullable": nullable,
+            }
+        elif is_inner_type_pydantic_model:
+            # This case is for a single field schema/model
+            # we need to generate snake_case field names
+            # and return flattened field schemas
+            schema_hints = pydantic_to_table_schema_columns(field.annotation)
+
+            for field_name, hints in schema_hints.items():
+                schema_key = snake_case_naming_convention.make_path(name, field_name)
+                result[schema_key] = {
+                    **hints,
+                    "name": snake_case_naming_convention.make_path(name, hints["name"]),
+                }
+        elif data_type == "complex" and skip_complex_types:
            continue
-
-        result[name] = {
-            "name": name,
-            "data_type": data_type,
-            "nullable": nullable,
-        }
+        else:
+            result[name] = {
+                "name": name,
+                "data_type": data_type,
+                "nullable": nullable,
+            }

    return result
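
Note: the new branch flattens a nested model into columns of the parent table. A minimal sketch of the resulting schema, based on the tests added in this PR (it assumes the model opts in through its `dlt_config`, as the `DltConfig` hunk above suggests):

```python
from typing import ClassVar, Optional

from pydantic import BaseModel

from dlt.common.libs.pydantic import DltConfig, pydantic_to_table_schema_columns


class Child(BaseModel):
    child_attribute: str
    optional_child_attribute: Optional[str] = None


class Parent(BaseModel):
    child: Child
    # opting in via dlt_config triggers the flattening branch above
    dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}


columns = pydantic_to_table_schema_columns(Parent)
# nested fields are emitted as snake_case "__" paths (see the tests below), e.g.
# columns["child__child_attribute"] == {
#     "name": "child__child_attribute",
#     "data_type": "text",
#     "nullable": False,
# }
```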

Expand Down Expand Up @@ -261,7 +288,8 @@ def create_list_model(
# TODO: use LenientList to create list model that automatically discards invalid items
# https://github.com/pydantic/pydantic/issues/2274 and https://gist.github.com/dmontagu/7f0cef76e5e0e04198dd608ad7219573
return create_model(
"List" + __name__, items=(List[model], ...) # type: ignore[return-value,valid-type]
"List" + __name__,
items=(List[model], ...), # type: ignore[return-value,valid-type]
)


Expand Down
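
For reference, a rough standalone re-creation of what `create_list_model` builds: a wrapper model with an `items` list so a whole batch validates in one call. The real helper lives in `dlt.common.libs.pydantic`; `make_list_model` below is illustrative and simplifies the generated model name:

```python
from typing import List, Type

from pydantic import BaseModel, create_model


def make_list_model(model: Type[BaseModel]) -> Type[BaseModel]:
    # wrap the model in an {"items": [...]} container so that a whole
    # batch of records can be validated with a single call
    return create_model("List" + model.__name__, items=(List[model], ...))  # type: ignore[valid-type]


class Item(BaseModel):
    value: int


ListItem = make_list_model(Item)
batch = ListItem(items=[{"value": 1}, {"value": 2}])  # validates every element
```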
1 change: 1 addition & 0 deletions dlt/common/schema/typing.py
@@ -7,6 +7,7 @@
    Optional,
    Sequence,
    Set,
+    Tuple,
    Type,
    TypedDict,
    NewType,
6 changes: 5 additions & 1 deletion dlt/extract/hints.py
@@ -105,7 +105,11 @@ def compute_table_schema(self, item: TDataItem = None) -> TTableSchema:
        if self._table_name_hint_fun and item is None:
            raise DataItemRequiredForDynamicTableHints(self.name)
        # resolve
-        resolved_template: TResourceHints = {k: self._resolve_hint(item, v) for k, v in table_template.items() if k not in ["incremental", "validator", "original_columns"]}  # type: ignore
+        resolved_template: TResourceHints = {
+            k: self._resolve_hint(item, v)
+            for k, v in table_template.items()
+            if k not in ["incremental", "validator", "original_columns"]
+        }  # type: ignore
        table_schema = self._merge_keys(resolved_template)
        table_schema["resource"] = self.name
        validate_dict_ignoring_xkeys(
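
The reformatted comprehension resolves callable hints against each data item. A short sketch of the behavior this method backs (standard dlt usage; the resource and table names here are illustrative):

```python
import dlt


# table_name given as a callable is a dynamic hint: compute_table_schema
# resolves it per data item, and calling it with item=None raises
# DataItemRequiredForDynamicTableHints
@dlt.resource(table_name=lambda event: event["event_type"])
def events():
    yield {"event_type": "purchase", "value": 42}
    yield {"event_type": "refund", "value": 13}


pipeline = dlt.pipeline(destination="duckdb")
pipeline.run(events())  # loads into tables "purchase" and "refund"
```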
1 change: 0 additions & 1 deletion tests/libs/test_pydantic.py
@@ -10,7 +10,6 @@
    Union,
    Optional,
    List,
-    Dict,
    Any,
)
from typing_extensions import Annotated, get_args, get_origin
154 changes: 154 additions & 0 deletions tests/pipeline/test_pipeline_extra.py
@@ -200,3 +200,157 @@ def generic(start=8):

    pipeline = dlt.pipeline(destination="duckdb")
    pipeline.run(generic(), loader_file_format=file_format)


class Child(BaseModel):
    child_attribute: str
    optional_child_attribute: Optional[str] = None


@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True, subset=["duckdb"]),
    ids=lambda x: x.name,
)
def test_flattens_model_when_skip_complex_types_is_set(
    destination_config: DestinationTestConfiguration,
) -> None:
    class Parent(BaseModel):
        child: Child
        optional_parent_attribute: Optional[str] = None
        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": True}

    example_data = {
        "optional_parent_attribute": None,
        "child": {
            "child_attribute": "any string",
            "optional_child_attribute": None,
        },
    }

    @dlt.resource
    def res():
        yield [example_data]

    @dlt.source(max_table_nesting=1)
    def src():
        yield res()

    p = destination_config.setup_pipeline("example", full_refresh=True)
    p.run(src(), table_name="items", columns=Parent)

    with p.sql_client() as client:
        with client.execute_query("SELECT * FROM items") as cursor:
            loaded_values = {
                col[0]: val
                for val, col in zip(cursor.fetchall()[0], cursor.description)
                if col[0] not in ("_dlt_id", "_dlt_load_id")
            }
            assert loaded_values == {
                "child__child_attribute": "any string",
                "child__optional_child_attribute": None,
                "optional_parent_attribute": None,
            }

    keys = p.default_schema.tables["items"]["columns"].keys()
    columns = p.default_schema.tables["items"]["columns"]

    assert keys == {
        "child__child_attribute",
        "child__optional_child_attribute",
        "optional_parent_attribute",
        "_dlt_load_id",
        "_dlt_id",
    }

    assert columns["child__child_attribute"] == {
        "name": "child__child_attribute",
        "data_type": "text",
        "nullable": False,
    }

    assert columns["child__optional_child_attribute"] == {
        "name": "child__optional_child_attribute",
        "data_type": "text",
        "nullable": True,
    }

    assert columns["optional_parent_attribute"] == {
        "name": "optional_parent_attribute",
        "data_type": "text",
        "nullable": True,
    }
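
The `child__child_attribute` keys asserted above come from the snake_case naming convention's path join, the same `make_path` call used in `pydantic_to_table_schema_columns`. A minimal sketch (assuming `make_path` joins identifiers with the `__` separator):

```python
from dlt.common.normalizers.naming.snake_case import NamingConvention

naming = NamingConvention()
# make_path joins parent and child field names with the "__" path separator
assert naming.make_path("child", "child_attribute") == "child__child_attribute"
```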


@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True, subset=["duckdb"]),
    ids=lambda x: x.name,
)
def test_flattens_model_when_skip_complex_types_is_not_set(
    destination_config: DestinationTestConfiguration,
):
    class Parent(BaseModel):
        child: Child
        optional_parent_attribute: Optional[str] = None
        data_dictionary: Dict[str, Any] = None
        dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False}

    example_data = {
        "optional_parent_attribute": None,
        "data_dictionary": {
            "child_attribute": "any string",
        },
        "child": {
            "child_attribute": "any string",
            "optional_child_attribute": None,
        },
    }

    @dlt.resource
    def res():
        yield [example_data]

    @dlt.source(max_table_nesting=1)
    def src():
        yield res()

    p = destination_config.setup_pipeline("example", full_refresh=True)
    p.run(src(), table_name="items", columns=Parent)

    with p.sql_client() as client:
        with client.execute_query("SELECT * FROM items") as cursor:
            loaded_values = {
                col[0]: val
                for val, col in zip(cursor.fetchall()[0], cursor.description)
                if col[0] not in ("_dlt_id", "_dlt_load_id")
            }

            assert loaded_values == {
                "child": '{"child_attribute":"any string","optional_child_attribute":null}',
                "optional_parent_attribute": None,
                "data_dictionary": '{"child_attribute":"any string"}',
            }

    keys = p.default_schema.tables["items"]["columns"].keys()
    assert keys == {
        "child",
        "optional_parent_attribute",
        "data_dictionary",
        "_dlt_load_id",
        "_dlt_id",
    }

    columns = p.default_schema.tables["items"]["columns"]

    assert columns["optional_parent_attribute"] == {
        "name": "optional_parent_attribute",
        "data_type": "text",
        "nullable": True,
    }

    assert columns["data_dictionary"] == {
        "name": "data_dictionary",
        "data_type": "complex",
        "nullable": False,
    }