From 6c4226d24e8d21ba6b663fd639a4b69eea65bc02 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 16 Jan 2024 11:59:34 +0100 Subject: [PATCH] add pydantic contracts implementation tests --- tests/pipeline/test_schema_contracts.py | 114 +++++++++++++++++++++--- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index 2f2e6b6932..a4d4707593 100644 --- a/tests/pipeline/test_schema_contracts.py +++ b/tests/pipeline/test_schema_contracts.py @@ -1,6 +1,6 @@ import dlt, os, pytest import contextlib -from typing import Any, Callable, Iterator, Union, Optional +from typing import Any, Callable, Iterator, Union, Optional, Type from dlt.common.schema.typing import TSchemaContract from dlt.common.utils import uniq_id @@ -9,6 +9,7 @@ from dlt.extract import DltResource from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.extract.exceptions import ResourceExtractionError from tests.load.pipeline.utils import load_table_counts from tests.utils import ( @@ -26,17 +27,19 @@ @contextlib.contextmanager -def raises_frozen_exception(check_raise: bool = True) -> Any: +def raises_step_exception(check_raise: bool = True, expected_nested_error: Type[Any] = None) -> Any: + expected_nested_error = expected_nested_error or DataValidationError if not check_raise: yield return with pytest.raises(PipelineStepFailed) as py_exc: yield if py_exc.value.step == "extract": - assert isinstance(py_exc.value.__context__, DataValidationError) + print(type(py_exc.value.__context__)) + assert isinstance(py_exc.value.__context__, expected_nested_error) else: # normalize - assert isinstance(py_exc.value.__context__.__context__, DataValidationError) + assert isinstance(py_exc.value.__context__.__context__, expected_nested_error) def items(settings: TSchemaContract) -> Any: @@ -94,6 +97,7 @@ def load_items(): VARIANT_COLUMN_NAME = "some_int__v_text" SUBITEMS_TABLE = "items__sub_items" NEW_ITEMS_TABLE = "new_items" +ITEMS_TABLE = "items" def run_resource( @@ -171,7 +175,7 @@ def test_new_tables( assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] # test adding new table - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, new_items, full_settings, item_format) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] @@ -191,7 +195,7 @@ def test_new_tables( assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] # test adding new subtable - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, items_with_subtable, full_settings) table_counts = load_table_counts( @@ -227,7 +231,7 @@ def test_new_columns( assert table_counts[NEW_ITEMS_TABLE] == 10 # test adding new column twice: filter will try to catch it before it is added for the second time - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, items_with_new_column, full_settings, item_format, duplicates=2) # delete extracted files if left after exception pipeline.drop_pending_packages() @@ -301,7 +305,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] # test adding variant column - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, items_with_variant, full_settings) if contract_setting == "evolve": @@ -485,7 +489,7 @@ def get_items_subtable(): # loading once with pydantic will freeze the cols pipeline = get_pipeline() pipeline.run([get_items_with_model()]) - with raises_frozen_exception(True): + with raises_step_exception(True): pipeline.run([get_items_new_col()]) # it is possible to override contract when there are new columns @@ -524,7 +528,7 @@ def get_items(): } yield {"id": 2, "tables": "two", "new_column": "some val"} - with raises_frozen_exception(table_mode == "freeze"): + with raises_step_exception(table_mode == "freeze"): pipeline.run([get_items()], schema_contract={"tables": table_mode}) if table_mode != "freeze": @@ -622,3 +626,93 @@ def get_items(): # apply hints apply to `items` not the original resource, so doing get_items() below removed them completely pipeline.run(items) assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 2 + + +@pytest.mark.parametrize("contract_setting", schema_contract) +@pytest.mark.parametrize("as_list", [True, False]) +def test_pydantic_contract_implementation(contract_setting: str, as_list: bool) -> None: + from pydantic import BaseModel + + class Items(BaseModel): + id: int # noqa: A003 + name: str + + def get_items(as_list: bool = False): + items = [ + { + "id": 5, + "name": "dave", + } + ] + if as_list: + yield items + else: + yield from items + + def get_items_extra_attribute(as_list: bool = False): + items = [{"id": 5, "name": "dave", "blah": "blubb"}] + if as_list: + yield items + else: + yield from items + + def get_items_extra_variant(as_list: bool = False): + items = [ + { + "id": "five", + "name": "dave", + } + ] + if as_list: + yield items + else: + yield from items + + # test columns complying to model + pipeline = get_pipeline() + pipeline.run( + [get_items(as_list)], + schema_contract={"columns": contract_setting}, + columns=Items, + table_name="items", + ) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert table_counts[ITEMS_TABLE] == 1 + + # test columns extra attribute + with raises_step_exception( + contract_setting in ["freeze"], + expected_nested_error=( + ResourceExtractionError if contract_setting == "freeze" else NotImplementedError + ), + ): + pipeline.run( + [get_items_extra_attribute(as_list)], + schema_contract={"columns": contract_setting}, + columns=Items, + table_name="items", + ) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert table_counts[ITEMS_TABLE] == 1 if (contract_setting in ["freeze", "discard_row"]) else 2 + + # test columns with variant + with raises_step_exception( + contract_setting in ["freeze", "discard_value"], + expected_nested_error=( + ResourceExtractionError if contract_setting == "freeze" else NotImplementedError + ), + ): + pipeline.run( + [get_items_extra_variant(as_list)], + schema_contract={"data_type": contract_setting}, + columns=Items, + table_name="items", + ) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert table_counts[ITEMS_TABLE] == 1 if (contract_setting in ["freeze", "discard_row"]) else 3