Skip to content

Commit

Permalink
add pydantic contracts implementation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 16, 2024
1 parent 9b05798 commit 6c4226d
Showing 1 changed file with 104 additions and 10 deletions.
114 changes: 104 additions & 10 deletions tests/pipeline/test_schema_contracts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dlt, os, pytest
import contextlib
from typing import Any, Callable, Iterator, Union, Optional
from typing import Any, Callable, Iterator, Union, Optional, Type

from dlt.common.schema.typing import TSchemaContract
from dlt.common.utils import uniq_id
Expand All @@ -9,6 +9,7 @@
from dlt.extract import DltResource
from dlt.pipeline.pipeline import Pipeline
from dlt.pipeline.exceptions import PipelineStepFailed
from dlt.extract.exceptions import ResourceExtractionError

from tests.load.pipeline.utils import load_table_counts
from tests.utils import (
Expand All @@ -26,17 +27,19 @@


@contextlib.contextmanager
def raises_step_exception(
    check_raise: bool = True, expected_nested_error: Optional[Type[Any]] = None
) -> Any:
    """Assert that the wrapped block fails a pipeline step with a given nested error.

    Args:
        check_raise: When False the wrapped block runs without any expectation
            and this context manager does nothing.
        expected_nested_error: Exception type expected in the context chain of
            the raised ``PipelineStepFailed``. Defaults to ``DataValidationError``.

    Raises:
        AssertionError: If the nested exception is not of the expected type.
    """
    if not check_raise:
        yield
        return
    if expected_nested_error is None:
        expected_nested_error = DataValidationError
    with pytest.raises(PipelineStepFailed) as py_exc:
        yield
    nested = py_exc.value.__context__
    if py_exc.value.step != "extract":
        # any other step (normalize): the expected error sits one level deeper
        nested = nested.__context__
    assert isinstance(nested, expected_nested_error)


def items(settings: TSchemaContract) -> Any:
Expand Down Expand Up @@ -94,6 +97,7 @@ def load_items():
VARIANT_COLUMN_NAME = "some_int__v_text"
SUBITEMS_TABLE = "items__sub_items"
NEW_ITEMS_TABLE = "new_items"
ITEMS_TABLE = "items"


def run_resource(
Expand Down Expand Up @@ -171,7 +175,7 @@ def test_new_tables(
assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]

# test adding new table
with raises_frozen_exception(contract_setting == "freeze"):
with raises_step_exception(contract_setting == "freeze"):
run_resource(pipeline, new_items, full_settings, item_format)
table_counts = load_table_counts(
pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]
Expand All @@ -191,7 +195,7 @@ def test_new_tables(
assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]

# test adding new subtable
with raises_frozen_exception(contract_setting == "freeze"):
with raises_step_exception(contract_setting == "freeze"):
run_resource(pipeline, items_with_subtable, full_settings)

table_counts = load_table_counts(
Expand Down Expand Up @@ -227,7 +231,7 @@ def test_new_columns(
assert table_counts[NEW_ITEMS_TABLE] == 10

# test adding new column twice: filter will try to catch it before it is added for the second time
with raises_frozen_exception(contract_setting == "freeze"):
with raises_step_exception(contract_setting == "freeze"):
run_resource(pipeline, items_with_new_column, full_settings, item_format, duplicates=2)
# delete extracted files if left after exception
pipeline.drop_pending_packages()
Expand Down Expand Up @@ -301,7 +305,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]

# test adding variant column
with raises_frozen_exception(contract_setting == "freeze"):
with raises_step_exception(contract_setting == "freeze"):
run_resource(pipeline, items_with_variant, full_settings)

if contract_setting == "evolve":
Expand Down Expand Up @@ -485,7 +489,7 @@ def get_items_subtable():
# loading once with pydantic will freeze the cols
pipeline = get_pipeline()
pipeline.run([get_items_with_model()])
with raises_frozen_exception(True):
with raises_step_exception(True):
pipeline.run([get_items_new_col()])

# it is possible to override contract when there are new columns
Expand Down Expand Up @@ -524,7 +528,7 @@ def get_items():
}
yield {"id": 2, "tables": "two", "new_column": "some val"}

with raises_frozen_exception(table_mode == "freeze"):
with raises_step_exception(table_mode == "freeze"):
pipeline.run([get_items()], schema_contract={"tables": table_mode})

if table_mode != "freeze":
Expand Down Expand Up @@ -622,3 +626,93 @@ def get_items():
# apply hints apply to `items` not the original resource, so doing get_items() below removed them completely
pipeline.run(items)
assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 2


@pytest.mark.parametrize("contract_setting", schema_contract)
@pytest.mark.parametrize("as_list", [True, False])
def test_pydantic_contract_implementation(contract_setting: str, as_list: bool) -> None:
    """Check schema contract enforcement when columns are given as a pydantic model.

    Three loads into the ``items`` table are attempted with the same contract
    setting: a row matching the model, a row with an extra attribute
    (``columns`` contract) and a row whose ``id`` has the wrong type
    (``data_type`` contract). After each load the row count is verified.

    Args:
        contract_setting: Contract mode applied to the entity under test.
        as_list: When True the resource yields one list of rows, otherwise it
            yields the rows one by one.
    """
    from pydantic import BaseModel

    class Items(BaseModel):
        id: int  # noqa: A003
        name: str

    def _emit(rows: Any, batch: bool) -> Iterator[Any]:
        # yield the rows either as one list or item by item
        if batch:
            yield rows
        else:
            yield from rows

    # the three generators keep their own names, only the shared body is factored out
    def get_items(as_list: bool = False):
        yield from _emit([{"id": 5, "name": "dave"}], as_list)

    def get_items_extra_attribute(as_list: bool = False):
        yield from _emit([{"id": 5, "name": "dave", "blah": "blubb"}], as_list)

    def get_items_extra_variant(as_list: bool = False):
        yield from _emit([{"id": "five", "name": "dave"}], as_list)

    def count_items() -> int:
        table_counts = load_table_counts(
            pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]
        )
        return table_counts[ITEMS_TABLE]

    # test columns complying to model
    pipeline = get_pipeline()
    pipeline.run(
        [get_items(as_list)],
        schema_contract={"columns": contract_setting},
        columns=Items,
        table_name="items",
    )
    assert count_items() == 1

    # test columns extra attribute: only "freeze" is expected to fail the step
    with raises_step_exception(
        contract_setting == "freeze",
        expected_nested_error=ResourceExtractionError,
    ):
        pipeline.run(
            [get_items_extra_attribute(as_list)],
            schema_contract={"columns": contract_setting},
            columns=Items,
            table_name="items",
        )
    # NOTE: the expected value must be parenthesized. `assert a == 1 if c else 2`
    # parses as `assert ((a == 1) if c else 2)` and passes for every `else` case.
    assert count_items() == (1 if contract_setting in ("freeze", "discard_row") else 2)

    # test columns with variant
    with raises_step_exception(
        contract_setting in ["freeze", "discard_value"],
        expected_nested_error=(
            ResourceExtractionError if contract_setting == "freeze" else NotImplementedError
        ),
    ):
        pipeline.run(
            [get_items_extra_variant(as_list)],
            schema_contract={"data_type": contract_setting},
            columns=Items,
            table_name="items",
        )
    if contract_setting in ("freeze", "discard_row"):
        expected_rows = 1
    elif contract_setting == "discard_value":
        # the run above is expected to raise, so the count from the previous step stays
        expected_rows = 2
    else:
        expected_rows = 3
    assert count_items() == expected_rows

0 comments on commit 6c4226d

Please sign in to comment.